Quantum Patch GAN
In this tutorial, we demonstrate how to implement the quantum patch GAN introduced in Chapter xxx to generate hand-written images of the digit five. The whole pipeline consists of the following steps:
- Load and pre-process the dataset
- Build the classical discriminator
- Build the quantum generator
- Train the quantum patch GAN
- Visualize the generated images
We begin by importing required libraries:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
import pennylane as qml
import math
Step 1: Dataset Preparation
We will use the Optical Recognition of Handwritten Digits dataset, where each data point represents an $8\times 8$ grayscale image. Let’s start by defining a custom dataset class to load and process the data.
class OptdigitsData(Dataset):
    def __init__(self, data_path, label):
        """
        Dataset class for Optical Recognition of Handwritten Digits.
        Keeps only the images whose class matches `label`.
        """
        super().__init__()
        self.data = []
        with open(data_path, 'r') as f:
            for line in f.readlines():
                if int(line.strip().split(',')[-1]) == label:
                    # Normalize image pixel values from [0, 16] to [0, 1]
                    image = [int(pixel) / 16 for pixel in line.strip().split(',')[:-1]]
                    image = np.array(image, dtype=np.float32).reshape(8, 8)
                    self.data.append(image)
        self.label = label

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return torch.from_numpy(self.data[index]), self.label
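As a quick sanity check, we can instantiate the dataset for a single digit and inspect one sample. The snippet below is a minimal sketch that reuses the training file path from later in this tutorial:

dataset_five = OptdigitsData('code/chapter5_qnn/optical+recognition+of+handwritten+digits/optdigits.tra', label=5)
print(len(dataset_five))  # number of images of the digit five in the training file
image, label = dataset_five[0]
print(image.shape, label)  # torch.Size([8, 8]) 5
print(image.min().item(), image.max().item())  # pixel values lie within [0, 1]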
After defining the dataset class, we visualize a few examples to better understand the structure of the dataset.
def visualize_dataset(data_path):
    """
    Visualizes the dataset by displaying examples for each digit label.
    """
    plt.figure(figsize=(10, 5))
    for i in range(10):
        plt.subplot(1, 10, i + 1)
        data = OptdigitsData(data_path, label=i)
        plt.imshow(data[0][0], cmap='gray')
        plt.title(f"Label: {i}")
        plt.axis('off')
    plt.tight_layout()
    plt.show()
visualize_dataset('code/chapter5_qnn/optical+recognition+of+handwritten+digits/optdigits.tra')
Step 2: Building the Classical Discriminator
The discriminator is a classical neural network responsible for distinguishing real images from fake ones. It consists of fully connected layers with ReLU activations; the final output is re-scaled into $(0, 1)$ by a Sigmoid function so that it can be read as the probability that the input image is real.
class ClassicalDiscriminator(nn.Module):
    """
    A classical discriminator for classifying real and fake images.
    """
    def __init__(self, input_shape):
        super().__init__()
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(int(np.prod(input_shape)), 256),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )

    def forward(self, img):
        return self.model(img)
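As a smoke test, we can feed a batch of random tensors through an untrained discriminator and confirm that every output lies in $(0, 1)$; the batch size here is arbitrary:

discriminator_test = ClassicalDiscriminator([8, 8])
dummy_batch = torch.rand(4, 8, 8)  # four random 8x8 "images"
pred = discriminator_test(dummy_batch)
print(pred.shape)  # torch.Size([4, 1])
print(pred.min().item(), pred.max().item())  # both inside (0, 1) thanks to the Sigmoid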
Step 3: Defining the Quantum Generator
The quantum generator produces images patch by patch using parameterized quantum circuits (PQCs). Each patch is constructed by a sub-generator: a quantum circuit with trainable parameters.
Parameterized Quantum Circuit (PQC)
The PQC applies alternating layers of single-qubit rotation gates and entangling controlled-Z gates to the qubits.
def PQC(params):
    n_layer, n_qubit = params.shape[0], params.shape[1]
    for i in range(n_layer):
        # Single-qubit rotations with trainable Euler angles
        for j in range(n_qubit):
            qml.Rot(params[i, j, 0], params[i, j, 1], params[i, j, 2], wires=j)
        # Entangle neighboring qubits with controlled-Z gates
        for j in range(n_qubit - 1):
            qml.CZ(wires=[j, j + 1])
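To see the resulting gate layout, we can wrap the PQC in a throwaway QNode and print it with qml.draw. This is a minimal sketch; the choice of 2 layers and 3 qubits is only for readability:

dev_draw = qml.device('default.qubit', wires=3)

@qml.qnode(dev_draw)
def draw_circuit(params):
    PQC(params)
    return qml.state()

print(qml.draw(draw_circuit)(np.random.rand(2, 3, 3)))  # params shape: (n_layer, n_qubit, 3)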
Quantum Generator Circuit
This quantum generator circuit (sub-generator) encodes the latent variable $\boldsymbol{z}$ into the latent quantum state $|\boldsymbol{z}\rangle$, applies the PQC, performs a partial measurement on the ancillary system $\mathcal{A}$, and returns the probability of each computational basis state of the remaining system as the generated pixel values. Since the probabilities are taken over the remaining $N - N_A$ qubits, each sub-generator outputs $2^{N - N_A}$ pixel values, where $N$ is the total number of qubits and $N_A$ the number of ancillary qubits.
def QuantumGenerator(params, z=None, n_qubit_a=1):
    n_qubit = params.shape[1]
    # Angle encoding of the latent variable z
    for i in range(n_qubit):
        qml.RY(z[i], wires=i)
    PQC(params)
    # Partial measurement on the ancillary qubits (the last n_qubit_a wires)
    for i in range(n_qubit - n_qubit_a, n_qubit):
        qml.measure(wires=i)
    return qml.probs(wires=range(n_qubit - n_qubit_a))
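A single call to the sub-generator should return a probability vector of length $2^{N - N_A}$. The following sketch assumes 5 qubits with 1 ancillary qubit (matching the hyperparameters chosen below), so each patch contains $2^4 = 16$ pixel values; the layer count is arbitrary:

dev_check = qml.device('lightning.qubit', wires=5)
qnode_check = qml.QNode(QuantumGenerator, dev_check)
params_check = torch.rand(2, 5, 3)  # (n_layer, n_qubit, 3)
z_check = torch.rand(5) * math.pi
patch = qnode_check(params_check, z=z_check, n_qubit_a=1)
print(patch.shape)  # torch.Size([16]): one probability per pixel of the patch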
Quantum Patch Generator
With the sub-generators in place, the quantum patch generator concatenates the patches produced by the individual sub-generators into a whole image.
class PatchQuantumGenerator(nn.Module):
    """
    Combines patches generated by quantum circuits into full images.
    """
    def __init__(self, qnode_generator, n_generator, n_qubit, n_qubit_a, n_layer):
        super().__init__()
        self.params_generator = nn.ParameterList([
            nn.Parameter(torch.rand((n_layer, n_qubit, 3)), requires_grad=True)
            for _ in range(n_generator)
        ])
        self.qnode_generator = qnode_generator
        self.n_qubit_a = n_qubit_a

    def forward(self, zs):
        images = []
        for z in zs:
            patches = []
            for params in self.params_generator:
                patch = self.qnode_generator(params, z=z, n_qubit_a=self.n_qubit_a).float()
                # post-processing: min-max scaling
                patch = (patch - patch.min()) / (patch.max() - patch.min() + 1e-8)
                patches.append(patch.unsqueeze(0))
            patches = torch.cat(patches, dim=0)
            images.append(patches.flatten().unsqueeze(0))
        return torch.cat(images, dim=0)
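Before training, it is worth confirming that a forward pass yields images of the expected size. This sketch mirrors the hyperparameters chosen in Step 4 (4 sub-generators with 5 qubits each, 1 of them ancillary):

dev_check = qml.device('lightning.qubit', wires=5)
qnode_check = qml.QNode(QuantumGenerator, dev_check)
generator_check = PatchQuantumGenerator(qnode_check, n_generator=4, n_qubit=5, n_qubit_a=1, n_layer=6)
zs_check = torch.rand(2, 5) * math.pi  # a batch of two latent vectors
print(generator_check(zs_check).shape)  # torch.Size([2, 64]): 4 patches of 16 pixels each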
Step 4: Training the Patch Quantum GAN
Initializing the Quantum GAN
We initialize the quantum generator and classical discriminator along with their optimizers.
# Hyperparameters
torch.manual_seed(0)
image_width = 8
image_height = 8
n_generator = 4  # number of sub-generators, one per image patch
n_qubit_d = int(np.log2((image_width * image_height) // n_generator))  # data qubits per sub-generator
n_qubit_a = 1  # ancillary qubits per sub-generator
n_qubit = n_qubit_d + n_qubit_a
n_layer = 6
# Quantum device
dev = qml.device("lightning.qubit", wires=n_qubit)
qnode_generator = qml.QNode(QuantumGenerator, dev)
# Initialize generator and discriminator
discriminator = ClassicalDiscriminator([image_height, image_width])
discriminator.train()
generator = PatchQuantumGenerator(qnode_generator, n_generator, n_qubit, n_qubit_a, n_layer)
generator.train()
# Optimizers
lr_generator = 0.3
lr_discriminator = 1e-2
opt_discriminator = optim.SGD(discriminator.parameters(), lr=lr_discriminator)
opt_generator = optim.SGD(generator.parameters(), lr=lr_generator)
# Construct dataset and dataloader
batch_size = 4
dataset = OptdigitsData('code/chapter5_qnn/optical+recognition+of+handwritten+digits/optdigits.tra', label=5)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)
# Loss function
loss_fn = nn.BCELoss()
labels_real = torch.ones(batch_size, dtype=torch.float)
labels_fake = torch.zeros(batch_size, dtype=torch.float)
# Testing setup
n_test = 10
z_test = torch.rand(n_test, n_qubit) * math.pi
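Before entering the training loop, a quick bookkeeping check confirms the patch and parameter counts implied by these hyperparameters:

pixels_per_patch = (image_width * image_height) // n_generator
print(pixels_per_patch)  # 16 pixels per sub-generator
print(n_qubit_d, n_qubit)  # 4 data qubits plus 1 ancillary qubit = 5 qubits
n_params = sum(p.numel() for p in generator.parameters())
print(n_params)  # 4 sub-generators x 6 layers x 5 qubits x 3 angles = 360 trainable parameters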
Training Loop
The GAN is trained with alternating updates of the discriminator and the generator: the generator learns to produce images that fool the discriminator. At the end of every epoch we save a batch of generated test images to record the generator's behaviour during training.
n_epoch = 10
record = {}
for i in range(n_epoch):
    for data, _ in dataloader:
        zs = torch.rand(batch_size, n_qubit) * math.pi
        image_fake = generator(zs)

        # Training the discriminator on detached fake images and real data
        discriminator.zero_grad()
        pred_fake = discriminator(image_fake.detach())
        pred_real = discriminator(data)
        loss_discriminator = loss_fn(pred_fake.squeeze(), labels_fake) + loss_fn(pred_real.squeeze(), labels_real)
        loss_discriminator.backward()
        opt_discriminator.step()

        # Training the generator to fool the updated discriminator
        generator.zero_grad()
        pred_fake = discriminator(image_fake)
        loss_generator = loss_fn(pred_fake.squeeze(), labels_real)
        loss_generator.backward()
        opt_generator.step()

    print(f'Epoch {i}: discriminator loss={loss_discriminator:.3f}, generator loss={loss_generator:.3f}')

    # Test: generate and record images at the end of each epoch
    generator.eval()
    image_generated = generator(z_test).view(n_test, image_height, image_width).detach()
    record[str(i)] = {
        'loss_discriminator': loss_discriminator.item(),
        'loss_generator': loss_generator.item(),
        'image_generated': image_generated.numpy().tolist()
    }
    generator.train()
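Since record contains only plain Python types (the generated images were converted with .tolist()), it can optionally be persisted for later inspection; the file name below is arbitrary:

import json

with open('qpatch_gan_record.json', 'w') as f:
    json.dump(record, f)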
Step 5: Visualizing the Generated Images
After training, we can visualize the images generated by the quantum generator.
n_epochs_to_visualize = len(record) // 2
n_images_per_epoch = 10
fig, axes = plt.subplots(n_epochs_to_visualize, n_images_per_epoch,
                         figsize=(n_images_per_epoch, n_epochs_to_visualize))
# Iterate through the recorded epochs and visualize generated images from every other epoch
for epoch_idx, (epoch, data) in enumerate(record.items()):
    if epoch_idx % 2 == 1:
        continue
    images = np.array(data['image_generated'])
    for img_idx in range(n_images_per_epoch):
        ax = axes[epoch_idx // 2, img_idx]
        ax.imshow(images[img_idx], cmap='gray')
        ax.axis('off')
        # Add epoch information to the title of each row
        if img_idx == 0:
            ax.set_title(f"Epoch {epoch}", fontsize=10)
plt.tight_layout()
plt.show()