3344. Maximum Sized Array 👍

  • Time: $O(\log^2 1196) = O(1)$ — a binary search over $n \in [0, 1196]$, where each probe loops over the $O(\log n)$ bits of $n$
  • Space: $O(1)$
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
class Solution {
 public:
  // Returns the largest n in [0, 1196] for which getArraySum(n) <= s, found by
  // binary search (getArraySum is non-decreasing in n because every added term
  // i * (j OR k) is non-negative). Returns 1 directly when s == 0.
  int maxSizedArray(long long s) {
    if (s == 0)
      return 1;

    int l = 0;
    int r = 1196;  // when s = 10^15, n = 1196

    while (l < r) {
      // Upper-biased midpoint so that `l = m` always makes progress.
      const int m = (l + r + 1) / 2;
      if (getArraySum(m) <= s)
        l = m;
      else
        r = m - 1;
    }

    return l;
  }

 private:
  // Returns the number of integers in [0, n - 1] with the i-th bit set.
  //
  // For the i-th bit, numbers in the range [0, n - 1] can be divided into
  // groups of 2^(i + 1) numbers. In each group, exactly half of the numbers
  // (the upper half) have the i-th bit set.
  int getNumbersWithBitSet(int n, int i) {
    const int groupSize = 1 << (i + 1);
    const int halfGroupSize = 1 << i;
    const int fullGroups = n / groupSize;
    // Only the part of the trailing partial group past its lower half has the
    // bit set.
    const int remaining = std::max(0, (n % groupSize) - halfGroupSize);
    return fullGroups * halfGroupSize + remaining;
  }

  // Returns the sum of all i * (j OR k) values in 3D arrays of size n^3.
  //
  //   sum(i * (j OR k)), where 0 <= i, j, k < n
  // = 0 * (j OR k) + 1 * (j OR k) + ... + (n - 1) * (j OR k)
  // = (0 + 1 + ... + n - 1) * sum(j OR k)
  // = (n * (n - 1) / 2) * sum(j OR k)
  //
  // `long long` is used instead of `long` because `long` is only 32 bits on
  // LLP64 platforms (e.g. Windows), where the result would overflow.
  long long getArraySum(int n) {
    const long long arithmeticSum = static_cast<long long>(n) * (n - 1) / 2;
    long long orSum = 0;  // sum of (j OR k) over 0 <= j, k < n
    const int bits = bitLength(n);
    for (int i = 0; i < bits; ++i) {
      const long long numbersWithoutBit = n - getNumbersWithBitSet(n, i);
      // A pair (j, k) has bit i set in (j OR k) unless neither j nor k has it.
      const long long pairsWithBit =
          static_cast<long long>(n) * n - numbersWithoutBit * numbersWithoutBit;
      orSum += pairsWithBit * (1LL << i);
    }
    return arithmeticSum * orSum;
  }

  // Returns the number of bits needed to represent n (0 for n == 0).
  // A loop is used instead of __builtin_clz, whose result is undefined for 0.
  int bitLength(int n) {
    int bits = 0;
    for (unsigned v = static_cast<unsigned>(n); v != 0; v >>= 1)
      ++bits;
    return bits;
  }
};
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
class Solution {
  /** Search ceiling for n; the original note states that n = 1196 when s = 10^15. */
  private static final int MAX_N = 1196;

  /**
   * Binary-searches for the largest n in [0, MAX_N] whose array sum does not exceed {@code s}.
   *
   * @param s the upper limit on the array sum
   * @return 1 when {@code s == 0}; otherwise the largest n with {@code getArraySum(n) <= s}
   */
  public int maxSizedArray(long s) {
    if (s == 0) {
      return 1;
    }

    int lo = 0;
    int hi = MAX_N;
    while (lo < hi) {
      // Upper-biased midpoint so that `lo = mid` always shrinks the range.
      final int mid = lo + (hi - lo + 1) / 2;
      if (getArraySum(mid) > s) {
        hi = mid - 1;
      } else {
        lo = mid;
      }
    }
    return lo;
  }

  /**
   * Counts how many integers in [0, n - 1] have bit {@code i} set.
   *
   * <p>Bit {@code i} repeats with period 2^(i + 1): within each period the lower half is clear
   * and the upper half is set, so full periods contribute 2^i each and a trailing partial
   * period contributes whatever extends past its lower half.
   */
  private int getNumbersWithBitSet(int n, int i) {
    final int half = 1 << i;
    final int period = half << 1;
    final int tail = n % period;
    return (n / period) * half + Math.max(0, tail - half);
  }

  /**
   * Computes sum(i * (j OR k)) over 0 <= i, j, k < n.
   *
   * <p>The i factor separates out, so the result equals (0 + 1 + ... + n - 1) times the sum of
   * (j OR k) over all pairs. That pair sum is built bit by bit: a pair contributes 2^bit
   * unless neither j nor k has the bit set.
   */
  private long getArraySum(int n) {
    long orSum = 0;
    final int bits = bitLength(n);
    for (int bit = 0; bit < bits; ++bit) {
      final long without = n - getNumbersWithBitSet(n, bit);
      final long pairsWithBit = (long) n * n - without * without;
      orSum += pairsWithBit << bit;
    }
    final int arithmeticSum = n * (n - 1) / 2;
    return arithmeticSum * orSum;
  }

  /** Returns the number of significant bits in {@code n} (0 for n == 0). */
  private int bitLength(int n) {
    return Integer.SIZE - Integer.numberOfLeadingZeros(n);
  }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
class Solution:
  def maxSizedArray(self, s: int) -> int:
    def getNumbersWithBitSet(n: int, i: int) -> int:
      """
      Returns the number of integers in [0, n - 1] with the i-th bit set.

      For the i-th bit, numbers in the range [0, n - 1] can be divided into
      groups of 2^(i + 1) numbers. In each group, exactly half of the numbers
      have the i-th bit set.
      """
      groupSize = 1 << (i + 1)
      halfGroupSize = 1 << i
      fullGroups = n // groupSize
      remaining = max(0, (n % groupSize) - halfGroupSize)
      return fullGroups * halfGroupSize + remaining

    def getArraySum(n: int) -> int:
      """
      Returns the sum of all i * (j OR k) values in 3D arrays of size n^3.

        sum(i * (j OR k)), where 0 <= i, j, k < n
      = 0 * (j OR k) + 1 * (j OR k) + ... + (n - 1) * (j OR k)
      = (0 + 1 + ... + n - 1) * sum(j OR k)
      = (n * (n - 1) / 2) * sum(j OR k)
      """
      arithmeticSum = n * (n - 1) // 2  # 0 + 1 + ... + n - 1
      orSum = 0  # the sum of (j OR k) values in 2D arrays of size n^2
      for i in range(n.bit_length()):
        numbersWithoutBit = n - getNumbersWithBitSet(n, i)
        pairsWithBit = n**2 - numbersWithoutBit**2
        orSum += pairsWithBit * (1 << i)  # Add contribution of this bit.
      return arithmeticSum * orSum

    if s == 0:
      return 1
    l = 0
    r = 1196  # when s = 10^15, n = 1196
    return bisect.bisect_right(range(l, r + 1), s, key=getArraySum) - 1 + l