The basic data structure of PyTorch is the tensor (Tensor). A tensor is a multidimensional array, and PyTorch tensors are very similar to numpy arrays.
In this section, we introduce basic concepts such as the data types of tensors, the dimensions of tensors, the size of tensors, and the relationship between tensors and numpy arrays.
The data types of tensors correspond roughly one-to-one with those of numpy.array, except that the str type is not supported.
They include:
torch.float64(torch.double),
torch.float32(torch.float),
torch.float16,
torch.int64(torch.long),
torch.int32(torch.int),
torch.int16,
torch.int8,
torch.uint8,
torch.bool
Neural network modeling generally uses the torch.float32 type.
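For example, here is a minimal sketch of creating tensors with automatically inferred or explicitly specified dtypes (the values are arbitrary, chosen only for illustration):

import torch

# the dtype is inferred automatically from the Python value
i = torch.tensor(1); print(i, i.dtype)        # torch.int64
x = torch.tensor(2.0); print(x, x.dtype)      # torch.float32
b = torch.tensor(True); print(b, b.dtype)     # torch.bool

# the dtype can also be specified explicitly
i = torch.tensor(1, dtype=torch.int32); print(i, i.dtype)
x = torch.tensor(2.0, dtype=torch.double); print(x, x.dtype)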
import numpy as np
import torch
# Construct tensors with type-specific constructors
i = torch.IntTensor(1); print(i, i.dtype)
x = torch.Tensor(np.array(2.0)); print(x, x.dtype)  # equivalent to torch.FloatTensor
b = torch.BoolTensor(np.array([1,0,2,0])); print(b, b.dtype)
# Conversions between different types
i = torch.tensor(1); print(i, i.dtype)
x = i.float(); print(x, x.dtype)            # the float method converts to floating point
y = i.type(torch.float); print(y, y.dtype)  # the type function converts to floating point
z = i.type_as(x); print(z, z.dtype)         # type_as converts to the same type as another Tensor
Different kinds of data can be represented by tensors of different dimensions.
A scalar is a 0-dimensional tensor, a vector is a 1-dimensional tensor, and a matrix is a 2-dimensional tensor.
A color image has three RGB channels and can be represented as a 3-dimensional tensor.
A video additionally has a time dimension and can be represented as a 4-dimensional tensor.
A simple rule of thumb: the number of nested levels of brackets equals the dimensionality of the tensor.
tensor3 = torch.tensor([[[1.0,2.0],[3.0,4.0]],[[5.0,6.0],[7.0,8.0]]])  # 3-dimensional tensor
print(tensor3)
print(tensor3.dim())
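Following the same bracket-counting rule, here is a sketch of a 4-dimensional tensor (the values are arbitrary, chosen only for illustration):

import torch

# four levels of nested brackets -> a 4-dimensional tensor
tensor4 = torch.tensor([[[[1.0,1.0],[2.0,2.0]],[[3.0,3.0],[4.0,4.0]]],
                        [[[5.0,5.0],[6.0,6.0]],[[7.0,7.0],[8.0,8.0]]]])
print(tensor4)
print(tensor4.dim())  # 4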
You can use the shape attribute or the size() method to view the length of a tensor along each dimension.
You can use the view method to change the size of a tensor.
If the view method fails to change the size, you can use the reshape method instead.
scalar = torch.tensor(True)
print(scalar.size())
print(scalar.shape)  # a 0-dimensional tensor has an empty shape: torch.Size([])
# Use view to change the tensor size
vector = torch.arange(0,12)
print(vector)
print(vector.shape)
matrix34 = vector.view(3,4)
print(matrix34)
print(matrix34.shape)
matrix43 = vector.view(4,-1)  # -1 means the length of this dimension is inferred automatically
print(matrix43)
print(matrix43.shape)
# Some operations distort the tensor's storage structure; using view directly will fail, but the reshape method can be used
matrix26 = torch.arange(0,12).view(2,6)
print(matrix26)
print(matrix26.shape)
# The transpose operation distorts the tensor storage structure (the result is no longer contiguous)
matrix62 = matrix26.t()
print(matrix62.is_contiguous())
# Using view directly will fail; the reshape method can be used instead
#matrix34 = matrix62.view(3,4) #error!
matrix34 = matrix62.reshape(3,4) # Equivalent to matrix34 = matrix62.contiguous().view(3,4)
print(matrix34)
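For context, view returns a tensor that shares storage with the original, whereas reshape returns a copy whenever the data cannot be presented as a view of the existing storage (as in the transposed case above). A small sketch of this difference, with variable names chosen only for illustration:

import torch

base = torch.arange(0,12)
viewed = base.view(3,4)
viewed[0,0] = 100      # modifying the view also modifies the original tensor
print(base[0])         # tensor(100)

transposed = torch.arange(0,12).view(2,6).t()
reshaped = transposed.reshape(3,4)   # transposed is not contiguous, so reshape copies the data
reshaped[0,0] = -1
print(transposed[0,0])               # unchanged: tensor(0)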
You can use the numpy method to obtain a numpy array from a Tensor, and you can use torch.from_numpy to obtain a Tensor from a numpy array.
Tensors and numpy arrays associated through these two methods share the same underlying memory:
if you change one of them, the value of the other changes as well.
If needed, you can use the tensor's clone method to copy the tensor and break this link.
In addition, you can use the item method to get the corresponding Python number from a scalar tensor,
and the tolist method to get the corresponding list of Python values from a tensor.
import numpy as np
import torch
# torch.from_numpy creates a Tensor from a numpy array
arr = np.zeros(3)
tensor = torch.from_numpy(arr)
print("before add 1:")
print(arr)
print(tensor)
print("\nafter add 1:")
np.add(arr, 1, out=arr)  # add 1 to arr; tensor changes as well
print(arr)
print(tensor)
# the numpy method obtains a numpy array from a Tensor
tensor = torch.zeros(3)
arr = tensor.numpy()
print("before add 1:")
print(tensor)
print(arr)
print("\nafter add 1:")
# methods with a trailing underscore write the result back into the calling tensor
tensor.add_(1)  # add 1 to tensor; arr changes as well
# or: torch.add(tensor, 1, out=tensor)
print(tensor)
print(arr)
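Finally, a short sketch (variable names are illustrative) of using clone to break the shared-memory link, and of using item and tolist to get Python values from a tensor:

import torch

# clone copies the tensor into new memory, breaking the link with the numpy array
tensor = torch.zeros(3)
arr = tensor.clone().numpy()
tensor.add_(1)
print(tensor)  # tensor([1., 1., 1.])
print(arr)     # [0. 0. 0.]  -- unchanged

# item converts a scalar tensor to a Python number
scalar = torch.tensor(1.0)
print(scalar.item(), type(scalar.item()))      # 1.0 <class 'float'>

# tolist converts a tensor to a (nested) list of Python values
matrix = torch.tensor([[1.0,2.0],[3.0,4.0]])
print(matrix.tolist(), type(matrix.tolist()))  # [[1.0, 2.0], [3.0, 4.0]] <class 'list'>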