Let's look at the code:
import torch
import torch.utils.data as Data

torch.manual_seed(1)    # reproducible

# BATCH_SIZE = 5
BATCH_SIZE = 8          # feed 8 samples into the network at a time

x = torch.linspace(1, 10, 10)   # this is x data (torch tensor)
y = torch.linspace(10, 1, 10)   # this is y data (torch tensor)

torch_dataset = Data.TensorDataset(x, y)
loader = Data.DataLoader(
    dataset=torch_dataset,      # torch TensorDataset format
    batch_size=BATCH_SIZE,      # mini batch size
    shuffle=False,              # do not randomly shuffle the data during training
    num_workers=2,              # use two subprocesses for loading data
)

def show_batch():
    for epoch in range(3):      # train over the entire dataset 3 times
        for step, (batch_x, batch_y) in enumerate(loader):  # for each training step
            # train your data...
            print('Epoch: ', epoch, '| Step: ', step, '| batch x: ',
                  batch_x.numpy(), '| batch y: ', batch_y.numpy())

if __name__ == '__main__':
    show_batch()
With BATCH_SIZE = 8, the entire dataset is used three times:
Epoch: 0 | Step: 0 | batch x: [1. 2. 3. 4. 5. 6. 7. 8.] | batch y: [10. 9. 8. 7. 6. 5. 4. 3.]
Epoch: 0 | Step: 1 | batch x: [ 9. 10.] | batch y: [2. 1.]
Epoch: 1 | Step: 0 | batch x: [1. 2. 3. 4. 5. 6. 7. 8.] | batch y: [10. 9. 8. 7. 6. 5. 4. 3.]
Epoch: 1 | Step: 1 | batch x: [ 9. 10.] | batch y: [2. 1.]
Epoch: 2 | Step: 0 | batch x: [1. 2. 3. 4. 5. 6. 7. 8.] | batch y: [10. 9. 8. 7. 6. 5. 4. 3.]
Epoch: 2 | Step: 1 | batch x: [ 9. 10.] | batch y: [2. 1.]
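For comparison, here is a minimal sketch of the same loop with shuffle=True. It reuses torch_dataset and BATCH_SIZE from the code above, and the name shuffled_loader is just for illustration; with shuffling enabled the 10 samples are redrawn in a new random order at the start of every epoch, so the batch contents differ from the sequential output shown above.

shuffled_loader = Data.DataLoader(
    dataset=torch_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,        # reshuffle the samples at the start of every epoch
    num_workers=2,
)

def show_shuffled_batch():
    for epoch in range(3):
        for step, (batch_x, batch_y) in enumerate(shuffled_loader):
            print('Epoch: ', epoch, '| Step: ', step, '| batch x: ',
                  batch_x.numpy(), '| batch y: ', batch_y.numpy())

if __name__ == '__main__':
    show_shuffled_batch()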
Supplement: a PyTorch batch-training bug
Problem description:
When doing batch training of a neural network in PyTorch, you may sometimes hit this error:
TypeError: batch must contain tensors, numbers, dicts or lists; found <class 'torch.autograd.variable.Variable'>
Solution:
Step 1:
Check (this is the key point!!!):
train_dataset = Data.TensorDataset(train_x, train_y)
train_x and train_y must be tensors; the first time I hit this error it was because I had passed in Variables.
You can convert the data to tensors like this:
train_x = torch.FloatTensor(train_x)
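As a minimal sketch of that check, assuming train_x and train_y might arrive as numpy arrays, plain Python lists, or old-style Variables (the helper name to_tensor and the example inputs are illustrations, not part of the original post):

import numpy as np
import torch
from torch.autograd import Variable

def to_tensor(data):
    if torch.is_tensor(data):
        # already a tensor: nothing to do
        return data
    if isinstance(data, Variable):
        # old-style Variable (pre-0.4 PyTorch): unwrap the underlying tensor
        return data.data
    # numpy array or plain Python list: build a FloatTensor
    return torch.FloatTensor(np.asarray(data, dtype=np.float32))

# hypothetical inputs that would otherwise break the DataLoader collation
train_x = to_tensor(np.arange(10))
train_y = to_tensor(Variable(torch.linspace(10, 1, 10)))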
Step 2:
train_loader = Data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True
)
Instantiate a DataLoader object.
Step 3:
for epoch in range(epochs):
    for step, (batch_x, batch_y) in enumerate(train_loader):
        batch_x, batch_y = Variable(batch_x), Variable(batch_y)
This is all it takes to train in batches.
Note that train_loader yields tensors; when training the network they need to be wrapped in Variable.
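Putting the three steps together, here is a minimal end-to-end sketch; the toy data, the nn.Linear model, the loss, the optimizer, and the hyperparameter values are assumptions for illustration rather than part of the original post. (On PyTorch 0.4+, the Variable wrapping is a no-op and can be dropped, since Tensors carry autograd state themselves.)

import torch
import torch.nn as nn
import torch.utils.data as Data
from torch.autograd import Variable

# toy data: fit y = 2x + 1 (assumed example data)
train_x = torch.linspace(0, 1, 100).unsqueeze(1)
train_y = 2 * train_x + 1

train_dataset = Data.TensorDataset(train_x, train_y)   # Step 1: pass in tensors
train_loader = Data.DataLoader(                        # Step 2: instantiate a DataLoader
    dataset=train_dataset,
    batch_size=10,
    shuffle=True,
)

net = nn.Linear(1, 1)                                  # assumed toy model
loss_func = nn.MSELoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)

epochs = 5
for epoch in range(epochs):                            # Step 3: batch training
    for step, (batch_x, batch_y) in enumerate(train_loader):
        batch_x, batch_y = Variable(batch_x), Variable(batch_y)
        pred = net(batch_x)
        loss = loss_func(pred, batch_y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()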
The above is just my personal experience; I hope it gives everyone a useful reference, and I hope everyone will continue to support 脚本之家.