濮阳杆衣贸易有限公司

主頁(yè) > 知識(shí)庫(kù) > pytorch交叉熵?fù)p失函數(shù)的weight參數(shù)的使用

pytorch交叉熵?fù)p失函數(shù)的weight參數(shù)的使用

熱門(mén)標(biāo)簽:400電話辦理哪種 開(kāi)封語(yǔ)音外呼系統(tǒng)代理商 地圖標(biāo)注線上如何操作 應(yīng)電話機(jī)器人打電話違法嗎 河北防封卡電銷(xiāo)卡 開(kāi)封自動(dòng)外呼系統(tǒng)怎么收費(fèi) 手機(jī)網(wǎng)頁(yè)嵌入地圖標(biāo)注位置 天津電話機(jī)器人公司 電銷(xiāo)機(jī)器人的風(fēng)險(xiǎn)

首先

必須將權(quán)重也轉(zhuǎn)為T(mén)ensor的cuda格式;

然后

將該class_weight作為交叉熵函數(shù)對(duì)應(yīng)參數(shù)的輸入值。

class_weight = torch.FloatTensor([0.13859937, 0.5821059, 0.63871904, 2.30220396, 7.1588294, 0]).cuda()

補(bǔ)充:關(guān)于pytorch的CrossEntropyLoss的weight參數(shù)

首先這個(gè)weight參數(shù)比想象中的要考慮的多

你可以試試下面代碼

import torch
import torch.nn as nn
inputs = torch.FloatTensor([0,1,0,0,0,1])
outputs = torch.LongTensor([0,1])
inputs = inputs.view((1,3,2))
outputs = outputs.view((1,2))
weight_CE = torch.FloatTensor([1,1,1])
ce = nn.CrossEntropyLoss(ignore_index=255,weight=weight_CE)
loss = ce(inputs,outputs)
print(loss)
tensor(1.4803)

這里的手動(dòng)計(jì)算是:

loss1 = 0 + ln(e0 + e0 + e0) = 1.098

loss2 = 0 + ln(e1 + e0 + e1) = 1.86

求平均 = (loss1 *1 + loss2 *1)/ 2 = 1.4803

加權(quán)呢?

import torch
import torch.nn as nn
inputs = torch.FloatTensor([0,1,0,0,0,1])
outputs = torch.LongTensor([0,1])
inputs = inputs.view((1,3,2))
outputs = outputs.view((1,2))
weight_CE = torch.FloatTensor([1,2,3])
ce = nn.CrossEntropyLoss(ignore_index=255,weight=weight_CE)
loss = ce(inputs,outputs)
print(loss)
tensor(1.6075)

手算發(fā)現(xiàn),并不是單純的那權(quán)重相乘:

loss1 = 0 + ln(e0 + e0 + e0) = 1.098

loss2 = 0 + ln(e1 + e0 + e1) = 1.86

求平均 = (loss1 * 1 + loss2 * 2)/ 2 = 2.4113

而是

loss1 = 0 + ln(e0 + e0 + e0) = 1.098

loss2 = 0 + ln(e1 + e0 + e1) = 1.86

求平均 = (loss1 *1 + loss2 *2) / 3 = 1.6075

發(fā)現(xiàn)了么,加權(quán)后,除以的是權(quán)重的和,不是數(shù)目的和。

我們?cè)衮?yàn)證一遍:

import torch
import torch.nn as nn
inputs = torch.FloatTensor([0,1,2,0,0,0,0,0,0,1,0,0.5])
outputs = torch.LongTensor([0,1,2,2])
inputs = inputs.view((1,3,4))
outputs = outputs.view((1,4))
weight_CE = torch.FloatTensor([1,2,3])
ce = nn.CrossEntropyLoss(weight=weight_CE)
# ce = nn.CrossEntropyLoss(ignore_index=255)
loss = ce(inputs,outputs)
print(loss)
tensor(1.5472)

手算:

loss1 = 0 + ln(e0 + e0 + e0) = 1.098

loss2 = 0 + ln(e1 + e0 + e1) = 1.86

loss3 = 0 + ln(e2 + e0 + e0) = 2.2395

loss4 = -0.5 + ln(e0.5 + e0 + e0) = 0.7943

求平均 = (loss1 * 1 + loss2 * 2+loss3 * 3+loss4 * 3) / 9 = 1.5472

可能有人對(duì)loss的CE計(jì)算過(guò)程有疑問(wèn),我這里細(xì)致寫(xiě)寫(xiě)交叉熵的計(jì)算過(guò)程,就拿最后一個(gè)例子的loss4的計(jì)算說(shuō)明

以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。

您可能感興趣的文章:
  • PyTorch的SoftMax交叉熵?fù)p失和梯度用法
  • pytorch中常用的損失函數(shù)用法說(shuō)明
  • Pytorch十九種損失函數(shù)的使用詳解
  • pytorch中交叉熵?fù)p失(nn.CrossEntropyLoss())的計(jì)算過(guò)程詳解
  • Python機(jī)器學(xué)習(xí)pytorch交叉熵?fù)p失函數(shù)的深刻理解

標(biāo)簽:成都 常州 山東 宿遷 駐馬店 六盤(pán)水 江蘇 蘭州

巨人網(wǎng)絡(luò)通訊聲明:本文標(biāo)題《pytorch交叉熵?fù)p失函數(shù)的weight參數(shù)的使用》,本文關(guān)鍵詞  pytorch,交叉,熵,損失,函數(shù),;如發(fā)現(xiàn)本文內(nèi)容存在版權(quán)問(wèn)題,煩請(qǐng)?zhí)峁┫嚓P(guān)信息告之我們,我們將及時(shí)溝通與處理。本站內(nèi)容系統(tǒng)采集于網(wǎng)絡(luò),涉及言論、版權(quán)與本站無(wú)關(guān)。
  • 相關(guān)文章
  • 下面列出與本文章《pytorch交叉熵?fù)p失函數(shù)的weight參數(shù)的使用》相關(guān)的同類(lèi)信息!
  • 本頁(yè)收集關(guān)于pytorch交叉熵?fù)p失函數(shù)的weight參數(shù)的使用的相關(guān)信息資訊供網(wǎng)民參考!
  • 推薦文章
    吉安县| 个旧市| 罗山县| 萨嘎县| 昌宁县| 岳阳市| 汉源县| 丹凤县| 五大连池市| 德庆县| 崇礼县| 涞源县| 台江县| 永吉县| 浏阳市| 师宗县| 曲沃县| 吉安市| 郴州市| 米泉市| 兴安盟| 商城县| 通江县| 南川市| 南投县| 林西县| 二手房| 五寨县| 临湘市| 安丘市| 开远市| 陆川县| 安义县| 晋城| 铜山县| 五峰| 安陆市| 乌苏市| 祁门县| 志丹县| 高唐县|