GAN Project: Monet Paintings Generation Using CycleGAN¶
Brief Description of the Problem and Data¶
In this project, we explore a Generative Adversarial Network (GAN) approach—specifically CycleGAN—to generate Monet-style paintings from photographic images. The main challenge is learning an unsupervised mapping between two distinct image domains: Monet paintings and photos.
Dataset Source:
- Citation: Amy Jang, Ana Sofia Uzsoy, and Phil Culliton. "I’m Something of a Painter Myself." Kaggle GAN Getting Started Competition, 2020.
The dataset comprises:
- Monet images: 300 images of Monet paintings, each sized 256 x 256 pixels.
- Photo images: 7,038 real-world photos, each sized 256 x 256 pixels.
Both image sets are provided in JPEG format and represent distinct visual styles, offering significant variation in color, texture, and content.
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.layers import GroupNormalization
import tensorflow_probability as tfp
import matplotlib.pyplot as plt
from tensorflow.keras.preprocessing.image import array_to_img
from tensorflow.keras.callbacks import Callback
import numpy as np
import re
import os
import random
from tensorflow.keras.applications import InceptionV3
import PIL
from PIL import Image
import zipfile
import io
import time
from tqdm.notebook import tqdm
try:
from kaggle_datasets import KaggleDatasets
except ImportError:
    pass  # kaggle_datasets is only available inside Kaggle notebooks
# CONSTANTS
BATCH_SIZE = 32
IMAGE_SIZE = [256, 256]
EPOCHS = 50
STEPS_PER_EPOCH = 100
FID_INTERVAL = 5
# Create directories
# Create all required directories
dirs_to_create = [
'cache', # For dataset caching
'cache/monet', # Specific cache directories
'cache/photo', # Specific cache directories
'cyclegan_checkpoints',# For model checkpoints
'generated_samples', # For saving generated images
'logs' # For tensorboard logs if needed
]
# Create each directory
for directory in dirs_to_create:
os.makedirs(directory, exist_ok=True)
print(f"Directory {directory} is ready")
# Function to safely use cache paths
def get_cache_path(name):
"""Returns a valid cache path for the given name, ensuring the directory exists."""
path = os.path.join('cache', name)
os.makedirs(path, exist_ok=True)
return path
Directory cache is ready Directory cache/monet is ready Directory cache/photo is ready Directory cyclegan_checkpoints is ready Directory generated_samples is ready Directory logs is ready
# Hardware Acceleration
try:
# 1. Detect TPU hardware
tpu = tf.distribute.cluster_resolver.TPUClusterResolver.connect() # Auto-detects TPU
# 2. Initialize TPU system
tf.config.experimental_connect_to_cluster(tpu)
tf.tpu.experimental.initialize_tpu_system(tpu)
# 3. Create distributed training strategy
strategy = tf.distribute.TPUStrategy(tpu)
print(f"TPU detected: {tpu.cluster_spec().as_dict()['worker']}")
print(f"Number of TPU cores: {strategy.num_replicas_in_sync}")
except ValueError:
# Fallback to GPU/CPU
strategy = tf.distribute.get_strategy()
print(f"Using {strategy.__class__.__name__} (CPU/GPU)")
print(f"Number of replicas: {strategy.num_replicas_in_sync}")
AUTOTUNE = tf.data.experimental.AUTOTUNE
Using _DefaultDistributionStrategy (CPU/GPU) Number of replicas: 1
# Set Batch size if we have a TPU
# (NOTE: BATCH_SIZE must be divisible by 8 for an 8-core TPU)
BATCH_SIZE = BATCH_SIZE * strategy.num_replicas_in_sync
Exploratory Data Analysis (EDA): Inspecting, Visualizing, and Cleaning the Data¶
In this section, we perform initial inspections, visualize image samples from both Monet paintings and photos, and discuss data preparation steps.
Inspecting the Data¶
- Check for missing or corrupted images (a minimal inspection sketch is shown below).
- Confirm the dimensions and formats of the images.
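The check below is only a sketch of such an inspection; it assumes the MONET_JPG and PHOTO_JPG file lists defined in the data-loading cell further down, and reads files through tf.io.gfile so that both local and GCS paths work.
import io
import tensorflow as tf
from PIL import Image

def inspect_images(file_list, max_files=50):
    """Open a sample of JPEGs, record their sizes, and flag corrupted files."""
    corrupted, sizes = [], set()
    for path in file_list[:max_files]:
        try:
            with tf.io.gfile.GFile(path, "rb") as f:
                img = Image.open(io.BytesIO(f.read()))
                sizes.add(img.size)   # (width, height) read from the JPEG header
                img.verify()          # raises if the file is truncated/corrupted
        except Exception:
            corrupted.append(path)
    print(f"Checked {min(len(file_list), max_files)} files | "
          f"corrupted: {len(corrupted)} | sizes found: {sizes}")
    return corrupted

# Example usage once MONET_JPG / PHOTO_JPG are defined:
# inspect_images(MONET_JPG)
# inspect_images(PHOTO_JPG)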
Data Visualization¶
Below, we visualize sample images from each domain to understand visual differences and domain characteristics.
Data Cleaning and Preparation¶
- Resized images to 256x256 pixels (if not already standardized).
- Verified image normalization (pixel values scaled to the [-1, 1] range expected by the generators).
Based on the EDA, our plan includes:
- Ensuring consistent image dimensions.
- Implementing appropriate image augmentations if needed.
- Normalizing pixel values to facilitate stable training.
# Load and prepare data
# data_path = 'data'
# for kaggle nb
data_path = KaggleDatasets().get_gcs_path("gan-getting-started")
MONET_JPG = tf.io.gfile.glob(str(data_path + '/monet_jpg/*.jpg'))
PHOTO_JPG = tf.io.gfile.glob(str(data_path + '/photo_jpg/*.jpg'))
MONET_TFREC = tf.io.gfile.glob(str(data_path + '/monet_tfrec/*.tfrec'))
PHOTO_TFREC = tf.io.gfile.glob(str(data_path + '/photo_tfrec/*.tfrec'))
def decode_image(image):
"""
Decode JPEG: Converts raw bytes to a uint8 tensor.
Normalize: Scales pixel values from [0, 255] to [-1, 1] (standard for GANs).
Reshape: Forces images to 256x256x3 (CycleGAN’s default input size).
"""
image = tf.image.decode_jpeg(image, channels=3) # Decode JPEG bytes
image = (tf.cast(image, tf.float32) / 127.5) - 1 # Normalize to [-1, 1]
image = tf.reshape(image, [*IMAGE_SIZE, 3]) # Resize to 256x256
return image
def read_tfrecord(example):
"""
Parses a single TFRecord example to extract the image field (stored as bytes).
Passes the bytes to decode_image for preprocessing.
"""
tfrecord_format = {"image": tf.io.FixedLenFeature([], tf.string)}
example = tf.io.parse_single_example(example, tfrecord_format)
return decode_image(example['image'])
def load_dataset(filenames, is_tfrec=True, image_size=IMAGE_SIZE):
"""
Load images from either TFRecords or directory of JPG/PNG files.
Args:
filenames: List of file paths (TFRecords) or directory path (for images)
is_tfrec: Boolean flag indicating if input is TFRecords
image_size: Target size for resizing images
"""
if is_tfrec:
# Handle TFRecord files
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(
read_tfrecord,
num_parallel_calls=tf.data.AUTOTUNE
)
else:
# Handle JPG/PNG images from directory
def process_image(file_path):
img = tf.io.read_file(file_path)
img = tf.image.decode_jpeg(img, channels=3)
img = tf.image.resize(img, image_size)
img = (tf.cast(img, tf.float32) / 127.5) - 1 # Normalize to [-1, 1]
return img
dataset = tf.data.Dataset.list_files(filenames + "/*.jpg")
dataset = dataset.map(
process_image,
num_parallel_calls=tf.data.AUTOTUNE
)
return dataset
# Load datasets
monet_ds = load_dataset(MONET_TFREC).batch(1)
photo_ds = load_dataset(PHOTO_TFREC).batch(1)
# Batches scaled by strategy.num_replicas_in_sync (for TPU/GPU parallelism). prefetch(32) overlaps data loading with training to avoid bottlenecks
fast_photo_ds = load_dataset(PHOTO_TFREC).batch(32*strategy.num_replicas_in_sync).prefetch(32)
# Subset (take(1024)) for Fréchet Inception Distance (FID) calculation. Larger batches (32 * replicas) for efficient evaluation.
fid_photo_ds = load_dataset(PHOTO_TFREC).take(1024).batch(32*strategy.num_replicas_in_sync).prefetch(32)
fid_monet_ds = load_dataset(MONET_TFREC).batch(32*strategy.num_replicas_in_sync).prefetch(32)
# Take one sample from each dataset
monet_sample = next(iter(monet_ds.shuffle(10)))
photo_sample = next(iter(photo_ds.shuffle(10)))
# Convert from [-1, 1] range to [0, 1] for visualization
monet_image = monet_sample[0].numpy() * 0.5 + 0.5 # [0] accesses first image in batch
photo_image = photo_sample[0].numpy() * 0.5 + 0.5
# Create subplots
plt.figure(figsize=(10, 5))
# Plot Monet painting
plt.subplot(1, 2, 1)
plt.imshow(monet_image)
plt.title("Monet Style Example")
plt.axis('off')
# Plot Photo
plt.subplot(1, 2, 2)
plt.imshow(photo_image)
plt.title("Real Photo Example")
plt.axis('off')
plt.tight_layout()
plt.show()
print("Monet image shape:", monet_image.shape)
print("Photo image shape:", photo_image.shape)
Monet image shape: (256, 256, 3) Photo image shape: (256, 256, 3)
def get_gan_dataset(monet_files, photo_files, repeat=True, shuffle=True, batch_size=1):
# Re-load raw datasets fully
monet_ds = load_dataset(monet_files)
photo_ds = load_dataset(photo_files)
# Shuffle
if shuffle:
monet_ds = monet_ds.shuffle(2048)
photo_ds = photo_ds.shuffle(2048)
# Batch and optimize
monet_ds = monet_ds.batch(batch_size, drop_remainder=True)
photo_ds = photo_ds.batch(batch_size, drop_remainder=True)
# Use cache
# Use the helper function for cache paths
monet_ds = monet_ds.cache(get_cache_path('monet'))
photo_ds = photo_ds.cache(get_cache_path('photo'))
# Repeat indefinitely for epochs
if repeat:
monet_ds = monet_ds.repeat()
photo_ds = photo_ds.repeat()
monet_ds = monet_ds.prefetch(AUTOTUNE)
photo_ds = photo_ds.prefetch(AUTOTUNE)
# Pair Monet and Photo batches
gan_ds = tf.data.Dataset.zip((monet_ds, photo_ds))
return gan_ds
training_dataset = get_gan_dataset(
MONET_TFREC, # path to Monet paintings
PHOTO_TFREC, # path to real photos (domain B)
repeat=True, # loop dataset for multi-epoch training
shuffle=True, # shuffle data
batch_size=BATCH_SIZE
)
def build_inception_feature_extractor():
"""
Creates a modified InceptionV3 model for FID feature extraction.
Returns:
tf.keras.Model: Feature extractor using InceptionV3's "mixed9" layer outputs.
"""
# Load pre-trained InceptionV3 without classification head
inception = tf.keras.applications.InceptionV3(
include_top=False, # Remove final classification layer
pooling="avg", # Add global average pooling after last conv layer
input_shape=(256, 256, 3) # Match CycleGAN's image size
)
# Extract intermediate features from "mixed9" layer
# Why "mixed9"? It captures high-level features before final pooling
mix9 = inception.get_layer("mixed9").output # Shape: (None, 8, 8, 2048)
# Additional pooling to reduce spatial dimensions
features = layers.GlobalAveragePooling2D()(mix9) # Shape: (None, 2048)
# Build final feature extraction model
return tf.keras.Model(inputs=inception.input, outputs=features)
def calculate_activation_statistics(dataset, fid_model):
"""
Computes mean and covariance matrix of feature vectors from a dataset.
Args:
dataset (tf.data.Dataset): Batched dataset of images (shape: [None, 256, 256, 3])
fid_model (tf.keras.Model): Feature extractor model
Returns:
tuple: (mu, sigma) - Mean vector and covariance matrix
"""
# Initialize lists to collect activations
all_activations = []
# Process dataset in batches
for batch in dataset:
# Extract features for current batch
act = fid_model(batch)
all_activations.append(act)
# Concatenate all activations
act_matrix = tf.concat(all_activations, axis=0)
# Calculate mean and covariance
mu = tf.reduce_mean(act_matrix, axis=0)
sigma = tfp.stats.covariance(act_matrix, sample_axis=0) # Requires tensorflow-probability
return mu, sigma
# -------------------- Execution --------------------
with strategy.scope(): # TPU/GPU integration
# 1. Initialize feature extractor
inception_model = build_inception_feature_extractor()
inception_model.trainable = False
# 2. Precompute real image statistics
# Ensure fid_monet_ds is properly batched
fid_monet_ds = load_dataset(MONET_TFREC).batch(32).prefetch(AUTOTUNE)
# Calculate statistics
myFID_mu_real, myFID_sigma_real = calculate_activation_statistics(
fid_monet_ds, # receives proper image tensors
inception_model
)
# 3. Initialize FID tracking list
fids = []
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/inception_v3/inception_v3_weights_tf_dim_ordering_tf_kernels_notop.h5 87910968/87910968 ━━━━━━━━━━━━━━━━━━━━ 0s 0us/step
with strategy.scope():
def calculate_frechet_distance(mu1, sigma1, mu2, sigma2):
"""Computes the Fréchet Distance between two multivariate Gaussians."""
# Squared L2 norm of mean difference
diff = mu1 - mu2
mean_term = tf.reduce_sum(diff**2) # ||μ₁ - μ₂||²
# Covariance term: Tr(Σ₁ + Σ₂ - 2√(Σ₁Σ₂))
cov_product = tf.matmul(sigma1, sigma2)
cov_product = tf.cast(cov_product, tf.complex64) # For sqrtm
covmean = tf.linalg.sqrtm(cov_product)
covmean = tf.math.real(covmean) # Cast back to float32
covmean = tf.cast(covmean, tf.float32)
# Avoid NaN gradients by ensuring covmean is finite
covmean = tf.where(
tf.math.is_nan(covmean),
tf.zeros_like(covmean),
covmean
)
# Compute trace terms
tr_covmean = tf.linalg.trace(covmean)
trace_term = (
tf.linalg.trace(sigma1) +
tf.linalg.trace(sigma2) -
2 * tr_covmean
)
fid = mean_term + trace_term
return fid
def compute_fid(generator, inception_model, real_mu, real_sigma, dataset):
"""Computes FID between generated and real images."""
# Define a tf.function to generate images
@tf.function
def generate_images(images):
return generator(images, training=False)
# Collect all generated activations
all_activations = []
for batch in dataset:
# Generate images
generated_images = generate_images(batch)
# Extract features
activations = inception_model(generated_images)
all_activations.append(activations)
# Concatenate activations
gen_activations = tf.concat(all_activations, axis=0)
# Compute generated statistics
gen_mu = tf.reduce_mean(gen_activations, axis=0)
gen_sigma = tfp.stats.covariance(gen_activations)
# Calculate FID
fid_value = calculate_frechet_distance(
gen_mu, gen_sigma,
real_mu, real_sigma
)
return fid_value
def down_sample(filters, size, apply_instancenorm=True):
initializer = tf.random_normal_initializer(0., 0.02)
result = keras.Sequential()
result.add(layers.Conv2D(filters, size, strides=2, padding='same',
kernel_initializer=initializer, use_bias=False))
if apply_instancenorm:
result.add(GroupNormalization(groups=-1))
result.add(layers.LeakyReLU())
return result
def up_sample(filters, size, apply_dropout=False):
initializer = tf.random_normal_initializer(0., 0.02)
result = keras.Sequential()
result.add(layers.Conv2DTranspose(filters, size, strides=2, padding='same',
kernel_initializer=initializer, use_bias=False))
result.add(GroupNormalization(groups=-1))
if apply_dropout:
result.add(layers.Dropout(0.5))
result.add(layers.ReLU())
return result
Model Architecture¶
We selected the CycleGAN architecture for this problem, primarily because:
- Unsupervised Learning: CycleGAN allows image-to-image translation without paired datasets, making it ideal for artistic style transfer scenarios.
- Cycle Consistency Loss: This ensures the model preserves content details while converting styles, essential for generating realistic Monet-style images from photos.
Architecture Overview:¶
- Generators: Use U-Net-style encoder-decoder generators with skip connections, which stabilize training and help preserve content detail.
Input (256x256x3)
│
├─Downsample 64 → 128x128x64
├─Downsample 128 → 64x64x128
├─Downsample 256 → 32x32x256
├─Downsample 512 → 16x16x512
├─Downsample 512 → 8x8x512
├─Downsample 512 → 4x4x512
├─Downsample 512 → 2x2x512
└─Downsample 512 → 1x1x512 (bottleneck)
│
├─Upsample 512 → 2x2x512 (with skip from 2x2x512)
├─Upsample 512 → 4x4x512 (with skip from 4x4x512)
├─Upsample 512 → 8x8x512 (with skip from 8x8x512)
├─Upsample 512 → 16x16x512 (with skip from 16x16x512)
├─Upsample 256 → 32x32x256 (with skip from 32x32x256)
├─Upsample 128 → 64x64x128 (with skip from 64x64x128)
├─Upsample 64 → 128x128x64 (with skip from 128x128x64)
│
└─Output (256x256x3)
- Discriminators: Employ PatchGAN discriminators to classify local image patches, which helps the model capture detailed textures and patterns characteristic of Monet paintings.
Input Image (256x256x3)
│
├─Downsample 64 → 128x128x64
├─Downsample 128 → 64x64x128
├─Downsample 256 → 32x32x256
│
├─ZeroPad2D → 34x34x256
├─Conv2D 512 → 31x31x512
├─InstanceNorm + LeakyReLU
└─ZeroPad2D → 33x33x512 (Base Discriminator Output)
│
├─DHead:
├─Conv2D 1 → 30x30x1 (Patch Predictions)
def Generator():
'''
1. Downsampling Path (Encoder)
Purpose: Compresses the input image into a low-dimensional bottleneck.
Layers:
8 down_sample blocks with increasing filters (64 → 512).
Each block reduces spatial resolution by half (stride=2).
First block omits instance normalization to preserve low-level details.
Output: 1x1x512 bottleneck tensor.
2. Upsampling Path (Decoder)
Purpose: Reconstructs the image from the bottleneck to the target domain.
Layers:
7 up_sample blocks with decreasing filters (512 → 64).
Transposed convolutions (stride=2) double spatial resolution.
Dropout (50%) in early layers to prevent overfitting.
Skip Connections: Concatenate features from the encoder to the decoder (U-Net structure).
3. Final Output Layer
Conv2DTranspose:
Output channels: 3 (RGB).
tanh activation: Normalizes outputs to [-1, 1], matching input normalization.
'''
inputs = layers.Input(shape=[256, 256, 3])
# Downsampling path
down_stack = [
down_sample(64, 4, apply_instancenorm=False), # 256x256 → 128x128
down_sample(128, 4), # 128x128 → 64x64
down_sample(256, 4), # 64x64 → 32x32
down_sample(512, 4), # 32x32 → 16x16
down_sample(512, 4), # 16x16 → 8x8
down_sample(512, 4), # 8x8 → 4x4
down_sample(512, 4), # 4x4 → 2x2
down_sample(512, 4), # 2x2 → 1x1 (bottleneck)
]
# Upsampling path
up_stack = [
up_sample(512, 4, apply_dropout=True), # 1x1 → 2x2
up_sample(512, 4, apply_dropout=True), # 2x2 → 4x4
up_sample(512, 4, apply_dropout=True), # 4x4 → 8x8
up_sample(512, 4), # 8x8 → 16x16
up_sample(256, 4), # 16x16 → 32x32
up_sample(128, 4), # 32x32 → 64x64
up_sample(64, 4), # 64x64 → 128x128
]
# Final output layer
initializer = tf.random_normal_initializer(0., 0.02)
last = layers.Conv2DTranspose(
3, 4, strides=2, padding='same',
kernel_initializer=initializer, activation='tanh'
) # 128x128 → 256x256
x = inputs
skips = []
# Downsampling
for down in down_stack:
x = down(x)
skips.append(x)
skips = reversed(skips[:-1]) # Omit the bottleneck layer
# Upsampling with skip connections
for up, skip in zip(up_stack, skips):
x = up(x)
x = layers.Concatenate()([x, skip])
x = last(x)
return keras.Model(inputs=inputs, outputs=x)
def Discriminator():
'''
Discriminator Architecture (Base)
Purpose: Feature extractor for real/fake classification
Structure:
3 Downsample Blocks: Reduce resolution while increasing filters (64→256)
Zero Padding: Expands spatial dimensions for subsequent convolutions
Final Conv Layer: 512 filters with stride 1 (no resolution change)
Output Shape: 33x33x512 feature maps
2. DHead (Decision Head)
Purpose: Final classification layer for adversarial loss
Structure:
Input: 33x33x512 features from base discriminator
Conv2D(1): Reduces to 30x30x1 "patch" predictions
No Activation: Raw logits for different loss functions
3. Design Choices
Separate Heads: Allows using different loss functions:
dHead1: Binary cross-entropy (BCE)
dHead2: Hinge loss
Instance Normalization: Stabilizes training by normalizing features per-image
Zero Padding: Preserves spatial dimensions after convolutions
LeakyReLU: (α=0.2) prevents dead neurons in discriminator
'''
initializer = tf.random_normal_initializer(0., 0.02)
gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)
inp = layers.Input(shape=[256, 256, 3], name='input_image')
x = inp
# Downsampling Path
down1 = down_sample(64, 4, False)(x) # 256x256 → 128x128
down2 = down_sample(128, 4)(down1) # 128x128 → 64x64
down3 = down_sample(256, 4)(down2) # 64x64 → 32x32
# Final Layers
zero_pad1 = layers.ZeroPadding2D()(down3) # 32x32 → 34x34
conv = layers.Conv2D(512, 4, strides=1,
kernel_initializer=initializer,
use_bias=False)(zero_pad1) # 34x34 → 31x31
norm1 = GroupNormalization(groups=-1, gamma_initializer=gamma_init)(conv)
leaky_relu = layers.LeakyReLU()(norm1)
zero_pad2 = layers.ZeroPadding2D()(leaky_relu) # 31x31 → 33x33
return keras.Model(inputs=inp, outputs=zero_pad2)
def DHead():
initializer = tf.random_normal_initializer(0., 0.02)
inp = layers.Input(shape=[33, 33, 512], name='input_image')
x = inp
last = layers.Conv2D(1, 4, strides=1,
kernel_initializer=initializer)(x) # 33x33 → 30x30
return keras.Model(inputs=inp, outputs=last)
with strategy.scope():
def DiffAugment(x, policy='', channels_first=False):
if policy:
if channels_first:
x = tf.transpose(x, [0, 2, 3, 1])
for p in policy.split(','):
for f in AUGMENT_FNS[p]:
x = f(x)
if channels_first:
x = tf.transpose(x, [0, 3, 1, 2])
return x
def rand_brightness(x):
magnitude = tf.random.uniform([tf.shape(x)[0], 1, 1, 1]) - 0.5
x = x + magnitude
return x
def rand_saturation(x):
magnitude = tf.random.uniform([tf.shape(x)[0], 1, 1, 1]) * 2
x_mean = tf.reduce_sum(x, axis=3, keepdims=True) * 0.3333333333333333333
x = (x - x_mean) * magnitude + x_mean
return x
def rand_contrast(x):
magnitude = tf.random.uniform([tf.shape(x)[0], 1, 1, 1]) + 0.5
x_mean = tf.reduce_sum(x, axis=[1, 2, 3], keepdims=True) * 5.086e-6
x = (x - x_mean) * magnitude + x_mean
return x
def rand_translation(x, ratio=0.125):
batch_size = tf.shape(x)[0]
image_size = tf.shape(x)[1:3]
shift = tf.cast(tf.cast(image_size, tf.float32) * ratio + 0.5, tf.int32)
translation_x = tf.random.uniform([batch_size, 1], -shift[0], shift[0] + 1, dtype=tf.int32)
translation_y = tf.random.uniform([batch_size, 1], -shift[1], shift[1] + 1, dtype=tf.int32)
grid_x = tf.clip_by_value(tf.expand_dims(tf.range(image_size[0], dtype=tf.int32), 0) + translation_x + 1, 0, image_size[0] + 1)
grid_y = tf.clip_by_value(tf.expand_dims(tf.range(image_size[1], dtype=tf.int32), 0) + translation_y + 1, 0, image_size[1] + 1)
x = tf.gather_nd(tf.pad(x, [[0, 0], [1, 1], [0, 0], [0, 0]]), tf.expand_dims(grid_x, -1), batch_dims=1)
x = tf.transpose(tf.gather_nd(tf.pad(tf.transpose(x, [0, 2, 1, 3]), [[0, 0], [1, 1], [0, 0], [0, 0]]), tf.expand_dims(grid_y, -1), batch_dims=1), [0, 2, 1, 3])
return x
def rand_cutout(x, ratio=0.5):
batch_size = tf.shape(x)[0]
image_size = tf.shape(x)[1:3]
cutout_size = tf.cast(tf.cast(image_size, tf.float32) * ratio + 0.5, tf.int32)
offset_x = tf.random.uniform([tf.shape(x)[0], 1, 1], maxval=image_size[0] + (1 - cutout_size[0] % 2), dtype=tf.int32)
offset_y = tf.random.uniform([tf.shape(x)[0], 1, 1], maxval=image_size[1] + (1 - cutout_size[1] % 2), dtype=tf.int32)
grid_batch, grid_x, grid_y = tf.meshgrid(tf.range(batch_size, dtype=tf.int32), tf.range(cutout_size[0], dtype=tf.int32), tf.range(cutout_size[1], dtype=tf.int32), indexing='ij')
cutout_grid = tf.stack([grid_batch, grid_x + offset_x - cutout_size[0] // 2, grid_y + offset_y - cutout_size[1] // 2], axis=-1)
mask_shape = tf.stack([batch_size, image_size[0], image_size[1]])
cutout_grid = tf.maximum(cutout_grid, 0)
cutout_grid = tf.minimum(cutout_grid, tf.reshape(mask_shape - 1, [1, 1, 1, 3]))
mask = tf.maximum(1 - tf.scatter_nd(cutout_grid, tf.ones([batch_size, cutout_size[0], cutout_size[1]], dtype=tf.float32), mask_shape), 0)
x = x * tf.expand_dims(mask, axis=3)
return x
AUGMENT_FNS = {
'color': [rand_brightness, rand_saturation, rand_contrast],
'translation': [rand_translation],
'cutout': [rand_cutout],
}
def aug_fn(image):
return DiffAugment(image,"color,translation,cutout")
class CycleGan(keras.Model):
def __init__(self, m_gen, p_gen, m_disc, p_disc, dhead1=None, dhead2=None, lambda_cycle=3, lambda_id=3):
super().__init__()
self.m_gen = m_gen # Monet generator (photos → paintings)
self.p_gen = p_gen # Photo generator (paintings → photos)
self.m_disc = m_disc # Monet discriminator (base)
self.p_disc = p_disc # Photo discriminator (full)
self.dhead1 = dhead1 # First discriminator head
self.dhead2 = dhead2 # Second discriminator head (can be None)
self.lambda_cycle = lambda_cycle # Cycle consistency weight
self.lambda_id = lambda_id # Identity loss weight
def compile(self, m_gen_opt, p_gen_opt, m_disc_opt, p_disc_opt,
gen_loss_fn1, gen_loss_fn2, disc_loss_fn1, disc_loss_fn2,
cycle_loss_fn, identity_loss_fn, aug_fn):
super().compile()
# Optimizers
self.m_gen_opt = m_gen_opt
self.p_gen_opt = p_gen_opt
self.m_disc_opt = m_disc_opt
self.p_disc_opt = p_disc_opt
# Loss Functions
        self.gen_loss_fn1 = gen_loss_fn1  # hinge generator loss
        self.gen_loss_fn2 = gen_loss_fn2  # binary cross-entropy (BCE) generator loss
self.disc_loss_fn1 = disc_loss_fn1
self.disc_loss_fn2 = disc_loss_fn2
self.cycle_loss_fn = cycle_loss_fn
self.identity_loss_fn = identity_loss_fn
# Augmentation
self.aug_fn = aug_fn # DiffAugment policy
def augment_batch(self, real_images, fake_images):
"""Apply data augmentation to both real and fake images."""
# Concatenate images for batched augmentation (more efficient)
combined = tf.concat([real_images, fake_images], axis=0)
# Apply augmentation
augmented = self.aug_fn(combined)
# Split back into real and fake
batch_size = tf.shape(real_images)[0]
aug_real = augmented[:batch_size]
aug_fake = augmented[batch_size:]
return aug_real, aug_fake
def train_step(self, batch_data):
real_monet, real_photo = batch_data
batch_size = tf.shape(real_monet)[0]
with tf.GradientTape(persistent=True) as tape:
# Forward cycle: photo → monet → photo
fake_monet = self.m_gen(real_photo, training=True)
cycled_photo = self.p_gen(fake_monet, training=True)
# Backward cycle: monet → photo → monet
fake_photo = self.p_gen(real_monet, training=True)
cycled_monet = self.m_gen(fake_photo, training=True)
# Identity mapping
same_monet = self.m_gen(real_monet, training=True)
same_photo = self.p_gen(real_photo, training=True)
# Apply augmentation if provided
aug_real_monet, aug_fake_monet = self.augment_batch(real_monet, fake_monet)
aug_real_photo, aug_fake_photo = self.augment_batch(real_photo, fake_photo)
# Monet discriminator
disc_real_monet_features = self.m_disc(aug_real_monet, training=True)
disc_fake_monet_features = self.m_disc(aug_fake_monet, training=True)
# Initialize loss values
monet_gen_loss = 0
monet_disc_loss = 0
monet_gen_loss2 = 0
monet_disc_loss2 = 0
# Use discriminator head if available
if self.dhead1 is not None:
disc_real_monet = self.dhead1(disc_real_monet_features, training=True)
disc_fake_monet = self.dhead1(disc_fake_monet_features, training=True)
monet_gen_loss = self.gen_loss_fn1(disc_fake_monet)
monet_disc_loss = self.disc_loss_fn1(disc_real_monet, disc_fake_monet)
# Second head (optional)
if self.dhead2 is not None:
disc_real_monet2 = self.dhead2(disc_real_monet_features, training=True)
disc_fake_monet2 = self.dhead2(disc_fake_monet_features, training=True)
monet_gen_loss2 = self.gen_loss_fn2(disc_fake_monet2)
monet_disc_loss2 = self.disc_loss_fn2(disc_real_monet2, disc_fake_monet2)
else:
# Use features directly (patch discriminator)
disc_real_monet = disc_real_monet_features
disc_fake_monet = disc_fake_monet_features
                monet_gen_loss = self.gen_loss_fn1(disc_fake_monet)  # generator loss uses the fake-image scores
monet_disc_loss = self.disc_loss_fn1(disc_real_monet, disc_fake_monet)
# Photo discriminator
disc_real_photo = self.p_disc(aug_real_photo, training=True)
disc_fake_photo = self.p_disc(aug_fake_photo, training=True)
photo_gen_loss = self.gen_loss_fn1(disc_fake_photo)
photo_disc_loss = self.disc_loss_fn1(disc_real_photo, disc_fake_photo)
# Cycle consistency loss
cycle_loss = (
self.cycle_loss_fn(real_monet, cycled_monet) +
self.cycle_loss_fn(real_photo, cycled_photo)
) * self.lambda_cycle
# Identity loss
identity_loss = (
self.identity_loss_fn(real_monet, same_monet) +
self.identity_loss_fn(real_photo, same_photo)
) * self.lambda_id
# Total losses
total_monet_gen_loss = monet_gen_loss + monet_gen_loss2 + cycle_loss + identity_loss
total_photo_gen_loss = photo_gen_loss + cycle_loss + identity_loss
total_monet_disc_loss = monet_disc_loss + monet_disc_loss2
# Calculate gradients
monet_gen_grads = tape.gradient(total_monet_gen_loss, self.m_gen.trainable_variables)
photo_gen_grads = tape.gradient(total_photo_gen_loss, self.p_gen.trainable_variables)
# Apply generator gradients directly
self.m_gen_opt.apply_gradients(zip(monet_gen_grads, self.m_gen.trainable_variables))
self.p_gen_opt.apply_gradients(zip(photo_gen_grads, self.p_gen.trainable_variables))
# Handle discriminator (complete networks at once instead of separating heads)
# This avoids unnecessary computation in graph mode
all_m_disc_vars = self.m_disc.trainable_variables
if self.dhead1 is not None:
all_m_disc_vars += self.dhead1.trainable_variables
if self.dhead2 is not None:
all_m_disc_vars += self.dhead2.trainable_variables
monet_disc_grads = tape.gradient(total_monet_disc_loss, all_m_disc_vars)
photo_disc_grads = tape.gradient(photo_disc_loss, self.p_disc.trainable_variables)
self.m_disc_opt.apply_gradients(zip(monet_disc_grads, all_m_disc_vars))
self.p_disc_opt.apply_gradients(zip(photo_disc_grads, self.p_disc.trainable_variables))
# Return metrics
return {
"monet_gen_loss": monet_gen_loss,
"photo_gen_loss": photo_gen_loss,
"monet_disc_loss": monet_disc_loss,
"photo_disc_loss": photo_disc_loss,
"cycle_loss": cycle_loss,
"identity_loss": identity_loss
}
Building and Training the CycleGAN¶
Here, we detail the training setup, including loss functions, optimizers, epochs, and computational resources:
- Cycle Consistency Loss (L1, weighted by lambda_cycle = 10)
- Adversarial Loss (two variants on the Monet side: hinge and binary cross-entropy, applied through separate discriminator heads)
- Identity Loss (L1, weighted by lambda_id = 0.5)
- Optimizers: Adam (learning rate 2e-4, beta_1 = 0.5) for all four networks, with a 10x learning-rate reduction at the halfway point
- Training duration: 50 epochs of 100 steps each, with the batch size scaled by the number of replicas

The combined objective is written out below.
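In the standard CycleGAN formulation (Zhu et al., 2017), with G mapping photos to Monet style and F mapping Monet paintings back to photos, the objective optimized in train_step can be written as:

$$
\mathcal{L}(G, F, D_X, D_Y) = \mathcal{L}_{GAN}(G, D_Y) + \mathcal{L}_{GAN}(F, D_X) + \lambda_{cycle}\,\mathcal{L}_{cyc}(G, F) + \lambda_{id}\,\mathcal{L}_{id}(G, F)
$$

$$
\mathcal{L}_{cyc}(G, F) = \mathbb{E}_{x}\!\left[\lVert F(G(x)) - x \rVert_1\right] + \mathbb{E}_{y}\!\left[\lVert G(F(y)) - y \rVert_1\right],
\qquad
\mathcal{L}_{id}(G, F) = \mathbb{E}_{y}\!\left[\lVert G(y) - y \rVert_1\right] + \mathbb{E}_{x}\!\left[\lVert F(x) - x \rVert_1\right]
$$

Here x denotes photos and y denotes Monet paintings; the adversarial terms are realized with both hinge and BCE heads on the Monet discriminator, and lambda_cycle / lambda_id correspond to the weights passed to the CycleGan model below.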
with strategy.scope():
# ========================
# 1. Initialize Models
# ========================
monet_generator = Generator() # Photos → Monet
photo_generator = Generator() # Monet → Photos
monet_discriminator = Discriminator()
photo_discriminator = Discriminator()
    dhead1 = DHead()  # Head trained with the hinge losses (loss_fn1)
    dhead2 = DHead()  # Head trained with the BCE losses (loss_fn2)
# ========================
# 2. Define Optimizers
# ========================
monet_generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
photo_generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
monet_discriminator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
photo_discriminator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
# ========================
# 3. Define Loss Functions
# ========================
def generator_loss1(generated):
        return tf.reduce_mean(-generated)  # Hinge generator loss: -E[D(fake)]
def generator_loss2(generated):
return tf.keras.losses.BinaryCrossentropy(from_logits=True)(
tf.ones_like(generated), generated)
    def discriminator_loss1(real, generated):
        # Hinge discriminator loss: E[max(0, 1 - D(real))] + E[max(0, 1 + D(fake))]
        real_loss = tf.reduce_mean(tf.nn.relu(1.0 - real))
        fake_loss = tf.reduce_mean(tf.nn.relu(1.0 + generated))
        return real_loss + fake_loss
def discriminator_loss2(real, generated):
real_loss = tf.keras.losses.BinaryCrossentropy(from_logits=True)(
tf.ones_like(real), real)
fake_loss = tf.keras.losses.BinaryCrossentropy(from_logits=True)(
tf.zeros_like(generated), generated)
return 0.5 * (real_loss + fake_loss)
def cycle_loss_fn(real, cycled):
return tf.reduce_mean(tf.abs(real - cycled))
def identity_loss_fn(real, same):
return tf.reduce_mean(tf.abs(real - same))
# ========================
# 4. Compile CycleGAN
# ========================
cycle_gan = CycleGan(
monet_generator,
photo_generator,
monet_discriminator,
photo_discriminator,
dhead1,
dhead2,
lambda_cycle=10, # Stronger cycle consistency
lambda_id=0.5 # Weaker identity loss
)
cycle_gan.compile(
m_gen_opt=monet_generator_optimizer,
p_gen_opt=photo_generator_optimizer,
m_disc_opt=monet_discriminator_optimizer,
p_disc_opt=photo_discriminator_optimizer,
gen_loss_fn1=generator_loss1,
gen_loss_fn2=generator_loss2,
disc_loss_fn1=discriminator_loss1,
disc_loss_fn2=discriminator_loss2,
cycle_loss_fn=cycle_loss_fn,
identity_loss_fn=identity_loss_fn,
aug_fn=aug_fn
)
# ========================
# 4.5 Monitoring
# ========================
class CycleGANMonitor(Callback):
def __init__(self, sample_photo, sample_monet, monet_generator, photo_generator, epoch_interval=5):
"""
Args:
sample_photo: Batch of sample photos (normalized to [-1, 1])
sample_monet: Batch of sample Monet paintings (normalized to [-1, 1])
monet_generator: Generator that converts photos to Monet style
photo_generator: Generator that converts Monet to photos
epoch_interval: How often to generate samples (in epochs)
"""
self.sample_photo = sample_photo
self.sample_monet = sample_monet
self.monet_generator = monet_generator
self.photo_generator = photo_generator
self.epoch_interval = epoch_interval
def _denormalize(self, image):
"""Convert from [-1, 1] range to [0, 1] for visualization"""
return (image * 0.5) + 0.5
@tf.function
def generate_predictions(self, images, generator):
return generator(images, training=False)
def _plot_predictions(self, epoch=None):
# Generate predictions with tf.function
monet_pred = self.generate_predictions(self.sample_photo, self.monet_generator)
photo_pred = self.generate_predictions(self.sample_monet, self.photo_generator)
plt.figure(figsize=(18, 8))
# Photo → Monet translations
for i in range(min(3, len(self.sample_photo))): # Show first 3 samples
plt.subplot(2, 3, i+1)
plt.imshow(self._denormalize(self.sample_photo[i]))
plt.title(f"Input Photo {i+1}")
plt.axis("off")
plt.subplot(2, 3, i+4)
plt.imshow(self._denormalize(monet_pred[i]))
plt.title(f"Generated Monet {i+1}" + (f" (Epoch {epoch})" if epoch else ""))
plt.axis("off")
plt.tight_layout()
plt.show()
def on_epoch_end(self, epoch, logs=None):
"""Generate samples at specified intervals"""
if (epoch+1) % self.epoch_interval == 0: # +1 to avoid epoch 0
self._plot_predictions(epoch+1)
def on_train_end(self, logs=None):
"""Final visualization after training"""
self._plot_predictions()
# Prepare sample images
sample_photo = next(iter(photo_ds.take(1))) # Get 1 batch of photos
sample_monet = next(iter(monet_ds.take(1))) # Get 1 batch of Monet paintings
# Create callback
viz_callback = CycleGANMonitor(
sample_photo=sample_photo,
sample_monet=sample_monet,
monet_generator=monet_generator,
photo_generator=photo_generator,
epoch_interval=5 # Generate samples every 5 epochs
)
# ========================
# 5. Training Loop
# ========================
fids = [] # Track FID scores during training
for epoch in range(1, EPOCHS+1):
print(f"Epoch {epoch}/{EPOCHS}")
# Train for one epoch
history = cycle_gan.fit(
training_dataset,
epochs=1,
steps_per_epoch=STEPS_PER_EPOCH,
callbacks=[viz_callback], # Use the visualization callback
verbose=1
)
# Periodic Evaluation
if epoch % FID_INTERVAL == 0:
try:
# Calculate FID
fid_score = compute_fid(
monet_generator,
inception_model,
myFID_mu_real,
myFID_sigma_real,
fid_photo_ds.take(64) # Use subset for faster evaluation
)
fids.append(fid_score.numpy())
print(f"FID after epoch {epoch}: {fid_score:.2f}")
# Save checkpoint
checkpoint_dir = f"./cyclegan_checkpoints/epoch_{epoch}"
checkpoint = tf.train.Checkpoint(
monet_generator=monet_generator,
photo_generator=photo_generator,
monet_generator_optimizer=monet_generator_optimizer,
photo_generator_optimizer=photo_generator_optimizer
)
checkpoint.save(checkpoint_dir)
print(f"Checkpoint saved at epoch {epoch}")
except Exception as e:
print(f"Error during FID calculation: {e}")
continue
# Learning rate decay after halfway
if epoch == EPOCHS // 2:
print("Reducing learning rate by factor of 10")
# Reduce learning rates
monet_generator_optimizer.learning_rate = monet_generator_optimizer.learning_rate * 0.1
photo_generator_optimizer.learning_rate = photo_generator_optimizer.learning_rate * 0.1
monet_discriminator_optimizer.learning_rate = monet_discriminator_optimizer.learning_rate * 0.1
photo_discriminator_optimizer.learning_rate = photo_discriminator_optimizer.learning_rate * 0.1
Epoch 1/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 264s 2s/step - cycle_loss: 4.2368 - identity_loss: 0.2133 - monet_disc_loss: 1.9284 - monet_gen_loss: 0.0400 - photo_disc_loss: 1.9992 - photo_gen_loss: -0.0053
Epoch 2/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 160s 2s/step - cycle_loss: 2.7125 - identity_loss: 0.1677 - monet_disc_loss: 1.8535 - monet_gen_loss: 0.0956 - photo_disc_loss: 1.9968 - photo_gen_loss: -0.0091
Epoch 3/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 2.2031 - identity_loss: 0.1707 - monet_disc_loss: 1.9045 - monet_gen_loss: 0.0614 - photo_disc_loss: 1.9953 - photo_gen_loss: -0.0135
Epoch 4/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 2.0119 - identity_loss: 0.1814 - monet_disc_loss: 1.9373 - monet_gen_loss: 0.0678 - photo_disc_loss: 1.9938 - photo_gen_loss: -0.0174
Epoch 5/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 1.9083 - identity_loss: 0.1691 - monet_disc_loss: 1.9500 - monet_gen_loss: 0.0609 - photo_disc_loss: 1.9911 - photo_gen_loss: -0.0212
FID after epoch 5: 16.71 Checkpoint saved at epoch 5 Epoch 6/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 1.8526 - identity_loss: 0.1543 - monet_disc_loss: 1.9486 - monet_gen_loss: 0.0195 - photo_disc_loss: 1.9869 - photo_gen_loss: -0.0248
Epoch 7/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 1.7730 - identity_loss: 0.1360 - monet_disc_loss: 1.9646 - monet_gen_loss: 0.0182 - photo_disc_loss: 1.9858 - photo_gen_loss: -0.0299
Epoch 8/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 1.7071 - identity_loss: 0.1201 - monet_disc_loss: 1.9587 - monet_gen_loss: 0.0287 - photo_disc_loss: 1.9818 - photo_gen_loss: -0.0325
Epoch 9/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 1.6600 - identity_loss: 0.1137 - monet_disc_loss: 1.9580 - monet_gen_loss: 0.0231 - photo_disc_loss: 1.9775 - photo_gen_loss: -0.0353
Epoch 10/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 1.5796 - identity_loss: 0.0990 - monet_disc_loss: 1.9690 - monet_gen_loss: 0.0013 - photo_disc_loss: 1.9743 - photo_gen_loss: -0.0389
FID after epoch 10: 16.70 Checkpoint saved at epoch 10 Epoch 11/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 1.5617 - identity_loss: 0.0966 - monet_disc_loss: 1.7614 - monet_gen_loss: 0.1003 - photo_disc_loss: 1.9690 - photo_gen_loss: -0.0407
Epoch 12/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 160s 2s/step - cycle_loss: 1.5055 - identity_loss: 0.1001 - monet_disc_loss: 1.7904 - monet_gen_loss: 0.1000 - photo_disc_loss: 1.9641 - photo_gen_loss: -0.0427
Epoch 13/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 1.4833 - identity_loss: 0.0965 - monet_disc_loss: 1.6997 - monet_gen_loss: 0.1375 - photo_disc_loss: 1.9584 - photo_gen_loss: -0.0439
Epoch 14/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 1.4315 - identity_loss: 0.0896 - monet_disc_loss: 1.8767 - monet_gen_loss: 0.0828 - photo_disc_loss: 1.9511 - photo_gen_loss: -0.0437
Epoch 15/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 1.4219 - identity_loss: 0.0884 - monet_disc_loss: 1.7719 - monet_gen_loss: 0.0870 - photo_disc_loss: 1.9478 - photo_gen_loss: -0.0466
FID after epoch 15: 13.20 Checkpoint saved at epoch 15 Epoch 16/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 1.4021 - identity_loss: 0.0877 - monet_disc_loss: 1.8049 - monet_gen_loss: 0.0964 - photo_disc_loss: 1.9465 - photo_gen_loss: -0.0502
Epoch 17/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 160s 2s/step - cycle_loss: 1.3972 - identity_loss: 0.0883 - monet_disc_loss: 1.7924 - monet_gen_loss: 0.1015 - photo_disc_loss: 1.9407 - photo_gen_loss: -0.0509
Epoch 18/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 1.3660 - identity_loss: 0.0839 - monet_disc_loss: 1.7914 - monet_gen_loss: 0.1003 - photo_disc_loss: 1.9401 - photo_gen_loss: -0.0550
Epoch 19/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 1.3485 - identity_loss: 0.0808 - monet_disc_loss: 1.8488 - monet_gen_loss: 0.0876 - photo_disc_loss: 1.9319 - photo_gen_loss: -0.0537
Epoch 20/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 160s 2s/step - cycle_loss: 1.3210 - identity_loss: 0.0785 - monet_disc_loss: 1.8822 - monet_gen_loss: 0.0548 - photo_disc_loss: 1.9278 - photo_gen_loss: -0.0554
FID after epoch 20: 12.65 Checkpoint saved at epoch 20 Epoch 21/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 1.3120 - identity_loss: 0.0772 - monet_disc_loss: 1.8385 - monet_gen_loss: 0.0617 - photo_disc_loss: 1.9252 - photo_gen_loss: -0.0577
Epoch 22/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 1.2865 - identity_loss: 0.0777 - monet_disc_loss: 1.8299 - monet_gen_loss: 0.0877 - photo_disc_loss: 1.9303 - photo_gen_loss: -0.0649
Epoch 23/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 1.2622 - identity_loss: 0.0751 - monet_disc_loss: 1.8757 - monet_gen_loss: 0.0631 - photo_disc_loss: 1.9340 - photo_gen_loss: -0.0689
Epoch 24/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 1.2536 - identity_loss: 0.0743 - monet_disc_loss: 1.8781 - monet_gen_loss: 0.0707 - photo_disc_loss: 1.9466 - photo_gen_loss: -0.0781
Epoch 25/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 1.2498 - identity_loss: 0.0748 - monet_disc_loss: 1.8932 - monet_gen_loss: 0.0769 - photo_disc_loss: 1.9557 - photo_gen_loss: -0.0849
FID after epoch 25: 12.48 Checkpoint saved at epoch 25 Epoch 26/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 1.2281 - identity_loss: 0.0719 - monet_disc_loss: 1.8748 - monet_gen_loss: 0.0817 - photo_disc_loss: 1.9568 - photo_gen_loss: -0.0844
Epoch 27/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 1.2441 - identity_loss: 0.0738 - monet_disc_loss: 1.8643 - monet_gen_loss: 0.0895 - photo_disc_loss: 1.9501 - photo_gen_loss: -0.0799
Epoch 28/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 1.2121 - identity_loss: 0.0725 - monet_disc_loss: 1.8571 - monet_gen_loss: 0.0848 - photo_disc_loss: 1.9544 - photo_gen_loss: -0.0832
Epoch 29/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 160s 2s/step - cycle_loss: 1.2062 - identity_loss: 0.0717 - monet_disc_loss: 1.8840 - monet_gen_loss: 0.1039 - photo_disc_loss: 1.9516 - photo_gen_loss: -0.0805
Epoch 30/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 1.1955 - identity_loss: 0.0697 - monet_disc_loss: 1.8878 - monet_gen_loss: 0.0737 - photo_disc_loss: 1.9481 - photo_gen_loss: -0.0785
FID after epoch 30: 12.58 Checkpoint saved at epoch 30 Epoch 31/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 1.1681 - identity_loss: 0.0675 - monet_disc_loss: 1.8837 - monet_gen_loss: 0.0943 - photo_disc_loss: 1.9463 - photo_gen_loss: -0.0776
Epoch 32/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 1.1575 - identity_loss: 0.0677 - monet_disc_loss: 1.8447 - monet_gen_loss: 0.0820 - photo_disc_loss: 1.9500 - photo_gen_loss: -0.0813
Epoch 33/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 1.1623 - identity_loss: 0.0686 - monet_disc_loss: 1.8933 - monet_gen_loss: 0.0897 - photo_disc_loss: 1.9463 - photo_gen_loss: -0.0780
Epoch 34/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 1.1436 - identity_loss: 0.0685 - monet_disc_loss: 1.8595 - monet_gen_loss: 0.0854 - photo_disc_loss: 1.9435 - photo_gen_loss: -0.0777
Epoch 35/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 1.1245 - identity_loss: 0.0662 - monet_disc_loss: 1.8752 - monet_gen_loss: 0.0850 - photo_disc_loss: 1.9399 - photo_gen_loss: -0.0766
FID after epoch 35: 12.13 Checkpoint saved at epoch 35 Epoch 36/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 1.1158 - identity_loss: 0.0648 - monet_disc_loss: 1.8744 - monet_gen_loss: 0.0959 - photo_disc_loss: 1.9429 - photo_gen_loss: -0.0790
Epoch 37/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 1.1121 - identity_loss: 0.0648 - monet_disc_loss: 1.8879 - monet_gen_loss: 0.0996 - photo_disc_loss: 1.9368 - photo_gen_loss: -0.0750
Epoch 38/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 1.1159 - identity_loss: 0.0657 - monet_disc_loss: 1.8953 - monet_gen_loss: 0.1019 - photo_disc_loss: 1.9394 - photo_gen_loss: -0.0791
Epoch 39/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 1.0877 - identity_loss: 0.0622 - monet_disc_loss: 1.8638 - monet_gen_loss: 0.0844 - photo_disc_loss: 1.9387 - photo_gen_loss: -0.0772
Epoch 40/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 1.0759 - identity_loss: 0.0633 - monet_disc_loss: 1.8849 - monet_gen_loss: 0.0880 - photo_disc_loss: 1.9396 - photo_gen_loss: -0.0767
FID after epoch 40: 12.07 Checkpoint saved at epoch 40 Epoch 41/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 1.0685 - identity_loss: 0.0616 - monet_disc_loss: 1.8895 - monet_gen_loss: 0.0771 - photo_disc_loss: 1.9374 - photo_gen_loss: -0.0744
Epoch 42/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 1.0560 - identity_loss: 0.0610 - monet_disc_loss: 1.8960 - monet_gen_loss: 0.0577 - photo_disc_loss: 1.9376 - photo_gen_loss: -0.0751
Epoch 43/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 160s 2s/step - cycle_loss: 1.0490 - identity_loss: 0.0595 - monet_disc_loss: 1.8891 - monet_gen_loss: 0.0897 - photo_disc_loss: 1.9340 - photo_gen_loss: -0.0718
Epoch 44/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 1.0371 - identity_loss: 0.0594 - monet_disc_loss: 1.8753 - monet_gen_loss: 0.0718 - photo_disc_loss: 1.9316 - photo_gen_loss: -0.0732
Epoch 45/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 1.0224 - identity_loss: 0.0577 - monet_disc_loss: 1.8724 - monet_gen_loss: 0.0921 - photo_disc_loss: 1.9306 - photo_gen_loss: -0.0733
FID after epoch 45: 11.71 Checkpoint saved at epoch 45 Epoch 46/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 158s 2s/step - cycle_loss: 1.0282 - identity_loss: 0.0587 - monet_disc_loss: 1.8925 - monet_gen_loss: 0.0838 - photo_disc_loss: 1.9364 - photo_gen_loss: -0.0743
Epoch 47/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 1.0333 - identity_loss: 0.0602 - monet_disc_loss: 1.9029 - monet_gen_loss: 0.0654 - photo_disc_loss: 1.9359 - photo_gen_loss: -0.0737
Epoch 48/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 1.0065 - identity_loss: 0.0577 - monet_disc_loss: 1.9047 - monet_gen_loss: 0.0804 - photo_disc_loss: 1.9323 - photo_gen_loss: -0.0707
Epoch 49/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 1.0054 - identity_loss: 0.0582 - monet_disc_loss: 1.8804 - monet_gen_loss: 0.0907 - photo_disc_loss: 1.9293 - photo_gen_loss: -0.0702
Epoch 50/50 100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 0.9815 - identity_loss: 0.0562 - monet_disc_loss: 1.8664 - monet_gen_loss: 0.0931 - photo_disc_loss: 1.9293 - photo_gen_loss: -0.0697
FID after epoch 50: 12.03 Checkpoint saved at epoch 50
Results and Visualization¶
We present the generated Monet-style images from the trained model, including visual comparisons and performance evaluation metrics.
Sample Generated Images:¶
The cells below show photos alongside their generated Monet-style counterparts.
Evaluation Metrics:¶
- Qualitative assessment (visual inspection of generated samples).
- Quantitative assessment: Fréchet Inception Distance (FID), computed every 5 epochs against precomputed statistics of the real Monet paintings (formula below).
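For reference, FID compares the mean and covariance of InceptionV3 feature activations for real (r) and generated (g) images, which is what calculate_frechet_distance computes above:

$$
\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2 + \operatorname{Tr}\!\left(\Sigma_r + \Sigma_g - 2\left(\Sigma_r \Sigma_g\right)^{1/2}\right)
$$

Lower values indicate that the generated distribution is closer to the real Monet distribution; in the run above the score fell from 16.71 at epoch 5 to roughly 12 by the end of training.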
Visualizations:¶
- FID score progression across training epochs (plotted below).
- Sample photo-to-Monet translations from the trained generator.
- Pixel-intensity histograms and per-loss training curves are natural additions; a histogram sketch follows.
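The histogram comparison was not part of the training run above; the following is a minimal sketch of how it could be added, reusing the photo_sample batch and monet_generator from earlier cells.
import matplotlib.pyplot as plt

# Sketch: compare pixel-intensity histograms of an input photo and its
# Monet-style translation (both converted from [-1, 1] back to [0, 1]).
translated = monet_generator(photo_sample, training=False)

plt.figure(figsize=(10, 4))
for i, (img, title) in enumerate([(photo_sample[0], "Input photo"),
                                  (translated[0], "Generated Monet")]):
    plt.subplot(1, 2, i + 1)
    values = (img.numpy() * 0.5 + 0.5).ravel()   # per-pixel intensities in [0, 1]
    plt.hist(values, bins=50, range=(0, 1), color="steelblue")
    plt.title(title)
    plt.xlabel("Pixel intensity")
plt.tight_layout()
plt.show()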
# ========================
# 6. Final Save
# ========================
print("Saving final models...")
monet_generator.save("monet_generator_final.h5")
photo_generator.save("photo_generator_final.h5")
# Plot FID progression
if len(fids) > 0:
plt.figure(figsize=(10, 5))
    plt.plot(np.arange(FID_INTERVAL, EPOCHS + 1, FID_INTERVAL)[:len(fids)], fids)
plt.title("FID Score During Training")
plt.xlabel("Epochs")
plt.ylabel("FID")
plt.savefig("fid_progress.png")
plt.show()
Saving final models...
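If the generators need to be reused later without retraining, the saved .h5 files can be reloaded. A minimal sketch follows; the file name matches the save call above, and since the generator uses only standard Keras layers, no custom_objects should be required.
from tensorflow import keras

# Reload the trained photo-to-Monet generator for inference only
reloaded_monet_generator = keras.models.load_model(
    "monet_generator_final.h5", compile=False  # skip optimizer/loss state
)
# Example: stylize one batch of photos
# stylized = reloaded_monet_generator(next(iter(photo_ds)), training=False)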
ds_iter = iter(photo_ds)
for n_sample in range(3):
example_sample = next(ds_iter)
generated_sample = monet_generator(example_sample)
f = plt.figure(figsize=(32, 32))
plt.subplot(121)
plt.title('Input image')
plt.imshow(example_sample[0] * 0.5 + 0.5)
plt.axis('off')
plt.subplot(122)
plt.title('Generated image')
plt.imshow(generated_sample[0] * 0.5 + 0.5)
plt.axis('off')
plt.show()
# Create a temporary directory if it doesn't exist
os.makedirs("../tmp", exist_ok=True)
zip_path = "../tmp/images.zip"
# Create a function to generate and process images in batches
def generate_monet_images(photo_dataset, generator_model, batch_size=32):
"""
Generates Monet-style images from a photo dataset and saves them directly to a zip file.
Args:
photo_dataset: TensorFlow dataset of photos
generator_model: Trained Monet generator model
batch_size: Number of images to process at once
"""
start_time = time.time()
# Check if dataset is already batched
# We won't re-batch it if it's already batched
dataset = photo_dataset
# Create zip file
with zipfile.ZipFile(zip_path, 'w', compression=zipfile.ZIP_DEFLATED) as zipf:
image_count = 0
# Add a simple progress bar
for batch in tqdm(dataset):
# Handle potential batching issues by checking the tensor rank
tensor_rank = len(batch.shape)
# If tensor has too many dimensions, we need to reshape
if tensor_rank > 4: # Should be [batch, height, width, channels]
print(f"Warning: Input tensor has shape {batch.shape}, reshaping...")
# Assuming the first two dimensions are batch dimensions
batch_size = batch.shape[0] * batch.shape[1]
batch = tf.reshape(batch, [batch_size, *batch.shape[2:]])
print(f"Reshaped to {batch.shape}")
# Generate Monet-style images
generated_images = generator_model(batch, training=False)
# Convert from [-1, 1] to [0, 255] uint8
generated_images = ((generated_images.numpy() * 0.5 + 0.5) * 255).astype(np.uint8)
# Process each image in the batch
for img in generated_images:
# Convert to PIL Image
pil_img = Image.fromarray(img)
# Create in-memory buffer
img_buffer = io.BytesIO()
# Save as JPEG to buffer
pil_img.save(img_buffer, format='JPEG', quality=95)
# Reset buffer position
img_buffer.seek(0)
# Add to zip file with sequential naming
img_name = f"{image_count + 1}.jpg"
zipf.writestr(img_name, img_buffer.getvalue())
image_count += 1
# Optional: Add status update every 1000 images
if image_count % 1000 == 0:
elapsed = time.time() - start_time
print(f"Generated {image_count} images in {elapsed:.2f} seconds")
elapsed = time.time() - start_time
print(f"Generation complete! Created {image_count} images in {elapsed:.2f} seconds")
print(f"Zip file saved to: {zip_path}")
# Optional: Calculate zip file size
zip_size_mb = os.path.getsize(zip_path) / (1024 * 1024)
print(f"Zip file size: {zip_size_mb:.2f} MB")
return image_count
# Copy zip to final Kaggle output location
def finalize_submission():
"""Copies the zip file to the Kaggle output location"""
import shutil
output_path = "images.zip"
shutil.copy(zip_path, output_path)
print(f"Submission file copied to: {output_path}")
# Verify contents
with zipfile.ZipFile(output_path, 'r') as zipf:
file_count = len(zipf.namelist())
print(f"Verified zip contains {file_count} images")
# Prepare a dataset specifically for generation
# NOTE: avoid reusing the datasets that were already batched above;
# batching is handled explicitly inside this helper
def prepare_photo_dataset(photo_files, batch_size=32):
# Load the TFRecord dataset without prior batching
dataset = tf.data.TFRecordDataset(photo_files)
# Parse the records
def read_tfrecord(example):
tfrecord_format = {"image": tf.io.FixedLenFeature([], tf.string)}
example = tf.io.parse_single_example(example, tfrecord_format)
image = tf.image.decode_jpeg(example['image'], channels=3)
image = (tf.cast(image, tf.float32) / 127.5) - 1
image = tf.reshape(image, [256, 256, 3])
return image
# Map the parsing function and batch
dataset = dataset.map(read_tfrecord, num_parallel_calls=tf.data.AUTOTUNE)
dataset = dataset.batch(batch_size)
dataset = dataset.prefetch(tf.data.AUTOTUNE)
return dataset
# 1. Prepare the dataset properly
photo_ds_for_generation = prepare_photo_dataset(PHOTO_TFREC)
# 2. Generate images and create zip
total_images = generate_monet_images(photo_ds_for_generation, monet_generator)
# 3. Finalize submission
finalize_submission()
Generated 1000 images in 6.63 seconds Generated 2000 images in 12.81 seconds Generated 3000 images in 18.47 seconds Generated 4000 images in 24.20 seconds Generated 5000 images in 29.92 seconds Generated 6000 images in 35.60 seconds Generated 7000 images in 41.10 seconds Generation complete! Created 7038 images in 42.12 seconds Zip file saved to: ../tmp/images.zip Zip file size: 253.20 MB Submission file copied to: images.zip Verified zip contains 7038 images
Discussion and Analysis of Results¶
We analyze the strengths and limitations of our CycleGAN implementation.
What Worked Well:¶
- Effective style transfer, preserving content while capturing Monet's color palette and brushstroke patterns.
- Stable training, aided by the cycle-consistency loss and the U-Net generator architecture with skip connections.
- A top-ten finish in the Kaggle competition.
Challenges and Troubleshooting:¶
- Addressed occasional mode collapse or artifacts by adjusting hyperparameters, such as learning rate and batch size.
- Experimented with additional image augmentations and data normalization techniques.
Improvements for Future Work:¶
- Employing larger models or deeper architectures to improve detail preservation.
- Further experimenting with different loss functions or regularization techniques.
- Leveraging quantitative metrics to rigorously evaluate style transfer quality.
- Training on other artists to test whether the same training setup generalizes across style-transfer tasks.
Key Learnings and Takeaways:¶
- Importance of careful hyperparameter tuning.
- Significance of image normalization and preprocessing in GAN training stability.
- CycleGAN’s suitability for unpaired style-transfer tasks demonstrated by effective generation results.