torch.where() 用于將兩個(gè)broadcastable的tensor組合成新的tensor,類似于c++中的三元操作符“?:”
區(qū)別于python numpy中的where()直接可以找到特定條件元素的index
![](/d/20211017/d471aac901290c30200b553b536009c1.gif)
想要實(shí)現(xiàn)numpy中where()的功能,可以借助nonzero()
![](/d/20211017/850911123e2cdc80da7027b556ab2ada.gif)
對(duì)應(yīng)numpy中的where()操作效果:
![](/d/20211017/dabebb21baeaef53fe7ebf25046313cc.gif)
補(bǔ)充:Pytorch torch.Tensor.detach()方法的用法及修改指定模塊權(quán)重的方法
detach
detach的中文意思是分離,官方解釋是返回一個(gè)新的Tensor,從當(dāng)前的計(jì)算圖中分離出來(lái)
![](/d/20211017/7129eaefb3a70b7c28a8877749217a87.gif)
需要注意的是,返回的Tensor和原Tensor共享相同的存儲(chǔ)空間,但是返回的 Tensor 永遠(yuǎn)不會(huì)需要梯度
![](/d/20211017/ce6991971175182e920ca550fafe7c91.gif)
import torch as t
a = t.ones(10,)
b = a.detach()
print(b)
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
那么這個(gè)函數(shù)有什么作用?
–假如A網(wǎng)絡(luò)輸出了一個(gè)Tensor類型的變量a, a要作為輸入傳入到B網(wǎng)絡(luò)中,如果我想通過(guò)損失函數(shù)反向傳播修改B網(wǎng)絡(luò)的參數(shù),但是不想修改A網(wǎng)絡(luò)的參數(shù),這個(gè)時(shí)候就可以使用detcah()方法
a = A(input)
a = detach()
b = B(a)
loss = criterion(b, target)
loss.backward()
來(lái)看一個(gè)實(shí)際的例子:
import torch as t
x = t.ones(1, requires_grad=True)
x.requires_grad #True
y = t.ones(1, requires_grad=True)
y.requires_grad #True
x = x.detach() #分離之后
x.requires_grad #False
y = x+y #tensor([2.])
y.requires_grad #我還是True
y.retain_grad() #y不是葉子張量,要加上這一行
z = t.pow(y, 2)
z.backward() #反向傳播
y.grad #tensor([4.])
x.grad #None
以上代碼就說(shuō)明了反向傳播到y(tǒng)就結(jié)束了,沒(méi)有到達(dá)x,所以x的grad屬性為None
既然談到了修改模型的權(quán)重問(wèn)題,那么還有一種情況是:
–假如A網(wǎng)絡(luò)輸出了一個(gè)Tensor類型的變量a, a要作為輸入傳入到B網(wǎng)絡(luò)中,如果我想通過(guò)損失函數(shù)反向傳播修改A網(wǎng)絡(luò)的參數(shù),但是不想修改B網(wǎng)絡(luò)的參數(shù),這個(gè)時(shí)候又應(yīng)該怎么辦了?
這時(shí)可以使用Tensor.requires_grad屬性,只需要將requires_grad修改為False即可.
for param in B.parameters():
param.requires_grad = False
a = A(input)
b = B(a)
loss = criterion(b, target)
loss.backward()
以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教。
您可能感興趣的文章:- Python深度學(xué)習(xí)之使用Pytorch搭建ShuffleNetv2
- win10系統(tǒng)配置GPU版本Pytorch的詳細(xì)教程
- 淺談pytorch中的nn.Sequential(*net[3: 5])是啥意思
- pytorch visdom安裝開(kāi)啟及使用方法
- PyTorch CUDA環(huán)境配置及安裝的步驟(圖文教程)
- pytorch中的nn.ZeroPad2d()零填充函數(shù)實(shí)例詳解
- 使用pytorch實(shí)現(xiàn)線性回歸
- pytorch實(shí)現(xiàn)線性回歸以及多元回歸
- pytorch顯存一直變大的解決方案
- 在Windows下安裝配置CPU版的PyTorch的方法
- PyTorch兩種安裝方法
- PyTorch的Debug指南