GAN Project: Monet Paintings Generation Using CycleGAN¶

Brief Description of the Problem and Data¶

In this project, we explore a Generative Adversarial Network (GAN) approach—specifically CycleGAN—to generate Monet-style paintings from photographic images. The main challenge is learning an unsupervised mapping between two distinct image domains: Monet paintings and photos.

Dataset Source:

  • Citation: Amy Jang, Ana Sofia Uzsoy, and Phil Culliton. "I’m Something of a Painter Myself." Kaggle GAN Getting Started Competition, 2020. Kaggle.

The dataset comprises:

  • Monet images: 300 images of Monet paintings, each sized 256 x 256 pixels.
  • Photo images: 7,000 real-world photos, each sized 256 x 256 pixels.

Both image sets are provided in JPEG format and represent distinct visual styles, offering significant variation in color, texture, and content.

  • Kaggle link: https://www.kaggle.com/code/rohanxaviergupta/cycleganmonet
In [ ]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.layers import GroupNormalization
import tensorflow_probability as tfp
import matplotlib.pyplot as plt
from tensorflow.keras.preprocessing.image import array_to_img
from tensorflow.keras.callbacks import Callback
import numpy as np
import re
import os
import random
from tensorflow.keras.applications import InceptionV3
import PIL
from PIL import Image
import zipfile
import io
import time
from tqdm.notebook import tqdm
try:
    from kaggle_datasets import KaggleDatasets
except:
    pass
In [ ]:
# CONSTANTS
BATCH_SIZE = 32
IMAGE_SIZE = [256, 256]
EPOCHS = 50
STEPS_PER_EPOCH = 100
FID_INTERVAL = 5  
In [ ]:
# Create directories

# Create all required directories
dirs_to_create = [
    'cache',               # For dataset caching
    'cache/monet',         # Specific cache directories 
    'cache/photo',         # Specific cache directories
    'cyclegan_checkpoints',# For model checkpoints
    'generated_samples',   # For saving generated images
    'logs'                 # For tensorboard logs if needed
]

# Create each directory
for directory in dirs_to_create:
    os.makedirs(directory, exist_ok=True)
    print(f"Directory {directory} is ready")

# Function to safely use cache paths
def get_cache_path(name):
    """Returns a valid cache path for the given name, ensuring the directory exists."""
    path = os.path.join('cache', name)
    os.makedirs(path, exist_ok=True)
    return path
Directory cache is ready
Directory cache/monet is ready
Directory cache/photo is ready
Directory cyclegan_checkpoints is ready
Directory generated_samples is ready
Directory logs is ready
In [ ]:
# Hardware Accelaration
try:
    # 1. Detect TPU hardware
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver.connect()  # Auto-detects TPU
    
    # 2. Initialize TPU system
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    
    # 3. Create distributed training strategy
    strategy = tf.distribute.TPUStrategy(tpu)
    
    print(f"TPU detected: {tpu.cluster_spec().as_dict()['worker']}")
    print(f"Number of TPU cores: {strategy.num_replicas_in_sync}")

except ValueError:
    # Fallback to GPU/CPU
    strategy = tf.distribute.get_strategy()
    print(f"Using {strategy.__class__.__name__} (CPU/GPU)")

print(f"Number of replicas: {strategy.num_replicas_in_sync}")

AUTOTUNE = tf.data.experimental.AUTOTUNE
Using _DefaultDistributionStrategy (CPU/GPU)
Number of replicas: 1
In [ ]:
# Set Batch size if we have a TPU 
# (NOTE BATCH_SIZE must be divisble by 8 for 8 core TPU)
BATCH_SIZE = BATCH_SIZE * strategy.num_replicas_in_sync

Exploratory Data Analysis (EDA): Inspecting, Visualizing, and Cleaning the Data¶

In this section, we perform initial inspections, visualize image samples from both Monet paintings and photos, and discuss data preparation steps.

Inspecting the Data¶

  • Check for missing or corrupted images.
  • Confirm the dimensions and formats of the images.

Data Visualization¶

Below, we visualize sample images from each domain to understand visual differences and domain characteristics.

Include visualizations here (e.g., matplotlib figures showing samples of images from each dataset)

Data Cleaning and Preparation¶

  • Resized images to 256x256 pixels (if not already standardized).
  • Verified image normalization (scaling pixel values to [0,1] range).

Based on the EDA, our plan includes:

  • Ensuring consistent image dimensions.
  • Implementing appropriate image augmentations if needed.
  • Normalizing pixel values to facilitate stable training.
In [ ]:
# Load and prepare data
# data_path = 'data'

# for kaggle nb
data_path = KaggleDatasets().get_gcs_path("gan-getting-started")


MONET_JPG= tf.io.gfile.glob(str(data_path + '/monet_jpg/*.jpg'))
PHOTO_JPG = tf.io.gfile.glob(str(data_path + '/photo_jpg/*.jpg'))
MONET_TFREC = tf.io.gfile.glob(str(data_path + '/monet_tfrec/*.tfrec'))
PHOTO_TFREC = tf.io.gfile.glob(str(data_path + '/photo_tfrec/*.tfrec'))
In [ ]:
def decode_image(image):
    """
    Decode JPEG: Converts raw bytes to a uint8 tensor.

    Normalize: Scales pixel values from [0, 255] to [-1, 1] (standard for GANs).

    Reshape: Forces images to 256x256x3 (CycleGAN’s default input size).
    """
    image = tf.image.decode_jpeg(image, channels=3)      # Decode JPEG bytes
    image = (tf.cast(image, tf.float32) / 127.5) - 1     # Normalize to [-1, 1]
    image = tf.reshape(image, [*IMAGE_SIZE, 3])         # Resize to 256x256
    return image

def read_tfrecord(example):
    """
    Parses a single TFRecord example to extract the image field (stored as bytes).

    Passes the bytes to decode_image for preprocessing.
    """
    tfrecord_format = {"image": tf.io.FixedLenFeature([], tf.string)}
    example = tf.io.parse_single_example(example, tfrecord_format)
    return decode_image(example['image'])
In [ ]:
def load_dataset(filenames, is_tfrec=True, image_size=IMAGE_SIZE):
    """
    Load images from either TFRecords or directory of JPG/PNG files.
    
    Args:
        filenames: List of file paths (TFRecords) or directory path (for images)
        is_tfrec: Boolean flag indicating if input is TFRecords
        image_size: Target size for resizing images
    """
    if is_tfrec:
        # Handle TFRecord files
        dataset = tf.data.TFRecordDataset(filenames)
        dataset = dataset.map(
            read_tfrecord, 
            num_parallel_calls=tf.data.AUTOTUNE
        )
    else:
        # Handle JPG/PNG images from directory
        def process_image(file_path):
            img = tf.io.read_file(file_path)
            img = tf.image.decode_jpeg(img, channels=3)
            img = tf.image.resize(img, image_size)
            img = (tf.cast(img, tf.float32) / 127.5) - 1  # Normalize to [-1, 1]
            return img
        
        dataset = tf.data.Dataset.list_files(filenames + "/*.jpg")
        dataset = dataset.map(
            process_image, 
            num_parallel_calls=tf.data.AUTOTUNE
        )
    
    return dataset

# Load datasets
monet_ds = load_dataset(MONET_TFREC).batch(1)
photo_ds = load_dataset(PHOTO_TFREC).batch(1)

# Batches scaled by strategy.num_replicas_in_sync (for TPU/GPU parallelism). prefetch(32) overlaps data loading with training to avoid bottlenecks
fast_photo_ds = load_dataset(PHOTO_TFREC).batch(32*strategy.num_replicas_in_sync).prefetch(32)

# Subset (take(1024)) for Fréchet Inception Distance (FID) calculation. Larger batches (32 * replicas) for efficient evaluation.
fid_photo_ds = load_dataset(PHOTO_TFREC).take(1024).batch(32*strategy.num_replicas_in_sync).prefetch(32)
fid_monet_ds = load_dataset(MONET_TFREC).batch(32*strategy.num_replicas_in_sync).prefetch(32)
In [ ]:
# Take one sample from each dataset
monet_sample = next(iter(monet_ds.shuffle(10)))
photo_sample = next(iter(photo_ds.shuffle(10)))

