diff --git a/implementations/gan/gan.py b/implementations/gan/gan.py index d6f1d935..75286cb2 100644 --- a/implementations/gan/gan.py +++ b/implementations/gan/gan.py @@ -8,7 +8,6 @@ from torch.utils.data import DataLoader from torchvision import datasets -from torch.autograd import Variable import torch.nn as nn import torch.nn.functional as F @@ -122,11 +121,11 @@ def forward(self, img): for i, (imgs, _) in enumerate(dataloader): # Adversarial ground truths - valid = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False) - fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False) + valid = torch.Tensor(imgs.size(0), 1).fill_(1.0) + fake = torch.Tensor(imgs.size(0), 1).fill_(0.0) # Configure input - real_imgs = Variable(imgs.type(Tensor)) + real_imgs = imgs.type(torch.Tensor) # ----------------- # Train Generator @@ -135,7 +134,7 @@ def forward(self, img): optimizer_G.zero_grad() # Sample noise as generator input - z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim)))) + z = torch.Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))) # Generate a batch of images gen_imgs = generator(z)