Since the last network visualisation post, I have spent time refining the approach so that it produces the best images with the most straightforward code! This is partly because it forms part of my thesis, and partly because I like exploring the internal representations of Convolutional Neural Networks (CNNs).

In the previous post on visualisation, I highlighted Total Variation and Diversity as two losses we need to account for when producing images that maximally activate a kernel. It turns out that this isn’t really necessary: we can produce images that are just as good by applying transformations (translations, rotations and scaling) between rounds of optimisation. This was the first simplification. The second change addresses the fact that some images occasionally failed to converge (they remained grey), which could have been due to the lack of ReLU redirection. ReLU redirection is needed during the early rounds of optimisation because the target kernel may not yet be active; with no activation there is no gradient to follow, so no optimisation can be performed and the image never converges (see this python file for more information). In addition, because the original code (Lucid and Lucent) used gradient descent rather than ascent (and the whole idea is to maximise the activation), some extra steps were needed to make sure the gradients were traversed in the correct direction. To simplify this, I’ve changed the optimisation to focus on maximisation:

opt = torch.optim.Adam(params=[image_c.spectrum_mp], lr=lr, maximize=True)

and added a small amount of noise to the gradients to kickstart image generation if there are no gradients to follow:

def image_grad_hook(grad):
    grad = grad / (torch.norm(grad) + 1e-8)  # Normalizing to unit norm
    grad = grad + torch.normal(mean=0, std=1e-6, size=grad.shape).to(device)  # tiny Gaussian noise so optimisation can start even when the gradient is zero
    return grad
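
For completeness, the first simplification (transforming the image between rounds of optimisation) boils down to a short stack of torchvision v2 transforms applied to the deprocessed image on every iteration. Abridged from the full listing below (which also adds a second, smaller translation and the GoogLeNet preprocessing), it looks roughly like this, where h and w are the height and width of the image being optimised:

from torchvision.transforms import v2

h, w = 224, 224  # e.g. the size of the image being optimised
round_transform = v2.Compose([
    v2.Pad(padding=16, fill=0.5),                      # give the image room to move around
    v2.RandomAffine(degrees=0, translate=(8/w, 8/h)),  # small random translation
    v2.RandomAffine(degrees=0, scale=(0.95, 1.05)),    # small random scaling
    v2.RandomAffine(degrees=5),                        # small random rotation
    v2.CenterCrop(size=(h, w)),                        # crop back to the original size
])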

Apart from these changes, nearly everything is the same! In testing, it achieves results that are very close to Lucent’s, all within 330 lines of code!

While the code is designed specifically to work with GoogLeNet1 (it uses the TorchVision GoogLeNet transforms, and the undo_decorrelate function uses a colour correlation matrix calculated specifically for GoogLeNet), it seems to work really well for ResNet2 too!

The images are included below, starting with the GoogLeNet activation images from 4c.branch4.1.conv:

[Image: colourful structured patterns which activate GoogLeNet]

These are the activation images from ResNet-18 layer3.1.conv2:

[Image: colourful structured patterns which activate ResNet-18]

The full code to produce these visualisations is here:

# BatchNorm based visualisation!
import torch
import torchvision
from torchvision.transforms import v2
import numpy as np


class Activation(torch.nn.Module):
    """
    Define a new loss function for visualisation
    """
    def forward(self, activation):
        """
        Overrides the default forward behaviour of torch.nn.Module

        Parameters:
        - activation (torch.Tensor): The activation tensor after the network function has been applied to the image

        Returns:
        - (torch.Tensor): The mean of the activation tensor (optionally excluding the border to reduce artifacts)
        """
        filter_act = activation
        # filter_act = activation[2:-2, 2:-2]  # <-- Uncomment to exclude the border from the mean
        return filter_act.mean()


class OptImage():
    """
    An image for optimisation which includes the colour-decorrelated, Fourier
    transformed image.
    Code from:
    https://github.com/greentfrapp/lucent/blob/dev/lucent/optvis/param/spatial.py
    and
    https://github.com/tensorflow/lucid/blob/master/lucid/optvis/param/spatial.py

    """
    def __init__(self, shape, stdev=0.01, decay=1, device='cpu'):
        # Create a colour decorrelated, Fourier transformed image
        self.batch, self.ch, self.h, self.w = shape
        freqs = self.rfft2d_freqs(self.h, self.w)
        init_val_size = (self.batch, self.ch) + freqs.shape + (2,) # extra dim of 2 for the real and imaginary FFT components

        self.spectrum_mp = torch.randn(*init_val_size, dtype=torch.float32) * stdev # This is what we optimise!
        self.spectrum_mp = self.spectrum_mp.to(device)
        self.spectrum_mp.requires_grad = True # Really important part!

        self.scale = 1/np.maximum(freqs, 1/max(self.h, self.w)) ** decay
        self.scale = torch.tensor(self.scale).float()[None, None, ..., None]
        self.scale = self.scale.to(device)

        self.device = device


    # Directly from Lucid
    @staticmethod
    def rfft2d_freqs(h, w):
        """Computes 2D spectrum frequencies."""

        fy = np.fft.fftfreq(h)[:, None]
        # when we have an odd input dimension we need to keep one additional
        # frequency and later cut off 1 pixel
        if w % 2 == 1:
            fx = np.fft.fftfreq(w)[: w // 2 + 2]
        else:
            fx = np.fft.fftfreq(w)[: w // 2 + 1]
        return np.sqrt(fx * fx + fy * fy)

    def deprocess(self):
        # Transform colour-decorrelated, Fourier transformed image back to normal
        scaled_spectrum = self.scale*self.spectrum_mp

        if not torch.is_complex(scaled_spectrum):
            scaled_spectrum = torch.view_as_complex(scaled_spectrum)

        image = torch.fft.irfftn(scaled_spectrum, s=(self.h,self.w), norm='ortho')

        image = image[:self.batch, :self.ch, :self.h, :self.w]
        image = image / 4.0 # MAGIC NUMBER

        if self.ch == 3:
            # Only decorrelate colour if we have a 3 channel image
            image = OptImage.undo_decorrelate(image, self.device)

        assert image.max() <= 1. and image.min() >= 0., "Image broke bounds! Max: {}, Min: {}".format(image.max(), image.min())
        return image

    @staticmethod
    def undo_decorrelate(image, device):
        # Undo the colour decorrelation
        color_correlation_svd_sqrt = np.asarray(
            [[0.26, 0.09, 0.02],
             [0.27, 0.00, -0.05],
             [0.27, -0.09, 0.03]]).astype("float32")

        max_norm_svd_sqrt = np.max(np.linalg.norm(color_correlation_svd_sqrt, axis=0))
        color_correlation_normalized = color_correlation_svd_sqrt / max_norm_svd_sqrt

        c_last_img = image.permute(0,2,3,1)
        c_last_img = torch.matmul(c_last_img, torch.tensor(color_correlation_normalized.T).to(device))
        image = c_last_img.permute(0,3,1,2)
        image = torch.sigmoid(image) # An important part of the decorrelation it seems!
        return image


def visualise(
        model,
        target,
        opt_img_shape,
        filters,
        iterations=256,
        lr=5e-2,
        opt_type='channel',
        device='cpu'):
    """
    Visualise the specified kernels in the target layer of the model

    Parameters:
    - model (torch.nn.Module): The model to visualise a layer of
    - target (str): The target layer to visualise
    - opt_img_shape ((int, int, int)): The shape of the visualised images (channel first)
    - filters ([int]): The list of kernels to visualise
    - iterations (int, optional): The number of optimisation iterations to run for (default is 256)
    - lr (float, optional):  The learning rate for image updates (default is 5e-2)
    - opt_type (str, optional): The type of optimisation (neuron, channel, layer/dream) (default is 'channel')
    - device (str, optional): The device to use for optimisation ('cpu' or 'cuda') (default is 'cpu')
    """

    # Set up result tracking
    best = {}
    for i in range(len(filters)):
        best[i] = {"loss": -np.inf, "image": None, "activation": None, "iter": None}

    # Set the model to evaluation mode - SUPER IMPORTANT
    model.eval()

    # DEBUG:
    output_losses = []
    ##################

    hook_list = []
    
    global activation
    activation = None
    def activation_hook(module, input, output):
        global activation
        activation = output

    hook = target.register_forward_hook(activation_hook)
    hook_list.append(hook)

    opt_img_shape = (len(filters),) + opt_img_shape # Add batch dimension
    image_c = OptImage(shape=opt_img_shape, device=device)


    ### Image Optimisation hook
    def image_grad_hook(grad):
        grad = grad / (torch.norm(grad) + 1e-8)  # Normalizing to unit norm
        grad = grad + torch.normal(mean=0, std=1e-6, size=grad.shape).to(device) # Experimental
        return grad
        
    img_hook = image_c.spectrum_mp.register_hook(image_grad_hook)
    hook_list.append(img_hook)
    ###########################
    

    # Define the custom loss function
    loss_fn = Activation()

    opt = torch.optim.Adam(params=[image_c.spectrum_mp], lr=lr, maximize=True)

    h, w = opt_img_shape[-2], opt_img_shape[-1]
    googlenet_transform = torchvision.models.GoogLeNet_Weights.IMAGENET1K_V1.transforms()
    viz_transform = torchvision.transforms.v2.Compose([
        torchvision.transforms.v2.Pad(padding=16, fill=(0.5)),
        torchvision.transforms.v2.RandomAffine(degrees=0, translate=(8/w, 8/h)),
        torchvision.transforms.v2.RandomAffine(degrees=0, scale=(0.95, 1.05)),
        torchvision.transforms.v2.RandomAffine(degrees=5),
        torchvision.transforms.v2.RandomAffine(degrees=0, translate=(4/w, 4/h)),
        torchvision.transforms.v2.CenterCrop(size=(h,w)),
        torchvision.transforms.v2.Resize(size=(256, 256), antialias=True),
        googlenet_transform,
    ])

    max_iterations = iterations
    for it in range(max_iterations):

        opt.zero_grad()


        transformed_image = viz_transform(image_c.deprocess())
        _ = model(transformed_image.to(device))

        # dim 0 is the batch dimension: one optimised image per target filter
        if opt_type == 'layer' or opt_type == 'dream':
            act = activation[:, :, :, :] # Layer (DeepDream)
        elif opt_type == 'channel':
            act = activation[:, filters, :, :] # Channel
        elif opt_type == 'neuron':
            # Select the central neuron by default (TODO: Allow this to be overridden)
            nx, ny = activation.shape[2], activation.shape[3]
            act = activation[:, filters, nx//2, ny//2] # Neuron

        deproc_img = image_c.deprocess().to(device)

        # Calculate loss for each image we're optimising
        losses = []

        noise_indices = []  # currently unused: collects images whose spectrum has little energy

        for i in range(deproc_img.shape[0]): # Iterate over all images
            magnitude_spectrum = torch.abs(image_c.spectrum_mp[i])
            spatial_freq = torch.mean(magnitude_spectrum).item()

            loss = loss_fn(act[i][i]) #+ 1e-3 * div #+ 1e-4 * total_variation_loss(image_c.deprocess())#+ 1e-2*div
            losses.append(loss)
            # if variance < 0.05:
            if spatial_freq < 4.0:
                noise_indices.append(i)

            # Check if this is the best loss we've found for the image
            if loss > best[i]["loss"]:
                best[i]["loss"] = loss
                best[i]["image"] = deproc_img.cpu()[i].clone()
                best[i]["activation"] = act.detach()[i].cpu()
                best[i]["iter"] = it+1        

        # Stack all losses for quick backpropagation
        total_loss = torch.stack(losses).sum()
        total_loss.backward()
        
        #### Check the gradient
        total_norm = 0

        # Check model gradient norm
        for name, param in model.named_parameters():
            if param.grad is not None:
                total_norm += param.grad.norm()**2
                param.grad.zero_()

        total_norm = total_norm**0.5
        if it == max_iterations-1:
            print("Iteration {} - Total norm: {}".format(it, total_norm))
        #######################

        torch.nn.utils.clip_grad_norm_(image_c.spectrum_mp, max_norm=0.1)
        opt.step()

        # DEBUG
        output_losses.append([l.detach().item() for l in losses])
        #######

    opt_images = []
    best_activations = []
    for i in range(len(filters)):
        opt_img = best[i]["image"].detach().squeeze().cpu()
        opt_img = opt_img.permute(1,2,0)
        opt_img = torch.clamp(opt_img, 0, 1)
        opt_images.append(opt_img)

        act = best[i]["activation"][i]
        best_activations.append(act)

    for h in hook_list:
        h.remove() # Remove the hook so subsequent runs don't use the previously registered hook!

    return best_activations, opt_images, output_losses


def batched_visualise(
        model,
        target,
        filters,
        opt_img_shape,
        iterations=30,
        lr=5e-2,
        batch_size=32,
        opt_type='channel',
        device='cpu'):
    
    all_activations, all_images, all_losses = [], [], []

    # Convert the target layer from string to torch layer
    target_layer = None
    for name, layer in model.named_modules():
        
        # Match the end of the strings since we could have some prefixes like `googlenet.`
        if name.endswith(target):
            print("Found target layer: {} - {}".format(name, type(layer)))
            target_layer = layer
            # Get the maximum number of filters that can be extracted
            max_filters = layer.weight.shape[0]
            break
    else:
        raise ValueError("Could not find layer: {} - exiting".format(target))

    # Check the number of kernels being visualised
    if filters is None:
        filters = range(max_filters)
    elif len(filters) > max_filters:
        print("Number of requested filters exceeds the amount available - setting to {}".format(max_filters))
        filters = range(max_filters)

    print("Activating {} filters".format(len(filters)))

    # Split the kernels into manageable batches
    num_splits = int(np.ceil(len(filters) / batch_size))
    filter_batches = np.array_split(np.asarray(filters), num_splits)
    # Visualise the batch
    for filter_batch in filter_batches:
        print("Processing batch: {}".format(filter_batch))
        acts, imgs, losses = visualise(
            model=model,
            target=target_layer,
            opt_img_shape=opt_img_shape,
            filters=filter_batch,
            iterations=iterations,
            lr=lr,
            opt_type=opt_type,
            device=device  
        )
        
        all_activations = all_activations + acts
        all_images = all_images + imgs
        all_losses.append(losses)

    all_activations = list(np.stack(all_activations))  # one activation map per visualised kernel
    # all_images = list(np.concatenate(all_images))
    all_losses = np.concatenate(all_losses, axis=-1)
    
    return all_activations, all_images, all_losses
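
Finally, here is a minimal usage sketch. The layer name matches the GoogLeNet images above (it is suffix-matched against model.named_modules(), so it resolves to inception4c.branch4.1.conv in the TorchVision model); the image shape, filter count and iteration count are illustrative rather than the exact values used for the figures:

if __name__ == '__main__':
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    googlenet = torchvision.models.googlenet(
        weights=torchvision.models.GoogLeNet_Weights.IMAGENET1K_V1).to(device)

    # Swap in e.g. torchvision.models.resnet18 and target='layer3.1.conv2'
    # to produce the ResNet-18 images instead.
    activations, images, losses = batched_visualise(
        model=googlenet,
        target='4c.branch4.1.conv',
        filters=list(range(16)),      # the first 16 kernels (None visualises them all)
        opt_img_shape=(3, 224, 224),  # channel-first, per the visualise() docstring
        iterations=256,
        batch_size=8,
        device=device)

    # Each entry in `images` is an HxWx3 tensor in [0, 1],
    # ready to be displayed with e.g. matplotlib's imshow.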