3145. Find Products of Elements of Big Array

  • Time: $O(q\log^2 \max(\texttt{to[i]}))$
  • Space: $O(q)$
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
class Solution {
 public:
  // For each query {from, to, mod}, returns
  // (big_nums[from] * big_nums[from + 1] * ... * big_nums[to]) % mod.
  //
  // Every element of big_nums is a power of two, so the product is
  // 2^(sum of exponents in [from, to]). That exponent sum is the difference of
  // two prefix sums, each computed without materializing big_nums.
  //
  // `long long` is used throughout: the query values arrive as `long long`,
  // and storing them (or the derived prefix sums) in `long` would truncate on
  // platforms where `long` is only 32 bits (LLP64, e.g. 64-bit Windows).
  vector<int> findProductsOfElements(vector<vector<long long>>& queries) {
    vector<int> ans;
    ans.reserve(queries.size());

    for (const vector<long long>& query : queries) {
      const long long a = query[0];
      const long long b = query[1];
      const int mod = static_cast<int>(query[2]);
      // Exponents of big_nums[a..b] = prefix(b + 1) - prefix(a).
      const long long exponent =
          sumPowersFirstKBigNums(b + 1) - sumPowersFirstKBigNums(a);
      ans.push_back(modPow(2, exponent, mod));
    }

    return ans;
  }

 private:
  // Returns the sum of the exponents of the first k numbers in `big_nums`.
  long long sumPowersFirstKBigNums(long long k) {
    if (k <= 0)
      return 0;
    // num is the integer whose set bits contribute big_nums[k - 1].
    const long long num = firstNumberHavingSumBitsTillAtLeast(k);
    long long sumPowers = sumPowersTill(num - 1);
    // Number of num's set bits (lowest first) that still belong to the prefix;
    // always >= 1 here because sumBitsTill(num - 1) < k.
    long long remainingCount = k - sumBitsTill(num - 1);
    const int numBits = bitLength(num);
    for (int power = 0; power < numBits && remainingCount > 0; ++power) {
      if (num >> power & 1) {
        sumPowers += power;
        --remainingCount;
      }
    }
    return sumPowers;
  }

  // Returns the smallest num in [1, k] with sumBitsTill(num) >= k (k >= 1).
  long long firstNumberHavingSumBitsTillAtLeast(long long k) {
    long long l = 1;
    long long r = k;  // sumBitsTill(k) >= k: every i >= 1 has a set bit.
    while (l < r) {
      const long long m = l + (r - l) / 2;
      if (sumBitsTill(m) < k)
        l = m + 1;
      else
        r = m;
    }
    return l;
  }

  // Returns sum(popcount(i)) over 1 <= i <= x.
  long long sumBitsTill(long long x) {
    long long sumBits = 0;
    for (long long powerOfTwo = 1; powerOfTwo <= x; powerOfTwo *= 2)
      sumBits += countWithBitSet(x, powerOfTwo);
    return sumBits;
  }

  // Returns the sum, over 1 <= i <= x, of the exponents of i's set bits.
  long long sumPowersTill(long long x) {
    long long sumPowers = 0;
    const int xBits = bitLength(x);
    for (int power = 0; power < xBits; ++power)
      sumPowers += countWithBitSet(x, 1LL << power) * power;
    return sumPowers;
  }

  // Returns how many i in [0, x] have the bit `powerOfTwo` set: that bit is on
  // for powerOfTwo values in every full period of 2 * powerOfTwo, plus the
  // tail of the final partial period.
  static long long countWithBitSet(long long x, long long powerOfTwo) {
    const long long period = 2 * powerOfTwo;
    return (x / period) * powerOfTwo + max(0LL, x % period + 1 - powerOfTwo);
  }

  // Returns x^n % mod for n >= 0 (1 % mod when n == 0, so mod == 1 yields 0).
  int modPow(long long x, long long n, int mod) {
    long long result = 1 % mod;
    x %= mod;
    while (n > 0) {
      if (n & 1)
        result = result * x % mod;
      x = x * x % mod;
      n >>= 1;
    }
    return static_cast<int>(result);
  }

  // Returns the number of bits needed to represent x >= 0 (0 for x == 0).
  int bitLength(long long x) {
    int length = 0;
    while (x > 0) {
      ++length;
      x >>= 1;
    }
    return length;
  }
};
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
class Solution {
  /**
   * Answers each query {from, to, mod} with the product of big_nums[from..to]
   * (inclusive) reduced modulo mod.
   *
   * <p>Every element of big_nums is a power of two, so the product is two raised
   * to the sum of the exponents in the range; that sum is the difference of two
   * prefix sums, each found without materializing big_nums.
   */
  public int[] findProductsOfElements(long[][] queries) {
    final int[] results = new int[queries.length];
    int slot = 0;
    for (final long[] query : queries) {
      final long from = query[0];
      final long to = query[1];
      final int modulus = (int) query[2];
      final long exponent = exponentPrefix(to + 1) - exponentPrefix(from);
      results[slot++] = (int) powMod(2, exponent, modulus);
    }
    return results;
  }

