This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import print_function | |
import argparse | |
import os | |
import random | |
import torch | |
import torch.nn as nn | |
import torch.nn.parallel | |
import torch.backends.cudnn as cudnn | |
import torch.optim as optim | |
import torch.utils.data | |
import torchvision.datasets as dset | |
import torchvision.transforms as transforms | |
import torchvision.utils as vutils | |
# Command-line interface for the training script.
parser = argparse.ArgumentParser()
# `choices` makes argparse reject an unsupported dataset name immediately with
# a usage message, instead of failing later when the dataset is constructed.
parser.add_argument('--dataset', required=True,
                    choices=['cifar10', 'lsun', 'imagenet', 'folder', 'lfw', 'fake'],
                    help='cifar10 | lsun | imagenet | folder | lfw | fake')
parser.add_argument('--dataroot', required=True, help='path to dataset')
parser.add_argument('--workers', type=int, help='number of data loading workers', default=8)
parser.add_argument('--batchSize', type=int, default=64, help='input batch size')
parser.add_argument('--imageSize', type=int, default=64, help='the height / width of the input image to network')
parser.add_argument('--nz', type=int, default=100, help='size of the latent z vector')
parser.add_argument('--ngf', type=int, default=64)
parser.add_argument('--ndf', type=int, default=64)
parser.add_argument('--niter', type=int, default=25, help='number of epochs to train for')
parser.add_argument('--lr', type=float, default=0.0002, help='learning rate, default=0.0002')
parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
parser.add_argument('--cuda', action='store_true', help='enables cuda')
parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use')
parser.add_argument('--netG', default='', help="path to netG (to continue training)")
parser.add_argument('--netD', default='', help="path to netD (to continue training)")
parser.add_argument('--outf', default='.', help='folder to output images and model checkpoints')
parser.add_argument('--manualSeed', type=int, help='manual seed')

opt = parser.parse_args()
# Echo the effective configuration so a run can be reproduced from its log.
print(opt)
# Make sure the output folder for images and checkpoints exists.
try:
    os.makedirs(opt.outf)
except OSError:
    # An already-existing directory is the only failure that is safe to
    # ignore; permission problems or a regular file in the way must surface
    # here rather than as a confusing save error after training has started.
    if not os.path.isdir(opt.outf):
        raise

# Pick a random seed when none was supplied, and always print it so the run
# can be repeated with --manualSeed.
if opt.manualSeed is None:
    opt.manualSeed = random.randint(1, 10000)
print("Random Seed: ", opt.manualSeed)
random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)

cudnn.benchmark = True

if torch.cuda.is_available() and not opt.cuda:
    print("WARNING: You have a CUDA device, so you should probably run with --cuda")
# Build the training dataset selected by --dataset.  All real-image datasets
# are resized to opt.imageSize and normalized with mean 0.5 / std 0.5 per
# channel.
if opt.dataset in ['imagenet', 'folder', 'lfw']:
    # folder dataset
    dataset = dset.ImageFolder(root=opt.dataroot,
                               transform=transforms.Compose([
                                   transforms.Resize(opt.imageSize),
                                   transforms.CenterCrop(opt.imageSize),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                               ]))
elif opt.dataset == 'lsun':
    dataset = dset.LSUN(root=opt.dataroot, classes=['bedroom_train'],
                        transform=transforms.Compose([
                            transforms.Resize(opt.imageSize),
                            transforms.CenterCrop(opt.imageSize),
                            transforms.ToTensor(),
                            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                        ]))
elif opt.dataset == 'cifar10':
    dataset = dset.CIFAR10(root=opt.dataroot, download=True,
                           transform=transforms.Compose([
                               transforms.Resize(opt.imageSize),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))
elif opt.dataset == 'fake':
    dataset = dset.FakeData(image_size=(3, opt.imageSize, opt.imageSize),
                            transform=transforms.ToTensor())
else:
    # Without this branch `dataset` would stay unbound and the script would
    # die with an unrelated-looking NameError.
    raise ValueError('unsupported dataset: %r' % (opt.dataset,))

# Explicit check instead of `assert dataset`: asserts are stripped under
# `python -O`, and an empty dataset deserves a readable message.
if len(dataset) == 0:
    raise ValueError('dataset %r at %r is empty' % (opt.dataset, opt.dataroot))

dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize,
                                         shuffle=True, num_workers=int(opt.workers))

