Skip to content

3549. Multiply Two Polynomials

  • Time: $O(n\log n)$, where $n = |\text{poly1}| + |\text{poly2}|$
  • Space: $O(n)$
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
class Solution {
 public:
  // Multiplies two integer polynomials given as coefficient lists (index i is
  // the coefficient of x^i) and returns the coefficients of the product.
  //
  // Both inputs are zero-padded to the smallest power of two that can hold all
  // n1 + n2 - 1 product coefficients, transformed with a radix-2 FFT,
  // multiplied pointwise, and transformed back. Each result is rounded to the
  // nearest integer, so it is exact only while the floating-point error stays
  // below 0.5.
  //
  // Returns an empty vector if either input is empty.
  std::vector<long long> multiply(std::vector<int>& poly1,
                                  std::vector<int>& poly2) {
    const std::size_t n1 = poly1.size();
    const std::size_t n2 = poly2.size();
    // Without this guard n1 + n2 - 1 would wrap around for an empty input.
    if (n1 == 0 || n2 == 0)
      return {};
    const std::size_t n = n1 + n2 - 1;

    // Smallest power of two >= n (same value as 1 << bit_width(n - 1), but
    // does not require C++20).
    std::size_t sz = 1;
    while (sz < n)
      sz <<= 1;

    // Value-initialized to 0, so everything past the inputs is zero padding.
    std::vector<std::complex<double>> a(sz);
    std::vector<std::complex<double>> b(sz);
    for (std::size_t i = 0; i < n1; ++i)
      a[i] = poly1[i];
    for (std::size_t i = 0; i < n2; ++i)
      b[i] = poly2[i];

    // Convolution in the coefficient domain is a pointwise product in the
    // frequency domain.
    fft(a, /*inverse=*/false);
    fft(b, /*inverse=*/false);
    for (std::size_t i = 0; i < sz; ++i)
      a[i] *= b[i];
    fft(a, /*inverse=*/true);

    std::vector<long long> ans(n);
    for (std::size_t i = 0; i < n; ++i)
      ans[i] = std::llround(a[i].real());
    return ans;
  }

 private:
  // In-place iterative radix-2 Cooley-Tukey FFT. a.size() must be a power of
  // two. When `inverse` is true, computes the inverse transform including the
  // 1/n normalization.
  static void fft(std::vector<std::complex<double>>& a, bool inverse) {
    const std::size_t n = a.size();

    // Bit-reversal permutation so the butterflies below can run in place.
    for (std::size_t i = 1, j = 0; i < n; ++i) {
      std::size_t bit = n >> 1;
      for (; j & bit; bit >>= 1)
        j ^= bit;
      j ^= bit;
      if (i < j)
        std::swap(a[i], a[j]);
    }

    // std::acos(-1) is portable; M_PI is not part of standard C++.
    const double pi = std::acos(-1.0);
    std::vector<std::complex<double>> roots;
    for (std::size_t len = 2; len <= n; len <<= 1) {
      const std::size_t half = len / 2;
      const double step =
          2 * pi / static_cast<double>(len) * (inverse ? -1.0 : 1.0);
      // Compute every twiddle factor directly. The repeated product
      // w *= w_len accumulates rounding error proportional to len.
      roots.resize(half);
      for (std::size_t j = 0; j < half; ++j)
        roots[j] = std::polar(1.0, step * static_cast<double>(j));

      for (std::size_t i = 0; i < n; i += len) {
        for (std::size_t j = 0; j < half; ++j) {
          const std::complex<double> u = a[i + j];
          const std::complex<double> v = a[i + j + half] * roots[j];
          a[i + j] = u + v;
          a[i + j + half] = u - v;
        }
      }
    }

    if (inverse)
      for (std::complex<double>& x : a)
        x /= static_cast<double>(n);
  }
};
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
/**
 * Minimal complex number used by the FFT below.
 *
 * <p>The arithmetic methods never modify the receiver; each returns a newly
 * allocated {@code Complex}.
 */
class Complex {
  public double real;
  public double imag;

  /** Creates the value {@code real + imag·i}. */
  public Complex(double real, double imag) {
    this.real = real;
    this.imag = imag;
  }

  /** Returns {@code this + other}. */
  public Complex add(Complex other) {
    final double sumRe = real + other.real;
    final double sumIm = imag + other.imag;
    return new Complex(sumRe, sumIm);
  }

  /** Returns {@code this - other}. */
  public Complex subtract(Complex other) {
    final double diffRe = real - other.real;
    final double diffIm = imag - other.imag;
    return new Complex(diffRe, diffIm);
  }

  /** Returns {@code this * other}, i.e. (a + bi)(c + di) = (ac - bd) + (ad + bc)i. */
  public Complex multiply(Complex other) {
    final double a = real;
    final double b = imag;
    final double c = other.real;
    final double d = other.imag;
    return new Complex(a * c - b * d, a * d + b * c);
  }

  /** Returns this value with both components divided by {@code scalar}. */
  public Complex divide(double scalar) {
    final double re = real / scalar;
    final double im = imag / scalar;
    return new Complex(re, im);
  }
}

