濮阳杆衣贸易有限公司

主頁 > 知識庫 > Pytorch之如何dropout避免過擬合

Pytorch之如何dropout避免過擬合

熱門標(biāo)簽:北瀚ai電銷機(jī)器人官網(wǎng)手機(jī)版 市場上的電銷機(jī)器人 所得系統(tǒng)電梯怎樣主板設(shè)置外呼 小蘇云呼電話機(jī)器人 儋州電話機(jī)器人 佛山400電話辦理 朝陽手機(jī)外呼系統(tǒng) 北京電銷外呼系統(tǒng)加盟 地圖標(biāo)注面積

一.做數(shù)據(jù)

二.搭建神經(jīng)網(wǎng)絡(luò)

三.訓(xùn)練

四.對比測試結(jié)果

注意:測試過程中,一定要注意模式切換

Pytorch的學(xué)習(xí)——過擬合

過擬合

過擬合是當(dāng)數(shù)據(jù)量較小時(shí)或者輸出結(jié)果過于依賴某些特定的神經(jīng)元,訓(xùn)練神經(jīng)網(wǎng)絡(luò)訓(xùn)練會(huì)發(fā)生一種現(xiàn)象。出現(xiàn)這種現(xiàn)象的神經(jīng)網(wǎng)絡(luò)預(yù)測的結(jié)果并不具有普遍意義,其預(yù)測結(jié)果極不準(zhǔn)確。

解決方法

1.增加數(shù)據(jù)量

2.L1,L2,L3…正規(guī)化,即在計(jì)算誤差值的時(shí)候加上要學(xué)習(xí)的參數(shù)值,當(dāng)參數(shù)改變過大時(shí),誤差也會(huì)變大,通過這種懲罰機(jī)制來控制過擬合現(xiàn)象

3.dropout正規(guī)化,在訓(xùn)練過程中通過隨機(jī)屏蔽部分神經(jīng)網(wǎng)絡(luò)連接,使神經(jīng)網(wǎng)絡(luò)不完整,這樣就可以使神經(jīng)網(wǎng)絡(luò)的預(yù)測結(jié)果不會(huì)過分依賴某些特定的神經(jīng)元

例子

這里小編通過dropout正規(guī)化的列子來更加形象的了解神經(jīng)網(wǎng)絡(luò)的過擬合現(xiàn)象

