背景:
基于PyTorch的模型,想固定主分支參數(shù),只訓(xùn)練子分支,結(jié)果發(fā)現(xiàn)在不同epoch相同的測試數(shù)據(jù)經(jīng)過主分支輸出的結(jié)果不同。
原因:
未固定主分支BN層中的running_mean和running_var。
解決方法:
將需要固定的BN層狀態(tài)設(shè)置為eval。
問題示例:
環(huán)境:torch:1.7.0
# -*- coding:utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 6, 3)
self.bn1 = nn.BatchNorm2d(6)
self.conv2 = nn.Conv2d(6, 16, 3)
self.bn2 = nn.BatchNorm2d(16)
# an affine operation: y = Wx + b
self.fc1 = nn.Linear(16 * 6 * 6, 120) # 6*6 from image dimension
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 5)
def forward(self, x):
# Max pooling over a (2, 2) window
x = F.max_pool2d(F.relu(self.bn1(self.conv1(x))), (2, 2))
# If the size is a square you can only specify a single number
x = F.max_pool2d(F.relu(self.bn2(self.conv2(x))), 2)
x = x.view(-1, self.num_flat_features(x))
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
def num_flat_features(self, x):
size = x.size()[1:] # all dimensions except the batch dimension
num_features = 1
for s in size:
num_features *= s
return num_features
def print_parameter_grad_info(net):
print('-------parameters requires grad info--------')
for name, p in net.named_parameters():
print(f'{name}:\t{p.requires_grad}')
def print_net_state_dict(net):
for key, v in net.state_dict().items():
print(f'{key}')
if __name__ == "__main__":
net = Net()
print_parameter_grad_info(net)
net.requires_grad_(False)
print_parameter_grad_info(net)
torch.random.manual_seed(5)
test_data = torch.rand(1, 1, 32, 32)
train_data = torch.rand(5, 1, 32, 32)
# print(test_data)
# print(train_data[0, ...])
for epoch in range(2):
# training phase, 假設(shè)每個epoch只迭代一次
net.train()
pre = net(train_data)
# 計算損失和參數(shù)更新等
# ....
# test phase
net.eval()
x = net(test_data)
print(f'epoch:{epoch}', x)
運(yùn)行結(jié)果:
-------parameters requires grad info--------
conv1.weight: True
conv1.bias: True
bn1.weight: True
bn1.bias: True
conv2.weight: True
conv2.bias: True
bn2.weight: True
bn2.bias: True
fc1.weight: True
fc1.bias: True
fc2.weight: True
fc2.bias: True
fc3.weight: True
fc3.bias: True
-------parameters requires grad info--------
conv1.weight: False
conv1.bias: False
bn1.weight: False
bn1.bias: False
conv2.weight: False
conv2.bias: False
bn2.weight: False
bn2.bias: False
fc1.weight: False
fc1.bias: False
fc2.weight: False
fc2.bias: False
fc3.weight: False
fc3.bias: False
epoch:0 tensor([[-0.0755, 0.1138, 0.0966, 0.0564, -0.0224]])
epoch:1 tensor([[-0.0763, 0.1113, 0.0970, 0.0574, -0.0235]])
可以看到:
net.requires_grad_(False)已經(jīng)將網(wǎng)絡(luò)中的各參數(shù)設(shè)置成了不需要梯度更新的狀態(tài),但是同樣的測試數(shù)據(jù)test_data在不同epoch中前向之后出現(xiàn)了不同的結(jié)果。
調(diào)用print_net_state_dict可以看到BN層中的參數(shù)running_mean和running_var并沒在可優(yōu)化參數(shù)net.parameters中
bn1.weight
bn1.bias
bn1.running_mean
bn1.running_var
bn1.num_batches_tracked
但在training pahse的前向過程中,這兩個參數(shù)被更新了。導(dǎo)致整個網(wǎng)絡(luò)在freeze的情況下,同樣的測試數(shù)據(jù)出現(xiàn)了不同的結(jié)果
Also by default, during training this layer keeps running estimates of its computed mean and variance, which are then used for normalization during evaluation. The running estimates are kept with a defaultmomentumof 0.1. source
因此在training phase時對BN層顯式設(shè)置eval狀態(tài):
if __name__ == "__main__":
net = Net()
net.requires_grad_(False)
torch.random.manual_seed(5)
test_data = torch.rand(1, 1, 32, 32)
train_data = torch.rand(5, 1, 32, 32)
# print(test_data)
# print(train_data[0, ...])
for epoch in range(2):
# training phase, 假設(shè)每個epoch只迭代一次
net.train()
net.bn1.eval()
net.bn2.eval()
pre = net(train_data)
# 計算損失和參數(shù)更新等
# ....
# test phase
net.eval()
x = net(test_data)
print(f'epoch:{epoch}', x)
可以看到結(jié)果正常了:
epoch:0 tensor([[ 0.0944, -0.0372, 0.0059, -0.0625, -0.0048]])
epoch:1 tensor([[ 0.0944, -0.0372, 0.0059, -0.0625, -0.0048]])
補(bǔ)充:pytorch---之BN層參數(shù)詳解及應(yīng)用(1,2,3)(1,2)?
BN層參數(shù)詳解(1,2)
一般來說pytorch中的模型都是繼承nn.Module類的,都有一個屬性trainning指定是否是訓(xùn)練狀態(tài),訓(xùn)練狀態(tài)與否將會影響到某些層的參數(shù)是否是固定的,比如BN層(對于BN層測試的均值和方差是通過統(tǒng)計訓(xùn)練的時候所有的batch的均值和方差的平均值)或者Dropout層(對于Dropout層在測試的時候所有神經(jīng)元都是激活的)。通常用model.train()指定當(dāng)前模型model為訓(xùn)練狀態(tài),model.eval()指定當(dāng)前模型為測試狀態(tài)。
同時,BN的API中有幾個參數(shù)需要比較關(guān)心的,一個是affine指定是否需要仿射,還有個是track_running_stats指定是否跟蹤當(dāng)前batch的統(tǒng)計特性。容易出現(xiàn)問題也正好是這三個參數(shù):trainning,affine,track_running_stats。
其中的affine指定是否需要仿射,也就是是否需要上面算式的第四個,如果affine=False則γ=1,β=0 \gamma=1,\beta=0γ=1,β=0,并且不能學(xué)習(xí)被更新。一般都會設(shè)置成affine=True。(這里是一個可學(xué)習(xí)參數(shù))
trainning和track_running_stats,track_running_stats=True表示跟蹤整個訓(xùn)練過程中的batch的統(tǒng)計特性,得到方差和均值,而不只是僅僅依賴與當(dāng)前輸入的batch的統(tǒng)計特性(意思就是說新的batch依賴于之前的batch的均值和方差這里使用momentum參數(shù),參考了指數(shù)移動平均的算法EMA)。相反的,如果track_running_stats=False那么就只是計算當(dāng)前輸入的batch的統(tǒng)計特性中的均值和方差了。當(dāng)在推理階段的時候,如果track_running_stats=False,此時如果batch_size比較小,那么其統(tǒng)計特性就會和全局統(tǒng)計特性有著較大偏差,可能導(dǎo)致糟糕的效果。
應(yīng)用技巧:(1,2)
通常pytorch都會用到optimizer.zero_grad() 來清空以前的batch所累加的梯度,因?yàn)閜ytorch中Variable計算的梯度會進(jìn)行累計,所以每一個batch都要重新清空一次梯度,原始的做法是下面這樣的:
問題:參數(shù)non_blocking,以及pytorch的整體框架??
代碼(1)
for index,data,target in enumerate(dataloader):
data = data.cuda(non_blocking=True)
target = torch.from_numpy(np.array(target)).float().cuda(non_blocking = Trye)
output = model(data)
loss = criterion(output,target)
#清空梯度
optimizer.zero_grad()
loss.backward()
optimizer.step()
而這里為了模仿minibacth,我們每次batch不清0,累積到一定次數(shù)再清0,再更新權(quán)重:
for index, data, target in enumerate(dataloader):
#如果不是Tensor,一般要用到torch.from_numpy()
data = data.cuda(non_blocking = True)
target = torch.from_numpy(np.array(target)).float().cuda(non_blocking = True)
output = model(data)
loss = criterion(data, target)
loss.backward()
if index%accumulation == 0:
#用累積的梯度更新權(quán)重
optimizer.step()
#清空梯度
optimizer.zero_grad()
雖然這里的梯度是相當(dāng)于原來的accumulation倍,但是實(shí)際在前向傳播的過程中,對于BN幾乎沒有影響,因?yàn)榍跋虻腂N還是只是一個batch的均值和方差,這個時候可以用pytorch中BN的momentum參數(shù),默認(rèn)是0.1,BN參數(shù)如下,就是指數(shù)移動平均
x_new_running = (1 - momentum) * x_running + momentum * x_new_observed. momentum
以上為個人經(jīng)驗(yàn),希望能給大家一個參考,也希望大家多多支持腳本之家。
您可能感興趣的文章:- pytorch 如何自定義卷積核權(quán)值參數(shù)
- pytorch交叉熵?fù)p失函數(shù)的weight參數(shù)的使用
- Pytorch 統(tǒng)計模型參數(shù)量的操作 param.numel()
- pytorch 一行代碼查看網(wǎng)絡(luò)參數(shù)總量的實(shí)現(xiàn)
- pytorch查看網(wǎng)絡(luò)參數(shù)顯存占用量等操作
- pytorch 優(yōu)化器(optim)不同參數(shù)組,不同學(xué)習(xí)率設(shè)置的操作
- pytorch LayerNorm參數(shù)的用法及計算過程