pytorch入門
讀取資料
需import的套件
1 | torch.utils.data.Dataset |
1 | dataset = MyDataset(file) |
MyDataset屬於自定義的類別,定義如下:
1
2
3
4
5
6
7class MyDataset(Dataset):
def __init__(self,file):
self.data = ...
def __getitem__(self, index):
return self.data[index]
def __len__(self):
return len(self.data)pytorch裡面的資料大多是以tensor來表示(向量)
我們除了讀取資料,也可以自行創建,範例如下
1
2
3x = torch.tensor([1,2],[3,4]) // 得到2x2
x = torch.from_numpy(np.array([1,2],[3,4])) //也可以自numpy轉移矩陣
x torch.zeros([2,2]) //創建一個shape為(2,2)的全0矩陣
常用函數
transpose : 讓兩個維度之間互換
範例1
2
3x.shape() == (2,3)
x = x.transpose(0,1)
x.shape() == (3,2)x.squeeze(0):將一個大小為1維度消除
相對的,呼叫x.unsqueeze(0)就是增加一個大小=1的維度
torch.cat:將一個維度連接起來(其餘維度大小需相同)
常見問題
- data type
- 運算用的處理器
x = x.to('...')