濮阳杆衣贸易有限公司

主頁(yè) > 知識(shí)庫(kù) > Pytorch distributed 多卡并行載入模型操作

Pytorch distributed 多卡并行載入模型操作

熱門(mén)標(biāo)簽:打電話機(jī)器人營(yíng)銷 聊城語(yǔ)音外呼系統(tǒng) 孝感營(yíng)銷電話機(jī)器人效果怎么樣 ai電銷機(jī)器人的優(yōu)勢(shì) 南陽(yáng)打電話機(jī)器人 地圖標(biāo)注自己和別人標(biāo)注區(qū)別 商家地圖標(biāo)注海報(bào) 騰訊地圖標(biāo)注沒(méi)法顯示 海外網(wǎng)吧地圖標(biāo)注注冊(cè)

一、Pytorch distributed 多卡并行載入模型

這次來(lái)介紹下如何載入模型。

目前沒(méi)有找到官方的distribute 載入模型的方式,所以采用如下方式。

大部分情況下,我們?cè)跍y(cè)試時(shí)不需要多卡并行計(jì)算。

所以,我在測(cè)試時(shí)只使用單卡。

from collections import OrderedDict
device = torch.device("cuda")
model = DGCNN(args).to(device)  #自己的模型
state_dict = torch.load(args.model_path)    #存放模型的位置

new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:] # remove `module.`
    new_state_dict[name] = v
    # load params
model.load_state_dict (new_state_dict)

二、pytorch DistributedParallel進(jìn)行單機(jī)多卡訓(xùn)練

One_導(dǎo)入庫(kù):

import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler

Two_進(jìn)程初始化:

parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=-1)
# 添加必要參數(shù)
# local_rank:系統(tǒng)自動(dòng)賦予的進(jìn)程編號(hào),可以利用該編號(hào)控制打印輸出以及設(shè)置device

torch.distributed.init_process_group(backend="nccl", init_method='file://shared/sharedfile',
rank=local_rank, world_size=world_size)

# world_size:所創(chuàng)建的進(jìn)程數(shù),也就是所使用的GPU數(shù)量
# (初始化設(shè)置詳見(jiàn)參考文檔)

Three_數(shù)據(jù)分發(fā):

dataset = datasets.ImageFolder(dataPath)
data_sampler = DistributedSampler(dataset, rank=local_rank, num_replicas=world_size)
# 使用DistributedSampler來(lái)為各個(gè)進(jìn)程分發(fā)數(shù)據(jù),其中num_replicas與world_size保持一致,用于將數(shù)據(jù)集等分成不重疊的數(shù)個(gè)子集

dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=1,drop_last=True, pin_memory=True, sampler=data_sampler)
# 在Dataloader中指定sampler時(shí),其中的shuffle必須為False,而DistributedSampler中的shuffle項(xiàng)默認(rèn)為T(mén)rue,因此訓(xùn)練過(guò)程默認(rèn)執(zhí)行shuffle

Four_網(wǎng)絡(luò)模型:

torch.cuda.set_device(local_rank)
device = torch.device('cuda:'+f'{local_rank}')
# 設(shè)置每個(gè)進(jìn)程對(duì)應(yīng)的GPU設(shè)備

D = Model()
D = torch.nn.SyncBatchNorm.convert_sync_batchnorm(D).to(device)
# 由于在訓(xùn)練過(guò)程中各卡的前向后向傳播均獨(dú)立進(jìn)行,因此無(wú)法進(jìn)行統(tǒng)一的批歸一化,如果想要將各卡的輸出統(tǒng)一進(jìn)行批歸一化,需要將模型中的BN轉(zhuǎn)換成SyncBN
   
D = torch.nn.parallel.DistributedDataParallel(
D, find_unused_parameters=True, device_ids=[local_rank], output_device=local_rank)
# 如果有forward的返回值如果不在計(jì)算loss的計(jì)算圖里,那么需要find_unused_parameters=True,即返回值不進(jìn)入backward去算grad,也不需要在不同進(jìn)程之間進(jìn)行通信。

Five_迭代:

data_sampler.set_epoch(epoch)
# 每個(gè)epoch需要為sampler設(shè)置當(dāng)前epoch

Six_加載:

dist.barrier()
D.load_state_dict(torch.load('D.pth'), map_location=torch.device('cpu'))
dist.barrier()
# 加載模型前后用dist.barrier()來(lái)同步不同進(jìn)程間的快慢

Seven_啟動(dòng):

CUDA_VISIBLE_DEVICES=1,3 python -m torch.distributed.launch --nproc_per_node=2 train.py --epochs 15000 --batchsize 10 --world_size 2
# 用-m torch.distributed.launch啟動(dòng),nproc_per_node為所使用的卡數(shù),batchsize設(shè)置為每張卡各自的批大小

以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。

您可能感興趣的文章:
  • pytorch DistributedDataParallel 多卡訓(xùn)練結(jié)果變差的解決方案
  • PyTorch 多GPU下模型的保存與加載(踩坑筆記)
  • pytorch多GPU并行運(yùn)算的實(shí)現(xiàn)

標(biāo)簽:楊凌 六盤(pán)水 撫州 聊城 迪慶 揚(yáng)州 南寧 牡丹江

巨人網(wǎng)絡(luò)通訊聲明:本文標(biāo)題《Pytorch distributed 多卡并行載入模型操作》,本文關(guān)鍵詞  Pytorch,distributed,多卡,并行,;如發(fā)現(xiàn)本文內(nèi)容存在版權(quán)問(wèn)題,煩請(qǐng)?zhí)峁┫嚓P(guān)信息告之我們,我們將及時(shí)溝通與處理。本站內(nèi)容系統(tǒng)采集于網(wǎng)絡(luò),涉及言論、版權(quán)與本站無(wú)關(guān)。
  • 相關(guān)文章
  • 下面列出與本文章《Pytorch distributed 多卡并行載入模型操作》相關(guān)的同類信息!
  • 本頁(yè)收集關(guān)于Pytorch distributed 多卡并行載入模型操作的相關(guān)信息資訊供網(wǎng)民參考!
  • 推薦文章
    巫溪县| 裕民县| 饶平县| 弋阳县| 黄山市| 日照市| 日土县| 安西县| 苏尼特右旗| 竹溪县| 侯马市| 莫力| 天柱县| 南平市| 白朗县| 土默特右旗| 屏山县| 海城市| 博兴县| 陇川县| 堆龙德庆县| 逊克县| 巫山县| 囊谦县| 抚远县| 若尔盖县| 平山县| 武宁县| 罗江县| 阿瓦提县| 清河县| 邢台县| 吴旗县| 石屏县| 诏安县| 营口市| 贵德县| 旌德县| 宁陵县| 大城县| 六安市|