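# Python code for loading and saving tensors (python中加载Tensor和储存Tensor的代码).
# The functions below load and save NumPy arrays in a simple binary tensor format.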
import numpy as np
def load_tensor(file):
    """Load a tensor stored in the binary tensor format.

    File layout (native byte order):
      * three uint32 header words ``[magic=0xFCCFE2E2, ndims, typeid]``;
      * ``ndims`` uint32 dimension sizes;
      * the raw element data.

    Type ids: 0 -> float32, 1 -> float16, 2 -> int32, 3 -> uint8 (the same
    ids that ``save_tensor`` writes).

    Args:
        file: Path of the tensor file to read.

    Returns:
        A NumPy array with the stored dtype and shape (0-d when ``ndims`` is
        0). It is a read-only view over the bytes read from disk; call
        ``.copy()`` if it needs to be modified.

    Raises:
        ValueError: if the magic number does not match, the type id is
            unknown, or the header/data size is inconsistent with the file.
    """
    with open(file, "rb") as f:
        binary_data = f.read()

    magic_number, ndims, dtype = np.frombuffer(binary_data, np.uint32, count=3, offset=0)
    # Explicit raises instead of assert: asserts vanish under `python -O`.
    if magic_number != 0xFCCFE2E2:
        raise ValueError(f"{file} not a tensor file.")

    ndims = int(ndims)
    dims = np.frombuffer(binary_data, np.uint32, count=ndims, offset=3 * 4)
    # A tuple (possibly empty) lets reshape produce 0-d arrays; `reshape(*dims)`
    # would call reshape() with no arguments when ndims == 0.
    shape = tuple(int(d) for d in dims)

    np_dtype = {0: np.float32, 1: np.float16, 2: np.int32, 3: np.uint8}.get(int(dtype))
    if np_dtype is None:
        raise ValueError(f"Unsupport dtype = {dtype}, can not convert to numpy dtype")

    return np.frombuffer(binary_data, np_dtype, offset=(ndims + 3) * 4).reshape(shape)
def save_tensor(tensor, file):
    """Save an array to *file* in the binary tensor format.

    File layout (native byte order):
      * three uint32 header words ``[0xFCCFE2E2, ndim, typeid]``;
      * ``ndim`` uint32 dimension sizes;
      * the raw element bytes in C order (``ndarray.tobytes()`` default).

    Supported dtypes and their type ids: float32 -> 0, float16 -> 1,
    int32 -> 2, uint8 -> 3.

    Args:
        tensor: A NumPy array, or anything ``np.asarray`` converts to one,
            with one of the supported dtypes.
        file: Path of the file to create or overwrite.

    Raises:
        ValueError: if the dtype is not supported. The check runs before the
            file is opened, so no partial or mislabeled file is left behind.
    """
    tensor = np.asarray(tensor)

    typeids = {
        np.dtype(np.float32): 0,
        np.dtype(np.float16): 1,
        np.dtype(np.int32): 2,
        np.dtype(np.uint8): 3,
    }
    # Previously any other dtype (float64, int64, bool, non-native byte order)
    # was silently tagged as float32, producing a corrupt file.
    typeid = typeids.get(tensor.dtype)
    if typeid is None:
        raise ValueError(
            f"Unsupported dtype {tensor.dtype}; expected one of float32, float16, int32, uint8"
        )

    head = np.array([0xFCCFE2E2, tensor.ndim, typeid], dtype=np.uint32).tobytes()
    dims = np.array(tensor.shape, dtype=np.uint32).tobytes()
    with open(file, "wb") as f:
        f.write(head)
        f.write(dims)
        f.write(tensor.tobytes())