The Path to StyleGan2 - The Finale

Author

Yusuf Mohammad


This post concludes The Path to StyleGan2 series and will see us implement the StyleGAN2 model (there is a StyleGAN3 but I will not cover it, for now).

The link for the StyleGAN2 paper is https://arxiv.org/pdf/1912.04958. I’d recommend giving it a read before and alongside this post :)

As I write, training for this model is still underway. If anything, it has shown me just how challenging GAN training is. I have run at least 100 training runs (over the course of four months), which either explode in FID or outright crash due to a segmentation fault. Read on for some of the errors I faced and what to look out for in GAN training!

Also, our implementation here is a little unique due to the image size and the number of images we can fit on our GPU. To help out, I made use of a few other implementations of StyleGAN2 to aid my discussion below. I urge you to check them out too, they are great, and the accompanying posts have a lot of cool insights which warrant further reading.

GPU NOTE: For this implementation you will need a GPU with at least 12 GB of memory; however, you can edit it to fit your GPU.

Links to other repos/posts:

We make use of the FFHQ dataset this time around, with 256x256 images (as I am GPU poor). You can find the dataset at https://www.kaggle.com/datasets/denislukovnikov/ffhq256-images-only/data. It is a similar dataset to CelebA-HQ-256, the key difference being that it includes more diverse images.

Lastly, I’d like to say that the StyleGAN series of models are amazing, I hope you enjoyed this series and I thank you for taking the time to read these posts!

Now let’s get into StyleGAN2


Expand this block to show code for imports and some helper functions :)

import os
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

import torch
import torchvision

import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torchvision.utils import save_image
from torchvision import datasets, transforms, utils

from torch.utils.data import DataLoader

from torchmetrics.image.fid import FrechetInceptionDistance

import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from datetime import datetime
from math import sqrt
import math
import sys
import random


# We can make use of a GPU if you have one on your computer. This works for Nvidia and M series GPUs
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    # is_available() checks for a usable GPU at runtime; is_built() only says PyTorch was compiled with CUDA
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# These 2 lines assign some data on the memory of the device and output it. The output confirms
# if we have set the intended device
x = torch.ones(1, device=device)
print(x)
    
# I also define a function we use to examine the outputs of the Generator
def show_images(images, num_images=16, figsize=(10,10)):
    # Ensure the input is on CPU
    images = images.cpu().detach()
    
    # Normalize images from [-1, 1] to [0, 1]
    images = (images + 1) / 2
    
    # Clamp values to [0, 1] range
    images = torch.clamp(images, 0, 1)
    
    # Make a grid of images
    grid = torchvision.utils.make_grid(images[:num_images], nrow=4)
    
    # Convert to numpy and transpose
    grid = grid.numpy().transpose((1, 2, 0))
    
    # Display the grid
    plt.figure(figsize=figsize)
    plt.imshow(grid)
    plt.axis('off')
    plt.show()
tensor([1.], device='cuda:0')

StyleGAN2 - A modification of the original StyleGAN

StyleGAN2 is an adaptation of StyleGAN. If you read the StyleGAN post (shameless self-plug alert: if you haven’t, I suggest you stop here and check it out), you will discover today that StyleGAN2 takes many elements of that model and adapts them to improve the quality of generated images. The images from StyleGAN also have two specific issues (I didn’t see them in my images, but that is because I didn’t train my StyleGAN to convergence). They arise from Adaptive Instance Normalisation (AdaIN) and the progressive growing regime. Let’s explore these before we move on to the implementation.

Issue 1 - Adaptive Instance Normalisation

The first issue we want to fix is “water droplets”. Take a look at Figure 1: they appear as distortions in specific areas of the images. The authors state this is related to the Adaptive Instance Normalisation (AdaIN) operation, and when AdaIN is removed these artefacts do not appear. AdaIN normalises the mean and variance of each feature map separately, potentially destroying any information found in the magnitudes of the features relative to each other. The “water droplet” effect may be an attempt by the G model to sneak signal strength information past AdaIN by creating localised spikes which dominate the statistics. To stop this effect we could just remove AdaIN, and the authors state this does stop the water droplets and reduces FID slightly, but a key strength of StyleGAN is the “Style” transfer, which currently takes place in the AdaIN operation. We need another way to enact style transfer, but without AdaIN.

Figure 1 - StyleGAN Artefacts
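
To make the "destroying magnitude information" point concrete, here is a minimal sketch (not code from the paper or from my training script). It assumes a toy tensor with two feature maps, one of which is 100x "louder" than the other, and uses PyTorch's `F.instance_norm` as a stand-in for the normalisation half of AdaIN.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# One sample, two feature maps: the second is 100x "louder" than the first.
x = torch.randn(1, 2, 8, 8)
x[:, 1] *= 100.0

# Instance normalisation: each feature map of each sample is normalised
# with its own mean and standard deviation.
x_norm = F.instance_norm(x)

std_before = x.std(dim=(2, 3), unbiased=False)
std_after = x_norm.std(dim=(2, 3), unbiased=False)

print(std_before)  # very different magnitudes
print(std_after)   # both close to 1: the relative magnitude is gone
```

After normalisation both maps have roughly unit standard deviation, so the only way for the network to signal "this map is strong" is through the shape of the map itself, which is consistent with the localised spike explanation above.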

Issue 2 - Progressive Growing

The second issue is a result of the progressive growing regime. This scheme has been in play ever since the first post in this series (I wrote that a while ago). Progressive growing was the key idea which allowed the authors to achieve such great image quality. A quick recap: we first focus on coarse, low-resolution details by starting training on 4x4 images, then we upscale to 8x8 -> 16x16 -> 32x32 -> … -> 1024x1024 in the end (I only go up to 256x256 due to GPU constraints). For more in-depth coverage refer to https://ym2132.github.io/Progressive_GAN.

The problem can be seen in Figure 2; the authors name this type of issue “phase artifacts”. It’s a subtle problem, but look at the teeth: they stay aligned with the camera rather than the pose, which is not how it works in the real world. This issue is supposedly caused by the way we implement progressive growing. The G model develops a strong location preference for specific details (here, the teeth) because we focus on low-resolution details first before moving to higher-resolution ones, so a detail stays in one place for a while and then suddenly jumps to a more preferred location, whereas we want it to move smoothly over time. To resolve this we scrap a part of the old progressive growing approach, namely the part where we change the network structure throughout training. Instead, the full network is trained from iteration 1: internally it still builds the image up from 4x4 to 1024x1024 (or in our case 256x256), but the output is always at the final resolution. This retains the stability of progressive growing without the phase artifacts.

Figure 2 - Progressive Growing Issue

Let’s take a deeper dive on the plan to rectify these problems before we get coding :)


Redesigning the Normalisation

The first major change in StyleGAN2 is a redesign of the normalisation, let’s get into it!

Remember the goal is to remove AdaIN but retain the ability to perform style transfer. Refer to Figure 3 for an overview of what is about to happen.

Figure 3 - StyleGAN2 New Style Block Setup. (a) is the original StyleGAN block, (b) is the original StyleGAN block but split up so that each block has only one conv operation. (c) is the new style block for StyleGAN2 and (d) is the same as (c) but with the modulation operation expanded.

So, we replace AdaIN with a new operation called modulation (which itself consists of two parts: modulation and demodulation). Before getting to that, let’s understand what the AdaIN operation does. In Figure 3b we show the two parts of AdaIN: normalisation and modulation. Recall that the formula for AdaIN has two parts: the normalisation, and the modulation, where we apply a scale and bias computed from the latent vector used for style transfer. In the original StyleGAN the bias and noise are applied within the style block, causing their relative impact to be inversely proportional to the current style’s magnitude. If you move these operations outside the style block the results are slightly improved, because they then operate on normalised data. After this change it is also enough to operate on the standard deviation alone; you do not need to modulate the mean of the feature maps.

The issue with style modulation currently is that it may amplify the magnitudes of certain feature maps far more than others. For style transfer to work, this amplification needs to be counteracted on a per-sample basis (feature map by feature map). This is what I mentioned earlier: the authors performed an ablation study and found that FID improves if you just remove AdaIN, but it’s better to keep style transfer in. The way to do so is to base the normalisation on the expected statistics of the incoming feature maps, rather than on the statistics we actually measure.

Modulation and Demodulation

Now, let’s explore the new normalisation and style transfer procedure. Review equation (1), this is the formula for modulation. Essentially, we scale each input feature map of the convolution based on the incoming style.

\[ w'_{ijk} = s_i \cdot w_{ijk} \tag{1}\]

\(w\) and \(w'\) are the original and modulated weights respectively, and \(s_{i}\) is the scale corresponding to the \(i\)th input feature map. \(j\) and \(k\) enumerate the output feature maps and the spatial footprint of the convolution. The purpose of instance normalisation is to remove the effect of \(s\) from the statistics of the convolution’s output feature maps. This goal can be achieved directly by equation (2).

\[ w''_{ijk} = w'_{ijk}/\sqrt{\sum_{i,k} {w'_{ijk}}^2 + \epsilon} \tag{2} \]

\(\epsilon\) is a small constant to avoid divide-by-zero errors. This demodulation operation is based on statistical assumptions about the signal (that the incoming activations are i.i.d. random variables with unit standard deviation) instead of the actual contents of the feature maps. That’s it for demodulation! We’ll take a look at the implementation later, it has a pretty neat trick baked into it :)
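
As a quick sanity check of equations (1) and (2), here is a minimal single-sample sketch. The sizes and the scale values in `s` are made up for illustration; the weight tensor uses PyTorch's usual conv layout `[out, in, k, k]`, so the first axis is \(j\) and the second is \(i\).

```python
import torch

torch.manual_seed(0)

out_c, in_c, k = 4, 3, 3
w = torch.randn(out_c, in_c, k, k)   # conv weights, laid out as [j, i, k, k]
s = torch.tensor([0.5, 2.0, 10.0])   # one scale per input feature map (made-up values)

# Equation (1): scale every weight that reads from input feature map i by s_i.
w_mod = w * s.view(1, in_c, 1, 1)

# Equation (2): divide each output filter j by its L2 norm over (i, k).
eps = 1e-8
w_demod = w_mod / torch.sqrt(w_mod.pow(2).sum(dim=(1, 2, 3), keepdim=True) + eps)

# After demodulation each output filter has unit L2 norm, whatever s was.
norms = w_demod.pow(2).sum(dim=(1, 2, 3)).sqrt()
print(norms)
```

Each output filter ends up with unit norm even though one of the styles was 20x larger than another, which is the sense in which demodulation removes the effect of \(s\) from the output statistics.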


Removing Issues of Progressive Growing

Progressive growing is one of the best things to happen to GANs, as it massively increased stability when training to generate high resolution images. So, despite the issue it causes, we don’t want to lose what it gives us. The fix is quite simple: we keep the coarse-to-fine behaviour, but unlike in StyleGAN, where we output each resolution in turn and change the network structure as training progresses, we just output the final resolution and train the entire network right from the start. Now, it is a little more involved as we introduce skip connections and residual connections, but overall it’s a nice fix and seems pretty intuitive as a first experiment, right?

Examine Figure 4. We introduce skip connections in the G model and residual connections in the D model.

Figure 4 - The New Generator and Discriminator architectures

These new architectures still retain the ability to focus first on lower resolutions before moving to higher ones, without progressive growing. This behaviour is not explicitly forced, so the G only does it if it is beneficial to training (which is quite awesome). Now it’s time to write some code!
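
As a rough sketch of the skip connections in the G from Figure 4: each resolution produces its own RGB image through a 1x1 conv, and the running RGB image is upsampled and summed with the next one. The feature maps, channel counts and layer names below are made up for illustration, and I use plain `nn.Conv2d` rather than an equalised learning rate conv to keep it short.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Made-up feature maps from two consecutive generator blocks.
feat_4 = torch.randn(2, 8, 4, 4)
feat_8 = torch.randn(2, 8, 8, 8)

# One 1x1 "to RGB" conv per resolution (plain nn.Conv2d here for brevity).
to_rgb_4 = nn.Conv2d(8, 3, kernel_size=1)
to_rgb_8 = nn.Conv2d(8, 3, kernel_size=1)
upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)

rgb = to_rgb_4(feat_4)                  # 4x4 RGB image
rgb = upsample(rgb) + to_rgb_8(feat_8)  # upsample, then add the 8x8 contribution

print(rgb.shape)
```

Repeating the last step at every resolution gives a final image that is the sum of contributions from all resolutions, so the network is free to lean on the low-resolution terms early in training.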


The Path to Implementation - A Roadmap

So, let’s lay out a path for us to travel in order to implement the StyleGAN2:

1 - The new style block - modulation and demodulation
2 - G and D architecture redesign - removal of Progressive Growing, skip connections for G and residual connections for D
3 - Lazy + Path Length Regularisation
4 - Changing the loss function (no more WGAN-GP)

We discussed changes 1 and 2 already, but changes 3 and 4 are still a mystery. I think these are best explored in conjunction with the code. 3 especially highlights some core behaviours of GANs.


Expand the following to check out the old code (stuff we covered in the previous blog post). Note: in the old blog posts the Mapping Network had hidden dimensions of 256x256; now I use 512x512 to reflect the paper.

# Credit: https://github.com/rosinality/stylegan2-pytorch/blob/master/model.py#L94
class EqualLRConv2d(nn.Module):
    def __init__(
        self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
    ):
        super().__init__()

        self.weight = nn.Parameter(
            torch.randn(out_channel, in_channel, kernel_size, kernel_size)
        )
        self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)

        self.stride = stride
        self.padding = padding

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_channel))

        else:
            self.bias = None

    def forward(self, input):
        out = F.conv2d(
            input,
            self.weight * self.scale,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
        )

        return out

    def __repr__(self):
        return (
            f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},"
            f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})"
        )

class EqualLRLinear(nn.Module):
    def __init__(
        self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=0.01, activation=None  # lr_mul from rosinality 
    ):
        super().__init__()

        self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))

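        if bias:
            self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))

        else:
            # No bias requested: store None so no extra trainable parameter is created
            self.bias = None

        self.activation = activation

        self.scale = (1 / math.sqrt(in_dim)) * lr_mul
        self.lr_mul = lr_mul

    def forward(self, input):
        bias = self.bias * self.lr_mul if self.bias is not None else None

        if self.activation:
            # Plain PyTorch stand-in for rosinality's fused_leaky_relu op (which is not defined
            # in this post): add the bias, apply LeakyReLU(0.2), then rescale by sqrt(2)
            out = F.linear(input, self.weight * self.scale, bias=bias)
            out = F.leaky_relu(out, negative_slope=0.2) * math.sqrt(2)
        else:
            out = F.linear(input, self.weight * self.scale, bias=bias)

        return out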

    def __repr__(self):
        return (
            f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})"
        )

def EMA(model1, model2, decay=0.999):
    par1 = dict(model1.named_parameters())
    par2 = dict(model2.named_parameters())

    for k in par1.keys():
        par1[k].data.mul_(decay).add_(par2[k].data, alpha=1 - decay) 

class PixelNorm(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x / torch.sqrt(torch.mean(x ** 2, dim=1, keepdim=True)+ 1e-8)

class MiniBatchStdDev(nn.Module):
    # The original StyleGAN paper states 8, but in StyleGAN2 repo they use 4 when they have batch size of 32
    # I set batch_size=16 and still keep group_size=4.
    def __init__(self, group_size=4):  
        super().__init__()
        self.group_size = group_size
    
    def forward(self, x):
        N, C, H, W = x.shape 
        G = min(self.group_size, N) 
        
        y = x.view(G, -1, C, H, W)
        
        y = y - torch.mean(y, dim=0, keepdim=True)
        y = torch.mean(torch.square(y), dim=0)
        y = torch.sqrt(y + 1e-8)
        
        y = torch.mean(y, dim=[1,2,3], keepdim=True)
        
        y = y.repeat(G, 1, H, W)
        
        return torch.cat([x,y], dim=1)

class LearnedConstant(nn.Module):
    def __init__(self, in_c):
        super().__init__()
        
        self.constant = nn.Parameter(torch.randn(1, in_c, 4, 4))
        
    def forward(self, batch_size):
        return self.constant.expand(batch_size, -1, -1, -1)

class NoiseLayer(nn.Module):
    def __init__(self, channels):
        super().__init__()
        
        self.weight = nn.Parameter(torch.zeros(1, channels, 1, 1))
    
    def forward(self, gen_image, noise=None):
        if noise is None:
            N, _, H, W = gen_image.shape
            noise = torch.randn(N, 1, H, W, device=gen_image.device)
        
        return gen_image + (noise * self.weight)

class MappingNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.norm = PixelNorm()
        
        self.layers = nn.Sequential(
            EqualLRLinear(512, 512),
            nn.LeakyReLU(0.2),
            EqualLRLinear(512, 512),
            nn.LeakyReLU(0.2),
            EqualLRLinear(512, 512),
            nn.LeakyReLU(0.2),
            EqualLRLinear(512, 512),
            nn.LeakyReLU(0.2),
            EqualLRLinear(512, 512),
            nn.LeakyReLU(0.2),
            EqualLRLinear(512, 512),
            nn.LeakyReLU(0.2),
            EqualLRLinear(512, 512),
            nn.LeakyReLU(0.2),
            EqualLRLinear(512, 512),
            nn.LeakyReLU(0.2),
        )
    
    def forward(self, x):
        # Normalise input latent
        x = self.norm(x)
        
        out = self.layers(x)
        
        return out

# To get the params for the mapping network we look for parameters with "mapping" in the name
def get_params_with_lr(model):
    mapping_params = []
    other_params = []
    for name, param in model.named_parameters():
        if 'mapping' in name:  # Adjust this condition based on your actual naming convention
            mapping_params.append(param)
        else:
            other_params.append(param)
    return mapping_params, other_params

# Resize helper used when passing images to the FID metric
def resize(images):
    # Resize to 299x299, the inception v3 model expects 299,299 images so we just resize our images to
    # this size
    transform = transforms.Compose([
        transforms.Resize((299, 299)),
    ])
    return transform(images)

# We make use of the FID class provided by the torchmetrics library
# This class works by "adding" fake and real images to the model
# So I create two functions, one for adding fake images and one for real images
# NOTE: these functions use a module-level `fid` object (a FrechetInceptionDistance instance),
# which must be created before they are called
def add_fake_images(g_running, num_images, batch_size, latent_dim, device):
    # The function takes in g_running and all params needed to generate images
    # we use g_running in this function as it is the model we use to output our fake images
    g_running.eval()
    
    # Set torch.no_grad to turn off gradients, it makes the code run faster and use less memory as gradients
    # aren't tracked
    with torch.no_grad():
        # Generate num_images images in batches and pass each batch to the FID model as we
        # generate it, to save memory
        for _ in tqdm(range(0, num_images, batch_size), desc="Generating images"):
            z = torch.randn(batch_size, latent_dim, device=device)
            batch_images = g_running(z)
        
            # resize images
            resize_batch = resize(batch_images)
            # Inception v3 requires pixel ranges to be [0,255] currently it's [-1,1], 
            # this line handles the conversion
            resize_batch = ((resize_batch + 1) * 127.5).clamp(0, 255)
            # Inception v3 also expects input data type to be uint8, this can be handled with a simple cast
            resize_batch = resize_batch.to(torch.uint8)
            
            # Update FID
            fid.update(resize_batch, real=False)
            
            # Clear GPU cache, to save memory
            torch.cuda.empty_cache()

# The second function just takes in the data_loader as a param
def add_real_imgs(data_loader):
    # we pass all batches to the FID model i.e. 70k images
    for batch in tqdm(data_loader, desc="Processing real images"):
        imgs, _ = batch
        # Resize, convert to [0,255] range and cast to uint8 as before
        imgs = resize(imgs)
        imgs = imgs.to(device)
        imgs = ((imgs + 1) * 127.5).clamp(0, 255)
        imgs = imgs.to(torch.uint8)
        fid.update(imgs, real=True)

        del imgs
        torch.cuda.empty_cache()

# Final function which combines both of the image adding functions
def calculate_and_save_fid(iteration, data_loader, g_running, num_fake_images, batch_size, latent_dim, device, fid_file):
    # We reset the FID score statistics for recalculation
    fid.reset()
    # Add the real and fake images
    add_fake_images(g_running, num_fake_images, batch_size, latent_dim, device)
    add_real_imgs(data_loader)
    # Compute FID score and output it
    fid_score = fid.compute()
    print(f"FID score for iteration {iteration}: {fid_score.item()}")
    
    # We also save the scores to a file
    with open(fid_file, 'a') as f:
        f.write(f"Iteration {iteration}: {fid_score.item()}\n")
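
One detail in the FID helpers above that is easy to get wrong is the pixel range conversion. A tiny sketch, assuming three made-up pixel values at the extremes and the midpoint of the G's [-1, 1] output range:

```python
import torch

# Three pixels covering the extremes and the midpoint of the generator's [-1, 1] range.
x = torch.tensor([-1.0, 0.0, 1.0])

# Same conversion as in add_fake_images / add_real_imgs.
as_uint8 = ((x + 1) * 127.5).clamp(0, 255).to(torch.uint8)

print(as_uint8)  # tensor([  0, 127, 255], dtype=torch.uint8)
```

The cast to `uint8` truncates, so the midpoint 127.5 becomes 127. This is harmless for FID, but note that the real images must also be in [-1, 1] before the conversion (i.e. the data loader has to normalise them that way) for the two sets to be comparable.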

1 - The New Style Block

This part requires two steps: first we implement the modulation and demodulation, then we create a new style block to accommodate it. Refer to Figure 5 for a mapping between the image and pieces of code.

Figure 5 - Order of Implementation
class Conv2d_mod(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, latent_dim, padding):
        super().__init__()
        
        # Weight initialization for our conv
        self.weight = nn.Parameter(
            torch.randn(1, out_channels, in_channels, kernel_size, kernel_size)
        )
        
        self.scale = 1 / math.sqrt(in_channels * kernel_size ** 2)
        
        # Style modulation layer, just a linear layer. Which takes the incoming style vector w and 
        # aligns it with number of feature maps the network expects
        self.modulation = EqualLRLinear(latent_dim, in_channels, bias_init=1)
        
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.padding = kernel_size // 2

    def forward(self, x, style):
        batch, channels, height, width = x.shape
        
        # Style modulation
        style = self.modulation(style).view(batch, 1, -1, 1, 1)
        
        # Scale weights and apply style - here weight = w'
        weight = self.scale * self.weight * style
        
        # Demodulation - Implementation of equation (2)
        demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
        # rsqrt means 1/sqrt, so weight * demod is the same as what is in the equation (2)
        weight = weight * demod.view(batch, self.out_channels, 1, 1, 1)
        
        # Reshape for grouped convolution
        weight = weight.view(
            batch * self.out_channels, self.in_channels, 
            self.kernel_size, self.kernel_size
        )
        
        # Reshape input and perform convolution
        x = x.view(1, batch * channels, height, width)
        out = F.conv2d(
            x, 
            weight,
            padding=self.padding,
            groups=batch
        )
        
        # Reshape output
        _, _, height, width = out.shape
        out = out.view(batch, self.out_channels, height, width)
        
        return out

That’s the mod/demod operation done. To better understand some of the reshapes and operations, let’s explore a little with a test case.

test_w = torch.rand(4, 512)  # Represents a latent vector w after z is passed through the mapping network
# this w vector has a batch size of 4. In the real network we set batch size to 32 but for illustration we can use 4

# Let's review the old style transfer methodology
map_layer = EqualLRLinear(512, 512*2)  # Remember that we split the mapped w into 2 for the y_s and y_b, hence the out dim is 512x2
style = map_layer(test_w)

y_s, y_b = style.chunk(2, dim=1)
        
# Reshape y_s and y_b to match x's dimensions
y_s = y_s.unsqueeze(2).unsqueeze(3)
y_b = y_b.unsqueeze(2).unsqueeze(3)

print(y_s.shape, y_b.shape)
print('---')
# Now here is the new style transfer code with mod and demod - note this is just to show the different resizings we perform.
map_layer = EqualLRLinear(512, 512)
s = map_layer(test_w)
print(f's shape before view: {s.shape}')
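# Add 3 singleton dims so s can broadcast against the 5D (batch, out, in, k, k) weights (they are not RGB dims).
# Note: this view puts s on the second axis (output channels after broadcasting), whereas Conv2d_mod uses
# view(batch, 1, -1, 1, 1) to scale the input channels. With in = out = 512 the weight shapes are the same either way.
s = s.view(test_w.size(0), -1, 1, 1, 1)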
print(f's shape after view: {s.shape}')

print()

# Example w which is the weights of a conv layer
# w = weights here this is the shape of a convolutional layers weights in PyTorch
w = nn.Parameter(torch.randn(512, 512, 3, 3))
print(f'w shape before mod: {w.shape}')

w = w * s
print(f'Shape of w after modulation: {w.shape}', '\n')

demod = torch.rsqrt(w.pow(2).sum([2,3,4], keepdim=True) + 1e-8)
print(f'w: {w.shape} | demod: {demod.shape}')
w = w * demod
print(f'w shape after demod: {w.shape}')

print('---')

# This covers the first part of the code, namely the modulation and demodulation.
# Let's also examine the rest of Conv2d_mod and explore how the grouped convolution works

# F.conv2d expects a 4d tensor as its weight, so we need to convert w to 4d
# current w has 5 dimensions, the first two can be condensed (batch size and feature maps) so we end up with basically
# the filters for each sample in the batch stacked on top of each other.
w = nn.Parameter(w.view(4 * 512, 512, 3, 3))
groups = 4

x = torch.randn(4, 512, 4, 4)
# Stack x in the same way we stacked the weights
x = x.view(1, 4*512, 4, 4)

# Set the weights of conv layer to the weights which have been modulated and demodulated.
out = F.conv2d(x, w, groups=groups, padding=1)
out = out.view(4, -1, out.shape[-2], out.shape[-1])
out.shape
torch.Size([4, 512, 1, 1]) torch.Size([4, 512, 1, 1])
---
s shape before view: torch.Size([4, 512])
s shape after view: torch.Size([4, 512, 1, 1, 1])

w shape before mod: torch.Size([512, 512, 3, 3])
Shape of w after modulation: torch.Size([4, 512, 512, 3, 3]) 

w: torch.Size([4, 512, 512, 3, 3]) | demod: torch.Size([4, 512, 1, 1, 1])
w shape after demod: torch.Size([4, 512, 512, 3, 3])
---
torch.Size([4, 512, 4, 4])
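
The shapes line up, but it is worth checking that the grouped convolution really does give every sample its own weights. A minimal sketch with small made-up sizes: it compares the grouped version against a plain loop over the batch, using random per-sample weights in place of the modulated ones.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

batch, in_c, out_c, k, size = 3, 4, 5, 3, 8
x = torch.randn(batch, in_c, size, size)
# One separate set of conv weights per sample, as produced by mod/demod.
w = torch.randn(batch, out_c, in_c, k, k)

# Reference: loop over the batch and convolve each sample with its own weights.
looped = torch.cat([F.conv2d(x[b:b + 1], w[b], padding=1) for b in range(batch)])

# Trick: fold the batch into the channel axis and let groups=batch keep samples apart.
grouped = F.conv2d(
    x.reshape(1, batch * in_c, size, size),
    w.reshape(batch * out_c, in_c, k, k),
    padding=1,
    groups=batch,
).reshape(batch, out_c, size, size)

print(torch.allclose(looped, grouped, atol=1e-4))
```

With `groups=batch`, group `b` only sees input channels `b*in_c` to `(b+1)*in_c` and only uses filters `b*out_c` to `(b+1)*out_c`, which is exactly sample `b` and its weights. The grouped version does this in one call instead of a Python loop.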

There we have the modulation and demodulation operation, this is applied to the weights of the convolutional layers in the G network. Let’s see how it fits in the new style block for the G!

# The actual style block itself isn't much different than that of StyleGAN, except we change the conv2d with
# our newly created conv2d_mod and that's it.
class g_style_block(nn.Module):
    def __init__(
        self, 
        in_c, 
        out_c, 
        ksize1, 
        padding,
        upsample=True,
        latent_dim=512,
    ):
        super().__init__()

        layers_list = []

        if upsample:
            layers_list.extend([
                nn.Upsample(scale_factor=2, mode='bilinear'),
                Conv2d_mod(in_c, out_c, ksize1, latent_dim, padding=padding),
                NoiseLayer(out_c),
            ])
        else:
            self.learned_constant = LearnedConstant(in_c)

        layers_list.extend([
            nn.LeakyReLU(0.2),
            Conv2d_mod(out_c, out_c, ksize1, latent_dim, padding=padding),
            NoiseLayer(out_c),
            nn.LeakyReLU(0.2),
        ])

        self.layers = nn.ModuleList(layers_list)
        self.upsample = upsample
        
    def forward(self, w, x=None):
        if not self.upsample:
            x = self.learned_constant(w.size(0))
        
        for layer in self.layers:
            # Only the modulated convs need the style vector w; the learned constant is handled above
            if isinstance(layer, Conv2d_mod):
                x = layer(x, w)
            else:
                x = layer(x)
            
        return x

While we’re here, we may as well cover the D convolutional block. It differs slightly from the StyleGAN but does not make use of modulation and demodulation.

# The module we create to apply a blur in the D downsample
class Blur(nn.Module):
    def __init__(self):
        super().__init__()
        f = torch.Tensor([1, 3, 3, 1])
        self.register_buffer('f', f)

    def forward(self, x):
        f = self.f
        f = f[None, None, :] * f[None, :, None]
        f = f / f.sum()
        return F.conv2d(x, f.expand(x.size(1), -1, -1, -1), 
                        groups=x.size(1), padding=1)

# D style block is pretty similar too with a few minor changes. We implement the residual connection in this block
# it makes more sense to me to do it here and we examine exactly what happens in the code here and the exposition to follow
class d_style_block(nn.Module):
    def __init__(
        self,
        in_c,
        out_c,
        ksize1, 
        padding,  
        ksize2=None, 
        padding2=None,
        stride=None,   
        mbatch=False,
    ):
        super().__init__()

        if ksize2 is None:
            ksize2 = ksize1
        if padding2 is None:
            padding2 = padding
        
        layers_list = []

        # Instead of using bi-linear interpolation as in the StyleGAN we create a downsample using convolutional layers
        # and we add a gaussian blur too
        self.down_res = nn.Sequential(
            Blur(),
            EqualLRConv2d(in_c, out_c, 3, padding = 1, stride=2),
            EqualLRConv2d(out_c, out_c, 2, padding=0, stride=1) if mbatch else nn.Identity()  # Needed for last layer otherwise res 
            # has two H, W instead of 1
        )
        
        if mbatch:
            layers_list.extend([
                MiniBatchStdDev(),
            ])
            in_c += 1
            
        layers_list.extend([
            EqualLRConv2d(in_c, in_c, ksize1, padding=padding),
            nn.LeakyReLU(0.2),
            EqualLRConv2d(in_c, out_c, ksize2, padding=padding2),
            nn.LeakyReLU(0.2),
        ])

        if not mbatch:  # No downsample if we are on the last layer
            layers_list.extend([
                EqualLRConv2d(out_c, out_c, 3, padding=1, stride=2)
            ])
        
        self.layers = nn.ModuleList(layers_list)
    
    def forward(self, x):
        # A residual connection is made by summing the output of the previous layer with the output of the current layer
        # To do this successfully we need to down sample the output of previous layer and store it
        res = self.down_res(x)
        for layer in self.layers:
            x = layer(x)

        # We perform the summation here, the second term is what's called a variance preservation term. When we sum res + x
        # the variances of both tensors add up (assuming they are roughly independent). Over time the variance would grow,
        # so we scale by 1/sqrt(2).
        # E.g: we have res with variance a^2 and x with variance a^2, when we sum them we get out with variance
        # 2a^2, multiplying this by 1/sqrt(2) brings variance back to a^2
        out = (res + x) * (1/ sqrt(2))       
        
        return out
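
The 1/sqrt(2) scaling in the residual connection can be checked numerically. This is a small sketch using only the standard library; it assumes `res` and `x` are independent with unit variance, which is an idealisation (in the real D both come from the same input, so the result is only approximate there).

```python
import math
import random
import statistics

random.seed(0)
n = 100_000

# Two independent signals with unit variance, standing in for `res` and `x`.
res = [random.gauss(0.0, 1.0) for _ in range(n)]
x = [random.gauss(0.0, 1.0) for _ in range(n)]

summed = [a + b for a, b in zip(res, x)]
scaled = [(a + b) / math.sqrt(2) for a, b in zip(res, x)]

var_summed = statistics.pvariance(summed)  # close to 2
var_scaled = statistics.pvariance(scaled)  # close to 1
print(round(var_summed, 2), round(var_scaled, 2))
```

Without the scaling, each residual block would roughly double the variance of the activations, which compounds quickly over the seven blocks of the D.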

And there we have it, both the G and D style/conv blocks! With those in place it’s time to introduce the new G and D architectures.


2 - The New Architecture

In this section we cover the removal of progressive growing, the addition of residual connections to the D (briefly discussed in the code comments for the D style block) and the addition of skip connections to the G. I have added Figure 4 back here for reference (it guides the following implementation).

Figure 4b - Repeat with only relevant diagrams

Generator

# Observe in the G how we have removed progressive growing. Look at the "out" variable in line 55, this variable is passed to 
# every layer right from the first 4x4 resolution layer to the final 256x256 resolution layer. We do not need any control
# variables (num_layer or alpha) to grow the network, instead we start from iteration 0 at the full size network.
class Generator(nn.Module):
    def __init__(self, in_c=512):
        super().__init__()
        
        self.g_mapping = MappingNetwork()

        # This fmaps params is another little hidden hint from: https://github.com/NVlabs/stylegan2-ada-pytorch/blob/main/train.py#L157
        # This is a repo which came out a while after StyleGAN2 repo, this highlights the difficulty of replicating papers.
        # In this repo they provide params for a 256x256 StyleGAN2 and one of the changes is to halve the number of feature maps
        # in the G model.
        fmaps = 0.5
        in_c = int(in_c * fmaps)
        
        self.block_4x4 = g_style_block(in_c, in_c, 3, 1, upsample=False)
        self.block_8x8 = g_style_block(in_c, in_c, 3, 1)
        self.block_16x16 = g_style_block(in_c, in_c, 3, 1)
        self.block_32x32 = g_style_block(in_c, in_c, 3, 1)
        self.block_64x64 = g_style_block(in_c, in_c//2, 3, 1)
        self.block_128x128 = g_style_block(in_c//2, in_c//4, 3, 1)
        self.block_256x256 = g_style_block(in_c//4, in_c//4, 3, 1)

        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear')
        
        self.to_rgb_4 = EqualLRConv2d(in_c, 3, 1)
        self.to_rgb_8 = EqualLRConv2d(in_c, 3, 1)
        self.to_rgb_16 = EqualLRConv2d(in_c, 3, 1)
        self.to_rgb_32 = EqualLRConv2d(in_c, 3, 1)
        self.to_rgb_64 = EqualLRConv2d(in_c//2, 3, 1)
        self.to_rgb_128 = EqualLRConv2d(in_c//4, 3, 1)
        self.to_rgb_256 = EqualLRConv2d(in_c//4, 3, 1)
                
        self.tanh = nn.Tanh()

    def forward(self, z, return_latents=False):
        w = self.g_mapping(z)
        batch_size = z.size(0)
        
        # Determine which samples undergo style mixing
        # Style mixing is different here than in the StyleGAN post, I think I made a mistake with my previous implementation
        # The mistake being that I applied style mixing across the entire batch, meaning either all samples in the batch 
        # undergo style mixing or none do. This is not what the authors intended, they wanted style mixing to apply per sample
        # in the batch. Essentially in the same batch of 16, 32 or whatever images some may get style mixed and others not.

        # This mixing variable generates a tensor of 32 (if batch_size=32) True or False values which determine whether each
        # sample in the batch undergoes style mixing
        mixing = torch.rand(batch_size, device=z.device) < 0.9
        
        # Generate z2 and w2 for samples that undergo style mixing
        z2 = torch.randn_like(z)
        w2 = self.g_mapping(z2)
        
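        # Generate crossover points for each sample. torch.randint's upper bound is exclusive, so these lie in 1..6,
        # i.e. a mixed sample always keeps w for layer 0 and always uses w2 for at least the last layer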
        crossover_points = torch.randint(1, 7, (batch_size,), device=z.device)
        
        # Initialize a list to hold the styles for each layer
        styles = []
        
        # For each layer, select w or w2 based on crossover points
        for layer_idx in range(7):  # 7 layers for a 256x256 network, if you increase resolution increase this too
            # Create a mask for samples that should use w2 at this layer
            use_w2 = mixing & (crossover_points <= layer_idx)
            # Select w or w2 for each sample in this layer
            style = torch.where(use_w2.unsqueeze(1), w2, w)
            styles.append(style)
        
        # Apply styles in generator blocks
        out = self.block_4x4(styles[0])
        out_4 = self.to_rgb_4(out)
        out_4 = self.upsample(out_4)
        
        out = self.block_8x8(styles[1], out)
        out_8 = self.to_rgb_8(out)
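        # Add the upsampled RGB output of the previous resolution, scaled by 1/sqrt(2). This is similar in spirit to the
        # variance reset in d_style_block, although here only the skip branch is scaled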
        out_8 += out_4 * (1 / np.sqrt(2))
        out_8 = self.upsample(out_8)
        
        out = self.block_16x16(styles[2], out)
        out_16 = self.to_rgb_16(out)
        out_16 += out_8 * (1 / np.sqrt(2))
        out_16 = self.upsample(out_16)
        
        out = self.block_32x32(styles[3], out)
        out_32 = self.to_rgb_32(out)
        out_32 += out_16 * (1 / np.sqrt(2))
        out_32 = self.upsample(out_32)
        
        out = self.block_64x64(styles[4], out)
        out_64 = self.to_rgb_64(out)
        out_64 += out_32 * (1 / np.sqrt(2))
        out_64 = self.upsample(out_64)

        out = self.block_128x128(styles[5], out)            
        out_128 = self.to_rgb_128(out)
        out_128 += out_64 * (1 / np.sqrt(2))
        out_128 = self.upsample(out_128)

        out = self.block_256x256(styles[6], out)
        out_256 = self.to_rgb_256(out)
        out_256 += out_128 * (1 / np.sqrt(2))

        # Finally, one of the regularisers (PPL) requires the latents used in this pass of the network, so we can return those too.
        if return_latents:
            return out_256, styles[6]
        else:
            return out_256
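The per-sample style mixing in the forward pass boils down to one boolean mask per layer. Here is a minimal sketch of just that logic, assuming stand-in tensors for w and w2 (zeros and ones, so it is easy to see which one each layer receives) and hand-picked values for mixing and crossover_points instead of random draws. The small w_dim is also my own choice to keep the output readable.

```python
import torch

batch_size, w_dim, num_layers = 4, 8, 7

w = torch.zeros(batch_size, w_dim)   # stand-in for g_mapping(z)
w2 = torch.ones(batch_size, w_dim)   # stand-in for g_mapping(z2)

# Hand-picked instead of random: sample 1 is not mixed at all
mixing = torch.tensor([True, False, True, True])
crossover_points = torch.tensor([1, 3, 6, 4])

rows = []
for layer_idx in range(num_layers):
    # A sample uses w2 at this layer only if it is mixed AND the layer is at or past its crossover point
    use_w2 = mixing & (crossover_points <= layer_idx)
    style = torch.where(use_w2.unsqueeze(1), w2, w)
    rows.append(style[:, 0])  # one entry is enough to tell w (0) from w2 (1)

# usage[b, l] == 1 means sample b received w2 at layer l
usage = torch.stack(rows, dim=1)
print(usage)
```

Each row of the printed tensor is one sample and each column one layer. Sample 1 stays at 0 everywhere because its mixing flag is False, even though it has a crossover point.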

Discriminator

Note, that the residual connections are implemented in the d_style_block rather than in the network. Essentially we downsample the output of the previous layer and store it, we then pass the output of the previous layer through the network as usual and then we sum the two tensors.

# The D model is basically the same as in StyleGAN. Except observe how we have removed progressive growing, the out variable is passed
# through the entire network. We do not have any intermediary outputs, we go straight from 256x256 passing through all the layers
# right from the beginning of training. There are no if statements which control which layers to activate (or the num_layer variable
# we came up with last time)
class Discriminator(nn.Module):
    def __init__(self, out_c=512):
        super().__init__()

        # Similar story to G for this fmaps variable. It's great to have official code repos available which show things like this,
        # but if there aren't any and you are implementing something, remember to never give up. There are many tricks like this one
        # in play and it is our job to discover them!
        fmaps = 0.5
        out_c = int(out_c * fmaps)

        self.from_rgb = EqualLRConv2d(3, out_c//4, 1)

        self.block_256x256 = d_style_block(out_c//4, out_c//4, 3, 1)
        self.block_128x128 = d_style_block(out_c//4, out_c//2, 3, 1)
        self.block_64x64 = d_style_block(out_c//2, out_c, 3, 1)
        self.block_32x32 =  d_style_block(out_c, out_c, 3, 1)
        self.block_16x16 = d_style_block(out_c, out_c, 3, 1)
        self.block_8x8 = d_style_block(out_c, out_c, 3, 1)
        self.block_4x4 = d_style_block(out_c, out_c, 3, 1, 4, 0, mbatch=True)

        self.linear = EqualLRLinear(out_c, 1)

    def forward(self, x):
        
        out = self.from_rgb(x)

        out = self.block_256x256(out)
        out = self.block_128x128(out)
        out = self.block_64x64(out)
        out = self.block_32x32(out)
        out = self.block_16x16(out)
        out = self.block_8x8(out)
        out = self.block_4x4(out)

        out = out.view(out.size(0), -1)
        out = self.linear(out)

        return out

Regularisation in the StyleGAN2

StyleGAN2 employs 2 types of regularisation, namely Perceptual Path Length (PPL) and R1 regularisation. Let’s start with R1.

R1 Regularisation

R1 regularisation is applied to the loss function for D. A key idea from StyleGAN2, which the authors call lazy regularisation, is that the regularisation term can be computed less frequently than at every single iteration; we apply it once every 16 iterations. Running it less frequently is good, as it’s a heavy computation and adds time to our training. It is heavy because it requires us to compute the gradients of the D output with respect to its input images. You will notice the slowdown in training: if you run the .py training file, the pbar pauses every 16 iterations to run this operation. The regularisation term originates from the https://arxiv.org/abs/1801.04406 paper (let’s call it the R1 paper for simplicity). It is defined as

\[ R_1(\psi) := \frac{\gamma}{2} \mathbb{E}_{p_D(x)} \left[\|\nabla D_\psi(x)\|^2\right] \tag{3}\]

The topic of regularisation in GANs is very interesting and the R1 paper breaks it down, so it’s worthwhile to discuss the paper before we continue. The premise is that (as we know) GANs are extremely hard to train and often do not converge (in theory the D and G should converge to a Nash equilibrium). One of the reasons for this is that GANs may overfit [4]. I think this paper is the reason Karras et al. moved away from WGAN-GP (the loss function): it states that the WGAN and WGAN-GP schemes do not always lead to convergence, whereas the R1 regularised loss does.

Another point is that of GAN instability. Let’s think about the see-saw nature of the GAN game. When the G is far away from the true data distribution, the D pushes the G towards the true distribution. Simultaneously, the D becomes better at distinguishing G samples from true data samples. As this occurs, the slope of the D increases, and when G reaches its optimal point the high gradient of D pushes it away. The G then moves away from the true data distribution and the D adjusts, and this back-and-forth continues. This overview covers the key points, but I encourage you to read the full paper for more details.

The method to counteract this is R1 regularisation! It applies a gradient penalty which stops the D from deviating from the Nash equilibrium. In the paper there are two variants of the gradient penalty. We use the first of the two (R1), which makes use of real data and real predictions. The code implementation will follow after we discuss PPL regularisation :)
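To make equation 3 concrete, here is a toy check, assuming a linear "discriminator" D(x) = w · x. For such a D the gradient with respect to x is simply w, so the penalty must equal ||w||². The weights, the five random samples and gamma = 10 are illustrative values of mine, not the settings used for training.

```python
import torch

torch.manual_seed(0)

w = torch.tensor([1.0, -2.0, 0.5])          # weights of the toy linear D
x = torch.randn(5, 3, requires_grad=True)   # five toy "real" samples

pred = x @ w                                # D(x) for every sample, shape [5]

# Gradient of the summed predictions with respect to the inputs
grad, = torch.autograd.grad(outputs=pred.sum(), inputs=x, create_graph=True)

# Squared gradient norm per sample, then the expectation over the batch
penalty = grad.pow(2).sum(dim=1).mean()

gamma = 10.0                                # illustrative value only
r1 = gamma / 2 * penalty

print(f"gradient penalty: {penalty.item():.4f}")   # ||w||^2 = 1 + 4 + 0.25
print(f"R1 term:          {r1.item():.4f}")
```

For a real D the gradient differs from sample to sample, which is why the penalty is averaged over the minibatch.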

Perceptual Path Length (PPL) Regularisation

This regularisation is linked to the PPL metric, which measures the quality of the latent space captured by \(w\), and it forces the G to favor “smooth” mappings. The authors observed a link between a lower PPL metric and better subjective image quality, so they decided to create the PPL regularisation in order to push the PPL scores down (thus increasing image quality). The hypothesis is as follows: during training the D penalizes broken images (the worst generated images), so the most direct way for G to improve is to stretch the regions of the latent space which give good images. The bad images then get squeezed into small regions of the latent space where the output changes rapidly, and this is what PPL picks up on.
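To build some intuition for what the regulariser measures, here is a toy sketch, assuming a linear "generator" G(w) = A w (the matrix A is a hypothetical stand-in, as are the sizes). For a linear map the Jacobian is A itself, so the path length ||Jᵀy|| can be checked in closed form. Note that this sketch keeps one path length per sample, following the formulation in the paper.

```python
import math
import torch

torch.manual_seed(0)
batch, w_dim, height, width = 4, 8, 16, 16

# Hypothetical linear generator: image = A @ w, reshaped to [3, H, W]
A = torch.randn(3 * height * width, w_dim)
latents = torch.randn(batch, w_dim, requires_grad=True)
fake_img = (latents @ A.t()).view(batch, 3, height, width)

# Random image-space direction y, scaled by 1/sqrt(H*W) as in the regulariser code
noise = torch.randn_like(fake_img) / math.sqrt(height * width)

# J^T y via autograd: gradient of <G(w), y> with respect to w
grad, = torch.autograd.grad(outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True)
path_lengths = grad.pow(2).sum(dim=1).sqrt()           # one length per sample

# Closed form for the linear case: J^T y = A^T y
expected = (noise.view(batch, -1) @ A).norm(dim=1)

# Running mean of the path lengths and the penalty around it
mean_path_length = torch.tensor(0.0)
decay = 0.01
path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length)
path_penalty = (path_lengths - path_mean).pow(2).mean()

print("autograd path lengths:", path_lengths.detach())
print("closed-form lengths:  ", expected)
print("penalty:", path_penalty.item())
```

The penalty pulls every path length towards the running mean, which is what encourages a fixed-size step in \(w\) to produce a fixed-size change in the image.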

Implementing the Two Regularisers

Now, let’s write some code! Note that R1 is run every 16 iterations and PPL every 4 (though the paper states PPL should be run every 8, in their code repo they run it every 4 [5] xD).
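A small sketch of the lazy schedule may help here. It lists which iterations run each regulariser and shows the Adam adjustment c = k / (k + 1) that the training code later in this post applies. The helper name lazy_adam_params and the 64-iteration window are my own, used for illustration only.

```python
d_reg_freq = 16   # R1 interval used in this post
g_reg_freq = 4    # PPL interval used in this post
num_iters = 64    # arbitrary window, just for illustration

# Iterations on which each regulariser runs (same modulo test as the training loop)
r1_iters = [i for i in range(num_iters) if i % d_reg_freq == 0]
ppl_iters = [i for i in range(num_iters) if i % g_reg_freq == 0]
print("R1 runs at iterations: ", r1_iters)
print("PPL runs at iterations:", ppl_iters)


def lazy_adam_params(lr, beta1, beta2, reg_freq):
    """Rescale Adam hyperparameters by c = k / (k + 1) (hypothetical helper)."""
    c = reg_freq / (reg_freq + 1)
    return lr * c, beta1 ** c, beta2 ** c


d_lr, d_beta1, d_beta2 = lazy_adam_params(0.0025, 0.0, 0.99, d_reg_freq)
g_lr, g_beta1, g_beta2 = lazy_adam_params(0.0025, 0.0, 0.99, g_reg_freq)
print(f"D: lr={d_lr:.6f}, betas=({d_beta1}, {d_beta2:.4f})")
print(f"G: lr={g_lr:.6f}, betas=({g_beta1}, {g_beta2:.4f})")
```

As I understand it, the reasoning is twofold: because a regulariser only runs every k iterations, its loss is multiplied by k so that its overall strength stays the same, and because the optimiser now takes k + 1 steps per k iterations, the Adam hyperparameters are rescaled to compensate.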

In the training code I have bits and pieces of the regulariser code spread around, so I will combine them all here for illustrative purposes. This is not the exact order in which you will find them in the python code.

# This function implements equation 3 above. It does not implement the whole thing however: observe we have no gamma/2,
# that step is performed later

def r1_loss(real_pred, real_images):
    with torch.set_grad_enabled(True):        
        # Get gradients with respect to inputs only
        # This implements the part in the square brackets; nabla (the upside-down triangle) refers to the gradients
        grad_real, = torch.autograd.grad(
            outputs=real_pred.sum(),
            inputs=real_images,
            grad_outputs=torch.ones([], device=real_pred.device),
            create_graph=True,
            only_inputs=True  # Only compute gradients for inputs, not weights
        )
    # For each sample we sum the squared gradients over all pixels (axis 1 after flattening), then average over the minibatch to get a scalar
    grad_penalty = grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean()

    return grad_penalty

# This code is used later in the following way

d_reg_freq = 16

# We run it every 16 minibatches
d_r1_regularise = i % d_reg_freq == 0
if d_r1_regularise:
    real_imgs.requires_grad_(True)

    real_preds = d(real_imgs)
    r1_loss_val = r1_loss(real_preds, real_imgs)

    d.zero_grad()

    # NOTE: gamma is a key parameter in training of this model. It's quite difficult to tune and the authors have created a 
    # heuristic formula to calculate it for your respective settings: gamma = 0.0002 * (resolution ** 2) / batch_size
    gamma = 0.8192 # from https://github.com/NVlabs/stylegan3/blob/c233a919a6faee6e36a316ddd4eddababad1adf9/docs/configs.md
    # Calculate the other part of equation 3
    # The purpose of 0 * real_preds[0] is to ensure the D output is included in the PyTorch computational graph. If we do not include 
    # it you may see some errors. It doesn't actually change the value as we just add 0 (a no-op)
    r1_reg = (gamma/2 * r1_loss_val * d_reg_freq + 0 * real_preds[0])
    r1_reg.backward()

    d_optimizer.step()

# I admit I was not fully able to implement the PPL regulariser myself, so I made use of the one provided by Rosinality
# Credit: https://github.com/rosinality/stylegan2-pytorch/blob/master/train.py#L87
def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
    noise = torch.randn_like(fake_img) / math.sqrt(
        fake_img.shape[2] * fake_img.shape[3]
    )
    
    grad, = torch.autograd.grad(
        outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True,
        retain_graph=True, only_inputs=True
    )
    
    path_lengths = torch.sqrt(grad.pow(2).sum(1).mean(0))
    path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length)
    path_penalty = (path_lengths - path_mean).pow(2).mean()

    del noise, grad  # To preserve memory
    torch.cuda.empty_cache()
    
    return path_penalty, path_mean.detach(), path_lengths

g_reg_freq = 4  # PPL is applied to the G model

# Similar logic to R1 when running PPL
g_ppl_regularise = i % g_reg_freq == 0
if g_ppl_regularise:
    z = torch.randn(real_size, latent_dim, device=device)
    gen_imgs, latents = g(z, return_latents=True)  # PPL calculation relies on latent vectors which produced imgs
    
    ppl_loss, mean_path_length, path_lengths = g_path_regularize(gen_imgs, latents, mean_path_length)

    g.zero_grad()

    # 2 is the PPL weighting; I use the same value that Rosinality used
    ppl_loss = 2 * g_reg_freq * ppl_loss
    
    ppl_loss.backward()
    g_optimizer.step()

So there we have it! We have all the bits and pieces ready to go, now it’s time to put them together.
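The gamma heuristic quoted in the R1 code comment above is simple enough to wrap in a helper. A minimal sketch follows; the function name r1_gamma is my own, and only the formula itself comes from the comment.

```python
def r1_gamma(resolution, batch_size):
    """Heuristic from the code comment: gamma = 0.0002 * resolution^2 / batch_size."""
    return 0.0002 * (resolution ** 2) / batch_size


# The setting used in this post: 256x256 images with a batch size of 16
gamma_256_16 = r1_gamma(256, 16)
# Doubling the batch size halves gamma
gamma_256_32 = r1_gamma(256, 32)

print(f"gamma for 256x256, batch 16: {gamma_256_16:.4f}")
print(f"gamma for 256x256, batch 32: {gamma_256_32:.4f}")
```

With a resolution of 256 and a batch size of 16 this reproduces the gamma = 0.8192 used in the code, so remember to recompute it if you change either value.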

Implementing the Training Loop

# The dataloader changes slightly in StyleGAN2
def get_dataloader(image_size, batch_size=8):
    # We only call this function once now, unlike in the StyleGAN where it would be required that we change resolution of the dataset when 
    # the network grew
    transform = transforms.Compose([
        #transforms.Resize((image_size, image_size)),  # No longer need to resize the images
        transforms.RandomHorizontalFlip(p=0.5),  # Add a random flip
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    # Reminder: we use FFHQ, another face-based dataset. There are more images and the diversity is increased.
    # An interesting note: the authors put a hell of a lot of work into building this dataset and the CelebA-HQ dataset.
    # This is one of the reasons the StyleGAN series performs so well: the images are front-lit human faces with good
    # colours and smooth transitions, so they lend themselves very strongly to GAN modelling.
    dataset = ImageFolder(root='./ffhq256_imgs', transform=transform)

    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=1,
        drop_last=True,
        pin_memory=True,
        persistent_workers=True
    )

    return dataloader

# Create new checkpoint dir with timestamp
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
run_dir = os.path.join('./checkpoints', f'run_{timestamp}')
checkpoint_dir = os.path.join(run_dir, f"checkpoint_{timestamp}")
sample_dir = os.path.join(run_dir, f"sample_{timestamp}")
# Create a file to log FID scores
fid_file = os.path.join(run_dir, 'fid.txt')

os.makedirs(checkpoint_dir, exist_ok=True)
os.makedirs(sample_dir, exist_ok=True)

# Init models
g = Generator().to(device)
d = Discriminator().to(device)
g_running = Generator().to(device)

g_running.train(False)

fid = FrechetInceptionDistance(feature=2048).to(device)

mapping_params, other_params = get_params_with_lr(g)

# I include here the range of learning rates I tested
lr = 0.0025  
# Failed LRs - An idea is that with harder datasets a lower LR is needed. FFHQ is quite an easy dataset: colours are smooth and images
# are well formed and often show a good front-lit photo.
#lr = 0.00016  #https://github.com/NVlabs/stylegan2-ada-pytorch/blob/main/train.py#L157
#lr = 0.00008
#lr = 0.00005  # 0.00008 is better
# Let's reduce LR after checkpoint 215k which has FID of 113
#lr = 0.00001  # didn't do much, LR decay was not useful for me

mapping_lr = lr * 0.01

g_reg_freq = 4
d_reg_freq = 16

g_reg_adjustment = g_reg_freq / (g_reg_freq + 1)
d_reg_adjustment = d_reg_freq / (d_reg_freq + 1) 

g_optimizer = torch.optim.Adam([
    {'params': mapping_params, 'lr': mapping_lr},  # 0.01 * LR for mapping network
    {'params': other_params, 'lr': lr * g_reg_adjustment}  # Regular LR for other parts
], betas=(0.0 ** g_reg_adjustment, 0.99 ** g_reg_adjustment))
d_optimizer = torch.optim.Adam(d.parameters(), lr=lr*d_reg_adjustment, betas=(0.0 ** d_reg_adjustment, 0.99 ** d_reg_adjustment))

start_iter = 0

resume_checkpoint = False
if resume_checkpoint:
    if os.path.isfile(resume_checkpoint):
        print(f"=> loading checkpoint '{resume_checkpoint}'")
        checkpoint = torch.load(resume_checkpoint, weights_only=False)
        start_iter = checkpoint['iteration'] + 1
        g.load_state_dict(checkpoint['g_state_dict'])
        d.load_state_dict(checkpoint['d_state_dict'])
        g_running.load_state_dict(checkpoint['g_running_state_dict'])
        g_optimizer.load_state_dict(checkpoint['g_optimizer_state_dict'])
        d_optimizer.load_state_dict(checkpoint['d_optimizer_state_dict'])

        print(f"=> loaded checkpoint '{resume_checkpoint}' ( iteration {start_iter})")
    else:
        print(f"=> no checkpoint found at '{resume_checkpoint}'")
else:
    print("Starting training from the beginning")

# Init EMA 
EMA(g_running, g, 0)
 
# We evaluate FID every 10k iterations - for more frequent updates
num_iters_for_eval = 10000

# We want to gen 70k fake images for FID calculation to match the 70k real images in FFHQ
num_fake_images = 70000
latent_dim = 512  # Adjust based on your model's input size

# Define vars used within training loop
d_loss_val = None
g_loss_val = None
r1_loss_val = None
mean_path_length = 0  # Running mean of the path lengths, updated by g_path_regularize

resolution = 256  # Resolution is always 256
batch_size = 16  # CHANGE to 32, 64 will run but is awfully slow - TRY 4???
total_iters = 25000 * 1000 // batch_size  # k imgs from paper
data_loader = get_dataloader(resolution, batch_size)
dataset = iter(data_loader)

# Init a progress bar
print(f'Training resolution: {resolution}x{resolution}, Batch size: {batch_size}')

if resume_checkpoint:
    remaining_iters = total_iters - start_iter  # Subtract what's been done
    # Create progress bar for remaining iterations but show correct absolute position
    pbar = tqdm(range(remaining_iters), initial=start_iter, total=total_iters)
else:
    pbar = tqdm(range(total_iters))

# This is used for our try and except loops
max_retries = 3

# Main training loop - a single loop, as there are no progressive growing phases any more
for i in pbar:
    requires_grad(g, False)
    requires_grad(d, True)

    try:
        real_imgs, label = next(dataset)
    except (OSError, StopIteration):
        # If we reach the end of the dataset, we reinitialise the iterator,
        # basically starting again
        dataset = iter(data_loader)
        real_imgs, label = next(dataset)

    # Train D
    real_size = real_imgs.size(0)
    real_imgs = real_imgs.to(device)
    label = label.to(device)
    real_preds = d(real_imgs)

    # Create gen images and gen preds
    z = torch.randn(real_size, latent_dim, device=device)
    gen_imgs = g(z)
    gen_preds = d(gen_imgs.detach())

    d_loss_val = d_loss(real_preds, gen_preds)

    d.zero_grad()
    d_loss_val.backward()
    d_optimizer.step()

    d_r1_regularise = i % d_reg_freq == 0
    if d_r1_regularise:
        real_imgs.requires_grad_(True)

        real_preds = d(real_imgs)
        r1_loss_val = r1_loss(real_preds, real_imgs)

        d.zero_grad()
        gamma = 0.8192
        r1_reg = ((gamma*0.5) * r1_loss_val * d_reg_freq + 0 * real_preds[0])
        r1_reg.backward()

        d_optimizer.step()
        torch.cuda.empty_cache()

    # Now lets train the Generator
    requires_grad(g, True)
    requires_grad(d, False)

    z = torch.randn(real_size, latent_dim, device=device)
    gen_imgs = g(z)
    gen_preds = d(gen_imgs)

    g_loss_val = g_loss(gen_preds)

    g.zero_grad()
    g_loss_val.backward()
    g_optimizer.step()

    # PPL regularisation is applied to G only (R1 is only used on D),
    # since PPL is computed from generated images
    g_ppl_regularise = i % g_reg_freq == 0
    if g_ppl_regularise:
        z = torch.randn(real_size, latent_dim, device=device)
        gen_imgs, latents = g(z, return_latents=True)  # PPL calculation relies on latent vectors which produced imgs
        
        ppl_loss, mean_path_length, path_lengths = g_path_regularize(gen_imgs, latents, mean_path_length)

        g.zero_grad()
        
        ppl_loss = 2 * g_reg_freq * ppl_loss
        
        ppl_loss.backward()
        g_optimizer.step()
        torch.cuda.empty_cache()
        
    EMA(g_running, g, decay=0.999)
    
    if i > 0 and i % num_iters_for_eval == 0:
        sample_z = torch.randn(16, latent_dim, device=device)
        with torch.no_grad():  # Sampling only, so there is no need to build a graph
            sample_imgs_EMA = g_running(sample_z)
        save_image(sample_imgs_EMA, f'{sample_dir}/sample__iter_{i}.png', nrow=4, normalize=True)
        print(f'G_running images after iter: {i+start_iter}')

        # Added a try and except loop for the FID calculation, note this try and except cannot catch Fatal Python errors
        for attempt in range(max_retries):
           try:
               calculate_and_save_fid(i, data_loader, g_running, num_fake_images, batch_size, latent_dim, device, fid_file)
               break  # If successful, exit retry loop
           except Exception as e:
               print(f"FID calculation failed (attempt {attempt + 1}/{max_retries})")
               print(f"Error: {str(e)}")
               if attempt == max_retries - 1:
                   print("Maximum retries reached. Skipping FID calculation for this iteration.")
               else:
                   print("Retrying...")
                   torch.cuda.empty_cache()
                   time.sleep(5)

        torch.save({
            'g_state_dict': g.state_dict(),
            'g_running_state_dict': g_running.state_dict(),
            'd_state_dict': d.state_dict(),
            'g_optimizer_state_dict': g_optimizer.state_dict(),
            'd_optimizer_state_dict': d_optimizer.state_dict(),
            'iteration': i,
            'mean_path_length': mean_path_length,
        }, f'{checkpoint_dir}/checkpoint_iter_{i}.pth')

with torch.no_grad():
    sample_z = torch.randn(16, latent_dim, device=device)
    sample_imgs = g(sample_z)
    sample_imgs_EMA = g_running(sample_z)
    print('G images')
    show_images(sample_imgs)
    print('G_running images')
    show_images(sample_imgs_EMA)

    calculate_and_save_fid('final', data_loader, g_running, num_fake_images, batch_size, latent_dim, device, fid_file)

# No need for stabilising period - we only have one loop system

torch.save({
    'g_state_dict': g.state_dict(),
    'g_running_state_dict': g_running.state_dict(),
    'd_state_dict': d.state_dict(),
    'g_optimizer_state_dict': g_optimizer.state_dict(),
    'd_optimizer_state_dict': d_optimizer.state_dict(),
    'iteration': i
}, f'{checkpoint_dir}/completed_checkpoint_g_and_EMA.pth')

Training - A Story of Time and Energy

Training this model correctly took a long time (I began the training around 3-4 months ago). It took so long because my training runs would break down a lot; I ran at least 100 different training runs. To that end I want to discuss some of the failure patterns, so you don’t make the same mistakes and can recognise failures earlier, saving you time and money. My current approach to training runs is this: run the model from a python script, keep a log of FID and wait. For StyleGAN it takes around 12 hours (on a single 3090) before you’ve had enough iterations to judge whether or not to continue the training run. At that point I review the FID and decide whether or not to stop the run. If I stop the run, I go into the code, check whether anything is wrong and then start the run again. Other times, the runs will just crash out with a “segmentation fault”, and at this time I have no way to automatically rerun the code [6]. It’s an iterative process: I do a run, look at the results and edit the code based on them, over and over again. So, I hope you can take some of my learnings and avoid this process.

Note, given that StyleGAN2 has many repos around (and that we are all still learning), I do use these other repos to review my code. It’s a massive help and I think it adds to the learning experience. One day we will cover a paper here where there is no reference code and we will put to use all we have learned!

Failure Patterns I Observed

  • At the beginning of training I observed that FID starts at around 300-400. The quickest failure pattern can happen in the first 20k iterations: FID will hover at 400, then start to increase and fluctuate wildly. STOP training if you see this, it will not go anywhere. I found that reducing the LR helped to mitigate this one, however it was not a full solution.
  • Another failure pattern began after reducing the LR. Previously I used the same mapping network as in my StyleGAN implementation, which has hidden layers with 256 in and 256 out features. Training would often proceed well until around iteration 30k, at which point the loss would explode. Decreasing the LR further would cause the same behaviour, but with smaller explosions each time. The fix was to give the mapping network 512 in and 512 out features. I posit that the increase in layer size allows the latent space of \(w\) to capture more information, which is needed for the more diverse/complex FFHQ images.
  • The idea to reduce the LR comes from the fact that in the paper they train with 8 GPUs and an LR of 0.0025. LR scales linearly with batch size; the rule is shown in Figure 6. For us it goes like this: 8 GPUs with a minibatch size of 64 (from the StyleGAN2-ADA repo) means a total batch size of 512 (8 * 64). To get our LR we compute their_batch_size / our_batch_size -> 512/32 = 16. So k=16, which means our LR is 0.0025 / 16 = 0.00015625. This is the LR that I settled on. EDIT: This is actually incorrect. In the StyleGAN2 paper they use a batch size of 32 which is split between 8 GPUs, so 4 images per GPU, whereas I had thought it was 32 on each GPU (so 32x8, or 64x8 in the StyleGAN2-ADA paper [8]). Still, the LR scaling rule is useful for other papers so I leave it here for you.
  • So far, with a mapping network with 512x512 hidden layers and LR=0.00015625, I can get to an FID score of around 80, which is still very bad. When training with this setup, FID reduces pretty nicely until around iteration 100k, then it hovers around 77->85ish. This is another failure pattern you should be aware of.

Figure 6 - Learning Rate Scaling Rule [7]
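The linear scaling arithmetic from the bullet list above can be written down as a tiny helper. This is a sketch that only reproduces the numbers in that bullet; the function name scale_lr is my own, and, as the EDIT in the bullet points out, the reference batch size of 512 was a misreading on my part, so treat the inputs as an illustration of the rule rather than the paper's settings.

```python
def scale_lr(reference_lr, reference_batch_size, our_batch_size):
    """Linear scaling rule: shrink the LR by the same factor k as the batch size."""
    k = reference_batch_size / our_batch_size
    return reference_lr / k, k


# Numbers from the bullet above: (assumed) 8 GPUs x 64 images versus our batch of 32
lr, k = scale_lr(0.0025, 8 * 64, 32)
print(f"k = {k:g}, scaled LR = {lr:.8f}")
```

The same helper works in the other direction: a larger batch than the reference gives k below 1 and therefore a larger LR.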

A Critical Discovery, one bug led to a breakthrough

As I write, my final training run is ongoing. I figured out what the issue was and why my training runs failed, and in doing so I have encountered more issues. The reason all my previous training runs failed was a bug introduced on my part; I take this as a lesson to double check all my code.

The way I uncovered this bug was that I had reached a point where training proceeded for around 350k iters and then started crashing. Around 350k I would also see a pattern in the FID: it would decrease steadily until this point and then, from 350k to 400kish, it would converge on a score of around 40. The crashes in question were segmentation faults, which aren’t very straightforward to resolve. In my attempts to rectify these, I began examining my code closely and came across the following:

Figure 7 - A silly mistake which caused a lot of pain

I was not using the modulator. At this point I realised why my training failed so hard. I made the change to the Conv2d_mod operation and voila, my training run now manages to get down to an FID of 11 at iteration 340k. (It remains to be seen how low it will go, but it is a vast improvement.)

Some of my FID patterns

I provide some of the FID schemes so you can recognise a bad run. I also provide in the code repo my final FID scheme for the successful model. To be honest, I’m not sure of the exact params for each of these runs given I did so many training runs, but I hope these examples show you what didn’t work for me (next time I will implement better logging so we can explore exactly why each one breaks down).

In the following scrollable output, you will see a few training runs which cover the different breakdowns I saw.


Here, the training appears to go well, but the GPU crashed at iter 180k. I restarted training after that from a checkpoint and you can see the performance just fluctuates, reaching a minimum of 115. To me this is a failed run, as it fails to improve even after 200k iterations (bear in mind this took about 3-4 days).

Iteration 5000: 372.2608947753906

Iteration 10000: 273.1268615722656

Iteration 15000: 260.4564208984375

Iteration 20000: 223.1716766357422

Iteration 25000: 219.0269775390625

Iteration 30000: 205.28900146484375

Iteration 35000: 188.085205078125

Iteration 40000: 177.87722778320312

Iteration 45000: 200.12591552734375

Iteration 50000: 177.63365173339844

Iteration 55000: 175.5360107421875

Iteration 60000: 175.09141540527344

Iteration 65000: 178.38917541503906

Iteration 70000: 170.81842041015625

Iteration 75000: 160.3494415283203

Iteration 80000: 158.06622314453125

Iteration 85000: 159.43179321289062

Iteration 90000: 154.20611572265625

Iteration 95000: 158.0677947998047

Iteration 100000: 143.0650634765625

Iteration 105000: 143.35877990722656

Iteration 110000: 147.93089294433594

Iteration 115000: 128.8855438232422

Iteration 120000: 145.07362365722656

Iteration 125000: 142.61892700195312

Iteration 130000: 129.04664611816406

Iteration 135000: 127.5386734008789

Iteration 140000: 144.24447631835938

Iteration 145000: 145.9493408203125

Iteration 150000: 131.70626831054688

Iteration 155000: 122.86494445800781

Iteration 160000: 133.6108856201172

Iteration 165000: 121.39984893798828

Iteration 170000: 125.3056411743164

Iteration 175000: 133.8019256591797

Iteration 180000: 120.23395538330078

Iteration 185000: 122.73983764648438

Iteration 190000: 123.81011962890625

Iteration 195000: 145.60452270507812

Iteration 200000: 131.8021240234375

Iteration 205000: 131.51727294921875

Iteration 210000: 132.5868682861328

Iteration 215000: 115.08854675292969

Iteration 220000: 119.64303588867188

Iteration 225000: 117.72576141357422

Iteration 230000: 128.736328125

Iteration 235000: 146.3714599609375

Iteration 240000: 134.3799285888672

Iteration 245000: 127.26163482666016

Iteration 250000: 118.39510345458984

Iteration 255000: 133.1106719970703


Another which seemingly starts off well, then fluctuates. We got to a minimum FID of around 94, which is still pretty bad. At this FID images look horrible still.

Iteration 5000: 375.0341491699219

Iteration 10000: 266.1919250488281

Iteration 15000: 225.51492309570312

Iteration 20000: 209.35122680664062

Iteration 25000: 169.07830810546875

Iteration 30000: 162.7826690673828

Iteration 35000: 158.3661346435547

Iteration 40000: 142.1690216064453

Iteration 45000: 138.27444458007812

Iteration 50000: 137.2218780517578

Iteration 55000: 135.9760284423828

Iteration 60000: 119.19623565673828

Iteration 65000: 119.33064270019531

Iteration 70000: 113.0389175415039

Iteration 75000: 108.53450775146484

Iteration 80000: 111.82980346679688

Iteration 85000: 109.79493713378906

Iteration 90000: 104.31437683105469

Iteration 95000: 110.7328872680664

Iteration 100000: 122.6935043334961

Iteration 105000: 110.53292083740234

Iteration 110000: 105.91259765625

Iteration 115000: 98.43663787841797

Iteration 120000: 106.87227630615234

Iteration 125000: 100.67845916748047

Iteration 130000: 96.66962432861328

Iteration 135000: 94.45530700683594

Iteration 140000: 100.49161529541016

Iteration 145000: 109.65937042236328

Iteration 150000: 95.66574096679688

Iteration 155000: 106.8659896850586

Iteration 160000: 108.08671569824219


This run performed quite well; the FID drops a lot quicker than in previous runs. This is when I set LR=0.00015625. However, once again it just fluctuates after iter 70k.

Iteration 5000: 400.09442138671875

Iteration 10000: 224.55154418945312

Iteration 15000: 199.23419189453125

Iteration 20000: 166.6976776123047

Iteration 25000: 133.10545349121094

Iteration 30000: 122.34843444824219

Iteration 35000: 116.79188537597656

Iteration 40000: 106.47919464111328

Iteration 45000: 103.15625

Iteration 50000: 97.2866439819336

Iteration 55000: 88.9021987915039

Iteration 60000: 92.6580810546875

Iteration 65000: 87.99129486083984

Iteration 70000: 86.91797637939453

Iteration 75000: 82.1233139038086

Iteration 80000: 82.40524291992188

Iteration 85000: 77.69095611572266

Iteration 90000: 76.31440734863281

Iteration 95000: 76.78862762451172

Iteration 100000: 82.9169692993164

Iteration 105000: 81.16597747802734

Iteration 110000: 78.42939758300781

Iteration 115000: 79.91007232666016


I don’t have any saved FID logs for the early failures I described, but I have made up an example one so you know what to look out for. This might not be exactly what you see, but the pattern will be close enough to recognise from this example.

Iteration 5000: 400.09442138671875

Iteration 10000: 330.039423

Iteration 15000: 302.231324

Iteration 20000: 299.209321

Iteration 25000: 325.99010

Iteration 30000: 320.392034

Iteration 35000: 340.9304

Iteration 40000: 396.0203109

Iteration 45000: 454.10394

Iteration 50000: 424.20319
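The pattern above (an early dip, then a steady climb) can be flagged automatically. Here is a minimal sketch, not something from my training script: the function name `fid_is_diverging` and the `patience` and `tolerance` parameters are assumptions I picked for illustration, and the thresholds would need tuning for your own runs.

```python
def fid_is_diverging(scores, patience=3, tolerance=1.10):
    """Return True if the last `patience` FID values are all more than
    `tolerance` times the best FID seen before them."""
    if len(scores) <= patience:
        return False  # not enough evaluations to judge yet
    best_so_far = min(scores[:-patience])
    return all(s > best_so_far * tolerance for s in scores[-patience:])

# The made-up failure log from above (iterations 5k to 50k, rounded).
failing = [400.09, 330.04, 302.23, 299.21, 325.99,
           320.39, 340.93, 396.02, 454.10, 424.20]

# The first six evaluations of the LR=0.00015625 run (rounded).
healthy = [400.09, 224.55, 199.23, 166.70, 133.11, 122.35]

print(fid_is_diverging(failing))
print(fid_is_diverging(healthy))
```

A check like this could run after each FID evaluation so that a hopeless run is stopped early instead of burning GPU hours.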


I hope you find these useful!

SegFaults and Illegal Instructions

So once I had fixed my major mishap, training ran smoothly… until it didn’t. I faced some very perplexing issues in areas of the code I had little control over. Frankly, I am still not sure why these errors cropped up, but I include the Python stack traces here in case you ever face similar issues. My takeaway is that I need to make my training loops more robust. I did so by adding try/except blocks with a maximum number of retries (I added this to the calculate_and_save_fid function). A try/except cannot handle seg faults though, since they kill the Python process outright; for those, our checkpointing and reloading structure needs to improve.
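To make the retry idea concrete, here is a minimal sketch of a retry wrapper. It is not the exact code from my script: `call_with_retries`, `max_retries` and `wait_seconds` are hypothetical names, and `flaky` is just a stand-in for a function like calculate_and_save_fid that sometimes fails.

```python
import time

def call_with_retries(fn, *args, max_retries=3, wait_seconds=0.0, **kwargs):
    """Call fn(*args, **kwargs), retrying when it raises a Python exception."""
    for attempt in range(1, max_retries + 1):
        try:
            return fn(*args, **kwargs)
        except Exception as err:
            # Only ordinary Python exceptions land here. A segmentation fault
            # or illegal instruction kills the interpreter before this runs.
            print(f"Attempt {attempt}/{max_retries} failed: {err!r}")
            if attempt == max_retries:
                raise
            time.sleep(wait_seconds)

# Stand-in for a flaky step: fails twice, then succeeds.
calls = {"n": 0}

def flaky():
    calls["n"] += 1
    if calls["n"] < 3:
        raise RuntimeError("transient failure")
    return "ok"

result = call_with_retries(flaky, max_retries=3)
print(result, calls["n"])
```

Re-raising on the final attempt is a deliberate choice here: a step that keeps failing should stop the run loudly rather than be skipped silently.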

Note The paths and hex numbers may look a little funny as I removed specific system details of mine.


The first is a SegFault in the calculate_and_save_fid function.

G_running images images after iter: 360000

Fatal Python error: Segmentation fault
Generating images: 38%|███▊ | 1669/4375 [01:51<03:01, 14.95it/s]

Thread 0x00000001 (most recent call first):

Thread 0x00000002 (most recent call first):
  File "/path/to/python/lib/python3.11/threading.py", line 331 in wait
  File "/path/to/python/lib/python3.11/threading.py", line 629 in wait
  File "/path/to/python/lib/python3.11/site-packages/tqdm/_monitor.py", line 60 in run
  File "/path/to/python/lib/python3.11/threading.py", line 1045 in _bootstrap_inner
  File "/path/to/python/lib/python3.11/threading.py", line 1002 in _bootstrap

Thread 0x00000003 (most recent call first):
  File "/path/to/python/lib/python3.11/threading.py", line 327 in wait
  File "/path/to/python/lib/python3.11/multiprocessing/queues.py", line 231 in _feed
  File "/path/to/python/lib/python3.11/threading.py", line 982 in run
  File "/path/to/python/lib/python3.11/threading.py", line 1045 in _bootstrap_inner
  File "/path/to/python/lib/python3.11/threading.py", line 1002 in _bootstrap

Thread 0x00000004 (most recent call first):
  File "/path/to/python/lib/python3.11/selectors.py", line 415 in select
  File "/path/to/python/lib/python3.11/multiprocessing/connection.py", line 947 in wait
  File "/path/to/python/lib/python3.11/multiprocessing/connection.py", line 440 in _poll
  File "/path/to/python/lib/python3.11/multiprocessing/connection.py", line 257 in poll
  File "/path/to/python/lib/python3.11/multiprocessing/queues.py", line 113 in get
  File "/path/to/python/lib/python3.11/site-packages/torch/utils/data/_utils/pin_memory.py", line 32 in do_one_step
  File "/path/to/python/lib/python3.11/site-packages/torch/utils/data/_utils/pin_memory.py", line 55 in _pin_memory_loop
  File "/path/to/python/lib/python3.11/threading.py", line 982 in run
  File "/path/to/python/lib/python3.11/threading.py", line 1045 in _bootstrap_inner
  File "/path/to/python/lib/python3.11/threading.py", line 1002 in _bootstrap

Current thread 0x00000005 (most recent call first):
  File "/path/to/python/lib/python3.11/site-packages/torch/nn/modules/conv.py", line 454 in _conv_forward
  File "/path/to/python/lib/python3.11/site-packages/torch/nn/modules/conv.py", line 458 in forward
  File "/path/to/python/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562 in _call_impl
  File "/path/to/python/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553 in _wrapped_call_impl
  File "/path/to/python/lib/python3.11/site-packages/torch_fidelity/feature_extractor_inceptionv3.py", line 208 in forward
  File "/path/to/python/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562 in _call_impl
  File "/path/to/python/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553 in _wrapped_call_impl
  File "/path/to/python/lib/python3.11/site-packages/torch_fidelity/feature_extractor_inceptionv3.py", line 232 in forward
  File "/path/to/python/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562 in _call_impl
  File "/path/to/python/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553 in _wrapped_call_impl
  File "/path/to/python/lib/python3.11/site-packages/torchmetrics/image/fid.py", line 111 in _torch_fidelity_forward
  File "/path/to/python/lib/python3.11/site-packages/torchmetrics/image/fid.py", line 155 in forward
  File "/path/to/python/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562 in _call_impl
  File "/path/to/python/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553 in _wrapped_call_impl
  File "/path/to/python/lib/python3.11/site-packages/torchmetrics/image/fid.py", line 365 in update
  File "/path/to/python/lib/python3.11/site-packages/torchmetrics/metric.py", line 483 in wrapped_func
  File "/path/to/project/StyleGAN2/StyleGAN2.py", line 898 in add_fake_images
  File "/path/to/project/StyleGAN2/StyleGAN2.py", line 923 in calculate_and_save_fid
  File "/path/to/project/StyleGAN2/StyleGAN2.py", line 1139 in <module>

[1]+ Segmentation fault (core dumped) nohup python3 -X faulthandler -u StyleGAN2.py > output.log 2>&1


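The second is an illegal instruction, again in the calculate_and_save_fid function. An illegal instruction error (SIGILL) is raised when the CPU is asked to execute an opcode it does not recognise. Typical causes are a binary compiled for instruction set extensions the CPU lacks, corrupted memory, or unstable hardware; I do not know which applies here.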

Fatal Python error: Illegal instruction
Generating images: 29%|██▉ | 1259/4375 [01:26<03:34, 14.54it/s]

Thread 0x00000001 (most recent call first):

Thread 0x00000002 (most recent call first):
  File "/path/to/python/lib/python3.11/threading.py", line 331 in wait
  File "/path/to/python/lib/python3.11/threading.py", line 629 in wait
  File "/path/to/python/lib/python3.11/site-packages/tqdm/_monitor.py", line 60 in run
  File "/path/to/python/lib/python3.11/threading.py", line 1045 in _bootstrap_inner
  File "/path/to/python/lib/python3.11/threading.py", line 1002 in _bootstrap

Thread 0x00000003 (most recent call first):
  File "/path/to/python/lib/python3.11/threading.py", line 327 in wait
  File "/path/to/python/lib/python3.11/multiprocessing/queues.py", line 231 in _feed
  File "/path/to/python/lib/python3.11/threading.py", line 982 in run
  File "/path/to/python/lib/python3.11/threading.py", line 1045 in _bootstrap_inner
  File "/path/to/python/lib/python3.11/threading.py", line 1002 in _bootstrap

Thread 0x00000004 (most recent call first):
  File "/path/to/python/lib/python3.11/selectors.py", line 415 in select
  File "/path/to/python/lib/python3.11/multiprocessing/connection.py", line 947 in wait
  File "/path/to/python/lib/python3.11/multiprocessing/connection.py", line 440 in _poll
  File "/path/to/python/lib/python3.11/multiprocessing/connection.py", line 257 in poll
  File "/path/to/python/lib/python3.11/multiprocessing/queues.py", line 113 in get
  File "/path/to/python/lib/python3.11/site-packages/torch/utils/data/_utils/pin_memory.py", line 32 in do_one_step
  File "/path/to/python/lib/python3.11/site-packages/torch/utils/data/_utils/pin_memory.py", line 55 in _pin_memory_loop
  File "/path/to/python/lib/python3.11/threading.py", line 982 in run
  File "/path/to/python/lib/python3.11/threading.py", line 1045 in _bootstrap_inner
  File "/path/to/python/lib/python3.11/threading.py", line 1002 in _bootstrap

Current thread 0x00000005 (most recent call first):
  File "/path/to/python/lib/python3.11/site-packages/torch_fidelity/feature_extractor_inceptionv3.py", line 209 in forward
  File "/path/to/python/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562 in _call_impl
  File "/path/to/python/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553 in _wrapped_call_impl
  File "/path/to/python/lib/python3.11/site-packages/torch_fidelity/feature_extractor_inceptionv3.py", line 234 in forward
  File "/path/to/python/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562 in _call_impl
  File "/path/to/python/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553 in _wrapped_call_impl
  File "/path/to/python/lib/python3.11/site-packages/torchmetrics/image/fid.py", line 113 in _torch_fidelity_forward
  File "/path/to/python/lib/python3.11/site-packages/torchmetrics/image/fid.py", line 155 in forward
  File "/path/to/python/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562 in _call_impl
  File "/path/to/python/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553 in _wrapped_call_impl
  File "/path/to/python/lib/python3.11/site-packages/torchmetrics/image/fid.py", line 365 in update
  File "/path/to/python/lib/python3.11/site-packages/torchmetrics/metric.py", line 483 in wrapped_func
  File "/path/to/project/StyleGAN2/StyleGAN2.py", line 896 in add_fake_images
  File "/path/to/project/StyleGAN2/StyleGAN2.py", line 921 in calculate_and_save_fid
  File "/path/to/project/StyleGAN2/StyleGAN2.py", line 1137 in <module>


The following error cropped up a number of times at different iterations.

FID score for iteration 20000: 64.7917022705078175 [02:54<00:00, 25.46it/s]
1%|▏ | 20668/1562500 [3:51:10<168:03:37, 2.55it/s]
Fatal Python error: Segmentation fault

Current thread 0x00000001 (most recent call first):

Thread 0x00000002 (most recent call first):
  File "/path/to/python/lib/python3.11/threading.py", line 331 in wait
  File "/path/to/python/lib/python3.11/threading.py", line 629 in wait
  File "/path/to/python/lib/python3.11/site-packages/tqdm/_monitor.py", line 60 in run
  File "/path/to/python/lib/python3.11/threading.py", line 1045 in _bootstrap_inner
  File "/path/to/python/lib/python3.11/threading.py", line 1002 in _bootstrap

Thread 0x00000003 (most recent call first):
  File "/path/to/python/lib/python3.11/threading.py", line 327 in wait
  File "/path/to/python/lib/python3.11/multiprocessing/queues.py", line 231 in _feed
  File "/path/to/python/lib/python3.11/threading.py", line 982 in run
  File "/path/to/python/lib/python3.11/threading.py", line 1045 in _bootstrap_inner
  File "/path/to/python/lib/python3.11/threading.py", line 1002 in _bootstrap

Thread 0x00000004 (most recent call first):
  File "/path/to/python/lib/python3.11/threading.py", line 327 in wait
  File "/path/to/python/lib/python3.11/multiprocessing/queues.py", line 231 in _feed
  File "/path/to/python/lib/python3.11/threading.py", line 982 in run
  File "/path/to/python/lib/python3.11/threading.py", line 1045 in _bootstrap_inner
  File "/path/to/python/lib/python3.11/threading.py", line 1002 in _bootstrap

Thread 0x00000005 (most recent call first):
  File "/path/to/python/lib/python3.11/selectors.py", line 415 in select
  File "/path/to/python/lib/python3.11/multiprocessing/connection.py", line 947 in wait
  File "/path/to/python/lib/python3.11/multiprocessing/connection.py", line 440 in _poll
  File "/path/to/python/lib/python3.11/multiprocessing/connection.py", line 257 in poll
  File "/path/to/python/lib/python3.11/multiprocessing/queues.py", line 113 in get
  File "/path/to/python/lib/python3.11/site-packages/torch/utils/data/_utils/pin_memory.py", line 32 in do_one_step
  File "/path/to/python/lib/python3.11/site-packages/torch/utils/data/_utils/pin_memory.py", line 55 in _pin_memory_loop
  File "/path/to/python/lib/python3.11/threading.py", line 982 in run
  File "/path/to/python/lib/python3.11/threading.py", line 1045 in _bootstrap_inner
  File "/path/to/python/lib/python3.11/threading.py", line 1002 in _bootstrap

Thread 0x00000006 (most recent call first):
  File "/path/to/python/lib/python3.11/site-packages/torch/autograd/graph.py", line 768 in _engine_run_backward
  File "/path/to/python/lib/python3.11/site-packages/torch/autograd/__init__.py", line 289 in backward
  File "/path/to/python/lib/python3.11/site-packages/torch/_tensor.py", line 521 in backward
  File "/path/to/project/StyleGAN2/StyleGAN2.py", line 1093 in <module>


Here’s a real weird one: this error is thrown by torch.cuda.empty_cache(). The whole point of putting this call in was to stabilise training, not to cause more crashes, but alas I do not know why it causes an illegal instruction.

38%|███▊ | 591068/1562500 [29:32:52<107:33:34, 2.51it/s]
Fatal Python error: Illegal instruction

Thread 0x00000001 (most recent call first):

Thread 0x00000002 (most recent call first):
  File "/path/to/python/lib/python3.11/threading.py", line 331 in wait
  File "/path/to/python/lib/python3.11/threading.py", line 629 in wait
  File "/path/to/python/lib/python3.11/site-packages/tqdm/_monitor.py", line 60 in run
  File "/path/to/python/lib/python3.11/threading.py", line 1045 in _bootstrap_inner
  File "/path/to/python/lib/python3.11/threading.py", line 1002 in _bootstrap

Thread 0x00000003 (most recent call first):
  File "/path/to/python/lib/python3.11/threading.py", line 327 in wait
  File "/path/to/python/lib/python3.11/multiprocessing/queues.py", line 231 in _feed
  File "/path/to/python/lib/python3.11/threading.py", line 982 in run
  File "/path/to/python/lib/python3.11/threading.py", line 1045 in _bootstrap_inner
  File "/path/to/python/lib/python3.11/threading.py", line 1002 in _bootstrap

Thread 0x00000004 (most recent call first):
  File "/path/to/python/lib/python3.11/selectors.py", line 415 in select
  File "/path/to/python/lib/python3.11/multiprocessing/connection.py", line 947 in wait
  File "/path/to/python/lib/python3.11/multiprocessing/connection.py", line 440 in _poll
  File "/path/to/python/lib/python3.11/multiprocessing/connection.py", line 257 in poll
  File "/path/to/python/lib/python3.11/multiprocessing/queues.py", line 113 in get
  File "/path/to/python/lib/python3.11/site-packages/torch/utils/data/_utils/pin_memory.py", line 32 in do_one_step
  File "/path/to/python/lib/python3.11/site-packages/torch/utils/data/_utils/pin_memory.py", line 55 in _pin_memory_loop
  File "/path/to/python/lib/python3.11/threading.py", line 982 in run
  File "/path/to/python/lib/python3.11/threading.py", line 1045 in _bootstrap_inner
  File "/path/to/python/lib/python3.11/threading.py", line 1002 in _bootstrap

Current thread 0x00000005 (most recent call first):
  File "/path/to/python/lib/python3.11/site-packages/torch/cuda/memory.py", line 170 in empty_cache
  File "/path/to/project/StyleGAN2/StyleGAN2.py", line 1129 in <module>


These issues appear randomly, which leads me to think that if we simply retried, the code would run as intended, and it often does. So, as mentioned, we need to develop a better method to handle retries and manage training runs. I will leave that for a later blog post, or for you!
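One possible shape for that "better method" is a small supervisor script that relaunches training whenever the process dies. This is only a sketch under assumptions: `run_with_restarts` and `max_restarts` are hypothetical names, and it only helps if the training script itself resumes from its most recent checkpoint on start-up, which my script does not yet do automatically.

```python
import subprocess
import sys

def run_with_restarts(cmd, max_restarts=5):
    """Run cmd and relaunch it whenever it exits with a non-zero status.

    On POSIX a negative return code -N means the child was killed by
    signal N (for example -11 for SIGSEGV and -4 for SIGILL on Linux).
    Returns the number of restarts that were needed.
    """
    for launch in range(max_restarts + 1):
        returncode = subprocess.run(cmd).returncode
        if returncode == 0:
            return launch
        print(f"Launch {launch} exited with code {returncode}; restarting")
    raise RuntimeError(f"gave up after {max_restarts} restarts")

# Harmless demo command that exits cleanly straight away. For training you
# would pass something like:
#   [sys.executable, "-X", "faulthandler", "-u", "StyleGAN2.py"]
restarts = run_with_restarts([sys.executable, "-c", "pass"])
print(restarts)
```

Because the supervisor lives in a separate process, a seg fault in the training process cannot take it down, which is exactly what a try/except inside the script cannot offer.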

Some other steps I took to reduce the frequency of these errors were setting num_workers = 1 in the dataloader (trying to prevent simultaneous memory access), adding more torch.cuda.empty_cache() calls (in the hope of avoiding bad memory accesses, although this call only releases cached, unused GPU memory) and adding a torch.cuda.synchronize() before calling calculate_and_save_fid.

The Successful Training Run

I run the training code with the following command (see footnote 9 for an explainer):

  • nohup python3 -X faulthandler -u StyleGAN2.py > lr_0025_batch_16_g_every_4.log 2>&1

The above command runs StyleGAN2.py so that it keeps going after the terminal is closed (append & to put it in the background), and you can view training progress in the .log file.

Let’s load in the final FID log for the successful training run and analyse it a little. You can view the final FID txt file in the repository for this code too, check the link in the sidebar ;)

Code
import matplotlib.pyplot as plt  # harmless if already imported above

# Read FID scores from a log file with lines like "Iteration 5000: 375.03"
def read_fid_scores(filename):
    iterations = []
    scores = []
    with open(filename, 'r') as f:
        for line in f:
            if line.startswith('Iteration'):
                # Split once on ':' into "Iteration 5000" and " 375.03"
                label, value = line.split(':', 1)
                iterations.append(int(label.split()[1]))
                scores.append(float(value.strip()))
    return iterations, scores

# Plot FID scores and mark the lowest one
def plot_fid(iterations, scores):
    plt.figure(figsize=(10, 6))
    plt.plot(iterations, scores)

    # Locate the minimum FID and the iteration it occurred at
    min_score = min(scores)
    min_iter = iterations[scores.index(min_score)]

    plt.plot(min_iter, min_score, 'ro')
    plt.annotate(f'Min FID: {min_score:.2f}\nIteration: {min_iter}',
                 xy=(min_iter, min_score),
                 xytext=(10, 10),
                 textcoords='offset points',
                 bbox=dict(facecolor='white', alpha=0.8))

    plt.xlabel('Iteration')
    plt.ylabel('FID Score')
    plt.title('FID Score vs Training Iteration')
    plt.grid(True)
    return plt

iterations, scores = read_fid_scores('./StyleGAN2_fid.txt')
plot_fid(iterations, scores)

Ain’t that a pretty FID curve, it’s exactly the sort of shape we want! There is a steep decrease to around FID 20, which occurs roughly at iteration 100000 (this took 20 hours of training to achieve). Then note how much longer it takes to reduce FID to the final score of 8.06 at iteration 1010k. In total this whole run probably took around 200-250 hours to complete, and the graph shows that most of training is spent on eking out small gains. The model converges around an FID score of 8; it oscillated in this range from iteration 600k until I stopped training. Even this lowest score of 8.06 is, I think, due to randomness: continued training after achieving this score didn’t lead to any clear improvements.
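A rough way to judge whether a new minimum is real progress or just noise is to compare it with the mean and spread of recent evaluations. The sketch below is an illustration, not part of my training code: `rolling_stats` and `window` are hypothetical names, and since the successful run's log is not reproduced in this post, the demo uses the last seven (rounded) FID values of the LR=0.00015625 run printed earlier.

```python
from statistics import mean, stdev

def rolling_stats(scores, window=5):
    """Return a (mean, standard deviation) pair for every run of
    `window` consecutive scores."""
    return [
        (mean(scores[i:i + window]), stdev(scores[i:i + window]))
        for i in range(len(scores) - window + 1)
    ]

# Iterations 85k to 115k of the LR=0.00015625 run, rounded to 2 d.p.
tail = [77.69, 76.31, 76.79, 82.92, 81.17, 78.43, 79.91]

stats = rolling_stats(tail)
for avg, spread in stats:
    print(f"mean {avg:.2f}, std {spread:.2f}")
```

If the rolling mean has stopped falling and a new minimum sits within roughly one standard deviation of it, that minimum is plausibly noise; this is a heuristic, not a statistical test.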

Getting here required a lot of manual intervention, which I hope to reduce in my next post: I had to restart training 4 times, manually editing the path to checkpoints and such. If you have any queries while reading or running the code please reach out to me and I’d be happy to help (my email is in the footer).
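The manual path editing could be avoided by having the script pick the newest checkpoint itself. Here is a minimal sketch, assuming checkpoints are saved with the iteration number in the file name; the `checkpoint_<iteration>.pt` naming and the `latest_checkpoint` helper are assumptions for illustration, so adapt the pattern to however your checkpoints are named.

```python
import re
import tempfile
from pathlib import Path

def latest_checkpoint(checkpoint_dir, pattern=r"checkpoint_(\d+)\.pt$"):
    """Return the file whose name carries the highest iteration number,
    or None if nothing in the directory matches the pattern."""
    best_iter, best_path = -1, None
    for path in Path(checkpoint_dir).iterdir():
        match = re.search(pattern, path.name)
        if match and int(match.group(1)) > best_iter:
            best_iter, best_path = int(match.group(1)), path
    return best_path

# Demo in a throwaway directory with empty stand-in files.
with tempfile.TemporaryDirectory() as tmp:
    for iteration in (10000, 200000, 95000):
        (Path(tmp) / f"checkpoint_{iteration}.pt").touch()
    found = latest_checkpoint(tmp)
    empty_result = latest_checkpoint(Path(tmp) / ".")
    print(found.name)
```

The iteration number is compared as an integer rather than as text, so checkpoint_200000.pt correctly beats checkpoint_95000.pt.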

So what do our images look like?

To illustrate the model’s capabilities I will show you 16 randomly generated images from different stages throughout the training, starting from iteration 10k up to iteration 950k, which has the lowest FID. To refer to individual images, think of each sample as a 4x4 grid with the top left-most image at position [0,0].

Iteration 10k - FID 111.28 - The model picks up on high-level details quite quickly; we see the emergence of faces and the sorts of colours/features a face should have, albeit with big distortions.

Iteration 50k - FID 32.16 - The model is improving, these images are starting to take shape. Facial features and accessories such as glasses are better formed now.

Iteration 100k - FID 19.60 - The 20 FID threshold. Note that 20 FID has no actual meaning, it’s just something I empirically noticed in this training run: the training slowdown occurs past this point. Also, look at these samples: they’re hardly great, and those at 50k could be seen as better. But the defining factor, I think, is the colours in the images; at 50k there are a lot of distortions which don’t occur so much at 100k iterations.

Iteration 150k - FID 15.86 - I think our samples are beginning to look quite good. Take for example the image at position [2,2], that dude looks pretty real.

Iteration 400k - FID 9.94 - So it took us 250k iterations to get from 15.86 to <10 FID. I hope this highlights the slowdown and what role the later stage of training plays. The images now possess most of the features but they just aren’t realistic yet.

Iteration 950k - FID 8.06 - The best FID score we got. These faces look great, albeit with some issues. I think if we were to handpick our images we’d get some believable faces!
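To put a number on the slowdown, we can work out how many FID points were gained per 10k iterations between the samples shown above. This is just arithmetic on the (iteration, FID) pairs quoted in the captions; the helper name `fid_drop_per_10k` is made up for this sketch.

```python
# (iteration, FID) pairs quoted in the captions above.
checkpoints = [(10_000, 111.28), (50_000, 32.16), (100_000, 19.60),
               (150_000, 15.86), (400_000, 9.94), (950_000, 8.06)]

def fid_drop_per_10k(points):
    """FID improvement per 10k iterations between consecutive points."""
    rates = []
    for (it0, fid0), (it1, fid1) in zip(points, points[1:]):
        rates.append((it0, it1, (fid0 - fid1) / (it1 - it0) * 10_000))
    return rates

rates = fid_drop_per_10k(checkpoints)
for start, end, rate in rates:
    print(f"{start // 1000}k -> {end // 1000}k: {rate:.2f} FID per 10k iterations")
```

The rate falls from roughly 19.8 FID per 10k iterations at the start to about 0.03 between 400k and 950k, which is the "most of training is spent on small gains" effect in numbers.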

That’s it! We’ve trained a StyleGAN2 model and it’s ready to use. Here’s a link to the best iteration’s weights: https://huggingface.co/YM2132/StyleGAN2/tree/main. To ensure convergence I continued training to 1200k iterations, and the FID kept oscillating between 8 and 9, indicating the model has settled. Perhaps some LR adjustments could reduce the FID score further.

Lastly, a closing remark about the behaviour this GAN displayed in training. We can think of it as the model learning the basic features at the start and then spending the rest of the time fine-tuning the output. Where previously we enforced this behaviour (learning high-level details before fine-grained details) through progressive growing, the model now does it by itself! This behaviour is truly remarkable: the optimisation process of StyleGAN2 results in the model effectively performing progressive growing on its own to create the best images possible. I leave this with you to ponder, and I hope you enjoyed the path to StyleGAN2!


1 This is an important aspect of deep learning research in general. The authors most likely came to this conclusion not through theoretical reasoning but by performing ablation studies. An ablation study is when you remove/change one aspect of a system and observe the outcome. These are crucial in deep learning research because it’s very hard to understand the why (i.e. why does AdaIN cause water droplets), but it’s easier to observe what effect AdaIN has. In the paper they state that when it is removed this water droplet effect doesn’t occur, hence ablation studies FTW! Some more info on ablation studies: https://x.com/fchollet/status/1831029432653599226

2 Now this is a little wordy, and I do not fully understand it myself yet. But my initial idea is that the G model creates pixel intensity spikes as a result of the AdaIN operation removing information about how features relate to each other. Because each feature map is treated independently by AdaIN, an issue in one feature map can be propagated through the network.

3 Here is another cool part of deep learning. There are many ideas which work in different contexts and often when applied to other problems they can yield good results too. For example, here residual connections come from ResNets which were for computer vision tasks. I’d recommend you in your journey to try out such things, try different combinations of architectures/ideas and observe the results.

4 I got this information from this absolutely amazing blog post by Gwern: https://gwern.net/face. I’d suggest you read it from start to finish (along with his other posts); it is packed full of insights (some of which we discuss here) into GANs and training GANs. On this idea of overfitting, there is an argument that GANs overfit to the dataset. This is a better problem than underfitting, as it shows we are learning from the data, and to reduce overfitting we can employ regularisation. So that’s where the idea of regularisation in GANs comes in.

5 https://github.com/NVlabs/stylegan2/blob/master/training/training_loop.py#L121C1-L121C124

6 This raises the question: how do large companies handle training runs? For me even this was quite expensive, with the electricity costing around £3 a day. I wonder how large training runs handle “segmentation faults”, GPU errors and the infra around training runs. Also, is there a method to know (or estimate) a priori how likely a run is to fail? If you have any answers to these questions please reach out to me: yusufmohammad@live.com. I’d love to know more.

7 The paper this rule comes from can be found at: https://arxiv.org/pdf/1706.02677

8 The info about batch size for StyleGAN2: https://github.com/rosinality/stylegan2-pytorch/issues/152

9 An explainer of what that command does (this is a very cool website) https://explainshell.com/explain?cmd=nohup+python3+-X+faulthandler+-u+StyleGAN2.py+%3E+lr_0025_batch_16_g_every_4.log+2%3E%261

10 In the StyleGAN3 readme there is a table of hyper parameters for different setups for StyleGAN2, you can use these to adjust your params to ensure the model fits on your GPU.