# Convert from [-1, 1] range to [0, 1] for visualization
monet_image = monet_sample[0].numpy() * 0.5 + 0.5  # [0] accesses first image in batch
photo_image = photo_sample[0].numpy() * 0.5 + 0.5

# Create subplots
plt.figure(figsize=(10, 5))

# Plot Monet painting
plt.subplot(1, 2, 1)
plt.imshow(monet_image)
plt.title("Monet Style Example")
plt.axis('off')

# Plot Photo
plt.subplot(1, 2, 2)
plt.imshow(photo_image)
plt.title("Real Photo Example")
plt.axis('off')

plt.tight_layout()
plt.show()

print("Monet image shape:", monet_image.shape)
print("Photo image shape:", photo_image.shape)
No description has been provided for this image
Monet image shape: (256, 256, 3)
Photo image shape: (256, 256, 3)
In [ ]:
def get_gan_dataset(monet_files, photo_files, repeat=True, shuffle=True, batch_size=1):
    # Re-load raw datasets fully
    monet_ds = load_dataset(monet_files)
    photo_ds = load_dataset(photo_files)

    # Shuffle
    if shuffle:
        monet_ds = monet_ds.shuffle(2048)
        photo_ds = photo_ds.shuffle(2048)

    # Batch and optimize
    monet_ds = monet_ds.batch(batch_size, drop_remainder=True)
    photo_ds = photo_ds.batch(batch_size, drop_remainder=True)

    # Use cache
    # Use the helper function for cache paths
    monet_ds = monet_ds.cache(get_cache_path('monet'))
    photo_ds = photo_ds.cache(get_cache_path('photo'))
   
    # Repeat indefinitely for epochs
    if repeat:
        monet_ds = monet_ds.repeat()
        photo_ds = photo_ds.repeat()

    monet_ds = monet_ds.prefetch(AUTOTUNE)
    photo_ds = photo_ds.prefetch(AUTOTUNE)

    # Pair Monet and Photo batches
    gan_ds = tf.data.Dataset.zip((monet_ds, photo_ds))
    return gan_ds
In [ ]:
training_dataset = get_gan_dataset(
    MONET_TFREC, # path to Monet paintings
    PHOTO_TFREC, # path to real photos (domain B)
    repeat=True, # loop dataset for multi-epoch training
    shuffle=True, # shuffle data
    batch_size=BATCH_SIZE
    )
In [ ]:
def build_inception_feature_extractor():
    """
    Creates a modified InceptionV3 model for FID feature extraction.
    Returns:
        tf.keras.Model: Feature extractor using InceptionV3's "mixed9" layer outputs.
    """
    # Load pre-trained InceptionV3 without classification head
    inception = tf.keras.applications.InceptionV3(
        include_top=False,  # Remove final classification layer
        pooling="avg",  # Add global average pooling after last conv layer
        input_shape=(256, 256, 3)  # Match CycleGAN's image size
    )
    
    # Extract intermediate features from "mixed9" layer
    # Why "mixed9"? It captures high-level features before final pooling
    mix9 = inception.get_layer("mixed9").output  # Shape: (None, 8, 8, 2048)
    
    # Additional pooling to reduce spatial dimensions
    features = layers.GlobalAveragePooling2D()(mix9)  # Shape: (None, 2048)
    
    # Build final feature extraction model
    return tf.keras.Model(inputs=inception.input, outputs=features)

def calculate_activation_statistics(dataset, fid_model):
    """
    Computes mean and covariance matrix of feature vectors from a dataset.
    
    Args:
        dataset (tf.data.Dataset): Batched dataset of images (shape: [None, 256, 256, 3])
        fid_model (tf.keras.Model): Feature extractor model
    
    Returns:
        tuple: (mu, sigma) - Mean vector and covariance matrix
    """
    # Initialize lists to collect activations
    all_activations = []
    
    # Process dataset in batches
    for batch in dataset:
        # Extract features for current batch
        act = fid_model(batch)
        all_activations.append(act)
    
    # Concatenate all activations
    act_matrix = tf.concat(all_activations, axis=0)
    
    # Calculate mean and covariance
    mu = tf.reduce_mean(act_matrix, axis=0)
    sigma = tfp.stats.covariance(act_matrix, sample_axis=0)  # Requires tensorflow-probability
    
    return mu, sigma

# -------------------- Execution -------------------- 
with strategy.scope(): # TPU/GPU integration
    
    # 1. Initialize feature extractor
    inception_model = build_inception_feature_extractor()
    inception_model.trainable = False
    
    # 2. Precompute real image statistics
    # Ensure fid_monet_ds is properly batched
    fid_monet_ds = load_dataset(MONET_TFREC).batch(32).prefetch(AUTOTUNE)
    
    # Calculate statistics
    myFID_mu_real, myFID_sigma_real = calculate_activation_statistics(
        fid_monet_ds,  # receives proper image tensors
        inception_model
    )
    
    # 3. Initialize FID tracking list
    fids = []
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/inception_v3/inception_v3_weights_tf_dim_ordering_tf_kernels_notop.h5
87910968/87910968 ━━━━━━━━━━━━━━━━━━━━ 0s 0us/step
In [ ]:
with strategy.scope():
    
    def calculate_frechet_distance(mu1, sigma1, mu2, sigma2):
        """Computes the Fréchet Distance between two multivariate Gaussians."""
        # Squared L2 norm of mean difference
        diff = mu1 - mu2
        mean_term = tf.reduce_sum(diff**2)  # ||μ₁ - μ₂||²

        # Covariance term: Tr(Σ₁ + Σ₂ - 2√(Σ₁Σ₂))
        cov_product = tf.matmul(sigma1, sigma2)
        cov_product = tf.cast(cov_product, tf.complex64)  # For sqrtm
        covmean = tf.linalg.sqrtm(cov_product)
        covmean = tf.math.real(covmean)  # Cast back to float32
        covmean = tf.cast(covmean, tf.float32)

        # Avoid NaN gradients by ensuring covmean is finite
        covmean = tf.where(
            tf.math.is_nan(covmean), 
            tf.zeros_like(covmean), 
            covmean
        )

        # Compute trace terms
        tr_covmean = tf.linalg.trace(covmean)
        trace_term = (
            tf.linalg.trace(sigma1) + 
            tf.linalg.trace(sigma2) - 
            2 * tr_covmean
        )

        fid = mean_term + trace_term
        return fid

    def compute_fid(generator, inception_model, real_mu, real_sigma, dataset):
    
        """Computes FID between generated and real images."""
        # Define a tf.function to generate images
        @tf.function
        def generate_images(images):
            return generator(images, training=False)
        
        # Collect all generated activations
        all_activations = []
        for batch in dataset:
            # Generate images
            generated_images = generate_images(batch)
            # Extract features
            activations = inception_model(generated_images)
            all_activations.append(activations)
        
        # Concatenate activations
        gen_activations = tf.concat(all_activations, axis=0)
        
        # Compute generated statistics
        gen_mu = tf.reduce_mean(gen_activations, axis=0)
        gen_sigma = tfp.stats.covariance(gen_activations)
        
        # Calculate FID
        fid_value = calculate_frechet_distance(
            gen_mu, gen_sigma, 
            real_mu, real_sigma
        )
        return fid_value
In [ ]:
def down_sample(filters, size, apply_instancenorm=True):
    initializer = tf.random_normal_initializer(0., 0.02)
    result = keras.Sequential()
    result.add(layers.Conv2D(filters, size, strides=2, padding='same',
                             kernel_initializer=initializer, use_bias=False))
    if apply_instancenorm:
        result.add(GroupNormalization(groups=-1))
    result.add(layers.LeakyReLU())
    return result