device = torch.device("cuda:0" if opt.cuda else "cpu")
ngpu = int(opt.ngpu)
nz = int(opt.nz)    # size of the latent z vector
ngf = int(opt.ngf)  # base channel multiplier used by the Generator
ndf = int(opt.ndf)  # base channel multiplier used by the Discriminator
nc = 3              # number of image channels
# custom weights initialization called on netG and netD
def weights_init(m):
    """Re-initialize one module in place, chosen by its class name.

    Intended to be passed to ``Module.apply``.  Modules whose class name
    contains ``Conv`` get weights drawn from N(0.0, 0.02).  Modules whose
    class name contains ``BatchNorm`` get weights drawn from N(1.0, 0.02)
    and a zeroed bias.  Every other module is left untouched.
    """
    layer_type = type(m).__name__
    if 'Conv' in layer_type:
        m.weight.data.normal_(0.0, 0.02)
        return
    if 'BatchNorm' in layer_type:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
class BatchNorm2d(nn.BatchNorm2d):
    """Hand-written batch normalization over the channel dimension.

    Reuses the parameters and buffers of ``nn.BatchNorm2d`` (``weight``,
    ``bias``, ``running_mean``, ``running_var``, ``eps``, ``momentum``) but
    performs the normalization explicitly in Python.

    NOTE(review): the normalization formula is kept exactly as this file has
    always computed it: unbiased variance (``Tensor.var`` default) and
    ``eps`` added to the standard deviation.  ``nn.BatchNorm2d`` documents a
    biased variance and ``sqrt(var + eps)``, so outputs differ slightly from
    the built-in layer; confirm whether that is intended before changing it.

    ``momentum`` must be a float; ``momentum=None`` is not supported.
    """

    def forward(self, x):
        """Normalize ``x`` per channel and apply the affine transform.

        Args:
            x: 4-D input tensor (validated by ``_check_input_dim``), with
                channels on dimension 1.

        Returns:
            A tensor with the same shape as ``x``.
        """
        self._check_input_dim(x)

        # Bring channels to the front and flatten everything else so that
        # per-channel statistics are a reduction over dim 1.
        y = x.transpose(0, 1)
        return_shape = y.shape
        y = y.contiguous().view(x.size(1), -1)

        # Fall back to batch statistics in eval mode when no running
        # statistics exist (track_running_stats=False leaves them as None).
        have_running_stats = (self.running_mean is not None
                              and self.running_var is not None)
        if self.training or not have_running_stats:
            mu = y.mean(dim=1)
            sigma2 = y.var(dim=1)
            if self.training and self.track_running_stats and have_running_stats:
                with torch.no_grad():
                    # Update in place rather than rebinding the attribute, so
                    # the registered buffers keep their storage and the
                    # update is visible through any replica that shares it.
                    self.running_mean.mul_(1 - self.momentum).add_(self.momentum * mu)
                    self.running_var.mul_(1 - self.momentum).add_(self.momentum * sigma2)
        else:
            mu = self.running_mean
            sigma2 = self.running_var

        y = y - mu.view(-1, 1)
        y = y / (sigma2.view(-1, 1) ** .5 + self.eps)

        # weight/bias are None when the layer was built with affine=False.
        if self.affine:
            y = self.weight.view(-1, 1) * y + self.bias.view(-1, 1)
        return y.view(return_shape).transpose(0, 1)
# Passed as track_running_stats to the BatchNorm2d layers built below.
track = True


