Adaptive width CNNs?

I was training a 2D and a 3D VAE today for cosmological fields. Nothing special had to be done for the 2D dark matter fields: they can just be treated as 256x256 images and GPU go brrr…

But for 3D fields… I had to reduce the width significantly, since fitting a batch of 4 volumes of 256³ voxels takes a lot of memory. If the first convolution produces 48 hidden channels (already smaller than the original U-Net), then 4x48x256³ floats at 4 bytes each is already 12 GB. And that's just one convolution's activation; on top of it come the gradients, the activations after each nonlinearity, etc.
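The 12 GB figure is easy to check with a back-of-the-envelope calculation (the function name here is my own, just for illustration; it assumes float32, i.e. 4 bytes per value):

```python
def activation_bytes(batch, channels, side, ndim=3, dtype_bytes=4):
    """Memory of a single conv layer's output activation, in bytes.

    Assumes a cubic field with `side` voxels per dimension and
    float32 storage (4 bytes per value).
    """
    return batch * channels * side**ndim * dtype_bytes

# Batch of 4, 48 hidden channels, 256^3 voxels:
gib = activation_bytes(4, 48, 256) / 2**30
print(f"{gib:.0f} GiB")  # -> 12 GiB
```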

So I had to reduce the number of channels in the hidden representations. But by how much can I afford to? There is a memory-performance trade-off here. I hope I can reduce the width a fair bit, since I only need to fit my dark matter fields and not all sorts of natural images.
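To get a feel for how much memory a narrower network buys back, here is a rough sketch of the total activation footprint of a U-Net-style 3D encoder as a function of the base width. The doubling-channels / halving-resolution schedule is an assumption (a common pattern, not necessarily the exact architecture used here), and it ignores gradients, optimizer state, and intermediate nonlinearity outputs:

```python
def encoder_activation_gib(width, side=256, batch=4, levels=4, dtype_bytes=4):
    """Rough activation footprint (GiB) of a 3D encoder where channels
    double and the spatial side halves at each level. Forward pass only;
    gradients and optimizer state would add on top of this."""
    total = 0
    for lvl in range(levels):
        ch = width * 2**lvl       # 48 -> 96 -> 192 -> ...
        s = side // 2**lvl        # 256 -> 128 -> 64 -> ...
        total += batch * ch * s**3 * dtype_bytes
    return total / 2**30

for w in (48, 24, 12):
    print(f"width {w:2d}: ~{encoder_activation_gib(w):.1f} GiB of activations")
```

Since every term scales linearly in `width`, halving the base width halves the whole activation budget, which is what makes the width knob so tempting despite the possible hit to reconstruction quality.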