Style transfer with Neural Cellular Automata
Local Context, Global Results: NCA for Style Transfer
Introduction
The motivation behind this project is to gain a deeper understanding of neural cellular automata, a fascinating topic tied to complex systems and the emergent behavior observed in nature. These systems demonstrate how simple rules (and small models) can lead to complex outcomes without the need for a central architect or planner. In this project, we aim to explore the extent to which simple rules can be applied to style transfer using neural networks.
Conway’s Game of Life
A simple example of a cellular automaton is Conway’s Game of Life. To play this game, you need a two-dimensional grid where the cells reside. Each cell can exist in one of two states: alive or dead. Each cell interacts with its eight neighbors, and specific rules dictate what happens during these interactions. In the original Game of Life, the following occurs at each time step:
- Any live cell with fewer than two live neighbors dies, as if by underpopulation.
- Any live cell with two or three live neighbors survives to the next generation.
- Any live cell with more than three live neighbors dies, as if by overpopulation.
- Any dead cell with exactly three live neighbors becomes a live cell, as if by reproduction.
The initial configuration of the grid (the seed) determines how the system evolves. We can start with most cells dead, most cells alive, or any pattern in between.
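The four rules above fit in a few lines of PyTorch. This sketch counts the eight neighbors with a 3×3 convolution, the same mechanism the NCA model uses later on. The wrap-around (toroidal) grid edges and the function name `life_step` are assumptions for illustration; the original game is defined on an unbounded grid.

```python
import torch
import torch.nn.functional as F

# 3x3 kernel that sums the eight neighbors (the center weight is 0)
NEIGHBOR_KERNEL = torch.ones(1, 1, 3, 3)
NEIGHBOR_KERNEL[0, 0, 1, 1] = 0.0

def life_step(grid: torch.Tensor) -> torch.Tensor:
    """One Game of Life update. grid: [h, w] tensor of 0.0 (dead) / 1.0 (alive).
    Assumption: edges wrap around, so the grid behaves like a torus."""
    x = F.pad(grid[None, None], [1, 1, 1, 1], mode="circular")
    neighbors = F.conv2d(x, NEIGHBOR_KERNEL)[0, 0]  # live-neighbor count per cell
    alive = grid.bool()
    survive = alive & ((neighbors == 2) | (neighbors == 3))  # rules 1-3
    born = ~alive & (neighbors == 3)                         # rule 4
    return (survive | born).float()

# A "blinker": three cells in a row flip between horizontal and vertical
g = torch.zeros(5, 5)
g[2, 1:4] = 1.0
print(life_step(g))
```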
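We can simplify the setup further. In an elementary cellular automaton, the grid is one-dimensional and the next state of each cell depends on only three cells: the cell itself and its left and right neighbors (each new generation is usually drawn as the row below). Here’s a neat visualization of some of these interactions: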
At first glance, these simple rules might seem to lack any significant meaning. However, it turns out that depending on the rules you apply, you can generate vastly different results and patterns (see examples here: Cellular Automata).
Each rule element looks at three binary cells, which gives 2³ = 8 possible configurations. A rule specifies the new state (0 or 1) for each of these 8 configurations, leading to a total of 2⁸ = 256 possible rules. While 256 rules may not seem like a large number, the differences between them are fascinating. For instance, Rule 30 generates a particularly intriguing pattern.
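A minimal pure-Python sketch of an elementary cellular automaton: the three-cell neighborhood is read as a 3-bit number (0 to 7) that selects one of the rule’s 8 output bits, following Wolfram’s rule numbering. The wrap-around edges and the function names `eca_step` and `run` are assumptions for illustration.

```python
def eca_step(cells, rule):
    """One step of an elementary CA. cells: list of 0/1, rule: Wolfram rule number 0-255."""
    n = len(cells)
    new = []
    for i in range(n):
        # Left, center and right cell; edges wrap around (cells[-1] is the last cell)
        left, center, right = cells[i - 1], cells[i], cells[(i + 1) % n]
        index = (left << 2) | (center << 1) | right  # neighborhood as a 3-bit number
        new.append((rule >> index) & 1)               # that bit of the rule is the new state
    return new

def run(rule, width=31, steps=15):
    cells = [0] * width
    cells[width // 2] = 1  # a single live cell in the middle as the seed
    rows = [cells]
    for _ in range(steps):
        cells = eca_step(cells, rule)
        rows.append(cells)
    return rows

for row in run(30):
    print("".join("#" if c else "." for c in row))
```

Swapping `30` for any other number between 0 and 255 shows how differently the rules behave from the same seed.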
No one designed or planned this pattern; it emerged naturally from a set of simple rules. Interestingly, this pattern closely resembles the intricate design found on the Conus textile shell.
Some rules produce similar patterns but with different fractal structures, another interesting topic within the study of complex systems:
I hope I’ve demonstrated how simple rules can give rise to complex patterns. Now, let’s explore how we can apply this concept to neural networks.
Model
I’ve drawn heavily on an example from here. My version of the code is here. I refactored the code a bit (using callbacks) and made some other changes to the model training pipeline. The model architecture I’m using is fairly straightforward; it consists of three 1×1 convolutional layers with stride 1:
NeuralCA(
  (w1): Conv2d(126, 90, kernel_size=(1, 1), stride=(1, 1))
  (w2): Conv2d(90, 90, kernel_size=(1, 1), stride=(1, 1))
  (w3): Conv2d(90, 15, kernel_size=(1, 1), stride=(1, 1))
)
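The idea is to keep the number of learnable parameters small by building the model as a neural cellular automaton, essentially integrating cellular automata into a neural network. How do we add the cellular automaton part? It’s simple: you use a fixed filter (kernel) that looks at each cell’s local neighborhood on the grid and apply it as a convolution. Because the learnable layers are 1×1 convolutions that act on each cell independently, these filters are the only way a cell receives information about its neighbors. The heavy lifting is done by performing the convolution per channel: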
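import torch
import torch.nn.functional as F

def perchannel_conv(x, filters):
    '''Apply every filter to every channel of x independently.
    x: [b, ch, h, w]; filters: [filter_n, 3, 3] (the padding assumes 3x3 kernels)'''
    b, ch, h, w = x.shape
    # Fold the channels into the batch so each channel is filtered on its own
    y = x.reshape(b*ch, 1, h, w)
    # Circular padding wraps the grid around like a torus
    y = F.pad(y, [1, 1, 1, 1], 'circular')
    y = F.conv2d(y, filters[:,None])  # -> [b*ch, filter_n, h, w]
    return y.reshape(b, -1, h, w)     # -> [b, ch*filter_n, h, w]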
This function takes a batch of images x and a stack of filters, and applies every filter to every channel separately. The main purpose of these filters is to detect features or edges and to approximate the spatial derivatives (gradients and the Laplacian) that appear in discretized partial differential equations (PDEs). Solving PDEs can be challenging, especially since many do not have closed-form solutions; discretizing continuous space into a grid makes the process more manageable. Since the filters are fixed rather than learnable, they add nothing to the parameter count, which keeps the model small.
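To make the “discretized PDE” point concrete, here is a small check (a sketch; the test image and the choice of filters are assumptions for illustration). It applies an identity filter and a 9-point Laplacian stencil to f(x, y) = x², whose true Laplacian is 2 everywhere. Note how the output channels are grouped per input channel: all filter responses of channel 0 come first, then those of channel 1.

```python
import torch
import torch.nn.functional as F

def perchannel_conv(x, filters):
    b, ch, h, w = x.shape
    y = F.pad(x.reshape(b * ch, 1, h, w), [1, 1, 1, 1], "circular")
    return F.conv2d(y, filters[:, None]).reshape(b, -1, h, w)

identity = torch.tensor([[0., 0., 0.], [0., 1., 0.], [0., 0., 0.]])
# 9-point Laplacian stencil, divided by 4 so it approximates d2/dx2 + d2/dy2 on a unit grid
laplace = torch.tensor([[1., 2., 1.], [2., -12., 2.], [1., 2., 1.]]) / 4
filters = torch.stack([identity, laplace])  # [filter_n=2, 3, 3]

col = torch.arange(8, dtype=torch.float32)
img = (col ** 2).expand(8, 8)               # f(x, y) = x^2 along the width axis
x = torch.stack([img, -img])[None]          # a batch of 1 image with 2 channels
out = perchannel_conv(x, filters)
print(out.shape)                            # torch.Size([1, 4, 8, 8])
print(out[0, 1, :, 1:-1].unique())          # tensor([2.])
```

Only the interior columns are checked: circular padding wraps the first and last columns around, so the stencil sees a discontinuity there.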
One of my additions concerns the filters. The initial solution used only the Sobel filter and the 9-point version of the discrete Laplace operator. I expanded this by adding the option to use the Scharr and Prewitt filters (operators) as well. The weights of these filters are illustrated below.
Sobel and Scharr filters have a similar structure but slightly different weights. The Laplace operator is included whichever gradient filter is selected. There is also an option to combine different filters, though this increases the model size slightly. Still, these filters are small tensors (3×3) that only see local information, so a single iteration cannot cover the whole image: information spreads by roughly one pixel per step, which is why many iterations are needed.
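A minimal sketch of such a filter bank. The function name `make_filter_bank`, the stack order (an x/y gradient pair for each selected operator, followed by one Laplacian), and the unnormalized integer weights are assumptions for illustration and may not match the project’s code; implementations often rescale these kernels (for example, Sobel by 1/8).

```python
import torch

# x-derivative kernels in the cross-correlation convention used by F.conv2d;
# the y-derivative kernel is the transpose of the x kernel.
GRADIENT_X = {
    "sobel":   torch.tensor([[-1., 0., 1.], [-2., 0., 2.], [-1., 0., 1.]]),
    "scharr":  torch.tensor([[-3., 0., 3.], [-10., 0., 10.], [-3., 0., 3.]]),
    "prewitt": torch.tensor([[-1., 0., 1.], [-1., 0., 1.], [-1., 0., 1.]]),
}
# One common 9-point discrete Laplacian stencil
LAPLACE_9PT = torch.tensor([[1., 2., 1.], [2., -12., 2.], [1., 2., 1.]])

def make_filter_bank(names=("sobel",)):
    """Stack [gx, gy for each selected operator..., laplacian] into [filter_n, 3, 3]."""
    kernels = []
    for name in names:
        gx = GRADIENT_X[name]
        kernels += [gx, gx.T]
    kernels.append(LAPLACE_9PT)
    return torch.stack(kernels)

print(make_filter_bank().shape)                     # torch.Size([3, 3, 3])
print(make_filter_bank(("sobel", "scharr")).shape)  # torch.Size([5, 3, 3])
```

Every kernel here sums to zero, so a constant region produces no response. Combining operators only adds rows to this fixed tensor; the extra cost shows up in the first learnable layer, whose input width grows with the number of filters.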
The model’s forward pass follows this schema:
The features of the original images are processed through the filters (perchannel_conv). The original features are then concatenated with the filtered features. Finally, the combined features are passed through several convolutional and ReLU layers. The more times this process is repeated, the more the image’s style changes.
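A minimal sketch of this forward pass, assuming: a state of 3 fixed RGB content channels plus 15 hidden channels (matching w3’s 15 outputs), a residual update that only touches the hidden channels, and the first three hidden channels holding the generated image. The class name `NeuralCASketch` and all of these choices are assumptions for illustration, not the project’s exact code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def perchannel_conv(x, filters):
    b, ch, h, w = x.shape
    y = F.pad(x.reshape(b * ch, 1, h, w), [1, 1, 1, 1], "circular")
    return F.conv2d(y, filters[:, None]).reshape(b, -1, h, w)

class NeuralCASketch(nn.Module):
    def __init__(self, filters, n_fixed=3, n_hidden=15, width=90):
        super().__init__()
        # Stored as a buffer, so the fixed filters are not learnable parameters
        self.register_buffer("filters", filters)
        self.n_fixed = n_fixed
        n_state = n_fixed + n_hidden
        # Each state channel contributes itself plus one response per filter
        c_in = n_state * (1 + filters.shape[0])
        self.w1 = nn.Conv2d(c_in, width, kernel_size=1)
        self.w2 = nn.Conv2d(width, width, kernel_size=1)
        self.w3 = nn.Conv2d(width, n_hidden, kernel_size=1)

    def step(self, state):
        # Perception: original features concatenated with the filtered features
        p = torch.cat([state, perchannel_conv(state, self.filters)], dim=1)
        dh = self.w3(F.relu(self.w2(F.relu(self.w1(p)))))
        fixed, hidden = state[:, :self.n_fixed], state[:, self.n_fixed:]
        # Only the hidden channels are updated; the RGB content channels stay fixed
        return torch.cat([fixed, hidden + dh], dim=1)

    def forward(self, state, n_steps):
        for _ in range(n_steps):
            state = self.step(state)
        return state

    def output_image(self, state):
        # Assumption: the first three hidden channels hold the generated RGB image
        return state[:, self.n_fixed:self.n_fixed + 3]

# Usage: random filters stand in for a real filter bank
model = NeuralCASketch(torch.randn(3, 3, 3))
content = torch.rand(1, 3, 32, 32) * 2 - 1
state = torch.cat([content, torch.zeros(1, 15, 32, 32)], dim=1)
with torch.no_grad():
    out = model(state, n_steps=32)
img = model.output_image(out)
```

Because every learnable layer is 1×1, the same small network runs at every pixel; only the number of steps decides how far local information travels.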
Training
The model is only one of the necessary components; we also need appropriate loss functions, which play a crucial role in the training process:
During training, the RGB channels of the images serve as the foundation (these are not updated), while additional hidden channels act as the training ground. Each batch is forwarded through the model 32 to 96 times, with the exact number chosen randomly within this range, and the hidden features are updated at every step. Every certain number of batches, a new image is inserted into the batch so that the training data does not stay the same. After each training iteration, several losses are calculated:
- Style Loss: Measures how effectively the style of the reference image has been transferred.
- Content Loss: Evaluates how well the original content of the image has been preserved, ensuring that the image’s style is changed without destroying the original content.
- Color Loss: Assesses how accurately the colors from the style image have been transferred.
- Overflow Loss: Regularizes numerical overflows if the predicted image pixel values fall significantly outside the range of -1 to 1.
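A minimal sketch of how these pieces could fit together in one training step. The function names, the use of a mean (rather than a sum) in the overflow term, and the default weights are assumptions for illustration; the style, content, and color terms are passed in as already-computed values.

```python
import random
import torch

def overflow_loss(img):
    """Penalize only the amount by which values fall outside [-1, 1]."""
    return (img - img.clamp(-1.0, 1.0)).abs().mean()

def total_loss(style_l, content_l, color_l, overflow_l,
               w_style=1.0, w_content=1.0, w_color=1.0, w_overflow=1.0):
    """Weighted sum of the four terms; the default weights are placeholders."""
    return (w_style * style_l + w_content * content_l
            + w_color * color_l + w_overflow * overflow_l)

def sample_n_steps(low=32, high=96):
    """Number of NCA steps for one batch, drawn uniformly from [low, high]."""
    return random.randint(low, high)
```

In a training step, the batch would be rolled out for `sample_n_steps()` steps, each term computed on the result, and `total_loss(...)` backpropagated. Values inside [-1, 1] contribute nothing to `overflow_loss`, so it only pushes back on pixels that drift out of range.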
For the style loss, I experimented with different loss functions. The original implementation used a Gram loss-based approach, but I also added Optimal Transport and Vincent’s losses. Here’s a brief overview of these losses:
- Gram Loss. A popular style loss function that captures the correlations between different feature channels of an image (usually obtained via a VGG model) while discarding their spatial positions. However, Gram Loss may not fully capture the distribution of these features.
- Optimal Transport (OT) Loss. This loss is rooted in optimal transport theory, which you can learn more about here. It measures the minimal cost of “transporting” the distribution of features from the source space (the model-generated image) onto the target space (our style image). I am using the sliced Wasserstein version, with the calculation depicted in the following graph. First, the target and source features (denoted as p and p-hat) are projected onto random directions (v). Then, the L² loss is calculated between the sorted lists of these projections; sorting matches the projected values rank by rank, which solves the one-dimensional transport problem exactly.
- Vincent’s Loss. This loss was originally developed to address some of the limitations of Gram Loss. Its core is the Wasserstein metric (not to be confused with the sliced Wasserstein distance used in the previous loss):
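For Gaussian approximations of the generated and style feature distributions, the squared 2-Wasserstein distance has a standard closed form, reproduced here as the presumed form of this loss (C₁ denotes the covariance matrix of the generated image features):

$$
W_2^2 = \lVert m_1 - m_2 \rVert_2^2 + \operatorname{tr}\!\left(C_1 + C_2 - 2\left(C_2^{1/2}\, C_1\, C_2^{1/2}\right)^{1/2}\right)
$$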
Where:
- C₂^(1/2) is the principal square root of C₂, the covariance matrix of the style image features (C₁ is the corresponding covariance matrix of the generated image features). It is calculated by performing an eigendecomposition C₂ = V Λ Vᵀ, taking the square root of the eigenvalues to form the diagonal matrix Λ^(1/2), and multiplying it by the eigenvectors: C₂^(1/2) = V Λ^(1/2) Vᵀ.
- m₁ and m₂ are the mean feature vectors of the generated and style images, respectively.
Vincent’s Loss is closely related to the optimal transport loss as it measures the distance between two probability distributions. However, unlike the sliced Wasserstein approach used previously, this method doesn’t rely on random projections. Instead, it utilizes information from the first and second moments (mean and (co)variance) of the generated and style images.
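The three style losses can be sketched in PyTorch as follows. This is a minimal sketch under stated assumptions, not the project’s exact code: features are [N, C] matrices (N spatial positions, C channels, e.g. flattened VGG activations), the function names are hypothetical, the sliced Wasserstein version assumes both images yield the same N, and the Gaussian loss implements the closed-form 2-Wasserstein distance between Gaussians with the eigendecomposition-based square root described above.

```python
import torch

def gram_loss(f_gen, f_style):
    """Compare channel correlations; spatial positions are summed out."""
    g_gen = f_gen.T @ f_gen / f_gen.shape[0]
    g_style = f_style.T @ f_style / f_style.shape[0]
    return ((g_gen - g_style) ** 2).mean()

def sliced_wasserstein_loss(f_gen, f_style, n_proj=32):
    """Project both feature sets onto random unit directions, compare sorted values.
    Assumes both inputs have the same number of positions N."""
    v = torch.randn(f_gen.shape[1], n_proj, dtype=f_gen.dtype, device=f_gen.device)
    v = v / v.norm(dim=0, keepdim=True)
    proj_gen = (f_gen @ v).sort(dim=0).values  # [N, n_proj], sorted per direction
    proj_style = (f_style @ v).sort(dim=0).values
    return ((proj_gen - proj_style) ** 2).mean()

def sqrtm_psd(c):
    """Principal square root of a symmetric PSD matrix: V diag(sqrt(lambda)) V^T."""
    evals, evecs = torch.linalg.eigh(c)
    return evecs @ torch.diag(evals.clamp(min=0).sqrt()) @ evecs.T

def gaussian_w2_loss(f_gen, f_style):
    """Squared 2-Wasserstein distance between Gaussians fitted to the two feature sets."""
    m1, m2 = f_gen.mean(dim=0), f_style.mean(dim=0)
    c1, c2 = torch.cov(f_gen.T), torch.cov(f_style.T)  # [C, C] covariance matrices
    s2 = sqrtm_psd(c2)
    inner = s2 @ c1 @ s2
    inner = (inner + inner.T) / 2  # symmetrize against round-off before eigh
    cross = sqrtm_psd(inner)
    return ((m1 - m2) ** 2).sum() + torch.trace(c1 + c2 - 2 * cross)
```

All three are insensitive to where features occur in the image (shuffling the rows of `f_gen` leaves them essentially unchanged), which is what makes them style rather than content losses. For training, the eigendecomposition can have unstable gradients when eigenvalues are (nearly) repeated, so a small epsilon on the eigenvalues may help.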
Now, let’s take a look at what this model has generated.
Outputs
I am using the following style and content images to demonstrate the effects of different parameters:
Next, let’s try to isolate the effects of different parameters. This task is not as straightforward as it might seem. Since various hyperparameters influence one another, changing one often necessitates adjustments to others in order to achieve the best results. Additionally, some level of randomness is inherent in the outputs. It’s quite common in generative AI to not achieve the best (or expected) result on the first try, leading to a deep dive into fine-tuning the parameters.
The first set of configuration changes relates to the loss functions. Keeping all other parameters unchanged, I trained models using Gram, OT, and Vincent’s losses:
The Gram Loss output appears the blurriest, likely because it needs more training or generation rounds, or perhaps some fine-tuning of other hyperparameters. Beyond that, you can observe that Vincent’s Loss produces a grainier texture. Would I choose a single “best” loss? Not necessarily. Different images may respond better to different losses, and each can yield good results depending on the context.
Next, I adjusted the weights for the Style, Content, and Color Losses (0.1, 1.0, and 10.0, respectively). Now, some clear patterns emerge. The higher the weight, the more that component is emphasized in the output: a high content weight keeps the original image more intact, a higher style weight means the style is more dominant in the output, and a high Color Loss weight results in the style colors being more closely matched in the output.
These were very simple comparisons. Changing multiple hyperparameters simultaneously opens up a wide range of possibilities (and challenges) for discovery. The hyperparameters discussed so far are the primary ones to adjust, but the model contains many more that can be tweaked.
From these samples, it’s clear that some level of style transfer is indeed occurring. Adjusting certain parameters can have a significant impact on the output. These outputs demonstrate how the same style image can produce different results based on variations in hyperparameters.
Finally, I’ll showcase some of the more impressive transfers produced by the model. In the first batch, I used the same style image but varied the number of training rounds. The model didn’t transfer the style exactly but instead captured different aspects, such as color, texture, and lines.
The second batch of images features two different style images. In the first, you can clearly see that the model has successfully transferred the line styles. In the second image, the model captures a “cellular” or “cellish” style. Of course, there is potential for many more experiments and images, but I hope these samples demonstrate that even this simple model can effectively be used for style transfer.
How small is this model? Depending on the number of filters used (we can stack more than one filter, with the Laplace filter added for each), the model can have anywhere from 17 745 parameters (with one filter) to 20 985 parameters (with two filters). I removed the originally used blur layers, which I found made things more complicated. When saved to disk, the model size ranges from 86 to 88 KB (yes, you read that correctly — kilobytes!). This tiny size was achieved without any pruning, quantization, or distillation.
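A quick sanity check of these counts: every layer in the printed model is a 1×1 convolution, so each contributes c_in · c_out weights plus c_out biases, and the fixed filters add nothing. With w1’s 126 inputs this gives exactly 20 985. The one-filter figure of 17 745 corresponds to w1 taking 90 input channels, a width inferred from the count rather than read from the code; the helper names are hypothetical.

```python
def conv1x1_params(c_in, c_out):
    # A 1x1 convolution has c_in * c_out weights plus one bias per output channel
    return c_in * c_out + c_out

def nca_param_count(c_in, width=90, c_out=15):
    # w1: c_in -> width, w2: width -> width, w3: width -> c_out
    return (conv1x1_params(c_in, width) + conv1x1_params(width, width)
            + conv1x1_params(width, c_out))

print(nca_param_count(126))  # 20985 (the two-filter model printed above)
print(nca_param_count(90))   # 17745
```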
Conclusion
Neural Cellular Automata can indeed be used for style transfer. This simple model is capable of learning complex style patterns while preserving the content of the original image. What’s most fascinating is that it relies on filters (convolutions) that only perceive local context, yet, from this simplicity, complex and intricate images emerge.