前言
1: 数据类型
2: 常用API
参考:
一 数据类型
除了string ,相对于Numpy PyTorch 都能找到对应的数据类型
1.1 常用的Data type
常用的5种:
IntTensor, LogTensor, ByteTensor, DoubleTensor, FloatTensor
- # -*- coding: utf-8 -*-
- """
- Created on Tue Nov 29 16:20:52 2022
- @author: chengxf2
- """
- import torch
-
- def checktype():
-
- a = torch.randn(2,3)
-
- print("\n\t type:",a.type())
-
-
- bFloat = isinstance(a, torch.FloatTensor)
-
- print("\n\t bFloat",bFloat)
-
- bDouble = isinstance(a, torch.cuda.DoubleTensor)
- print("\n\t bDouble: ",bDouble)
- print("\n\t sp",a.shape)
-
- dim = len(a.shape)
- print("\n\t 维度",dim)
-
-
- b = torch.tensor(1)
- print("\n\t b: ",b.type())
- print("\n\t value ",b.item())
- print("\n\t sp",b.shape)
-
-
-
-
-
-
-
- if __name__ == "__main__":
-
- checktype()
二 常用API 函数
Function | 说明 |
Dim() | 维度 |
Numel() | 占用内存 |
shape | 数组的每个维度的长度 |
Size(0) | 对应维度的大小 |
FromNumpy() | Numpy to Pytorch |
Numpy() | Pytorch to Numpy |
item | 取元素 |
torch.FloatTensor | Float 类型 |
torch.DoubleTensor | Double 类型 |
torch.IntTensor | Int 类型 |
torch.ByteTensor | Byte 类型 |