class Generator(nn.Module):
    """Transposed-convolution network mapping latent vectors to images.

    The layers live in ``self.main``; ``ngpu`` controls whether the forward
    pass is spread over several GPUs.
    """

    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu

        # (in_channels, out_channels, stride, padding) for each
        # ConvTranspose2d -> BatchNorm2d -> ReLU stage; kernel size is 4.
        stages = [
            (nz, ngf * 8, 1, 0),       # input is Z -> (ngf*8) x 4 x 4
            (ngf * 8, ngf * 4, 2, 1),  # -> (ngf*4) x 8 x 8
            (ngf * 4, ngf * 2, 2, 1),  # -> (ngf*2) x 16 x 16
            (ngf * 2, ngf, 2, 1),      # -> (ngf) x 32 x 32
        ]
        layers = []
        for in_ch, out_ch, stride, padding in stages:
            layers.append(nn.ConvTranspose2d(in_ch, out_ch, 4, stride, padding, bias=False))
            layers.append(BatchNorm2d(out_ch, track_running_stats=track))
            layers.append(nn.ReLU(True))
        # Output stage -> (nc) x 64 x 64
        layers.append(nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False))
        layers.append(nn.Tanh())
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Run ``self.main``; use data_parallel for CUDA input when ngpu > 1."""
        use_multi_gpu = input.is_cuda and self.ngpu > 1
        if not use_multi_gpu:
            return self.main(input)
        return nn.parallel.data_parallel(self.main, input, range(self.ngpu))
# Build the generator, apply the custom weight initialization, and optionally
# resume from a checkpoint given with --netG.
netG = Generator(ngpu).to(device)
netG.apply(weights_init)
if opt.netG != '':
    # map_location lets a checkpoint saved on a GPU be resumed on whatever
    # device this run uses (for example a CPU-only machine).
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    netG.load_state_dict(torch.load(opt.netG, map_location=device))
print(netG)
class Discriminator(nn.Module):
    """Convolutional network producing one sigmoid score per input image.

    The layers live in ``self.main``; ``ngpu`` controls whether the forward
    pass is spread over several GPUs.
    """

    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu

        # input is (nc) x 64 x 64; the first stage has no batch norm
        layers = [
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Conv2d -> BatchNorm2d -> LeakyReLU stages, each doubling the
        # channel count:
        # (ndf) x 32 x 32 -> (ndf*2) x 16 x 16 -> (ndf*4) x 8 x 8 -> (ndf*8) x 4 x 4
        for in_mult, out_mult in ((1, 2), (2, 4), (4, 8)):
            layers.append(nn.Conv2d(ndf * in_mult, ndf * out_mult, 4, 2, 1, bias=False))
            layers.append(BatchNorm2d(ndf * out_mult, track_running_stats=track))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        # Collapse the (ndf*8) x 4 x 4 state to a single score.
        layers.append(nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False))
        layers.append(nn.Sigmoid())
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Score ``input``; returns a 1-D tensor of flattened scores."""
        if input.is_cuda and self.ngpu > 1:
            scores = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        else:
            scores = self.main(input)
        flat_scores = scores.view(-1, 1)
        return flat_scores.squeeze(1)
# Build the discriminator, apply the custom weight initialization, and
# optionally resume from a checkpoint given with --netD.
netD = Discriminator(ngpu).to(device)
netD.apply(weights_init)
if opt.netD != '':
    # map_location lets a checkpoint saved on a GPU be resumed on whatever
    # device this run uses (for example a CPU-only machine).
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    netD.load_state_dict(torch.load(opt.netD, map_location=device))
print(netD)
# Binary cross-entropy between discriminator scores and real/fake targets.
criterion = nn.BCELoss()

# A fixed batch of latent vectors, reused for every preview image.
fixed_noise = torch.randn(opt.batchSize, nz, 1, 1, device=device)

# Targets are floats on purpose: torch.full infers its dtype from the fill
# value, and BCELoss needs floating-point targets matching the scores.
real_label = 1.0
fake_label = 0.0

# setup optimizer
optimizerD = optim.Adam(netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
# Main training loop: for every batch, update D on real and generated images,
# then update G.  Preview images are saved every 100 batches and both
# networks are checkpointed at the end of every epoch.
for epoch in range(opt.niter):
    for i, data in enumerate(dataloader, 0):
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        # train with real
        netD.zero_grad()
        real_cpu = data[0].to(device)
        batch_size = real_cpu.size(0)
        # Explicit float dtype: with an integer fill value torch.full would
        # infer an integer tensor, which BCELoss rejects as a target.
        label = torch.full((batch_size,), real_label,
                           dtype=torch.float, device=device)

        output = netD(real_cpu)
        errD_real = criterion(output, label)
        errD_real.backward()
        D_x = output.mean().item()

        # train with fake
        noise = torch.randn(batch_size, nz, 1, 1, device=device)
        fake = netG(noise)
        label.fill_(fake_label)
        # detach() keeps this backward pass from reaching the generator.
        output = netD(fake.detach())
        errD_fake = criterion(output, label)
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        errD = errD_real + errD_fake
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        output = netD(fake)
        errG = criterion(output, label)
        errG.backward()
        D_G_z2 = output.mean().item()
        optimizerG.step()

        print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f'
              % (epoch, opt.niter, i, len(dataloader),
                 errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
        if i % 100 == 0:
            vutils.save_image(real_cpu,
                              '%s/real_samples.png' % opt.outf,
                              normalize=True)
            # Preview only: no gradients are needed, so skip building the
            # autograd graph for this forward pass.
            with torch.no_grad():
                fake = netG(fixed_noise)
            vutils.save_image(fake.detach(),
                              '%s/fake_samples_epoch_%03d.png' % (opt.outf, epoch),
                              normalize=True)

    # do checkpointing
    torch.save(netG.state_dict(), '%s/netG_epoch_%d.pth' % (opt.outf, epoch))
    torch.save(netD.state_dict(), '%s/netD_epoch_%d.pth' % (opt.outf, epoch))
$ python main.py --cuda --dataset cifar10 --dataroot .
No comments:
Post a Comment