"""Demonstrate converting a tensor between dtypes and printing the results."""
import torch

# Source tensor of standard-normal samples, shape (3, 200, 200).
# Its dtype is the global default (float32 unless torch.set_default_dtype was changed).
a = torch.randn(3, 200, 200)
print(a.dtype)

# Half precision.
b = a.to(torch.float16)
print(b.dtype)

# 32-bit integers: the fractional part is discarded (truncation toward zero).
c = a.to(torch.int32)
print(c.dtype)

# torch.long is an alias for torch.int64; also truncates toward zero.
d = a.to(torch.long)
print(d.dtype)

# Same dtype as `a` (by default): `.to` returns `a` itself rather than a copy.
e = a.to(torch.float32)
print(e.dtype)
京公网安备 11010502049817号