Fast Fourier Transform
How to implement the Fast Fourier Transform algorithm in Python from scratch.
If you have a background in electrical engineering, you have in all likelihood heard of the Fourier Transform. In layman’s terms, the Fourier Transform is a mathematical operation that changes the domain (x-axis) of a signal from time to frequency. The frequency-domain view is particularly useful for decomposing a signal made up of multiple pure frequencies. For more details, have a look at the following video.
The application of the Fourier Transform isn’t limited to digital signal processing. The Fourier Transform can, in fact, speed up the training process of convolutional neural networks. Recall how a convolutional layer overlays a kernel on a section of an image, multiplies the kernel element-wise with the values at that location and sums the products. The kernel is then shifted to another section of the image, and the process is repeated until it has traversed the entire image.
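The Fourier Transform can speed up convolutions by taking advantage of the following property (the convolution theorem), where ℱ denotes the Fourier Transform, * denotes convolution and · denotes element-wise multiplication.
$$\mathcal{F}\{f * g\} = \mathcal{F}\{f\} \cdot \mathcal{F}\{g\}$$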
The above equation states that the Fourier Transform of the convolution of two signals equals the element-wise product of their Fourier Transforms. Therefore, once the input is transformed into frequency space, a convolution becomes a single element-wise multiplication. In other words, the input to a convolutional layer and the kernel (zero-padded to a common size) can be converted into frequencies using the Fourier Transform, multiplied once and then converted back using the inverse Fourier Transform. There is an overhead associated with transforming the inputs into the Fourier domain and applying the inverse Fourier Transform to bring the responses back to the spatial domain. For sufficiently large inputs and kernels, however, this overhead is offset by the speed up gained from performing a single multiplication instead of multiplying the kernel with every section of the image.
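The following small 1-D sketch makes the property concrete by checking it numerically with numpy's own `np.fft.fft` and `np.fft.ifft`. The helper name `fft_convolve` is hypothetical. The example assumes 1-D signals, whereas a convolutional layer works on 2-D images (and most frameworks actually compute cross-correlation). Treat it as an illustration of the theorem, not of how a layer is implemented. Both signals are zero-padded to the full output length so that the frequency-domain product corresponds to ordinary (linear) convolution rather than circular convolution.

```python
import numpy as np

def fft_convolve(a, b):
    """Linear convolution of 1-D arrays a and b via the convolution theorem."""
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    # A full linear convolution has len(a) + len(b) - 1 samples; zero-padding
    # both signals to this length avoids the wrap-around of circular convolution.
    n = len(a) + len(b) - 1
    A = np.fft.fft(a, n)  # np.fft.fft zero-pads the input to length n
    B = np.fft.fft(b, n)
    # One element-wise product in the frequency domain, then transform back.
    return np.fft.ifft(A * B).real

signal = np.array([1.0, 2.0, 3.0, 4.0])
kernel = np.array([0.5, -1.0, 0.5])
print(np.allclose(fft_convolve(signal, kernel), np.convolve(signal, kernel)))  # True
```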
Discrete Fourier Transform
The Discrete Fourier Transform (DFT) can be written as follows.
$$X[k] = \sum_{n=0}^{N-1} x[n]\, e^{-i 2\pi k n / N}, \qquad k = 0, 1, \ldots, N-1$$
To determine the DFT of a discrete signal x[n] (where N is the size of its domain), we multiply each of its values by e raised to some function of n and of the frequency index k. We then sum the results obtained for a given k. If we used a computer to calculate the Discrete Fourier Transform of a signal this way, it would need to compute N outputs, each requiring N multiplications and N − 1 additions, for a total of O(N²) operations.
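The counting argument can be sketched with a plain double loop. `naive_dft` is a hypothetical helper that evaluates the sum literally. It tallies only multiplications and additions, as the text does, and ignores the cost of evaluating each exponential.

```python
import cmath

def naive_dft(x):
    """Direct DFT X[k] = sum_n x[n] * exp(-2j*pi*k*n/N); returns (X, op_count)."""
    N = len(x)
    X = []
    mults = adds = 0
    for k in range(N):                 # one output per frequency index k
        total = 0j
        for n in range(N):             # N multiplications per output
            total += x[n] * cmath.exp(-2j * cmath.pi * k * n / N)
            mults += 1
        adds += N - 1                  # summing N products takes N - 1 additions
        X.append(total)
    return X, mults + adds

X, ops = naive_dft([1.0, 2.0, 3.0, 4.0])
print(ops)  # 28 = 4*4 multiplications + 4*3 additions
```

Doubling N roughly quadruples the count, which is the O(N²) behaviour described above.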
As the name implies, the Fast Fourier Transform (FFT) is an algorithm that computes the Discrete Fourier Transform of an input significantly faster than evaluating it directly. In computer science lingo, the FFT reduces the number of computations needed for a problem of size N from O(N²) to O(N log N).
On the surface, this might not seem like a big deal. However, when N is large enough, it can make a world of difference. Have a look at the following table.
Say it took 1 nanosecond to perform one operation. For a problem of size N = 10⁹, the Fast Fourier Transform algorithm needs roughly N log₂ N ≈ 3 × 10¹⁰ operations, so it would take approximately 30 seconds to compute the Discrete Fourier Transform. In contrast, the direct algorithm would need N² = 10¹⁸ operations, or roughly three decades.
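The arithmetic behind these figures is easy to reproduce. The sketch below assumes 1 ns per operation, as the text does, and uses N² and N log₂ N as rough operation counts with constant factors ignored. `compare` is a hypothetical helper.

```python
import math

SECONDS_PER_OP = 1e-9                  # assumption from the text: 1 ns per operation
SECONDS_PER_YEAR = 365.25 * 24 * 3600

def compare(N):
    """Return (direct_seconds, fft_seconds) for a problem of size N."""
    direct = N ** 2                    # direct DFT: O(N^2)
    fast = N * math.log2(N)            # FFT: O(N log N), base-2 log for radix-2
    return direct * SECONDS_PER_OP, fast * SECONDS_PER_OP

for N in (10**3, 10**6, 10**9):
    direct_s, fast_s = compare(N)
    print(f"N = {N:>13,}: direct {direct_s:.3g} s "
          f"({direct_s / SECONDS_PER_YEAR:.3g} years), FFT {fast_s:.3g} s")
```

For N = 10⁹, this gives roughly 29.9 seconds for the FFT and about 31.7 years for the direct method.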
Fast Fourier Transform Algorithm
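Suppose we separate the Discrete Fourier Transform into even- and odd-indexed sub-sequences.
$$X[k] = \sum_{m=0}^{N/2-1} x[2m]\, e^{-i 2\pi k (2m)/N} + \sum_{m=0}^{N/2-1} x[2m+1]\, e^{-i 2\pi k (2m+1)/N}$$
After performing a bit of algebra, we end up with the summation of two terms.
$$X[k] = \underbrace{\sum_{m=0}^{N/2-1} x[2m]\, e^{-i 2\pi k m/(N/2)}}_{E[k]} + e^{-i 2\pi k/N} \underbrace{\sum_{m=0}^{N/2-1} x[2m+1]\, e^{-i 2\pi k m/(N/2)}}_{O[k]}$$
Each term is a Discrete Fourier Transform of length N/2: E[k] of the even-indexed samples and O[k] of the odd-indexed samples. The advantage of this approach is that the two half-size transforms are independent of each other, so they can be computed concurrently. Moreover, E[k] and O[k] repeat with period N/2, and the factor in front of O[k] changes sign when k is shifted by N/2. As a result, the same two half-size transforms also give the second half of the output.
$$e^{-i 2\pi (k + N/2)/N} = -e^{-i 2\pi k/N} \quad\Rightarrow\quad X[k + N/2] = E[k] - e^{-i 2\pi k/N}\, O[k], \qquad k = 0, 1, \ldots, N/2 - 1$$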
For N = 8, we can visualize the flow of data using a butterfly diagram. We compute the Discrete Fourier Transforms of the even and odd terms simultaneously. Then, we calculate X[k] using the formula from above.
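A single butterfly stage can be checked numerically. The sketch below uses a hypothetical helper, `split_once`, and assumes N is even. It computes the two half-size transforms E and O with numpy's `np.fft.fft`. It then combines them into X[k] = E[k] + e^(−i2πk/N) O[k] for the first half of the output. Because E and O repeat with period N/2 and the twiddle factor flips sign, it uses X[k + N/2] = E[k] − e^(−i2πk/N) O[k] for the second half.

```python
import numpy as np

def split_once(x):
    """One radix-2 decimation-in-time step, using numpy's FFT for the two halves."""
    x = np.asarray(x, dtype=complex)
    N = x.shape[0]
    E = np.fft.fft(x[0::2])                  # DFT of even-indexed samples, length N/2
    O = np.fft.fft(x[1::2])                  # DFT of odd-indexed samples, length N/2
    k = np.arange(N // 2)
    twiddle = np.exp(-2j * np.pi * k / N)    # e^{-i 2 pi k / N}
    # Butterfly: first half is X[k], second half is X[k + N/2]
    return np.concatenate([E + twiddle * O, E - twiddle * O])

x = np.random.random(8)
print(np.allclose(split_once(x), np.fft.fft(x)))  # True
```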
We can express the gains in terms of Big O Notation as follows.
$$2 \left(\frac{N}{2}\right)^2 + N = \frac{N^2}{2} + N$$
The first term comes from the fact that we compute the Discrete Fourier Transform twice, each time on half the original input, at a cost of (N/2)² per transform. In the final step, it takes N steps to combine the two halves, one for each output index k. We account for this by adding N to the result.
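The factor-of-2 saving can be tallied directly. `direct_cost` and `one_split_cost` are hypothetical helpers. They count N² operations for the direct DFT and 2(N/2)² + N for a single even/odd split.

```python
def direct_cost(N):
    """Rough operation count of the direct DFT."""
    return N ** 2

def one_split_cost(N):
    """Two half-size direct DFTs plus N combining steps."""
    return 2 * (N // 2) ** 2 + N

for N in (8, 1024, 2 ** 20):
    ratio = direct_cost(N) / one_split_cost(N)
    print(N, direct_cost(N), one_split_cost(N), round(ratio, 4))
```

The ratio is 1.6 for N = 8 and approaches 2 as N grows, since the N combining steps become negligible next to N²/2.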
Notice how we were able to cut the time taken to compute the Fourier Transform by a factor of 2. We can improve the algorithm further by applying the same divide-and-conquer step recursively to each half, halving the size of the sub-problems each time. In other words, we can keep splitting the problem until we’re left with groups of two and then directly compute the Discrete Fourier Transform of each of those pairs.
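So long as N is a power of 2, the maximum number of times you can split it into two equal halves is given by p = log₂(N). Each of these p stages performs N combining steps, which is where the O(N log N) cost comes from.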
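Applying the split recursively gives the recurrence T(N) = 2T(N/2) + N. The sketch below evaluates it under a simplifying assumption: a single sample costs nothing (T(1) = 0). Under that assumption, the count comes out to exactly N log₂ N. `fft_cost` is a hypothetical helper.

```python
import math

def fft_cost(N):
    """Operation count for T(N) = 2*T(N/2) + N with T(1) = 0 (N a power of 2)."""
    if N == 1:
        return 0
    return 2 * fft_cost(N // 2) + N

for N in (8, 1024, 2 ** 20):
    p = int(math.log2(N))              # number of halving stages
    print(N, p, fft_cost(N), N * p)    # the last two columns match
```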
Here’s what it would look like if we were to use the Fast Fourier Transform algorithm with a problem size of N = 8. Notice how we have p = log₂(8) = 3 stages.
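In the usual decimation-in-time butterfly diagram, repeatedly splitting into even and odd indices leaves the inputs in bit-reversed order. For N = 8, the indices 0 to 7 split into [0, 2, 4, 6] and [1, 3, 5, 7], and then into [0, 4], [2, 6], [1, 5] and [3, 7]. The hypothetical helper below computes that order by reversing the p = log₂(N) bits of each index. It assumes N is a power of 2 and at least 2.

```python
def bit_reversed_order(N):
    """Input order of a radix-2 decimation-in-time FFT of size N (a power of 2, N >= 2)."""
    bits = N.bit_length() - 1          # p = log2(N) stages
    order = []
    for i in range(N):
        # Write i with p binary digits, reverse them and read the result back
        rev = int(format(i, f"0{bits}b")[::-1], 2)
        order.append(rev)
    return order

print(bit_reversed_order(8))  # [0, 4, 2, 6, 1, 5, 3, 7]
```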
Let’s take a look at how we could go about implementing the Fast Fourier Transform algorithm from scratch using Python. To begin, we import the numpy library.
```python
import numpy as np
```
Next, we define a function to calculate the Discrete Fourier Transform directly.
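```python
def dft(x):
    """Compute the Discrete Fourier Transform of the 1-D array x directly, in O(N^2)."""
    x = np.asarray(x, dtype=float)
    N = x.shape[0]
    n = np.arange(N)
    k = n.reshape((N, 1))
    # M[k, n] = exp(-2j * pi * k * n / N) is the N x N DFT matrix
    M = np.exp(-2j * np.pi * k * n / N)
    return np.dot(M, x)
```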
We can ensure our implementation is correct by comparing the results with those obtained from numpy’s fft function.
```python
x = np.random.random(1024)
np.allclose(dft(x), np.fft.fft(x))
```
Timing both functions shows that the direct Discrete Fourier Transform is orders of magnitude slower than numpy’s Fast Fourier Transform.
```python
%timeit dft(x)
%timeit np.fft.fft(x)
```
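`%timeit` is an IPython/Jupyter magic, so the cell above won't run as a plain Python script. A rough equivalent using the standard-library `timeit` module is sketched below. It repeats the `dft` function so that it runs on its own. The number of repetitions (`runs = 20`) is an arbitrary choice, and absolute timings will vary between machines.

```python
import timeit
import numpy as np

def dft(x):
    # Same direct O(N^2) DFT as above
    x = np.asarray(x, dtype=float)
    N = x.shape[0]
    n = np.arange(N)
    k = n.reshape((N, 1))
    M = np.exp(-2j * np.pi * k * n / N)
    return np.dot(M, x)

x = np.random.random(1024)
runs = 20
# Average seconds per call over `runs` repetitions
t_dft = timeit.timeit(lambda: dft(x), number=runs) / runs
t_np = timeit.timeit(lambda: np.fft.fft(x), number=runs) / runs
print(f"dft: {t_dft * 1e3:.3f} ms per call, np.fft.fft: {t_np * 1e3:.3f} ms per call")
```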
Like we saw before, the Fast Fourier Transform works by computing the Discrete Fourier Transform for small subsets of the overall problem and then combining the results. This divide-and-conquer structure maps naturally onto recursion.
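```python
def fft(x):
    """A recursive implementation of the radix-2 Cooley-Tukey FFT."""
    x = np.asarray(x, dtype=float)
    N = x.shape[0]
    if N % 2 > 0:
        raise ValueError("must be a power of 2")
    elif N <= 2:
        # Base case: compute the DFT of a pair directly
        return dft(x)
    else:
        X_even = fft(x[::2])    # DFT of the even-indexed samples
        X_odd = fft(x[1::2])    # DFT of the odd-indexed samples
        # Twiddle factors exp(-2j * pi * k / N) for k = 0, ..., N - 1
        terms = np.exp(-2j * np.pi * np.arange(N) / N)
        return np.concatenate([X_even + terms[:N // 2] * X_odd,
                               X_even + terms[N // 2:] * X_odd])
```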
Again, we can validate whether our implementation is correct by comparing the results with those obtained from numpy.
```python
x = np.random.random(1024)
np.allclose(fft(x), np.fft.fft(x))
```
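One limitation of `fft()` as written is that `np.asarray(x, dtype=float)` forces the input to real floating-point values. As a result, it can't handle complex data, such as a spectrum that we want to transform back with the inverse Fourier Transform mentioned at the start. The sketch below is a complex-capable variant; `fft_complex` and `ifft_complex` are hypothetical names. It recurses down to single samples and computes only the first N/2 twiddle factors. For the second half, it relies on the sign flip e^(−i2π(k+N/2)/N) = −e^(−i2πk/N). The inverse uses the standard identity ifft(X) = conj(fft(conj(X))) / N.

```python
import numpy as np

def fft_complex(x):
    """Recursive radix-2 FFT that, unlike fft() above, accepts complex input."""
    x = np.asarray(x, dtype=complex)
    N = x.shape[0]
    if N == 1:
        return x.copy()                 # the DFT of a single sample is the sample itself
    if N % 2 > 0:
        raise ValueError("must be a power of 2")
    X_even = fft_complex(x[::2])
    X_odd = fft_complex(x[1::2])
    terms = np.exp(-2j * np.pi * np.arange(N // 2) / N)
    # Second half reuses the same terms with a minus sign
    return np.concatenate([X_even + terms * X_odd, X_even - terms * X_odd])

def ifft_complex(X):
    """Inverse DFT via the identity ifft(X) = conj(fft(conj(X))) / N."""
    X = np.asarray(X, dtype=complex)
    return np.conj(fft_complex(np.conj(X))) / X.shape[0]

x = np.random.random(1024)
X = fft_complex(x)
print(np.allclose(X, np.fft.fft(x)), np.allclose(ifft_complex(X), x))  # True True
```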
The recursive FFT is significantly faster than the direct implementation. However, it still lags behind the numpy implementation by quite a bit. One reason is that our recursive version makes a large number of Python function calls, each operating on a tiny array, whereas numpy’s implementation runs entirely in compiled code.
```python
%timeit dft(x)
%timeit fft(x)
%timeit np.fft.fft(x)
```
We define another function to compute the Fourier Transform. This time around, we use vectorized array operations instead of recursion, so that all sub-problems at the same stage are computed at once.
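```python
def fft_v(x):
    """A vectorized, non-recursive implementation of the radix-2 FFT."""
    x = np.asarray(x, dtype=float)
    N = x.shape[0]
    if np.log2(N) % 1 > 0:
        raise ValueError("must be a power of 2")

    # Compute the N_min-point DFT of every sub-problem at once (one per column)
    N_min = min(N, 2)
    n = np.arange(N_min)
    k = n[:, None]
    M = np.exp(-2j * np.pi * n * k / N_min)
    X = np.dot(M, x.reshape((N_min, -1)))

    # Combine sub-problems one stage at a time, doubling their size each pass
    while X.shape[0] < N:
        X_even = X[:, :X.shape[1] // 2]
        X_odd = X[:, X.shape[1] // 2:]
        terms = np.exp(-1j * np.pi * np.arange(X.shape[0]) / X.shape[0])[:, None]
        X = np.vstack([X_even + terms * X_odd,
                       X_even - terms * X_odd])

    return X.ravel()
```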
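A small case shows why reshaping into `N_min` rows works. The sketch below mirrors the first stage of `fft_v()` for N = 8 with `N_min = 2`, using the sample indices as values. Each column of the reshaped array is a length-2 sub-sequence `x[j::N//2]`, and a single matrix product computes the 2-point DFT of every column at once. Later iterations of the `while` loop then merge columns pairwise with the twiddle factors, doubling the sub-problem size each time.

```python
import numpy as np

N = 8
x = np.arange(N)                  # use the indices themselves as sample values
pairs = x.reshape((2, -1))        # the reshape fft_v() performs when N_min = 2
print(pairs)
# [[0 1 2 3]
#  [4 5 6 7]]
# Column j holds (x[j], x[j + N/2]), i.e. the length-2 sub-sequence x[j::N//2].

# 2-point DFT matrix, approximately [[1, 1], [1, -1]], built the same way as in fft_v()
n = np.arange(2)
M = np.exp(-2j * np.pi * n * n[:, None] / 2)
first_stage = M @ pairs           # the 2-point DFT of every column in one product
```

The columns of `first_stage` come out as approximately (4, −4), (6, −4), (8, −4) and (10, −4). These are the 2-point DFTs of (0, 4), (1, 5), (2, 6) and (3, 7).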
Once again, we can ensure we obtained the correct results by comparing them with those from the numpy library.
```python
x = np.random.random(1024)
np.allclose(fft_v(x), np.fft.fft(x))
```
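Both `fft()` and `fft_v()` reject inputs whose length isn't a power of 2. A common workaround is to zero-pad the signal to the next power of 2. This computes the DFT of the padded signal, which samples the spectrum on a finer frequency grid; it is not the DFT of the original signal. `next_pow2` and `pad_to_pow2` are hypothetical helpers. The sketch checks against numpy's `np.fft.fft(x, n)`, which zero-pads in the same way.

```python
import numpy as np

def next_pow2(n):
    """Smallest power of 2 that is >= n (for n >= 1)."""
    return 1 << (n - 1).bit_length()

def pad_to_pow2(x):
    """Zero-pad a 1-D array to the next power-of-2 length."""
    x = np.asarray(x, dtype=float)
    return np.concatenate([x, np.zeros(next_pow2(x.shape[0]) - x.shape[0])])

x = np.random.random(1000)
padded = pad_to_pow2(x)
print(padded.shape)  # (1024,)
# A power-of-2-only FFT such as fft_v(padded) should match numpy's padded FFT:
print(np.allclose(np.fft.fft(padded), np.fft.fft(x, next_pow2(len(x)))))  # True
```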
As we can see, the FFT implementation using vector operations is significantly faster than what we had obtained previously. We still haven’t come close to the speed at which the numpy library computes the Fourier Transform. This is because numpy’s fft runs compiled code that has received years of tweaks and optimizations. Older versions used a C translation of the Fortran FFTPACK library, and since NumPy 1.17 it has been backed by pocketfft. If you are interested in finding out more, I recommend you have a look at the source code.
```python
%timeit fft(x)
%timeit fft_v(x)
%timeit np.fft.fft(x)
```