class Solution {
  /**
   * Multiplies two integer polynomials given as coefficient arrays (index i is
   * the coefficient of x^i) using a radix-2 complex FFT.
   *
   * <p>Results are rounded to the nearest integer, so they are exact only
   * while the floating-point error stays below 0.5.
   *
   * @param poly1 coefficients of the first polynomial
   * @param poly2 coefficients of the second polynomial
   * @return the {@code poly1.length + poly2.length - 1} product coefficients,
   *     or an empty array if either input is empty
   */
  public long[] multiply(int[] poly1, int[] poly2) {
    final int n1 = poly1.length;
    final int n2 = poly2.length;
    // An empty factor has no coefficients. Without this guard, n would be
    // 0 or -1 and sz too small to hold the inputs.
    if (n1 == 0 || n2 == 0)
      return new long[0];
    final int n = n1 + n2 - 1;
    // Smallest power of two >= n.
    final int sz = 1 << bitLength(n - 1);

    // Load coefficients and zero-pad to sz in a single pass.
    Complex[] a = new Complex[sz];
    Complex[] b = new Complex[sz];
    for (int i = 0; i < sz; ++i) {
      a[i] = new Complex(i < n1 ? poly1[i] : 0, 0);
      b[i] = new Complex(i < n2 ? poly2[i] : 0, 0);
    }

    // Convolution in the coefficient domain is a pointwise product in the
    // frequency domain.
    fft(a, false);
    fft(b, false);
    for (int i = 0; i < sz; ++i)
      a[i] = a[i].multiply(b[i]);
    fft(a, true);

    long[] ans = new long[n];
    for (int i = 0; i < n; ++i)
      ans[i] = Math.round(a[i].real);
    return ans;
  }

  /**
   * In-place iterative radix-2 Cooley-Tukey FFT. {@code a.length} must be a
   * power of two. When {@code inverse} is true, computes the inverse transform
   * including the 1/n normalization.
   */
  private void fft(Complex[] a, boolean inverse) {
    final int n = a.length;

    // Bit-reversal permutation so the butterflies below can run in place.
    for (int i = 1, j = 0; i < n; ++i) {
      int bit = n >> 1;
      for (; (j & bit) != 0; bit >>= 1)
        j ^= bit;
      j ^= bit;
      if (i < j)
        swap(a, i, j);
    }

    for (int len = 2; len <= n; len *= 2) {
      final int half = len / 2;
      final double step = 2 * Math.PI / len * (inverse ? -1 : 1);
      // Compute every twiddle factor directly. The repeated product
      // w = w.multiply(wLen) accumulates rounding error proportional to len.
      final Complex[] roots = new Complex[half];
      for (int j = 0; j < half; ++j)
        roots[j] = new Complex(Math.cos(step * j), Math.sin(step * j));

      for (int i = 0; i < n; i += len) {
        for (int j = 0; j < half; ++j) {
          final Complex u = a[i + j];
          final Complex v = a[i + j + half].multiply(roots[j]);
          a[i + j] = u.add(v);
          a[i + j + half] = u.subtract(v);
        }
      }
    }

    if (inverse)
      for (int i = 0; i < n; ++i)
        a[i] = a[i].divide(n);
  }

  /** Swaps {@code a[i]} and {@code a[j]}. */
  private void swap(Complex[] a, int i, int j) {
    Complex temp = a[i];
    a[i] = a[j];
    a[j] = temp;
  }

  /** Number of bits needed to represent {@code n} (0 for n == 0). */
  private int bitLength(int n) {
    return Integer.SIZE - Integer.numberOfLeadingZeros(n);
  }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
from numpy import array, round
from numpy.fft import fft, ifft


class Solution:
  def multiply(self, poly1: list[int], poly2: list[int]) -> list[int]:
    """Multiplies two integer polynomials via a real-input FFT.

    Args:
      poly1: Coefficients of the first polynomial (index i is the x^i term).
      poly2: Coefficients of the second polynomial.

    Returns:
      The len(poly1) + len(poly2) - 1 product coefficients as Python ints,
      or [] if either input is empty. Values are rounded from floating point,
      so they are exact only while the FFT error stays below 0.5.
    """
    import numpy as np
    from numpy.fft import irfft, rfft

    n1 = len(poly1)
    n2 = len(poly2)
    # Without this guard, n would be -1 and the result would be spurious zeros.
    if n1 == 0 or n2 == 0:
      return []
    n = n1 + n2 - 1
    # Smallest power of two >= n.
    sz = 1 << (n - 1).bit_length()
    # rfft zero-pads each input to length sz. irfft returns real values, so no
    # imaginary part is silently discarded.
    spectrum = (rfft(np.asarray(poly1, dtype=np.float64), sz) *
                rfft(np.asarray(poly2, dtype=np.float64), sz))
    product = irfft(spectrum, sz)[:n]
    # Explicit int64: a platform `int` dtype may be 32-bit and overflow.
    return np.rint(product).astype(np.int64).tolist()