def up_sample(filters, size, apply_dropout=False):
    initializer = tf.random_normal_initializer(0., 0.02)
    result = keras.Sequential()
    result.add(layers.Conv2DTranspose(filters, size, strides=2, padding='same',
                                      kernel_initializer=initializer, use_bias=False))
    result.add(GroupNormalization(groups=-1))
    if apply_dropout:
        result.add(layers.Dropout(0.5))
    result.add(layers.ReLU())
    return result

Model Architecture¶

We selected the CycleGAN architecture for this problem, primarily because:

  • Unsupervised Learning: CycleGAN allows image-to-image translation without paired datasets, making it ideal for artistic style transfer scenarios.
  • Cycle Consistency Loss: This ensures the model preserves content details while converting styles, essential for generating realistic Monet-style images from photos.

Architecture Overview:¶

  • Generators: Utilize ResNet-based generators for stable training and effective feature extraction.

Input (256x256x3)

│

├─Downsample 64 → 128x128x64

├─Downsample 128 → 64x64x128

├─Downsample 256 → 32x32x256

├─Downsample 512 → 16x16x512

├─Downsample 512 → 8x8x512

├─Downsample 512 → 4x4x512

├─Downsample 512 → 2x2x512

└─Downsample 512 → 1x1x512 (bottleneck)

│

├─Upsample 512 → 2x2x512 (with skip from 2x2x512)

├─Upsample 512 → 4x4x512 (with skip from 4x4x512)

├─Upsample 512 → 8x8x512 (with skip from 8x8x512)

├─Upsample 512 → 16x16x512 (with skip from 16x16x512)

├─Upsample 256 → 32x32x256 (with skip from 32x32x256)

├─Upsample 128 → 64x64x128 (with skip from 64x64x128)

├─Upsample 64 → 128x128x64 (with skip from 128x128x64)

│

└─Output (256x256x3)

  • Discriminators: Employ PatchGAN discriminators to classify local image patches, which helps the model capture detailed textures and patterns characteristic of Monet paintings.

Input Image (256x256x3)

│

├─Downsample 64 → 128x128x64

├─Downsample 128 → 64x64x128

├─Downsample 256 → 32x32x256

│

├─ZeroPad2D → 34x34x256

├─Conv2D 512 → 31x31x512

├─InstanceNorm + LeakyReLU

└─ZeroPad2D → 33x33x512 (Base Discriminator Output)

│

├─DHead:

  ├─Conv2D 1 → 30x30x1 (Patch Predictions)

In [ ]:
def Generator():
    '''
    1. Downsampling Path (Encoder)
    Purpose: Compresses the input image into a low-dimensional bottleneck.

    Layers:

    8 down_sample blocks with increasing filters (64 → 512).

    Each block reduces spatial resolution by half (stride=2).

    First block omits instance normalization to preserve low-level details.

    Output: 1x1x512 bottleneck tensor.

    2. Upsampling Path (Decoder)
    Purpose: Reconstructs the image from the bottleneck to the target domain.

    Layers:

    7 up_sample blocks with decreasing filters (512 → 64).

    Transposed convolutions (stride=2) double spatial resolution.

    Dropout (50%) in early layers to prevent overfitting.

    Skip Connections: Concatenate features from the encoder to the decoder (U-Net structure).

    3. Final Output Layer
    Conv2DTranspose:

    Output channels: 3 (RGB).

    tanh activation: Normalizes outputs to [-1, 1], matching input normalization.
    '''
    inputs = layers.Input(shape=[256, 256, 3])
    
    # Downsampling path
    down_stack = [
        down_sample(64, 4, apply_instancenorm=False),  # 256x256 → 128x128
        down_sample(128, 4),                            # 128x128 → 64x64
        down_sample(256, 4),                            # 64x64 → 32x32
        down_sample(512, 4),                            # 32x32 → 16x16
        down_sample(512, 4),                            # 16x16 → 8x8
        down_sample(512, 4),                            # 8x8 → 4x4
        down_sample(512, 4),                            # 4x4 → 2x2
        down_sample(512, 4),                            # 2x2 → 1x1 (bottleneck)
    ]

    # Upsampling path
    up_stack = [
        up_sample(512, 4, apply_dropout=True),  # 1x1 → 2x2
        up_sample(512, 4, apply_dropout=True),   # 2x2 → 4x4
        up_sample(512, 4, apply_dropout=True),   # 4x4 → 8x8
        up_sample(512, 4),                       # 8x8 → 16x16
        up_sample(256, 4),                       # 16x16 → 32x32
        up_sample(128, 4),                       # 32x32 → 64x64
        up_sample(64, 4),                        # 64x64 → 128x128
    ]

    # Final output layer
    initializer = tf.random_normal_initializer(0., 0.02)
    last = layers.Conv2DTranspose(
        3, 4, strides=2, padding='same',
        kernel_initializer=initializer, activation='tanh'
    )  # 128x128 → 256x256

    x = inputs
    skips = []
    
    # Downsampling
    for down in down_stack:
        x = down(x)
        skips.append(x)
    
    skips = reversed(skips[:-1])  # Omit the bottleneck layer
    
    # Upsampling with skip connections
    for up, skip in zip(up_stack, skips):
        x = up(x)
        x = layers.Concatenate()([x, skip])
    
    x = last(x)
    return keras.Model(inputs=inputs, outputs=x)
In [ ]:
def Discriminator():
    '''
    Discriminator Architecture (Base)
    Purpose: Feature extractor for real/fake classification

    Structure:

    3 Downsample Blocks: Reduce resolution while increasing filters (64→256)

    Zero Padding: Expands spatial dimensions for subsequent convolutions

    Final Conv Layer: 512 filters with stride 1 (no resolution change)

    Output Shape: 33x33x512 feature maps

    2. DHead (Decision Head)
    Purpose: Final classification layer for adversarial loss

    Structure:

    Input: 33x33x512 features from base discriminator

    Conv2D(1): Reduces to 30x30x1 "patch" predictions

    No Activation: Raw logits for different loss functions

    3. Design Choices
    Separate Heads: Allows using different loss functions:

    dHead1: Binary cross-entropy (BCE)

    dHead2: Hinge loss

    Instance Normalization: Stabilizes training by normalizing features per-image

    Zero Padding: Preserves spatial dimensions after convolutions

    LeakyReLU: (α=0.2) prevents dead neurons in discriminator
    '''
    initializer = tf.random_normal_initializer(0., 0.02)
    gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)
    
    inp = layers.Input(shape=[256, 256, 3], name='input_image')
    x = inp
    
    # Downsampling Path
    down1 = down_sample(64, 4, False)(x)       # 256x256 → 128x128
    down2 = down_sample(128, 4)(down1)         # 128x128 → 64x64
    down3 = down_sample(256, 4)(down2)         # 64x64 → 32x32
    
    # Final Layers
    zero_pad1 = layers.ZeroPadding2D()(down3)  # 32x32 → 34x34
    conv = layers.Conv2D(512, 4, strides=1, 
                        kernel_initializer=initializer, 
                        use_bias=False)(zero_pad1)  # 34x34 → 31x31
    norm1 = GroupNormalization(groups=-1, gamma_initializer=gamma_init)(conv)
    leaky_relu = layers.LeakyReLU()(norm1)
    zero_pad2 = layers.ZeroPadding2D()(leaky_relu)  # 31x31 → 33x33
    
    return keras.Model(inputs=inp, outputs=zero_pad2)

def DHead():
    initializer = tf.random_normal_initializer(0., 0.02)
    
    inp = layers.Input(shape=[33, 33, 512], name='input_image')
    x = inp
    last = layers.Conv2D(1, 4, strides=1, 
                        kernel_initializer=initializer)(x)  # 33x33 → 30x30
    return keras.Model(inputs=inp, outputs=last)
