- import torch
- import torch.nn.functional as F
-
class Residual_Block(torch.nn.Module):
    """Residual block: two 3x3 convolutions with an identity skip connection.

    Computes ``relu(conv_2(relu(conv_1(x))) + x)``. Both convolutions map
    ``channels`` -> ``channels`` with ``kernel_size=3`` and ``padding=1``
    (default stride 1), so the convolution output has the same channel count
    and spatial size as the input and can be added to it element-wise.

    Args:
        channels: Number of input channels; also the number of output
            channels of both convolutions.
    """

    def __init__(self, channels):
        super().__init__()
        # The input x and the branch output y are summed in forward(), so
        # their channel count and spatial dimensions must be identical.
        self.channels = channels

        self.conv_1 = torch.nn.Conv2d(
            channels, channels, kernel_size=3, padding=1
        )
        self.conv_2 = torch.nn.Conv2d(
            channels, channels, kernel_size=3, padding=1
        )

    def forward(self, x):
        """Apply the residual block to ``x``.

        Args:
            x: Input tensor accepted by ``Conv2d`` whose channel dimension
                equals ``channels``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        branch = F.relu(self.conv_1(x))
        branch = self.conv_2(branch)
        # The activation is applied after the skip-connection sum, not before.
        return F.relu(branch + x)
-
<