15-451 Algorithms                                                  04/23/09

* Multiplying polynomials
* Fast Fourier Transform (FFT)
===========================================================================

Today we are going to talk about the Fast Fourier Transform, a widely
used algorithm in areas like signal processing, speech understanding,
digital audio, and radar.  We'll develop it while trying to solve the
"how to multiply quickly" problem we talked about on the first day of
class.  This is not how it was invented historically (and it obscures
the connection to the usual kind of Fourier Transform), but I think it's
most natural from the perspective of Algorithms.  You don't need to know
what a FT is for this lecture.

Warning: even though the algorithm in the end won't be that complicated,
this will be one of the most difficult lectures of the class!
===================================================================

At the start of class we talked about the problem of multiplying big
integers.  Let's look at a simpler version of the problem: multiplying
polynomials.  It sounds more complicated, but really it's just the
problem of multiplying integers without the carries.

MULTIPLYING POLYNOMIALS
=======================
E.g., multiply (x^2 + 2x + 3)(2x^2 + 5):

                      2x^2 + 0x + 5
                      1x^2 + 2x + 3
                      -------------
                         6    0   15
                    4    0   10
               2    0    5
          ------------------------------
          2x^4 + 4x^3 + 11x^2 + 10x + 15

MODEL: view each individual small multiplication as a unit-cost operation.

More generally, given

   A(x) = a_0 + a_1 x + ... + a_{n-1} x^{n-1},
   B(x) = b_0 + b_1 x + ... + b_{n-1} x^{n-1},

our goal is to compute the polynomial C(x) = A(x)*B(x), where

   c_i = a_0*b_i + a_1*b_{i-1} + ... + a_i*b_0.

If we think of A and B as vectors, then the C-vector is called the
"convolution" of A and B.

 - Straightforward computation takes O(n^2) time.  Karatsuba takes
   O(n^{1.58...}) time.
 - We'll use FFTs to do it in O(n log n) time.  This is then used in the
   Schonhage-Strassen integer multiplication algorithm, which multiplies
   two n-bit integers in O(n log n loglog n) time.  We're only going to
   do polynomial multiplication.
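To make the O(n^2) baseline concrete, here is a quick Python sketch of
the straightforward convolution (the function name naive_multiply is
just illustrative):

  def naive_multiply(a, b):
      # a[i] is the coefficient of x^i; same for b.  Computes
      # c_i = a_0*b_i + a_1*b_{i-1} + ... + a_i*b_0 directly.
      c = [0] * (len(a) + len(b) - 1)
      for i, ai in enumerate(a):
          for j, bj in enumerate(b):
              c[i + j] += ai * bj
      return c

  # The example above: (x^2 + 2x + 3)(2x^2 + 5)
  print(naive_multiply([3, 2, 1], [5, 0, 2]))   # [15, 10, 11, 4, 2]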
High Level Idea of Algorithm
============================
Let m = 2n-1.  [so the degree of C is less than m]

1. Pick m points x_0, x_1, ..., x_{m-1} according to a secret formula.
2. Evaluate A at each of the points: A(x_0), ..., A(x_{m-1}).
3. Same for B.
4. Now compute C(x_0), ..., C(x_{m-1}), where C(x) = A(x)*B(x).
5. Interpolate to get the coefficients of C.

This approach is based on the fact that a polynomial of degree < m is
uniquely specified by its value on m points.

It seems patently crazy, since it looks like steps 2 and 3 should take
O(n^2) time just in themselves.  However, the FFT will allow us to
quickly move from the "coefficient representation" of a polynomial to
the "value on m points" representation, and back, for our special set
of m points.  (This doesn't work for *arbitrary* m points.  The special
points will turn out to be roots of unity.)

The reason we like this is that multiplying is easy in the "value on m
points" representation.  We just do: C(x_i) = A(x_i)*B(x_i).  So, step 4
takes only O(m) time.

Let's focus on the forward direction first.  In that case, we've reduced
our problem to the following:

GOAL: Given a polynomial A of degree < m, evaluate A at m points of our
      choosing in total time O(m log m).  Assume m is a power of 2.

The FFT:
========
Let's first develop it through an example.  Say m=8, so we have a
polynomial

  A(x) = a_0 + a_1 x + a_2 x^2 + a_3 x^3 + a_4 x^4 + a_5 x^5 + a_6 x^6 + a_7 x^7

(as a vector, A = [a_0, a_1, ..., a_7]), and we want to evaluate it at
eight points of our choosing.

Here is an idea.  Split A into two pieces, but instead of left and
right, have them be even and odd.  So, as vectors,

  A_even = [a_0, a_2, a_4, a_6]
  A_odd  = [a_1, a_3, a_5, a_7]

or, as polynomials:

  A_even(x) = a_0 + a_2 x + a_4 x^2 + a_6 x^3
  A_odd(x)  = a_1 + a_3 x + a_5 x^2 + a_7 x^3.

Each has degree < m/2.  How can we write A(x) in terms of A_even and
A_odd?

  A(x) = A_even(x^2) + x A_odd(x^2).

What's nice is that the effort spent computing A(x) will give us A(-x)
almost for free.  So, let's say our special set of m points will have
the property:

  (*) The 2nd half of the points are the negatives of the 1st half.

E.g., {1, 2, 3, 4, -1, -2, -3, -4}.  Now things look good: let T(m) be
the time to evaluate a polynomial of degree < m at our special set of m
points.  We're doing this by evaluating two polynomials of degree < m/2
at m/2 points each (the squares), and then doing O(m) work to combine
the results.  This is great because the recurrence T(m) = 2T(m/2) + O(m)
solves to O(m log m).

But we're deluding ourselves by saying "just do it recursively".  Why is
that?  The problem is that recursively, our special points (now
{1, 4, 9, 16}) have to satisfy property (*).  E.g., they should really
look like {1, 4, -1, -4}.  BUT THESE ARE SQUARES!!

How to fix?  Just use complex numbers!  E.g., if {1, 4, -1, -4} are the
squares, what do the original points look like?

  {1, 2, i, 2i, -1, -2, -i, -2i}

so then their squares are:  1, 4, -1, -4
and *their* squares are:    1, 16

But at the next level we again need property (*).  So, we want to have
{1, -1} there.  This means we want the level before that to be
{1, i, -1, -i}, which is the same as {1, i, i^2, i^3}.  So, for the
original level, let w = sqrt(i) = 0.707 + 0.707i, and then our original
set of points will be:

  1, w, w^2, w^3, w^4 (= -1), w^5 (= -w), w^6 (= -w^2), w^7 (= -w^3)

so that the squares are:   1, i, i^2 (= -1), i^3 (= -i)
and *their* squares are:   1, -1
and *their* squares are:   1

The "w" we are using is called the "primitive eighth root of unity"
(since w^8 = 1 and w^k != 1 for 0 < k < 8).  In general, the primitive
mth root of unity is the complex number

  w = cos(2*pi/m) + i*sin(2*pi/m).

Alternatively, we can use MODULAR ARITHMETIC!
============================================
E.g., 2 is a primitive 8th root of unity mod 17:

  {2^0, 2^1, 2^2, ..., 2^7} = {1, 2, 4, 8, 16, 15, 13, 9}
                            = {1, 2, 4, 8, -1, -2, -4, -8}.

Then when you square them, you get {1, 4, -1, -4}, etc.  This is nice
because we don't need to deal with messy floating-point numbers.

THE FFT ALGORITHM
=================
Here is the general algorithm in pseudo-C.  Let A be an array of length
m, and let w be a primitive mth root of unity.  Goal: produce the DFT
F(A), the evaluation of A at 1, w, w^2, ..., w^{m-1}.

  FFT(A, m, w)
  {
    if (m==1) return vector (a_0)
    else {
      A_even = (a_0, a_2, ..., a_{m-2})
      A_odd  = (a_1, a_3, ..., a_{m-1})
      F_even = FFT(A_even, m/2, w^2)  // w^2 is a primitive (m/2)-th root of unity
      F_odd  = FFT(A_odd, m/2, w^2)
      F = new vector of length m
      x = 1
      for (j=0; j < m/2; ++j) {
        F[j]     = F_even[j] + x*F_odd[j]
        F[j+m/2] = F_even[j] - x*F_odd[j]
        x = x * w
      }
      return F
    }
  }
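For concreteness, here is one way to transcribe this pseudocode into
runnable Python, using Python's built-in complex numbers for w (just a
sketch; it assumes the length of A is a power of 2):

  import cmath

  def fft(A, w):
      # Evaluate the polynomial with coefficient list A at the points
      # 1, w, w^2, ..., w^{m-1}, where m = len(A) is a power of 2 and
      # w is a primitive m-th root of unity.
      m = len(A)
      if m == 1:
          return [A[0]]
      F_even = fft(A[0::2], w * w)  # w^2 is a primitive (m/2)-th root of unity
      F_odd  = fft(A[1::2], w * w)
      F = [0] * m
      x = 1
      for j in range(m // 2):
          F[j]          = F_even[j] + x * F_odd[j]
          F[j + m // 2] = F_even[j] - x * F_odd[j]
          x = x * w
      return F

  # Example: evaluate A(x) = 1 + 2x + 3x^2 + 4x^3 at the 4th roots of unity.
  w = cmath.exp(2j * cmath.pi / 4)  # primitive 4th root of unity (= i)
  print(fft([1, 2, 3, 4], w))       # [A(1), A(i), A(-1), A(-i)] = [10, -2-2j, -2, -2+2j], up to float error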
THE INVERSE OF THE FFT
======================
Remember, we started all this by saying that we were going to multiply
two polynomials A and B by evaluating each at a special set of m points
(which we can now do in time O(m log m)), then multiply the values
point-wise to get C evaluated at all these points (in O(m) time), but
then we need to interpolate back to get the coefficients of C.  In other
words, we're computing F^{-1}(F(A) \cdot F(B)).

So, we need to compute F^{-1}.  Here's how.  First, we can look at the
forward computation (computing A(x) at 1, w, w^2, ..., w^{m-1}) as an
implicit matrix-vector product:

  +---------------------------------------------+  +-------+      +--------+
  | 1  1       1       1       ...  1           |  |  a_0  |      |  A(1)  |
  | 1  w       w^2     w^3     ...  w^{m-1}     |  |  a_1  |      |  A(w)  |
  | 1  w^2     w^4     w^6     ...  w^{2(m-1)}  |  |  a_2  |  ==  | A(w^2) |
  | 1  w^3     w^6     w^9     ...  w^{3(m-1)}  |  |  a_3  |      | A(w^3) |
  | .....                           .....       |  |  ...  |      |  ...   |
  | 1  w^{-1}  w^{-2}  w^{-3}  ...  w           |  |a_{m-1}|      |  ...   |
  +---------------------------------------------+  +-------+      +--------+

(Note w^{m-1} = w^{-1} since w^m = 1.)

(We're doing this "implicitly" in the sense that we don't even have time
to write down the matrix.)

To do the inverse transform, what we want is to multiply by the inverse
of the F matrix.  As it turns out, this inverse looks very much like F
itself.  In particular, notice that w^{-1} is also a primitive mth root
of unity.  Let's define \bar{F} to be the Fourier transform matrix using
w^{-1} instead of w.  Then,

Claim: F^{-1} = (1/m) * \bar{F}.  I.e., (1/m) * \bar{F} * F = identity.

Proof: What is the (i,j) entry of \bar{F} * F?  It is:

  1 + w^{j-i} + w^{2(j-i)} + w^{3(j-i)} + ... + w^{(m-1)(j-i)}.

If i=j, then these terms are all equal to 1, so the sum is m.  Then when
we divide by m we get 1.

If i!=j, then the claim is that these all cancel out and we get zero.
Maybe this is easier to see if we let z = w^{j-i}, so the sum is:

  1 + z + z^2 + z^3 + z^4 + ... + z^{m-1}.

You can see these cancel by picture: the terms are spread evenly around
the unit circle and sum to zero (for instance, try z = w or z = w^2).
Or you can use the formula for geometric sums:
(1 - z^m)/(1 - z) = 0/(1 - z) = 0, since z^m = 1 but z != 1.

So, the final algorithm is:

  Let F_A = FFT(A, m, w)                       // time O(n log n)
  Let F_B = FFT(B, m, w)                       // time O(n log n)
  For i=1 to m, let F_C[i] = F_A[i]*F_B[i]     // time O(n)
  Output C = (1/m) * FFT(F_C, m, w^{-1})       // time O(n log n)

NOTE: If you're an EE or Physics person, what we're calling the "Fourier
Transform" is what you would usually call the "inverse Fourier
transform", and vice-versa.
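Putting it all together, here is a sketch of the whole multiplication
pipeline in Python, reusing the fft() sketch above (the name
poly_multiply and the padding details are illustrative; it assumes
integer coefficients so we can round away floating-point error):

  import cmath

  def poly_multiply(a, b):
      # Multiply polynomials a and b (coefficient lists, constant term
      # first) by evaluating at the m-th roots of unity, multiplying
      # point-wise, and interpolating with the inverse FFT.
      m = 1
      while m < len(a) + len(b) - 1:    # need at least 2n-1 points
          m *= 2
      a = a + [0] * (m - len(a))        # pad coefficient vectors to length m
      b = b + [0] * (m - len(b))
      w = cmath.exp(2j * cmath.pi / m)  # primitive m-th root of unity
      F_A = fft(a, w)
      F_B = fft(b, w)
      F_C = [fa * fb for fa, fb in zip(F_A, F_B)]  # C(w^i) = A(w^i)*B(w^i)
      C = fft(F_C, 1 / w)               # inverse transform: same FFT with w^{-1}
      return [round((c / m).real) for c in C]      # divide by m, round off float error

  # The running example: (x^2 + 2x + 3)(2x^2 + 5), padded to m = 8 points.
  print(poly_multiply([3, 2, 1], [5, 0, 2]))   # [15, 10, 11, 4, 2, 0, 0, 0]

The rounding step is only needed because we are using floating-point
complex numbers; the modular-arithmetic variant mentioned above avoids
it (given a suitable modulus).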