import torch
import matplotlib.pyplot as plt
N_SAMPLES = 20
N_HIDDEN = 300
# train數(shù)據(jù)
x = torch.unsqueeze(torch.linspace(-1, 1, N_SAMPLES), 1)
y = x + 0.3*torch.normal(torch.zeros(N_SAMPLES, 1), torch.ones(N_SAMPLES, 1))
# test數(shù)據(jù)
test_x = torch.unsqueeze(torch.linspace(-1, 1, N_SAMPLES), 1)
test_y = test_x + 0.3*torch.normal(torch.zeros(N_SAMPLES, 1), torch.ones(N_SAMPLES, 1))
# 可視化
plt.scatter(x.data.numpy(), y.data.numpy(), c='magenta', s=50, alpha=0.5, label='train')
plt.scatter(test_x.data.numpy(), test_y.data.numpy(), c='cyan', s=50, alpha=0.5, label='test')
plt.legend(loc='upper left')
plt.ylim((-2.5, 2.5))
plt.show()
# 網(wǎng)絡(luò)一,未使用dropout正規(guī)化
net_overfitting = torch.nn.Sequential(
    torch.nn.Linear(1, N_HIDDEN),
    torch.nn.ReLU(),
    torch.nn.Linear(N_HIDDEN, N_HIDDEN),
    torch.nn.ReLU(),
    torch.nn.Linear(N_HIDDEN, 1),
)
# 網(wǎng)絡(luò)二,使用dropout正規(guī)化
net_dropped = torch.nn.Sequential(
    torch.nn.Linear(1, N_HIDDEN),
    torch.nn.Dropout(0.5),  # 隨機(jī)屏蔽50%的網(wǎng)絡(luò)連接
    torch.nn.ReLU(),
    torch.nn.Linear(N_HIDDEN, N_HIDDEN),
    torch.nn.Dropout(0.5),  # 隨機(jī)屏蔽50%的網(wǎng)絡(luò)連接
    torch.nn.ReLU(),
    torch.nn.Linear(N_HIDDEN, 1),
)
# 選擇優(yōu)化器
optimizer_ofit = torch.optim.Adam(net_overfitting.parameters(), lr=0.01)
optimizer_drop = torch.optim.Adam(net_dropped.parameters(), lr=0.01)
# 選擇計(jì)算誤差的工具
loss_func = torch.nn.MSELoss()
plt.ion()
for t in range(500):
    # 神經(jīng)網(wǎng)絡(luò)訓(xùn)練數(shù)據(jù)的固定過程
    pred_ofit = net_overfitting(x)
    pred_drop = net_dropped(x)
    loss_ofit = loss_func(pred_ofit, y)
    loss_drop = loss_func(pred_drop, y)
    optimizer_ofit.zero_grad()
    optimizer_drop.zero_grad()
    loss_ofit.backward()
    loss_drop.backward()
    optimizer_ofit.step()
    optimizer_drop.step()
    if t % 10 == 0:
        # 脫離訓(xùn)練模式,這里便于展示神經(jīng)網(wǎng)絡(luò)的變化過程
        net_overfitting.eval()
        net_dropped.eval() 
        # 可視化
        plt.cla()
        test_pred_ofit = net_overfitting(test_x)
        test_pred_drop = net_dropped(test_x)
        plt.scatter(x.data.numpy(), y.data.numpy(), c='magenta', s=50, alpha=0.3, label='train')
        plt.scatter(test_x.data.numpy(), test_y.data.numpy(), c='cyan', s=50, alpha=0.3, label='test')
        plt.plot(test_x.data.numpy(), test_pred_ofit.data.numpy(), 'r-', lw=3, label='overfitting')
        plt.plot(test_x.data.numpy(), test_pred_drop.data.numpy(), 'b--', lw=3, label='dropout(50%)')
        plt.text(0, -1.2, 'overfitting loss=%.4f' % loss_func(test_pred_ofit, test_y).data.numpy(),
                 fontdict={'size': 20, 'color':  'red'})
        plt.text(0, -1.5, 'dropout loss=%.4f' % loss_func(test_pred_drop, test_y).data.numpy(),
                 fontdict={'size': 20, 'color': 'blue'})
        plt.legend(loc='upper left'); plt.ylim((-2.5, 2.5));plt.pause(0.1)
        # 重新進(jìn)入訓(xùn)練模式,并繼續(xù)上次訓(xùn)練
        net_overfitting.train()
        net_dropped.train()
plt.ioff()
plt.show()

效果

可以看到紅色的線雖然更加擬合train數(shù)據(jù),但是通過test數(shù)據(jù)發(fā)現(xiàn)它的誤差反而比較大

以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。

您可能感興趣的文章:
  • keras處理欠擬合和過擬合的實(shí)例講解
  • pytorch構(gòu)建網(wǎng)絡(luò)模型的4種方法
  • 解決Pytorch 加載訓(xùn)練好的模型 遇到的error問題
  • Python機(jī)器學(xué)習(xí)pytorch模型選擇及欠擬合和過擬合詳解

標(biāo)簽:江蘇 云南 商丘 酒泉 定西 金融催收 龍巖 寧夏

巨人網(wǎng)絡(luò)通訊聲明:本文標(biāo)題《Pytorch之如何dropout避免過擬合》,本文關(guān)鍵詞  Pytorch,之,如何,dropout,避免,;如發(fā)現(xiàn)本文內(nèi)容存在版權(quán)問題,煩請?zhí)峁┫嚓P(guān)信息告之我們,我們將及時(shí)溝通與處理。本站內(nèi)容系統(tǒng)采集于網(wǎng)絡(luò),涉及言論、版權(quán)與本站無關(guān)。
  • 相關(guān)文章
  • 下面列出與本文章《Pytorch之如何dropout避免過擬合》相關(guān)的同類信息!
  • 本頁收集關(guān)于Pytorch之如何dropout避免過擬合的相關(guān)信息資訊供網(wǎng)民參考!
  • 推薦文章
    含山县| 宜兴市| 文水县| 南京市| 宜都市| 盐源县| 中方县| 兰考县| 巴东县| 泰来县| 吴堡县| 边坝县| 蕲春县| 宣恩县| 鄯善县| 扶绥县| 榆中县| 交城县| 绍兴市| 邛崃市| 永善县| 若羌县| 改则县| 涞源县| 砀山县| 云霄县| 丹棱县| 海晏县| 扎囊县| 蓝田县| 醴陵市| 延寿县| 观塘区| 嘉鱼县| 多伦县| 奉节县| 仙桃市| 昌江| 柘荣县| 定结县| 瓦房店市|