濮阳杆衣贸易有限公司

主頁(yè) > 知識(shí)庫(kù) > Pytorch學(xué)習(xí)筆記DCGAN極簡(jiǎn)入門(mén)教程

Pytorch學(xué)習(xí)筆記DCGAN極簡(jiǎn)入門(mén)教程

熱門(mén)標(biāo)簽:佛山防封外呼系統(tǒng)收費(fèi) 哈爾濱外呼系統(tǒng)代理商 鄭州智能外呼系統(tǒng)運(yùn)營(yíng)商 電話(huà)機(jī)器人適用業(yè)務(wù) 徐州天音防封電銷(xiāo)卡 不錯(cuò)的400電話(huà)辦理 湛江電銷(xiāo)防封卡 獲客智能電銷(xiāo)機(jī)器人 南昌辦理400電話(huà)怎么安裝

1.圖片分類(lèi)網(wǎng)絡(luò)

這是一個(gè)二分類(lèi)網(wǎng)絡(luò),可以是alxnet ,vgg,resnet任何一個(gè),負(fù)責(zé)對(duì)圖片進(jìn)行二分類(lèi),區(qū)分圖片是真實(shí)圖片還是生成的圖片

2.圖片生成網(wǎng)絡(luò)

輸入是一個(gè)隨機(jī)噪聲,輸出是一張圖片,使用的是反卷積層

相信學(xué)過(guò)深度學(xué)習(xí)的都能寫(xiě)出這兩個(gè)網(wǎng)絡(luò),當(dāng)然如果你寫(xiě)不出來(lái),沒(méi)關(guān)系,有人替你寫(xiě)好了

首先是圖片分類(lèi)網(wǎng)絡(luò):

簡(jiǎn)單來(lái)說(shuō)就是cnn+relu+sogmid,可以換成任何一個(gè)分類(lèi)網(wǎng)絡(luò),比如bgg,resnet等

class Discriminator(nn.Module):
    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )
    def forward(self, input):
        return self.main(input)

重點(diǎn)是生成網(wǎng)絡(luò)

代碼如下,其實(shí)就是反卷積+bn+relu

class Generator(nn.Module):
    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d( nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d( ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d( ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d( ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )
    def forward(self, input):
        return self.main(input)


講道理,以上兩個(gè)網(wǎng)絡(luò)都挺簡(jiǎn)單。

真正的重點(diǎn)到了,怎么訓(xùn)練

每一個(gè)step分為三個(gè)步驟:

  • 訓(xùn)練二分類(lèi)網(wǎng)絡(luò)
    1.輸入真實(shí)圖片,經(jīng)過(guò)二分類(lèi),希望判定為真實(shí)圖片,更新二分類(lèi)網(wǎng)絡(luò)
    2.輸入噪聲,進(jìn)過(guò)生成網(wǎng)絡(luò),生成一張圖片,輸入二分類(lèi)網(wǎng)絡(luò),希望判定為虛假圖片,更新二分類(lèi)網(wǎng)絡(luò)
  • 訓(xùn)練生成網(wǎng)絡(luò)
    3.輸入噪聲,進(jìn)過(guò)生成網(wǎng)絡(luò),生成一張圖片,輸入二分類(lèi)網(wǎng)絡(luò),希望判定為真實(shí)圖片,更新生成網(wǎng)絡(luò)

不多說(shuō)直接上代碼

for epoch in range(num_epochs):
    # For each batch in the dataloader
    for i, data in enumerate(dataloader, 0):
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        ## Train with all-real batch
        netD.zero_grad()
        # Format batch
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label, device=device)
        # Forward pass real batch through D
        output = netD(real_cpu).view(-1)
        # Calculate loss on all-real batch
        errD_real = criterion(output, label)
        # Calculate gradients for D in backward pass
        errD_real.backward()
        D_x = output.mean().item()
        ## Train with all-fake batch
        # Generate batch of latent vectors
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        # Generate fake image batch with G
        fake = netG(noise)
        label.fill_(fake_label)
        # Classify all fake batch with D
        output = netD(fake.detach()).view(-1)
        # Calculate D's loss on the all-fake batch
        errD_fake = criterion(output, label)
        # Calculate the gradients for this batch
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        # Add the gradients from the all-real and all-fake batches
        errD = errD_real + errD_fake
        # Update D
        optimizerD.step()
        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        # Since we just updated D, perform another forward pass of all-fake batch through D
        output = netD(fake).view(-1)
        # Calculate G's loss based on this output
        errG = criterion(output, label)
        # Calculate gradients for G
        errG.backward()
        D_G_z2 = output.mean().item()
        # Update G
        optimizerG.step()
        # Output training stats
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
        # Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())
        # Check how the generator is doing by saving G's output on fixed_noise
        if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))
        iters += 1

以上就是Pytorch學(xué)習(xí)筆記DCGAN極簡(jiǎn)入門(mén)教程的詳細(xì)內(nèi)容,更多關(guān)于Pytorch學(xué)習(xí)DCGAN入門(mén)教程的資料請(qǐng)關(guān)注腳本之家其它相關(guān)文章!

您可能感興趣的文章:
  • Pytorch使用MNIST數(shù)據(jù)集實(shí)現(xiàn)基礎(chǔ)GAN和DCGAN詳解
  • PyTorch安裝與基本使用詳解
  • 使用Pytorch搭建模型的步驟

標(biāo)簽:蕪湖 紹興 廣西 呂梁 懷化 蘭州 吉安 安康

巨人網(wǎng)絡(luò)通訊聲明:本文標(biāo)題《Pytorch學(xué)習(xí)筆記DCGAN極簡(jiǎn)入門(mén)教程》,本文關(guān)鍵詞  Pytorch,學(xué)習(xí),筆記,DCGAN,極簡(jiǎn),;如發(fā)現(xiàn)本文內(nèi)容存在版權(quán)問(wèn)題,煩請(qǐng)?zhí)峁┫嚓P(guān)信息告之我們,我們將及時(shí)溝通與處理。本站內(nèi)容系統(tǒng)采集于網(wǎng)絡(luò),涉及言論、版權(quán)與本站無(wú)關(guān)。
  • 相關(guān)文章
  • 下面列出與本文章《Pytorch學(xué)習(xí)筆記DCGAN極簡(jiǎn)入門(mén)教程》相關(guān)的同類(lèi)信息!
  • 本頁(yè)收集關(guān)于Pytorch學(xué)習(xí)筆記DCGAN極簡(jiǎn)入門(mén)教程的相關(guān)信息資訊供網(wǎng)民參考!
  • 推薦文章
    武冈市| 曲周县| 龙山县| 婺源县| 丰都县| 盐亭县| 桂林市| 凌云县| 日喀则市| 新源县| 阿瓦提县| 满洲里市| 察隅县| 长乐市| 伊宁县| 罗甸县| 夏邑县| 板桥市| 确山县| 宝鸡市| 十堰市| 佛冈县| 浪卡子县| 阳东县| 白银市| 读书| 乌审旗| 寻乌县| 兴安县| 珠海市| 深圳市| 桦甸市| 特克斯县| 横山县| 上杭县| 大丰市| 云和县| 雷波县| 阳泉市| 佳木斯市| 重庆市|