問題:
自己寫了個(gè)dataloader,為了部署方便,用OpenCV的接口進(jìn)行數(shù)據(jù)讀取,而沒有用PIL,代碼大致如下:
def __getitem__(self, idx):
sample = self.samples[idx]
img = cv2.imread(sample[0])
img = cv2.resize(img, tuple(self.input_size))
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# if not self.val and random.randint(1, 10) 3:
# img = self.img_aug(img)
img = Image.fromarray(img)
img = self.transforms(img)
...
結(jié)果在訓(xùn)練過程中,在第1個(gè)epoch的最后一個(gè)batch時(shí),程序卡死。
解決方案:
可能是因?yàn)镺penCV與Pytorch互鎖的問題,關(guān)閉OpenCV的多線程,問題解決。
cv2.setNumThreads(0)
cv2.ocl.setUseOpenCL(False)
補(bǔ)充:pytorch 中一個(gè)batch的訓(xùn)練過程
# 一般情況下
optimizer.zero_grad() # 梯度清零
preds = model(inputs) # inference,前向傳播求出預(yù)測值
loss = criterion(preds, targets) # 計(jì)算loss
loss.backward() # 反向傳播求解梯度
optimizer.step() # 更新權(quán)重,更新網(wǎng)絡(luò)權(quán)重參數(shù)
此外,反向傳播前,如果不進(jìn)行梯度清零,則可以實(shí)現(xiàn)梯度累加,從而一定程度上解決顯存受限的問題。
以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
您可能感興趣的文章:- pytorch鎖死在dataloader(訓(xùn)練時(shí)卡死)
- pytorch Dataset,DataLoader產(chǎn)生自定義的訓(xùn)練數(shù)據(jù)案例
- 解決Pytorch dataloader時(shí)報(bào)錯(cuò)每個(gè)tensor維度不一樣的問題
- pytorch中DataLoader()過程中遇到的一些問題
- Pytorch 如何加速Dataloader提升數(shù)據(jù)讀取速度
- pytorch DataLoader的num_workers參數(shù)與設(shè)置大小詳解
- pytorch 實(shí)現(xiàn)多個(gè)Dataloader同時(shí)訓(xùn)練