Long Convolutions via Polynomial Multiplication
Comments
touisteur
cochlear
I think I'm beginning to wrap my head around the way modern, "deep" state-space models (e.g. Mamba, S4, etc.) leverage polynomial multiplication to speed up very long convolutions.
I'm curious if there are other methods for approximating long convolutions that are well-known or widely-used, outside of overlap-add and overlap-save? I'm in the audio field and interested in learning long _FIR_ filters to describe the resonances of physical objects, like instruments, or rooms. Block-coding, or fixed-frame size approaches reign supreme, of course, but have their own issues in terms of windowing artifacts, etc.
I'm definitely aware that multiplication in the (complex) frequency domain is equivalent to convolution in the time domain and that, because of the fast Fourier transform, this can yield increased efficiency. However, this still results in storing a lot of gradient information that my intuition tells me (possibly incorrectly) is full of redundancy and waste.
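For readers who want to see that equivalence concretely, here is a minimal NumPy sketch. The function name `fft_convolve` and the signal sizes are illustrative assumptions, not anything from the article: both inputs are zero-padded to a common FFT length, the spectra are multiplied pointwise, and the result is compared against direct convolution.

```python
import numpy as np

def fft_convolve(signal, kernel):
    """Linear convolution via pointwise multiplication of real FFTs."""
    n = len(signal) + len(kernel) - 1          # length of the full linear convolution
    n_fft = 1 << (n - 1).bit_length()          # next power of two >= n
    spectrum = np.fft.rfft(signal, n_fft) * np.fft.rfft(kernel, n_fft)
    return np.fft.irfft(spectrum, n_fft)[:n]   # drop the zero-padding tail

rng = np.random.default_rng(0)
x = rng.standard_normal(4096)   # stand-in for an audio signal
h = rng.standard_normal(1024)   # stand-in for a long FIR filter

y_fft = fft_convolve(x, h)
y_direct = np.convolve(x, h)    # direct O(N*M) reference
max_err = float(np.max(np.abs(y_fft - y_direct)))
```

The zero-padding is what turns the FFT's circular convolution into a linear one; without it the tail of the result would wrap around onto the start.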
Stateful, IIR, or auto-regressive approaches are _one_ obvious answer, but this changes the game in terms of training and inference parallelization.
A couple ideas I've considered, but have not yet tried, or looked too deeply into:
- First performing PCA in the complex frequency domain, reducing the point-wise multiplication that must occur. Without some additional normalization up-front, it's likely this would be equivalent to downsampling/low-pass filtering and performing the convolution there. The learnable filter bank would live in the PCA space, reducing the overall number of learned parameters.
- A compressed-sensing-inspired approach, where we take a sparse, random sub-sample of points from both signals and recover the full result, based on the assumption that both convolver and convolvee(?) are sparse in the Fourier domain. This one is pretty half-baked.
I'd love to hear about papers you've read, or thoughts you've had about this problem.
touisteur
The convolution by FFT overlap and save can have very low intermediate storage (none on GPU with cuFFTDx for example). And most of the time, the IFFT doesn't have to happen right away; lots of processing can still be performed in the frequency domain.
Having each of the 18k CUDA cores of an L40S perform small 128-point FFTs and, with very little sync or overlap, manage long filters... is pretty efficient by itself.
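A rough CPU-side sketch of the overlap-save idea, in NumPy rather than cuFFTDx. The function name `overlap_save` and the sizes are assumptions for illustration, and this version only handles filters that fit inside a single FFT block; filters longer than the block would additionally need the filter itself to be partitioned, which is not shown here.

```python
import numpy as np

def overlap_save(x, h, n_fft=128):
    """Linear convolution of x with h using fixed-size FFT blocks (overlap-save)."""
    m = len(h)
    if m > n_fft:
        raise ValueError("this sketch needs len(h) <= n_fft")
    hop = n_fft - (m - 1)                 # new valid output samples per block
    out_len = len(x) + m - 1
    n_blocks = -(-out_len // hop)         # ceiling division
    tail = n_blocks * hop - len(x)        # zeros needed to fill the last block
    # m - 1 leading zeros act as the "history" of the first block
    padded = np.concatenate([np.zeros(m - 1), x, np.zeros(tail)])
    H = np.fft.rfft(h, n_fft)             # filter spectrum, computed once
    y = np.empty(n_blocks * hop)
    for k in range(n_blocks):
        block = padded[k * hop : k * hop + n_fft]
        yk = np.fft.irfft(np.fft.rfft(block) * H, n_fft)
        # the first m - 1 samples of each block are circularly aliased: discard them
        y[k * hop : (k + 1) * hop] = yk[m - 1 :]
    return y[:out_len]

rng = np.random.default_rng(0)
x = rng.standard_normal(1000)
h = rng.standard_normal(33)
y_blocks = overlap_save(x, h)
y_ref = np.convolve(x, h)
```

Each block is independent of the others, which is the property that maps well onto many small FFTs running in parallel.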
There's a lot happening in the HPC world on low-rank (what you're intuiting with PCA), sparse and tiled operations. I have a hard time applying all this to 'simple' signal processing and most of it lacks nicer APIs.
I've seen lots of interesting things with 'irregular FFT' codes and work on reducing the storage space necessary for FFT intermediate results, sometimes through multi-resolution tricks.
Look up Capon filters and adaptive filtering in general; there's a whole world of tricks there too. You might need a whole lot of SVDs and matrix inversions there...
But mostly, if you're on a GPU, there's a wealth of parallelism to exploit to work around the 'memory-bound' limits of FFT-based convolution. This thesis https://theses.hal.science/tel-04542844 had some discussion and numbers on the topic. Not complete but inspiring.
wbl
The gradient information in backprop can be computed similarly to the forward pass, I think. Certainly the FFT blocks are linear, so now it's a question about the multiplication, which is pretty compact.
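To make that concrete, a small NumPy sketch under assumed sizes, with a random stand-in for the upstream gradient. For y = x * h, the gradient of a scalar loss with respect to the filter is the cross-correlation of the upstream gradient with the input, which in the frequency domain is one more pointwise product, this time against the conjugate of the input spectrum.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal(2048)            # input signal
h = rng.standard_normal(256)             # learnable FIR filter
n_out = len(x) + len(h) - 1
n_fft = 1 << (n_out - 1).bit_length()    # next power of two >= n_out

X = np.fft.rfft(x, n_fft)                # input spectrum, reused in the backward pass
y = np.fft.irfft(X * np.fft.rfft(h, n_fft), n_fft)[:n_out]   # forward pass

g = rng.standard_normal(n_out)           # upstream gradient dLoss/dy (random stand-in)

# Backward pass: dLoss/dh[j] = sum_n g[n] * x[n - j], a cross-correlation,
# computed as a pointwise product with the conjugate spectrum.
grad_h = np.fft.irfft(np.fft.rfft(g, n_fft) * np.conj(X), n_fft)[:len(h)]

grad_h_ref = np.correlate(g, x, mode="valid")   # direct reference
```

The gradient with respect to the input has the same shape of computation with the roles of `x` and `h` swapped, so in this sketch the backward pass costs about the same as the forward pass.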
almostgotcaught
I don't understand this reasoning. Depending on the target (GPU/DSP/FPGA), going to frequency domain for convolution made some amount of sense when FFT primitives were highly optimized relative to conventional conv or matmul implementations. But now we're like 10 years into the software/hardware arms race and the conv/matmul kernels are just as highly optimized. In addition the hardware has adapted too.
> using Tensor Cores for FFT
Why would I do this when I could just directly use tensor cores for matmul...? We have MMA, WMMA, WGMMA, etc and they all target tensor cores explicitly.
program_whiz
The time complexity of a large matrix multiplication is still much higher than using Fourier; for large matrices it has superior performance.
touisteur
Exactly this. I was thinking exactly like GP but I've been doing a large amount of benchmarking on this and the FFT rapidly overcomes direct convolution with cudnn, cublas, cutlass... I think I'd seen a recent PhD thesis exploring and confirming this. NlogN beats N^2 or N^3 quickly, even with tensor cores. At some point complexity overcomes even the bestest hardware optimizations.
And the longer the convolution, the more the matrices look tall-skinny (less optimized). Also you have to duplicate a lot of data to make your matrix Toeplitz/circulant and fit into matmul kernels, of which convolution is a special case...
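A small NumPy illustration of the duplication being described. The helper `convolution_matrix` and the sizes are assumptions for this sketch: it materializes the dense Toeplitz matrix whose product with the signal equals the convolution. Optimized kernels may well avoid building this matrix explicitly, so this shows the data layout, not how any particular library works.

```python
import numpy as np

def convolution_matrix(h, n):
    """Dense (n + m - 1) x n Toeplitz matrix T such that T @ x == np.convolve(h, x)."""
    m = len(h)
    T = np.zeros((n + m - 1, n))
    for j in range(n):
        T[j : j + m, j] = h        # every column is a shifted copy of the same filter
    return T

rng = np.random.default_rng(0)
h = rng.standard_normal(64)
x = rng.standard_normal(512)

T = convolution_matrix(h, len(x))
y = T @ x                          # convolution expressed as a matrix-vector product
duplication = T.size / h.size      # matrix entries stored per actual filter tap
```

With these sizes the matrix is 575 x 512, so it stores 4600 entries for every one of the 64 filter taps, most of them zeros or repeats.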
almostgotcaught
> for large matrices it has superior performance.
how large? and umm citation please?
touisteur
You can mostly compute it for yourself or get an asymptotic feeling. Matmul is N^3 even with tensor cores eating a large part of it (with diminished precision, but ok) and FFT-based convolution is NlogN mostly. At some point - not that high - even the TFLOPS available on tensor cores can't keep up. I was surprised by how far cuDNN will take you, but at some point the curves cross: matmul-based convolution time keeps growing at a polynomial rate while FFT-based time grows mostly linearly (until it gets memory bound), and even that can be slashed down with block-based FFT schemes. Look up the thesis I posted earlier there's an introduction in it.
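A back-of-the-envelope way to "compute it for yourself", as a pure operation count in Python. The function names are made up for this sketch, and the cost model is an assumption: 2 flops per multiply-accumulate for direct 1-D convolution, and the conventional 5·N·log2(N) estimate per FFT. It ignores memory traffic, reduced-precision tensor-core throughput, and kernel quality, so it says nothing about where the crossover lands on real hardware.

```python
import math

def direct_flops(n, m):
    """Flops for direct convolution of n samples with an m-tap filter (2 per MAC)."""
    return 2 * n * m

def fft_flops(n, m):
    """Rough flops for one-shot FFT convolution: three transforms of a
    power-of-two size, each costed at 5 * N * log2(N), plus 6 flops per
    complex pointwise multiply."""
    size = 1 << (n + m - 2).bit_length()   # next power of two >= n + m - 1
    return 3 * 5 * size * math.log2(size) + 6 * size

n = 1 << 20                                # about one million input samples
for m in (8, 64, 512, 4096, 32768):
    d, f = direct_flops(n, m), fft_flops(n, m)
    print(f"taps={m:6d}  direct={d:.2e}  fft={f:.2e}  direct/fft={d / f:.2f}")
```

In this model the direct cost grows linearly with the filter length while the FFT cost barely moves, which is the asymptotic argument being made; the hardware-dependent constants are exactly what the rest of the thread disputes.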
almostgotcaught
> Matmul is N^3
are we talking about matmul or conv? because no one, today, is using matmul for conv (at least not gemm - igemm isn't the same thing).
like i said - the arms race has been going for ~10 years and everyone already discovered divide-conquer and twiddle factors a long time ago. if you do `nm -C libcudnn.so` you will see lots of mentions of winograd and it's not because the cudnn team likes grapes.
> Look up the thesis I posted earlier there's an introduction in it.
there are a billion theses on this. including mine. there are infinite tricks/approaches/methods for different platforms. that's why saying something like "fft is best" is silly.
touisteur
You ask how large and citation, and given one with sizes, you already have a billion ones. The article is about long convolutions (winograd being still a matmul implementation for square matrices and a pain to adapt to long convolutions and still N^2.3), and given comparisons with 10-years-optimized-(as you say) cuDNN and very naive barely running FFT-based long convolutions, showing actual O complexity matters at some point - even with the best hardware tricks and implementations.
I don't know what more to say ? Don't use FFT-based convolution for long convolutions if it doesn't work for your use cases or if you don't believe it would or should. And those of us who have benchmarked against SOTA direct convolution and found out FFT-based convolution worked better for our use cases will keep using it and talk about it when people ask on forums ?
almostgotcaught
> I don't know what more to say ?
Do you understand that you can't magic away the complexity bound on conv and matmul by simply taking the fft? I know that you do, so given this very very obvious fact, there are only two options for how an fft approach could beat XYZ conventional kernel:
1. The fft primitive you're using for your platform is more highly optimized due to sheer attention/effort/years prior to basically 2010. FFTW could fall into this category on some platforms.
2. The shape/data-layout you're operating on is particularly suited to the butterfly ops in cooley-tukey
That's it. There is no other possibility because again fft isn't some magical oracle for conv - it's literally just a linear mapping right?
So taking points 1 and 2 together you arrive at the implication: whatever you're doing/seeing/benching isn't general enough for anyone to care. I mean think about it: do you think the cudnn org doesn't know about cooley-tukey? Like it just so happens they've completely slept on this method that's taught in like every single undergrad signals and systems class? So that must mean it's not a coincidence that fft doesn't rate as highly as you think it does. If you disagree just write your fftdnn library that revolutionizes conv perf for the whole world and collect your fame and fortune from every single FAANG that currently uses cudnn/cublas.
human_llm
Also known as Z-transform in digital signal processing.
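A tiny NumPy check of that connection, using arbitrary example sequences and a hypothetical helper `z_transform`: the Z-transform of a finite sequence is a polynomial in z^-1, so convolving two sequences is the same as multiplying their polynomials.

```python
import numpy as np

def z_transform(seq, z):
    """Evaluate sum_n seq[n] * z**(-n) for a finite sequence."""
    n = np.arange(len(seq), dtype=float)
    return np.sum(seq * z ** (-n))

a = np.array([1.0, 2.0, 3.0])        # a[0] + a[1] z^-1 + a[2] z^-2
b = np.array([0.5, -1.0])            # b[0] + b[1] z^-1

conv = np.convolve(a, b)             # time-domain convolution
poly = np.polymul(a, b)              # product of the two coefficient polynomials

z0 = 0.9 * np.exp(0.3j)              # arbitrary evaluation point
lhs = z_transform(a, z0) * z_transform(b, z0)
rhs = z_transform(conv, z0)
```

Evaluating these polynomials at the N-th roots of unity on the unit circle gives the DFT, which is why the FFT can be used to multiply them quickly.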
JimmyWilliams1
[dead]
Talking long convolutions on parallel architectures through FFT, there's a lot of performance to gain from Overlap-and-Add or Overlap-and-Save schemes, especially on GPUs with cuFFTDx, which brings a whole new set of device primitives to the table. It's also worth looking at (or getting inspiration from) tcFFT, which allows using Tensor Cores for FFT and actually increases throughput on lots of convolution workloads.