  /** Returns the sum of the exponents of the first k entries of big_nums. */
  private long exponentPrefix(long k) {
    // The integer whose set bits supply the k-th entry of big_nums.
    final long owner = smallestWithBitTotalAtLeast(k);
    long total = exponentTotalUpTo(owner - 1);
    // How many of owner's set bits (lowest first) still fall inside the prefix.
    long left = k - bitTotalUpTo(owner - 1);
    final int width = bitLength(owner);
    int power = 0;
    while (power < width) {
      if (((owner >> power) & 1L) != 0) {
        total += power;
        if (--left == 0) {
          break;
        }
      }
      ++power;
    }
    return total;
  }

  /** Binary-searches [1, k] for the smallest v with bitTotalUpTo(v) >= k. */
  private long smallestWithBitTotalAtLeast(long k) {
    long lo = 1;
    long hi = k;
    while (lo < hi) {
      final long mid = lo + (hi - lo) / 2;
      if (bitTotalUpTo(mid) >= k) {
        hi = mid;
      } else {
        lo = mid + 1;
      }
    }
    return lo;
  }

  /** Returns the total number of set bits over all i with 1 <= i <= x. */
  private long bitTotalUpTo(long x) {
    long total = 0;
    long bit = 1;
    while (bit <= x) {
      total += onesAt(x, bit);
      bit <<= 1;
    }
    return total;
  }

  /** Returns the sum, over all i with 1 <= i <= x, of the exponents of i's set bits. */
  private long exponentTotalUpTo(long x) {
    long total = 0;
    final int width = bitLength(x);
    for (int power = 0; power < width; ++power) {
      total += onesAt(x, 1L << power) * power;
    }
    return total;
  }

  /**
   * Counts i in [0, x] whose bit `bit` (a power of two) is set: it is on for
   * `bit` values in every full block of 2 * bit, plus the tail of the last block.
   */
  private static long onesAt(long x, long bit) {
    final long period = bit << 1;
    return (x / period) * bit + Math.max(0L, x % period + 1 - bit);
  }

  /** Returns base^exp mod m by square-and-multiply; 1 % m when exp <= 0. */
  private long powMod(long base, long exp, int m) {
    long result = 1 % m;
    base %= m;
    while (exp > 0) {
      if ((exp & 1) == 1) {
        result = result * base % m;
      }
      base = base * base % m;
      exp >>= 1;
    }
    return result;
  }

  /** Returns the number of significant bits of x (0 for x == 0). */
  private int bitLength(long x) {
    return Long.SIZE - Long.numberOfLeadingZeros(x);
  }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
class Solution:
  def findProductsOfElements(self, queries: list[list[int]]) -> list[int]:
    def sumBitsTill(x: int) -> int:
      """Returns sum(i.bit_count()), where 1 <= i <= x."""
      sumBits = 0
      powerOfTwo = 1
      while powerOfTwo <= x:
        sumBits += (x // (2 * powerOfTwo)) * powerOfTwo
        sumBits += max(0, x % (2 * powerOfTwo) + 1 - powerOfTwo)
        powerOfTwo *= 2
      return sumBits

    def sumPowersTill(x: int) -> int:
      """Returns sum(all powers of i), where 1 <= i <= x."""
      sumPowers = 0
      powerOfTwo = 1
      for power in range(x.bit_length()):
        sumPowers += (x // (2 * powerOfTwo)) * powerOfTwo * power
        sumPowers += max(0, x % (2 * powerOfTwo) + 1 - powerOfTwo) * power
        powerOfTwo *= 2
      return sumPowers

    def sumPowersFirstKBigNums(k: int) -> int:
      """Returns the sum of powers of the first k numbers in `big_nums`."""
      # Find the first number in [1, k] that has sumBitsTill(num) >= k.
      num = bisect.bisect_left(range(k), k, key=sumBitsTill)
      sumPowers = sumPowersTill(num - 1)
      remainingCount = k - sumBitsTill(num - 1)
      for power in range(num.bit_length()):
        if num >> power & 1:
          sumPowers += power
          remainingCount -= 1
          if remainingCount == 0:
            break
      return sumPowers

    return [pow(2,
                sumPowersFirstKBigNums(b + 1) -
                sumPowersFirstKBigNums(a), mod)
            for a, b, mod in queries]