pytorch入門

讀取資料

需import的套件

1
2
torch.utils.data.Dataset
torch.utils.data.DataLoader
1
2
3
dataset = MyDataset(file)
dataloader = DataLoader(dataset, batch_size, shuffle=True) //Training用true, Testing用false

  • MyDataset屬於自定義的類別,定義如下:

    1
    2
    3
    4
    5
    6
    7
    class MyDataset(Dataset):
    def __init__(self,file):
    self.data = ...
    def __getitem__(self, index):
    return self.data[index]
    def __len__(self):
    return len(self.data)
  • pytorch裡面的資料大多是以tensor來表示(向量)

  • 我們除了讀取資料,也可以自行創建,範例如下

    1
    2
    3
    x = torch.tensor([1,2],[3,4]) // 得到2x2
    x = torch.from_numpy(np.array([1,2],[3,4])) //也可以自numpy轉移矩陣
    x torch.zeros([2,2]) //創建一個shape為(2,2)的全0矩陣

常用函數

  • transpose : 讓兩個維度之間互換
    範例

    1
    2
    3
    x.shape() == (2,3)
    x = x.transpose(0,1)
    x.shape() == (3,2)
  • x.squeeze(0):將一個大小為1維度消除

  • 相對的,呼叫x.unsqueeze(0)就是增加一個大小=1的維度

  • torch.cat:將一個維度連接起來(其餘維度大小需相同)

常見問題

  • data type
  • 運算用的處理器 x = x.to('...')