torch.nn.``Module
[来源]
所有神经网络模块的基类。
所有神经网络的模型也应该继承这个类。
模块还可以包含其他模块,允许将它们嵌套在树结构中。您可以将子模块分配为常规属性:
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(1, 20, 5)
self.conv2 = nn.Conv2d(20, 20, 5)
def forward(self, x):
'''
前向传播
输入,x,经过卷积、非线性、卷积、非线性,最后输出。
其中self.conv1(x)为卷积,
relu() 为非线性
'''
x = F.relu(self.conv1(x))
return F.relu(self.conv2(x))
以这种方式分配的子模块将被注册,并且在您调用时也会转换其参数to()
等。
根据以上官网内容,实现简单的神经网络:
import torch
from torch import nn
class Test(nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self,input):
output = input+1
return output
test = Test()
x = torch.tensor(1.0)
output = test(x)
print(output) # tensor(2.)
它是实现卷积计算的核心步骤函数,
首先,卷积的意思就是从图像的像素点上抽象出特征,然而这个特征抽取的过程并不是传统意义上的人工的抽取,而是通过卷积核进行自动抽取,当然这种抽取的结果对于人类来说也很难讲有什么能够解释的意义。数字图像(比如说一张照片),可以看做是一个矩阵,每一个像素点都是矩阵中的一个元素:特别的,如果照片是黑白的,那么可以看做是一个length×width×1的三维矩阵;如果是彩色的(比如RGB)那么就可以看做一个length×width×3的三维矩阵。卷积核像一小块方形的抹布,在图片上由上到下从左到右均匀抹过,并不时的停下来,当抹布停下来的时候,抹布上的点就会和其覆盖的点进行计算,得到一个值,这个值就将成为卷积计算输出矩阵的对应点的值。
从直观上看,卷积的过程相当于将图片“浓缩了”,当然在浓缩的过程中,厚度是可以变的。
torch.nn.Conv2d
( in_channels , out_channels , kernel_size , stride = 1 , padding = 0 , dilation = 1 , groups = 1 , bias = True , padding_mode = ‘zeros’ , device = None , dtype = None )
官网链接
该模块支持[TensorFloat32]
stride
控制互相关、单个数字或元组的步幅。padding
控制应用于输入的填充量。它可以是一个字符串 {‘valid’, ‘same’} 或一个整数元组,给出在两边应用的隐式填充量。dilation
控制内核点之间的间距;也称为 à trous 算法。很难描述,但是这个链接 很好地可视化了dilation
它的作用。groups
控制输入和输出之间的连接。 in_channels
并且out_channels
都必须能被 整除 groups
。import torch
import torch.nn.functional as F
input = torch.tensor([[1, 2, 0, 3, 1],
[0, 1, 2, 3, 1],
[1, 2, 1, 0, 0],
[5, 2, 3, 1, 1],
[2, 1, 0, 1, 1]])
kernel = torch.tensor([[1, 2, 1],
[0, 1, 0],
[2, 1, 0]])
input = torch.reshape(input, (1, 1, 5, 5))
kernel = torch.reshape(kernel, (1, 1, 3, 3))
print(input.shape)
print(kernel.shape)
output = F.conv2d(input, kernel, stride=1)
print(output)
output2 = F.conv2d(input, kernel, stride=2)
print(output2)
output3 = F.conv2d(input, kernel, stride=1, padding=1)
print(output3)