濮阳杆衣贸易有限公司

主頁(yè) > 知識(shí)庫(kù) > pytorch MSELoss計(jì)算平均的實(shí)現(xiàn)方法

pytorch MSELoss計(jì)算平均的實(shí)現(xiàn)方法

熱門(mén)標(biāo)簽:哈爾濱ai外呼系統(tǒng)定制 騰訊外呼線路 廣告地圖標(biāo)注app 陜西金融外呼系統(tǒng) 海南400電話如何申請(qǐng) 公司電話機(jī)器人 唐山智能外呼系統(tǒng)一般多少錢(qián) 激戰(zhàn)2地圖標(biāo)注 白銀外呼系統(tǒng)

給定損失函數(shù)的輸入y,pred,shape均為bxc。

若設(shè)定loss_fn = torch.nn.MSELoss(reduction='mean'),最終的輸出值其實(shí)是(y - pred)每個(gè)元素?cái)?shù)字的平方之和除以(bxc),也就是在batch和特征維度上都取了平均。

如果只想在batch上做平均,可以這樣寫(xiě):

loss_fn = torch.nn.MSELoss(reduction='sum')
loss = loss_fn(pred, y) / pred.size(0)

補(bǔ)充:PyTorch中MSELoss的使用

參數(shù)

torch.nn.MSELoss(size_average=None, reduce=None, reduction: str = 'mean')

size_average和reduce在當(dāng)前版本的pytorch已經(jīng)不建議使用了,只設(shè)置reduction就行了。

reduction的可選參數(shù)有:'none' 、'mean' 、'sum'

reduction='none':求所有對(duì)應(yīng)位置的差的平方,返回的仍然是一個(gè)和原來(lái)形狀一樣的矩陣。

reduction='mean':求所有對(duì)應(yīng)位置差的平方的均值,返回的是一個(gè)標(biāo)量。

reduction='sum':求所有對(duì)應(yīng)位置差的平方的和,返回的是一個(gè)標(biāo)量。

更多可查看官方文檔​

舉例

首先假設(shè)有三個(gè)數(shù)據(jù)樣本分別經(jīng)過(guò)神經(jīng)網(wǎng)絡(luò)運(yùn)算,得到三個(gè)輸出與其標(biāo)簽分別是:

y_pre = torch.Tensor([[1, 2, 3],
                      [2, 1, 3],
                      [3, 1, 2]])

y_label = torch.Tensor([[1, 0, 0],
                        [0, 1, 0],
                        [0, 0, 1]])

如果reduction='none':

criterion1 = nn.MSELoss(reduction='none')
loss1 = criterion1(x, y)
print(loss1)

則輸出:

tensor([[0., 4., 9.],

[4., 0., 9.],

[9., 1., 1.]])

如果reduction='mean':

criterion2 = nn.MSELoss(reduction='mean')
loss2 = criterion2(x, y)
print(loss2)

則輸出:

tensor(4.1111)

如果reduction='sum':

criterion3 = nn.MSELoss(reduction='sum')
loss3 = criterion3(x, y)
print(loss3)

則輸出:

tensor(37.)

在反向傳播時(shí)的使用

一般在反向傳播時(shí),都是先求loss,再使用loss.backward()求loss對(duì)每個(gè)參數(shù) w_ij和b的偏導(dǎo)數(shù)(也可以理解為梯度)。

這里要注意的是,只有標(biāo)量才能執(zhí)行backward()函數(shù),因此在反向傳播中reduction不能設(shè)為'none'。

但具體設(shè)置為'sum'還是'mean'都是可以的。

若設(shè)置為'sum',則有Loss=loss_1+loss_2+loss_3,表示總的Loss由每個(gè)實(shí)例的loss_i構(gòu)成,在通過(guò)Loss求梯度時(shí),將每個(gè)loss_i的梯度也都考慮進(jìn)去了。

若設(shè)置為'mean',則相比'sum'相當(dāng)于Loss變成了Loss*(1/i),這在參數(shù)更新時(shí)影響不大,因?yàn)橛袑W(xué)習(xí)率a的存在。

以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教。

您可能感興趣的文章:
  • Pytorch中accuracy和loss的計(jì)算知識(shí)點(diǎn)總結(jié)
  • 基于MSELoss()與CrossEntropyLoss()的區(qū)別詳解
  • 解決Pytorch訓(xùn)練過(guò)程中l(wèi)oss不下降的問(wèn)題
  • Pytorch 的損失函數(shù)Loss function使用詳解

標(biāo)簽:益陽(yáng) 惠州 上海 黑龍江 常德 四川 鷹潭 黔西

巨人網(wǎng)絡(luò)通訊聲明:本文標(biāo)題《pytorch MSELoss計(jì)算平均的實(shí)現(xiàn)方法》,本文關(guān)鍵詞  pytorch,MSELoss,計(jì)算,平均,;如發(fā)現(xiàn)本文內(nèi)容存在版權(quán)問(wèn)題,煩請(qǐng)?zhí)峁┫嚓P(guān)信息告之我們,我們將及時(shí)溝通與處理。本站內(nèi)容系統(tǒng)采集于網(wǎng)絡(luò),涉及言論、版權(quán)與本站無(wú)關(guān)。
  • 相關(guān)文章
  • 下面列出與本文章《pytorch MSELoss計(jì)算平均的實(shí)現(xiàn)方法》相關(guān)的同類(lèi)信息!
  • 本頁(yè)收集關(guān)于pytorch MSELoss計(jì)算平均的實(shí)現(xiàn)方法的相關(guān)信息資訊供網(wǎng)民參考!
  • 推薦文章
    洛阳市| 九龙县| 恭城| 三河市| 广灵县| 万荣县| 湾仔区| 秭归县| 芜湖市| 故城县| 临邑县| 朔州市| 阆中市| 苏尼特左旗| 黄冈市| 榆中县| 汉寿县| 遵义县| 静乐县| 政和县| 晋江市| 咸宁市| 蒙阴县| 西平县| 无锡市| 汪清县| 平定县| 香格里拉县| 柘城县| 漯河市| 白沙| 平江县| 阜新| 宜都市| 望都县| 兴文县| 景东| 通河县| 青阳县| 龙里县| 潼南县|