In [ ]:
with strategy.scope():
    def DiffAugment(x, policy='', channels_first=False):
        if policy:
            if channels_first:
                x = tf.transpose(x, [0, 2, 3, 1])
            for p in policy.split(','):
                for f in AUGMENT_FNS[p]:
                    x = f(x)
            if channels_first:
                x = tf.transpose(x, [0, 3, 1, 2])
        return x


    def rand_brightness(x):
        magnitude = tf.random.uniform([tf.shape(x)[0], 1, 1, 1]) - 0.5
        x = x + magnitude
        return x


    def rand_saturation(x):
        magnitude = tf.random.uniform([tf.shape(x)[0], 1, 1, 1]) * 2
        x_mean = tf.reduce_sum(x, axis=3, keepdims=True) * 0.3333333333333333333
        x = (x - x_mean) * magnitude + x_mean
        return x


    def rand_contrast(x):
        magnitude = tf.random.uniform([tf.shape(x)[0], 1, 1, 1]) + 0.5
        x_mean = tf.reduce_sum(x, axis=[1, 2, 3], keepdims=True) * 5.086e-6
        x = (x - x_mean) * magnitude + x_mean
        return x

    def rand_translation(x, ratio=0.125):
        batch_size = tf.shape(x)[0]
        image_size = tf.shape(x)[1:3]
        shift = tf.cast(tf.cast(image_size, tf.float32) * ratio + 0.5, tf.int32)
        translation_x = tf.random.uniform([batch_size, 1], -shift[0], shift[0] + 1, dtype=tf.int32)
        translation_y = tf.random.uniform([batch_size, 1], -shift[1], shift[1] + 1, dtype=tf.int32)
        grid_x = tf.clip_by_value(tf.expand_dims(tf.range(image_size[0], dtype=tf.int32), 0) + translation_x + 1, 0, image_size[0] + 1)
        grid_y = tf.clip_by_value(tf.expand_dims(tf.range(image_size[1], dtype=tf.int32), 0) + translation_y + 1, 0, image_size[1] + 1)
        x = tf.gather_nd(tf.pad(x, [[0, 0], [1, 1], [0, 0], [0, 0]]), tf.expand_dims(grid_x, -1), batch_dims=1)
        x = tf.transpose(tf.gather_nd(tf.pad(tf.transpose(x, [0, 2, 1, 3]), [[0, 0], [1, 1], [0, 0], [0, 0]]), tf.expand_dims(grid_y, -1), batch_dims=1), [0, 2, 1, 3])
        return x


    def rand_cutout(x, ratio=0.5):
        batch_size = tf.shape(x)[0]
        image_size = tf.shape(x)[1:3]
        cutout_size = tf.cast(tf.cast(image_size, tf.float32) * ratio + 0.5, tf.int32)
        offset_x = tf.random.uniform([tf.shape(x)[0], 1, 1], maxval=image_size[0] + (1 - cutout_size[0] % 2), dtype=tf.int32)
        offset_y = tf.random.uniform([tf.shape(x)[0], 1, 1], maxval=image_size[1] + (1 - cutout_size[1] % 2), dtype=tf.int32)
        grid_batch, grid_x, grid_y = tf.meshgrid(tf.range(batch_size, dtype=tf.int32), tf.range(cutout_size[0], dtype=tf.int32), tf.range(cutout_size[1], dtype=tf.int32), indexing='ij')
        cutout_grid = tf.stack([grid_batch, grid_x + offset_x - cutout_size[0] // 2, grid_y + offset_y - cutout_size[1] // 2], axis=-1)
        mask_shape = tf.stack([batch_size, image_size[0], image_size[1]])
        cutout_grid = tf.maximum(cutout_grid, 0)
        cutout_grid = tf.minimum(cutout_grid, tf.reshape(mask_shape - 1, [1, 1, 1, 3]))
        mask = tf.maximum(1 - tf.scatter_nd(cutout_grid, tf.ones([batch_size, cutout_size[0], cutout_size[1]], dtype=tf.float32), mask_shape), 0)
        x = x * tf.expand_dims(mask, axis=3)
        return x


    AUGMENT_FNS = {
        'color': [rand_brightness, rand_saturation, rand_contrast],
        'translation': [rand_translation],
        'cutout': [rand_cutout],
}
    def aug_fn(image):
        return DiffAugment(image,"color,translation,cutout")
In [ ]:
class CycleGan(keras.Model):
    def __init__(self, m_gen, p_gen, m_disc, p_disc, dhead1=None, dhead2=None, lambda_cycle=3, lambda_id=3):
        super().__init__()
        self.m_gen = m_gen  # Monet generator (photos → paintings)
        self.p_gen = p_gen  # Photo generator (paintings → photos)
        self.m_disc = m_disc  # Monet discriminator (base)
        self.p_disc = p_disc  # Photo discriminator (full)
        self.dhead1 = dhead1  # First discriminator head
        self.dhead2 = dhead2  # Second discriminator head (can be None)
        self.lambda_cycle = lambda_cycle  # Cycle consistency weight
        self.lambda_id = lambda_id  # Identity loss weight

    def compile(self, m_gen_opt, p_gen_opt, m_disc_opt, p_disc_opt, 
               gen_loss_fn1, gen_loss_fn2, disc_loss_fn1, disc_loss_fn2,
               cycle_loss_fn, identity_loss_fn, aug_fn):
        super().compile()
        # Optimizers
        self.m_gen_opt = m_gen_opt
        self.p_gen_opt = p_gen_opt
        self.m_disc_opt = m_disc_opt
        self.p_disc_opt = p_disc_opt
        
        # Loss Functions
        self.gen_loss_fn1 = gen_loss_fn1  # e.g., BCE
        self.gen_loss_fn2 = gen_loss_fn2  # e.g., Hinge
        self.disc_loss_fn1 = disc_loss_fn1
        self.disc_loss_fn2 = disc_loss_fn2
        self.cycle_loss_fn = cycle_loss_fn
        self.identity_loss_fn = identity_loss_fn
        
        # Augmentation
        self.aug_fn = aug_fn  # DiffAugment policy

    def augment_batch(self, real_images, fake_images):
        """Apply data augmentation to both real and fake images."""
        # Concatenate images for batched augmentation (more efficient)
        combined = tf.concat([real_images, fake_images], axis=0)
        
        # Apply augmentation
        augmented = self.aug_fn(combined)
        
        # Split back into real and fake
        batch_size = tf.shape(real_images)[0]
        aug_real = augmented[:batch_size]
        aug_fake = augmented[batch_size:]
        
        return aug_real, aug_fake

    def train_step(self, batch_data):
        real_monet, real_photo = batch_data
        batch_size = tf.shape(real_monet)[0]
        
        with tf.GradientTape(persistent=True) as tape:
            # Forward cycle: photo → monet → photo
            fake_monet = self.m_gen(real_photo, training=True)
            cycled_photo = self.p_gen(fake_monet, training=True)
            
            # Backward cycle: monet → photo → monet
            fake_photo = self.p_gen(real_monet, training=True)
            cycled_monet = self.m_gen(fake_photo, training=True)
            
            # Identity mapping
            same_monet = self.m_gen(real_monet, training=True)
            same_photo = self.p_gen(real_photo, training=True)
            
            # Apply augmentation if provided
            aug_real_monet, aug_fake_monet = self.augment_batch(real_monet, fake_monet)
            aug_real_photo, aug_fake_photo = self.augment_batch(real_photo, fake_photo)
            
            # Monet discriminator
            disc_real_monet_features = self.m_disc(aug_real_monet, training=True)
            disc_fake_monet_features = self.m_disc(aug_fake_monet, training=True)
            
            # Initialize loss values
            monet_gen_loss = 0
            monet_disc_loss = 0
            monet_gen_loss2 = 0
            monet_disc_loss2 = 0
            
            # Use discriminator head if available
            if self.dhead1 is not None:
                disc_real_monet = self.dhead1(disc_real_monet_features, training=True)
                disc_fake_monet = self.dhead1(disc_fake_monet_features, training=True)
                monet_gen_loss = self.gen_loss_fn1(disc_fake_monet)
                monet_disc_loss = self.disc_loss_fn1(disc_real_monet, disc_fake_monet)
                
                # Second head (optional)
                if self.dhead2 is not None:
                    disc_real_monet2 = self.dhead2(disc_real_monet_features, training=True)
                    disc_fake_monet2 = self.dhead2(disc_fake_monet_features, training=True)
                    monet_gen_loss2 = self.gen_loss_fn2(disc_fake_monet2)
                    monet_disc_loss2 = self.disc_loss_fn2(disc_real_monet2, disc_fake_monet2)
            else:
                # Use features directly (patch discriminator)
                disc_real_monet = disc_real_monet_features
                disc_fake_monet = disc_fake_monet_features
                monet_gen_loss = self.gen_loss_fn1(disc_real_monet)
                monet_disc_loss = self.disc_loss_fn1(disc_real_monet, disc_fake_monet)
            
            # Photo discriminator
            disc_real_photo = self.p_disc(aug_real_photo, training=True)
            disc_fake_photo = self.p_disc(aug_fake_photo, training=True)
            photo_gen_loss = self.gen_loss_fn1(disc_fake_photo)
            photo_disc_loss = self.disc_loss_fn1(disc_real_photo, disc_fake_photo)
            
            # Cycle consistency loss
            cycle_loss = (
                self.cycle_loss_fn(real_monet, cycled_monet) + 
                self.cycle_loss_fn(real_photo, cycled_photo)
            ) * self.lambda_cycle
            
            # Identity loss
            identity_loss = (
                self.identity_loss_fn(real_monet, same_monet) + 
                self.identity_loss_fn(real_photo, same_photo)
            ) * self.lambda_id
            
            # Total losses
            total_monet_gen_loss = monet_gen_loss + monet_gen_loss2 + cycle_loss + identity_loss
            total_photo_gen_loss = photo_gen_loss + cycle_loss + identity_loss
            total_monet_disc_loss = monet_disc_loss + monet_disc_loss2
        
        # Calculate gradients
        monet_gen_grads = tape.gradient(total_monet_gen_loss, self.m_gen.trainable_variables)
        photo_gen_grads = tape.gradient(total_photo_gen_loss, self.p_gen.trainable_variables)
        
        # Apply generator gradients directly
        self.m_gen_opt.apply_gradients(zip(monet_gen_grads, self.m_gen.trainable_variables))
        self.p_gen_opt.apply_gradients(zip(photo_gen_grads, self.p_gen.trainable_variables))
        
        # Handle discriminator (complete networks at once instead of separating heads)
        # This avoids unnecessary computation in graph mode
        all_m_disc_vars = self.m_disc.trainable_variables
        if self.dhead1 is not None:
            all_m_disc_vars += self.dhead1.trainable_variables
        if self.dhead2 is not None:
            all_m_disc_vars += self.dhead2.trainable_variables
        
        monet_disc_grads = tape.gradient(total_monet_disc_loss, all_m_disc_vars)
        photo_disc_grads = tape.gradient(photo_disc_loss, self.p_disc.trainable_variables)
        
        self.m_disc_opt.apply_gradients(zip(monet_disc_grads, all_m_disc_vars))
        self.p_disc_opt.apply_gradients(zip(photo_disc_grads, self.p_disc.trainable_variables))
        
        # Return metrics
        return {
            "monet_gen_loss": monet_gen_loss,
            "photo_gen_loss": photo_gen_loss,
            "monet_disc_loss": monet_disc_loss,
            "photo_disc_loss": photo_disc_loss,
            "cycle_loss": cycle_loss,
            "identity_loss": identity_loss
        }

Building and Training the CycleGAN¶

Here, we detail the training setup, including loss functions, optimizers, epochs, and computational resources:

  • Cycle Consistency Loss
  • Adversarial Loss
  • Identity Loss (optional)
  • Optimizers: Adam optimizer with learning rate scheduler
  • Training duration: Number of epochs and batch sizes
In [ ]:
with strategy.scope():
    # ========================
    # 1. Initialize Models
    # ========================
    monet_generator = Generator()  # Photos → Monet
    photo_generator = Generator()  # Monet → Photos
    
    monet_discriminator = Discriminator()
    photo_discriminator = Discriminator()
    dhead1 = DHead()  # For BCE loss
    dhead2 = DHead()  # For Hinge loss

    # ========================
    # 2. Define Optimizers
    # ========================
    monet_generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
    photo_generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
    monet_discriminator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
    photo_discriminator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

    # ========================
    # 3. Define Loss Functions
    # ========================
    def generator_loss1(generated):
        return tf.reduce_mean(-generated)  # Hinge loss`

    def generator_loss2(generated):
        return tf.keras.losses.BinaryCrossentropy(from_logits=True)(
            tf.ones_like(generated), generated)

    def discriminator_loss1(real, generated):
        real_loss = tf.reduce_mean(tf.minimum(0., -1. + real))
        fake_loss = tf.reduce_mean(tf.minimum(0., -1. - generated))
        return -tf.reduce_mean(real_loss + fake_loss)

    def discriminator_loss2(real, generated):
        real_loss = tf.keras.losses.BinaryCrossentropy(from_logits=True)(
            tf.ones_like(real), real)
        fake_loss = tf.keras.losses.BinaryCrossentropy(from_logits=True)(
            tf.zeros_like(generated), generated)
        return 0.5 * (real_loss + fake_loss)
    
    def cycle_loss_fn(real, cycled):
        return tf.reduce_mean(tf.abs(real - cycled))
    
    def identity_loss_fn(real, same):
        return tf.reduce_mean(tf.abs(real - same))

    # ========================
    # 4. Compile CycleGAN
    # ========================
    cycle_gan = CycleGan(
        monet_generator,
        photo_generator,
        monet_discriminator,
        photo_discriminator,
        dhead1,
        dhead2,
        lambda_cycle=10,  # Stronger cycle consistency
        lambda_id=0.5     # Weaker identity loss
    )

    cycle_gan.compile(
        m_gen_opt=monet_generator_optimizer,
        p_gen_opt=photo_generator_optimizer,
        m_disc_opt=monet_discriminator_optimizer,
        p_disc_opt=photo_discriminator_optimizer,
        gen_loss_fn1=generator_loss1,
        gen_loss_fn2=generator_loss2,
        disc_loss_fn1=discriminator_loss1,
        disc_loss_fn2=discriminator_loss2,
        cycle_loss_fn=cycle_loss_fn,
        identity_loss_fn=identity_loss_fn,
        aug_fn=aug_fn
    )

# ========================
# 4.5 Monitoring
# ========================
class CycleGANMonitor(Callback):
    def __init__(self, sample_photo, sample_monet, monet_generator, photo_generator, epoch_interval=5):
        """
        Args:
            sample_photo: Batch of sample photos (normalized to [-1, 1])
            sample_monet: Batch of sample Monet paintings (normalized to [-1, 1])
            monet_generator: Generator that converts photos to Monet style
            photo_generator: Generator that converts Monet to photos
            epoch_interval: How often to generate samples (in epochs)
        """
        self.sample_photo = sample_photo
        self.sample_monet = sample_monet
        self.monet_generator = monet_generator
        self.photo_generator = photo_generator
        self.epoch_interval = epoch_interval

    def _denormalize(self, image):
        """Convert from [-1, 1] range to [0, 1] for visualization"""
        return (image * 0.5) + 0.5
    @tf.function
    def generate_predictions(self, images, generator):
        return generator(images, training=False)

    def _plot_predictions(self, epoch=None):
        # Generate predictions with tf.function
        monet_pred = self.generate_predictions(self.sample_photo, self.monet_generator)
        photo_pred = self.generate_predictions(self.sample_monet, self.photo_generator)

        plt.figure(figsize=(18, 8))
        
        # Photo → Monet translations
        for i in range(min(3, len(self.sample_photo))):  # Show first 3 samples
            plt.subplot(2, 3, i+1)
            plt.imshow(self._denormalize(self.sample_photo[i]))
            plt.title(f"Input Photo {i+1}")
            plt.axis("off")
            
            plt.subplot(2, 3, i+4)
            plt.imshow(self._denormalize(monet_pred[i]))
            plt.title(f"Generated Monet {i+1}" + (f" (Epoch {epoch})" if epoch else ""))
            plt.axis("off")

        plt.tight_layout()
        plt.show()

    def on_epoch_end(self, epoch, logs=None):
        """Generate samples at specified intervals"""
        if (epoch+1) % self.epoch_interval == 0:  # +1 to avoid epoch 0
            self._plot_predictions(epoch+1)

    def on_train_end(self, logs=None):
        """Final visualization after training"""
        self._plot_predictions()

# Prepare sample images
sample_photo = next(iter(photo_ds.take(1)))  # Get 1 batch of photos
sample_monet = next(iter(monet_ds.take(1)))  # Get 1 batch of Monet paintings

# Create callback
viz_callback = CycleGANMonitor(
    sample_photo=sample_photo,
    sample_monet=sample_monet,
    monet_generator=monet_generator,
    photo_generator=photo_generator,
    epoch_interval=5  # Generate samples every 5 epochs
)
In [ ]:
# ========================
# 5. Training Loop
# ========================
fids = []  # Track FID scores during training

for epoch in range(1, EPOCHS+1):
    print(f"Epoch {epoch}/{EPOCHS}")
    
    # Train for one epoch
    history = cycle_gan.fit(
        training_dataset,
        epochs=1,
        steps_per_epoch=STEPS_PER_EPOCH,
        callbacks=[viz_callback],  # Use the visualization callback
        verbose=1
    )
    
    # Periodic Evaluation
    if epoch % FID_INTERVAL == 0:
        try:
            # Calculate FID
            fid_score = compute_fid(
                monet_generator,
                inception_model,
                myFID_mu_real,
                myFID_sigma_real,
                fid_photo_ds.take(64)  # Use subset for faster evaluation
            )
            fids.append(fid_score.numpy())
            print(f"FID after epoch {epoch}: {fid_score:.2f}")
            
            # Save checkpoint
            checkpoint_dir = f"./cyclegan_checkpoints/epoch_{epoch}"
            checkpoint = tf.train.Checkpoint(
                monet_generator=monet_generator,
                photo_generator=photo_generator,
                monet_generator_optimizer=monet_generator_optimizer,
                photo_generator_optimizer=photo_generator_optimizer
            )
            checkpoint.save(checkpoint_dir)
            print(f"Checkpoint saved at epoch {epoch}")
        except Exception as e:
            print(f"Error during FID calculation: {e}")
            continue

# Learning rate decay after halfway
if epoch == EPOCHS // 2:
    print("Reducing learning rate by factor of 10")
    # Reduce learning rates
    monet_generator_optimizer.learning_rate = monet_generator_optimizer.learning_rate * 0.1
    photo_generator_optimizer.learning_rate = photo_generator_optimizer.learning_rate * 0.1
    monet_discriminator_optimizer.learning_rate = monet_discriminator_optimizer.learning_rate * 0.1
    photo_discriminator_optimizer.learning_rate = photo_discriminator_optimizer.learning_rate * 0.1
Epoch 1/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 264s 2s/step - cycle_loss: 4.2368 - identity_loss: 0.2133 - monet_disc_loss: 1.9284 - monet_gen_loss: 0.0400 - photo_disc_loss: 1.9992 - photo_gen_loss: -0.0053
No description has been provided for this image
Epoch 2/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 160s 2s/step - cycle_loss: 2.7125 - identity_loss: 0.1677 - monet_disc_loss: 1.8535 - monet_gen_loss: 0.0956 - photo_disc_loss: 1.9968 - photo_gen_loss: -0.0091
No description has been provided for this image
Epoch 3/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 2.2031 - identity_loss: 0.1707 - monet_disc_loss: 1.9045 - monet_gen_loss: 0.0614 - photo_disc_loss: 1.9953 - photo_gen_loss: -0.0135
No description has been provided for this image
Epoch 4/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 2.0119 - identity_loss: 0.1814 - monet_disc_loss: 1.9373 - monet_gen_loss: 0.0678 - photo_disc_loss: 1.9938 - photo_gen_loss: -0.0174
No description has been provided for this image
Epoch 5/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 1.9083 - identity_loss: 0.1691 - monet_disc_loss: 1.9500 - monet_gen_loss: 0.0609 - photo_disc_loss: 1.9911 - photo_gen_loss: -0.0212
No description has been provided for this image
FID after epoch 5: 16.71
Checkpoint saved at epoch 5
Epoch 6/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 1.8526 - identity_loss: 0.1543 - monet_disc_loss: 1.9486 - monet_gen_loss: 0.0195 - photo_disc_loss: 1.9869 - photo_gen_loss: -0.0248
No description has been provided for this image
Epoch 7/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 1.7730 - identity_loss: 0.1360 - monet_disc_loss: 1.9646 - monet_gen_loss: 0.0182 - photo_disc_loss: 1.9858 - photo_gen_loss: -0.0299
No description has been provided for this image
Epoch 8/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 1.7071 - identity_loss: 0.1201 - monet_disc_loss: 1.9587 - monet_gen_loss: 0.0287 - photo_disc_loss: 1.9818 - photo_gen_loss: -0.0325
No description has been provided for this image
Epoch 9/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 1.6600 - identity_loss: 0.1137 - monet_disc_loss: 1.9580 - monet_gen_loss: 0.0231 - photo_disc_loss: 1.9775 - photo_gen_loss: -0.0353
No description has been provided for this image
Epoch 10/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 1.5796 - identity_loss: 0.0990 - monet_disc_loss: 1.9690 - monet_gen_loss: 0.0013 - photo_disc_loss: 1.9743 - photo_gen_loss: -0.0389
No description has been provided for this image
FID after epoch 10: 16.70
Checkpoint saved at epoch 10
Epoch 11/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 1.5617 - identity_loss: 0.0966 - monet_disc_loss: 1.7614 - monet_gen_loss: 0.1003 - photo_disc_loss: 1.9690 - photo_gen_loss: -0.0407
No description has been provided for this image
Epoch 12/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 160s 2s/step - cycle_loss: 1.5055 - identity_loss: 0.1001 - monet_disc_loss: 1.7904 - monet_gen_loss: 0.1000 - photo_disc_loss: 1.9641 - photo_gen_loss: -0.0427
No description has been provided for this image
Epoch 13/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 1.4833 - identity_loss: 0.0965 - monet_disc_loss: 1.6997 - monet_gen_loss: 0.1375 - photo_disc_loss: 1.9584 - photo_gen_loss: -0.0439
No description has been provided for this image
Epoch 14/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 1.4315 - identity_loss: 0.0896 - monet_disc_loss: 1.8767 - monet_gen_loss: 0.0828 - photo_disc_loss: 1.9511 - photo_gen_loss: -0.0437
No description has been provided for this image
Epoch 15/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 1.4219 - identity_loss: 0.0884 - monet_disc_loss: 1.7719 - monet_gen_loss: 0.0870 - photo_disc_loss: 1.9478 - photo_gen_loss: -0.0466
No description has been provided for this image
FID after epoch 15: 13.20
Checkpoint saved at epoch 15
Epoch 16/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 1.4021 - identity_loss: 0.0877 - monet_disc_loss: 1.8049 - monet_gen_loss: 0.0964 - photo_disc_loss: 1.9465 - photo_gen_loss: -0.0502
No description has been provided for this image
Epoch 17/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 160s 2s/step - cycle_loss: 1.3972 - identity_loss: 0.0883 - monet_disc_loss: 1.7924 - monet_gen_loss: 0.1015 - photo_disc_loss: 1.9407 - photo_gen_loss: -0.0509
No description has been provided for this image
Epoch 18/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 1.3660 - identity_loss: 0.0839 - monet_disc_loss: 1.7914 - monet_gen_loss: 0.1003 - photo_disc_loss: 1.9401 - photo_gen_loss: -0.0550
No description has been provided for this image
Epoch 19/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 1.3485 - identity_loss: 0.0808 - monet_disc_loss: 1.8488 - monet_gen_loss: 0.0876 - photo_disc_loss: 1.9319 - photo_gen_loss: -0.0537
No description has been provided for this image
Epoch 20/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 160s 2s/step - cycle_loss: 1.3210 - identity_loss: 0.0785 - monet_disc_loss: 1.8822 - monet_gen_loss: 0.0548 - photo_disc_loss: 1.9278 - photo_gen_loss: -0.0554
No description has been provided for this image
FID after epoch 20: 12.65
Checkpoint saved at epoch 20
Epoch 21/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 1.3120 - identity_loss: 0.0772 - monet_disc_loss: 1.8385 - monet_gen_loss: 0.0617 - photo_disc_loss: 1.9252 - photo_gen_loss: -0.0577
No description has been provided for this image
Epoch 22/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 1.2865 - identity_loss: 0.0777 - monet_disc_loss: 1.8299 - monet_gen_loss: 0.0877 - photo_disc_loss: 1.9303 - photo_gen_loss: -0.0649
No description has been provided for this image
Epoch 23/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 1.2622 - identity_loss: 0.0751 - monet_disc_loss: 1.8757 - monet_gen_loss: 0.0631 - photo_disc_loss: 1.9340 - photo_gen_loss: -0.0689
No description has been provided for this image
Epoch 24/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 1.2536 - identity_loss: 0.0743 - monet_disc_loss: 1.8781 - monet_gen_loss: 0.0707 - photo_disc_loss: 1.9466 - photo_gen_loss: -0.0781
No description has been provided for this image
Epoch 25/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 1.2498 - identity_loss: 0.0748 - monet_disc_loss: 1.8932 - monet_gen_loss: 0.0769 - photo_disc_loss: 1.9557 - photo_gen_loss: -0.0849
No description has been provided for this image
FID after epoch 25: 12.48
Checkpoint saved at epoch 25
Epoch 26/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 1.2281 - identity_loss: 0.0719 - monet_disc_loss: 1.8748 - monet_gen_loss: 0.0817 - photo_disc_loss: 1.9568 - photo_gen_loss: -0.0844
No description has been provided for this image
Epoch 27/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 1.2441 - identity_loss: 0.0738 - monet_disc_loss: 1.8643 - monet_gen_loss: 0.0895 - photo_disc_loss: 1.9501 - photo_gen_loss: -0.0799
No description has been provided for this image
Epoch 28/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 1.2121 - identity_loss: 0.0725 - monet_disc_loss: 1.8571 - monet_gen_loss: 0.0848 - photo_disc_loss: 1.9544 - photo_gen_loss: -0.0832
No description has been provided for this image
Epoch 29/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 160s 2s/step - cycle_loss: 1.2062 - identity_loss: 0.0717 - monet_disc_loss: 1.8840 - monet_gen_loss: 0.1039 - photo_disc_loss: 1.9516 - photo_gen_loss: -0.0805
No description has been provided for this image
Epoch 30/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 1.1955 - identity_loss: 0.0697 - monet_disc_loss: 1.8878 - monet_gen_loss: 0.0737 - photo_disc_loss: 1.9481 - photo_gen_loss: -0.0785
No description has been provided for this image
FID after epoch 30: 12.58
Checkpoint saved at epoch 30
Epoch 31/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 1.1681 - identity_loss: 0.0675 - monet_disc_loss: 1.8837 - monet_gen_loss: 0.0943 - photo_disc_loss: 1.9463 - photo_gen_loss: -0.0776
No description has been provided for this image
Epoch 32/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 1.1575 - identity_loss: 0.0677 - monet_disc_loss: 1.8447 - monet_gen_loss: 0.0820 - photo_disc_loss: 1.9500 - photo_gen_loss: -0.0813
No description has been provided for this image
Epoch 33/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 1.1623 - identity_loss: 0.0686 - monet_disc_loss: 1.8933 - monet_gen_loss: 0.0897 - photo_disc_loss: 1.9463 - photo_gen_loss: -0.0780
No description has been provided for this image
Epoch 34/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 1.1436 - identity_loss: 0.0685 - monet_disc_loss: 1.8595 - monet_gen_loss: 0.0854 - photo_disc_loss: 1.9435 - photo_gen_loss: -0.0777
No description has been provided for this image
Epoch 35/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 1.1245 - identity_loss: 0.0662 - monet_disc_loss: 1.8752 - monet_gen_loss: 0.0850 - photo_disc_loss: 1.9399 - photo_gen_loss: -0.0766
No description has been provided for this image
FID after epoch 35: 12.13
Checkpoint saved at epoch 35
Epoch 36/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 1.1158 - identity_loss: 0.0648 - monet_disc_loss: 1.8744 - monet_gen_loss: 0.0959 - photo_disc_loss: 1.9429 - photo_gen_loss: -0.0790
No description has been provided for this image
Epoch 37/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 1.1121 - identity_loss: 0.0648 - monet_disc_loss: 1.8879 - monet_gen_loss: 0.0996 - photo_disc_loss: 1.9368 - photo_gen_loss: -0.0750
No description has been provided for this image
Epoch 38/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 1.1159 - identity_loss: 0.0657 - monet_disc_loss: 1.8953 - monet_gen_loss: 0.1019 - photo_disc_loss: 1.9394 - photo_gen_loss: -0.0791
No description has been provided for this image
Epoch 39/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 1.0877 - identity_loss: 0.0622 - monet_disc_loss: 1.8638 - monet_gen_loss: 0.0844 - photo_disc_loss: 1.9387 - photo_gen_loss: -0.0772
No description has been provided for this image
Epoch 40/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 1.0759 - identity_loss: 0.0633 - monet_disc_loss: 1.8849 - monet_gen_loss: 0.0880 - photo_disc_loss: 1.9396 - photo_gen_loss: -0.0767
No description has been provided for this image
FID after epoch 40: 12.07
Checkpoint saved at epoch 40
Epoch 41/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 1.0685 - identity_loss: 0.0616 - monet_disc_loss: 1.8895 - monet_gen_loss: 0.0771 - photo_disc_loss: 1.9374 - photo_gen_loss: -0.0744
No description has been provided for this image
Epoch 42/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 1.0560 - identity_loss: 0.0610 - monet_disc_loss: 1.8960 - monet_gen_loss: 0.0577 - photo_disc_loss: 1.9376 - photo_gen_loss: -0.0751
No description has been provided for this image
Epoch 43/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 160s 2s/step - cycle_loss: 1.0490 - identity_loss: 0.0595 - monet_disc_loss: 1.8891 - monet_gen_loss: 0.0897 - photo_disc_loss: 1.9340 - photo_gen_loss: -0.0718
No description has been provided for this image
Epoch 44/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 1.0371 - identity_loss: 0.0594 - monet_disc_loss: 1.8753 - monet_gen_loss: 0.0718 - photo_disc_loss: 1.9316 - photo_gen_loss: -0.0732
No description has been provided for this image
Epoch 45/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 1.0224 - identity_loss: 0.0577 - monet_disc_loss: 1.8724 - monet_gen_loss: 0.0921 - photo_disc_loss: 1.9306 - photo_gen_loss: -0.0733
No description has been provided for this image
FID after epoch 45: 11.71
Checkpoint saved at epoch 45
Epoch 46/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 158s 2s/step - cycle_loss: 1.0282 - identity_loss: 0.0587 - monet_disc_loss: 1.8925 - monet_gen_loss: 0.0838 - photo_disc_loss: 1.9364 - photo_gen_loss: -0.0743
No description has been provided for this image
Epoch 47/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 1.0333 - identity_loss: 0.0602 - monet_disc_loss: 1.9029 - monet_gen_loss: 0.0654 - photo_disc_loss: 1.9359 - photo_gen_loss: -0.0737
No description has been provided for this image
Epoch 48/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 1.0065 - identity_loss: 0.0577 - monet_disc_loss: 1.9047 - monet_gen_loss: 0.0804 - photo_disc_loss: 1.9323 - photo_gen_loss: -0.0707
No description has been provided for this image
Epoch 49/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 1.0054 - identity_loss: 0.0582 - monet_disc_loss: 1.8804 - monet_gen_loss: 0.0907 - photo_disc_loss: 1.9293 - photo_gen_loss: -0.0702
No description has been provided for this image
Epoch 50/50
100/100 ━━━━━━━━━━━━━━━━━━━━ 159s 2s/step - cycle_loss: 0.9815 - identity_loss: 0.0562 - monet_disc_loss: 1.8664 - monet_gen_loss: 0.0931 - photo_disc_loss: 1.9293 - photo_gen_loss: -0.0697
No description has been provided for this image
FID after epoch 50: 12.03
Checkpoint saved at epoch 50

Results and Visualization¶

We present the generated Monet-style images from the trained model, including visual comparisons and performance evaluation metrics.

Sample Generated Images:¶

Insert visualizations showcasing photos alongside their generated Monet-style counterparts.

Evaluation Metrics:¶

  • Qualitative assessment (visual inspection)
  • Quantitative metrics (e.g., FID score if used)

Visualizations:¶

  • Histograms showing the distribution of pixel intensities before and after style transfer.
  • Loss Curves displaying training stability and convergence.
In [ ]:
# ========================
# 6. Final Save
# ========================
print("Saving final models...")
monet_generator.save("monet_generator_final.h5")
photo_generator.save("photo_generator_final.h5")

# Plot FID progression
if len(fids) > 0:
    plt.figure(figsize=(10, 5))
    plt.plot(np.arange(0, EPOCHS, FID_INTERVAL)[0:len(fids)], fids)
    plt.title("FID Score During Training")
    plt.xlabel("Epochs")
    plt.ylabel("FID")
    plt.savefig("fid_progress.png")
    plt.show()
Saving final models...
No description has been provided for this image
In [ ]:
ds_iter = iter(photo_ds)
for n_sample in range(3):
        example_sample = next(ds_iter)
        generated_sample = monet_generator(example_sample)
        
        f = plt.figure(figsize=(32, 32))
        
        plt.subplot(121)
        plt.title('Input image')
        plt.imshow(example_sample[0] * 0.5 + 0.5)
        plt.axis('off')
        
        plt.subplot(122)
        plt.title('Generated image')
        plt.imshow(generated_sample[0] * 0.5 + 0.5)
        plt.axis('off')
        plt.show()
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
In [ ]:
# Create a temporary directory if it doesn't exist
os.makedirs("../tmp", exist_ok=True)
zip_path = "../tmp/images.zip"

# Create a function to generate and process images in batches
def generate_monet_images(photo_dataset, generator_model, batch_size=32):
    """
    Generates Monet-style images from a photo dataset and saves them directly to a zip file.
    
    Args:
        photo_dataset: TensorFlow dataset of photos
        generator_model: Trained Monet generator model
        batch_size: Number of images to process at once
    """
    start_time = time.time()
    
    # Check if dataset is already batched
    # We won't re-batch it if it's already batched
    dataset = photo_dataset
    
    # Create zip file
    with zipfile.ZipFile(zip_path, 'w', compression=zipfile.ZIP_DEFLATED) as zipf:
        image_count = 0
        
        # Add a simple progress bar
        for batch in tqdm(dataset):
            # Handle potential batching issues by checking the tensor rank
            tensor_rank = len(batch.shape)
            
            # If tensor has too many dimensions, we need to reshape
            if tensor_rank > 4:  # Should be [batch, height, width, channels]
                print(f"Warning: Input tensor has shape {batch.shape}, reshaping...")
                # Assuming the first two dimensions are batch dimensions
                batch_size = batch.shape[0] * batch.shape[1]
                batch = tf.reshape(batch, [batch_size, *batch.shape[2:]])
                print(f"Reshaped to {batch.shape}")
            
            # Generate Monet-style images
            generated_images = generator_model(batch, training=False)
            
            # Convert from [-1, 1] to [0, 255] uint8
            generated_images = ((generated_images.numpy() * 0.5 + 0.5) * 255).astype(np.uint8)
            
            # Process each image in the batch
            for img in generated_images:
                # Convert to PIL Image
                pil_img = Image.fromarray(img)
                
                # Create in-memory buffer
                img_buffer = io.BytesIO()
                
                # Save as JPEG to buffer
                pil_img.save(img_buffer, format='JPEG', quality=95)
                
                # Reset buffer position
                img_buffer.seek(0)
                
                # Add to zip file with sequential naming
                img_name = f"{image_count + 1}.jpg"
                zipf.writestr(img_name, img_buffer.getvalue())
                
                image_count += 1
                
                # Optional: Add status update every 1000 images
                if image_count % 1000 == 0:
                    elapsed = time.time() - start_time
                    print(f"Generated {image_count} images in {elapsed:.2f} seconds")
    
    elapsed = time.time() - start_time
    print(f"Generation complete! Created {image_count} images in {elapsed:.2f} seconds")
    print(f"Zip file saved to: {zip_path}")
    
    # Optional: Calculate zip file size
    zip_size_mb = os.path.getsize(zip_path) / (1024 * 1024)
    print(f"Zip file size: {zip_size_mb:.2f} MB")
    
    return image_count

# Copy zip to final Kaggle output location
def finalize_submission():
    """Copies the zip file to the Kaggle output location"""
    import shutil
    output_path = "images.zip"
    shutil.copy(zip_path, output_path)
    print(f"Submission file copied to: {output_path}")
    
    # Verify contents
    with zipfile.ZipFile(output_path, 'r') as zipf:
        file_count = len(zipf.namelist())
        print(f"Verified zip contains {file_count} images")

# Prepare a proper dataset for generation
# CRITICAL: Don't use load_dataset that already batches the data
# We need to handle the batching carefully here
def prepare_photo_dataset(photo_files, batch_size=32):
    # Load the TFRecord dataset without prior batching
    dataset = tf.data.TFRecordDataset(photo_files)
    
    # Parse the records
    def read_tfrecord(example):
        tfrecord_format = {"image": tf.io.FixedLenFeature([], tf.string)}
        example = tf.io.parse_single_example(example, tfrecord_format)
        image = tf.image.decode_jpeg(example['image'], channels=3)
        image = (tf.cast(image, tf.float32) / 127.5) - 1
        image = tf.reshape(image, [256, 256, 3])
        return image
    
    # Map the parsing function and batch
    dataset = dataset.map(read_tfrecord, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)
    
    return dataset

# 1. Prepare the dataset properly
photo_ds_for_generation = prepare_photo_dataset(PHOTO_TFREC)

# 2. Generate images and create zip
total_images = generate_monet_images(photo_ds_for_generation, monet_generator)

# 3. Finalize submission
finalize_submission()
Generated 1000 images in 6.63 seconds
Generated 2000 images in 12.81 seconds
Generated 3000 images in 18.47 seconds
Generated 4000 images in 24.20 seconds
Generated 5000 images in 29.92 seconds
Generated 6000 images in 35.60 seconds
Generated 7000 images in 41.10 seconds
Generation complete! Created 7038 images in 42.12 seconds
Zip file saved to: ../tmp/images.zip
Zip file size: 253.20 MB
Submission file copied to: images.zip
Verified zip contains 7038 images

Discussion and Analysis of Results¶

We analyze the strengths and limitations of our CycleGAN implementation.

What Worked Well:¶

  • Effective style transfer, preserving content while capturing Monet's color palette and brushstroke patterns.
  • Stable training due to Cycle consistency and ResNet architecture.
  • Top ten(!!!!!) for the Kaggle Competition

Challenges and Troubleshooting:¶

  • Addressed occasional mode collapse or artifacts by adjusting hyperparameters, such as learning rate and batch size.
  • Experimented with additional image augmentations and data normalization techniques.

Improvements for Future Work:¶

  • Employing larger models or deeper architectures to improve detail preservation.
  • Further experimenting with different loss functions or regularization techniques.
  • Leveraging quantitative metrics to rigorously evaluate style transfer quality.
  • Training on different artists to see if training architecture can be universal for these tasks.

Key Learnings and Takeaways:¶

  • Importance of careful hyperparameter tuning.
  • Significance of image normalization and preprocessing in GAN training stability.
  • CycleGAN’s suitability for unpaired style-transfer tasks demonstrated by effective generation results.