
Deep Learning - Variational Autoencoder (VAE) generates MNIST handwritten digit images

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torchvision.utils import save_image

# Configure GPU or CPU settings
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Create a directory to save the generated images
sample_dir = 'samples'
if not os.path.exists(sample_dir):
    os.makedirs(sample_dir)
# Hyperparameter settings
image_size = 784      # flattened image size (28*28)
h_dim = 400           # hidden layer dimension
z_dim = 20            # latent variable dimension
num_epochs = 15       # number of training epochs
batch_size = 128      # number of samples per batch
learning_rate = 1e-3  # learning rate
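The value image_size = 784 follows directly from the MNIST image shape: each image is a single-channel 28x28 array, so flattening gives 1 * 28 * 28 = 784 features per sample. A quick check of that arithmetic (illustrative only, not part of the original script):

assert 1 * 28 * 28 == 784  # image_size equals a flattened 1x28x28 MNIST image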
# Get the dataset
dataset = torchvision.datasets.MNIST(root='./data',
                                     train=True,
                                     transform=transforms.ToTensor(),
                                     download=True)

# Data is loaded in batches of batch_size and randomly shuffled
data_loader = torch.utils.data.DataLoader(dataset=dataset,
                                          batch_size=batch_size,
                                          shuffle=True)
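To see the batch format the loader produces, and why the training loop below flattens every batch to batch_size x 784, one can inspect a single batch. This is an optional sketch for illustration, not part of the original script:

images, labels = next(iter(data_loader))
print(images.shape)                # torch.Size([128, 1, 28, 28]) for a full batch
print(labels.shape)                # torch.Size([128])
print(images.view(-1, 784).shape)  # torch.Size([128, 784]) after flattening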
# VAE model
class VAE(nn.Module):
    def __init__(self, image_size=784, h_dim=400, z_dim=20):
        super(VAE, self).__init__()
        self.fc1 = nn.Linear(image_size, h_dim)
        self.fc2 = nn.Linear(h_dim, z_dim)
        self.fc3 = nn.Linear(h_dim, z_dim)
        self.fc4 = nn.Linear(z_dim, h_dim)
        self.fc5 = nn.Linear(h_dim, image_size)

    # Encode: learn the mean and (log-)variance of the Gaussian distribution over z
    def encode(self, x):
        h = F.relu(self.fc1(x))
        return self.fc2(h), self.fc3(h)

    # Reparameterization trick: express z through the mean and variance so that gradients can flow.
    # If x ~ N(mu, var), then (x - mu) / std ~ N(0, 1), so sample eps ~ N(0, 1) and set z = mu + eps * std.
    def reparameterize(self, mu, log_var):
        std = torch.exp(log_var / 2)
        eps = torch.randn_like(std)
        return mu + eps * std

    # Decode the latent variable z
    def decode(self, z):
        h = F.relu(self.fc4(z))
        return torch.sigmoid(self.fc5(h))

    # Compute the reconstruction and the distribution parameters of the latent variable z
    def forward(self, x):
        mu, log_var = self.encode(x)          # learn the Gaussian mean and log-variance of z from the input x
        z = self.reparameterize(mu, log_var)  # sample the latent variable z via the reparameterization trick
        x_reconst = self.decode(z)            # decode z to generate the reconstruction x'
        return x_reconst, mu, log_var         # return the reconstruction and the latent distribution parameters
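For reference, the reparameterization step and the KL term used in the training loop below can be written out explicitly; these are the standard closed-form VAE expressions for a diagonal Gaussian posterior, stated here for clarity:

\[
z = \mu + \sigma \odot \epsilon, \qquad \epsilon \sim \mathcal{N}(0, I), \qquad \sigma = \exp\!\left(\tfrac{1}{2}\log\sigma^{2}\right)
\]

\[
\mathrm{KL}\!\left(\mathcal{N}(\mu, \sigma^{2}) \,\middle\|\, \mathcal{N}(0, I)\right) = -\tfrac{1}{2} \sum_{j=1}^{20} \left(1 + \log\sigma_{j}^{2} - \mu_{j}^{2} - \sigma_{j}^{2}\right)
\]

The second expression is exactly what the code computes as kl_div, with the code's log_var playing the role of \(\log\sigma^{2}\).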
# Construct the VAE instance
model = VAE().to(device)
print(model)
"""VAE(
  (fc1): Linear(in_features=784, out_features=400, bias=True)
  (fc2): Linear(in_features=400, out_features=20, bias=True)
  (fc3): Linear(in_features=400, out_features=20, bias=True)
  (fc4): Linear(in_features=20, out_features=400, bias=True)
  (fc5): Linear(in_features=400, out_features=784, bias=True)
)"""

# Select the optimizer and pass in the VAE model parameters and learning rate
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
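Before training, it can be worth confirming the tensor shapes with a throwaway forward pass. This is an optional sanity check, not part of the original script, and the dummy batch of 4 flattened images is made up for illustration:

dummy = torch.randn(4, image_size).to(device)    # 4 fake flattened images
x_reconst, mu, log_var = model(dummy)
print(x_reconst.shape, mu.shape, log_var.shape)  # torch.Size([4, 784]) torch.Size([4, 20]) torch.Size([4, 20])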
# Start training, 15 epochs in total
for epoch in range(num_epochs):
    for i, (x, _) in enumerate(data_loader):
        # Forward propagation
        x = x.to(device).view(-1, image_size)  # reshape batch_size*1*28*28 ---> batch_size*image_size, where image_size = 1*28*28 = 784
        x_reconst, mu, log_var = model(x)      # forward pass of the batch_size*784 input: returns the reconstruction and the Gaussian parameters (mean, log-variance) of the latent variable z

        # Compute the reconstruction loss and the KL divergence
        # Reconstruction loss
        reconst_loss = F.binary_cross_entropy(x_reconst, x, reduction='sum')
        # KL divergence
        kl_div = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

        # Backpropagation and optimization
        # Total loss (reconstruction error plus KL divergence)
        loss = reconst_loss + kl_div
        # Clear the gradients left over from the previous step
        optimizer.zero_grad()
        # Backpropagate the error and compute the parameter updates
        loss.backward()
        # Apply the parameter updates to the VAE model
        optimizer.step()

        # Print the result every 10 iterations
        if (i + 1) % 10 == 0:
            print("Epoch[{}/{}], Step [{}/{}], Reconst Loss: {:.4f}, KL Div: {:.4f}"
                  .format(epoch + 1, num_epochs, i + 1, len(data_loader), reconst_loss.item(), kl_div.item()))
    with torch.no_grad():
        # Save sampled images
        # Generate random latent codes z
        z = torch.randn(batch_size, z_dim).to(device)  # z is batch_size * z_dim = 128*20
        # Decode the random z
        out = model.decode(z).view(-1, 1, 28, 28)
        # Save the result
        save_image(out, os.path.join(sample_dir, 'sampled-{}.png'.format(epoch + 1)))

        # Save the reconstructions
        # Run the batch_size*784 input x through the model again to obtain the reconstruction out
        out, _, _ = model(x)
        # Concatenate input and output side by side: batch_size*1*28*(28+28) = batch_size*1*28*56
        x_concat = torch.cat([x.view(-1, 1, 28, 28), out.view(-1, 1, 28, 28)], dim=3)
        save_image(x_concat, os.path.join(sample_dir, 'reconst-{}.png'.format(epoch + 1)))
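After the final epoch, the trained decoder can be reused on its own to turn random latent codes into new digit images. A minimal sketch along those lines is shown below; the file names new-samples.png and vae_mnist.pth are assumptions for illustration, not from the original script:

with torch.no_grad():
    z = torch.randn(64, z_dim).to(device)            # 64 random latent codes
    generated = model.decode(z).view(-1, 1, 28, 28)  # decode into 64 images of shape 1x28x28
    save_image(generated, os.path.join(sample_dir, 'new-samples.png'), nrow=8)

# Optionally save the trained weights for later reuse
torch.save(model.state_dict(), 'vae_mnist.pth')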