Skip to content

327. Count of Range Sum 👍

  • Time: $O(n\log n)$
  • Space: $O(n)$
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
/**
 * 327. Count of Range Sum.
 *
 * Counts the pairs (a, b) with a < b such that prefix[b] - prefix[a], i.e. the
 * sum of nums[a..b-1], lies in [lower, upper]. The counting is done while
 * merge-sorting the prefix-sum array: O(n log n) time, O(n) extra space.
 */
class Solution {
 public:
  /**
   * @param nums  input values
   * @param lower inclusive lower bound of the accepted range sum
   * @param upper inclusive upper bound of the accepted range sum
   * @return number of contiguous ranges whose sum is within [lower, upper]
   */
  int countRangeSum(vector<int>& nums, int lower, int upper) {
    const int n = nums.size();
    int ans = 0;
    // `long` is only guaranteed to be 32 bits wide (and is on LLP64 targets),
    // which would overflow when summing many ints. `long long` is at least 64.
    vector<long long> prefix(n + 1);

    // prefix[k] holds the sum of the first k elements of nums.
    for (int i = 0; i < n; ++i)
      prefix[i + 1] = prefix[i] + nums[i];

    mergeSort(prefix, 0, n, lower, upper, ans);
    return ans;
  }

 private:
  // Sorts prefix[l..r] in place and adds to `ans` the number of qualifying
  // pairs whose both indices originate from [l, r].
  void mergeSort(vector<long long>& prefix, int l, int r, int lower, int upper,
                 int& ans) {
    if (l >= r)
      return;

    const int m = (l + r) / 2;
    mergeSort(prefix, l, m, lower, upper, ans);
    mergeSort(prefix, m + 1, r, lower, upper, ans);
    merge(prefix, l, m, r, lower, upper, ans);
  }

  // Expects prefix[l..m] and prefix[m + 1..r] to each be sorted ascending.
  // First counts the cross pairs (i in the left half, j in the right half)
  // with lower <= prefix[j] - prefix[i] <= upper, then merges both halves.
  void merge(vector<long long>& prefix, int l, int m, int r, int lower,
             int upper, int& ans) {
    int lo = m + 1;  // 1st index s.t. prefix[lo] - prefix[i] >= lower
    int hi = m + 1;  // 1st index s.t. prefix[hi] - prefix[i] > upper

    // Both halves are sorted, so as prefix[i] grows, lo and hi only ever move
    // right. For each index i in range [l, m], add hi - lo to ans.
    for (int i = l; i <= m; ++i) {
      while (lo <= r && prefix[lo] - prefix[i] < lower)
        ++lo;
      while (hi <= r && prefix[hi] - prefix[i] <= upper)
        ++hi;
      ans += hi - lo;
    }

    vector<long long> sorted(r - l + 1);
    int k = 0;      // sorted's index
    int i = l;      // left's index
    int j = m + 1;  // right's index

    while (i <= m && j <= r)
      if (prefix[i] < prefix[j])
        sorted[k++] = prefix[i++];
      else
        sorted[k++] = prefix[j++];

    // Put possible remaining left part to the sorted array
    while (i <= m)
      sorted[k++] = prefix[i++];

    // Put possible remaining right part to the sorted array
    while (j <= r)
      sorted[k++] = prefix[j++];

    copy(begin(sorted), end(sorted), begin(prefix) + l);
  }
};
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
/**
 * 327. Count of Range Sum.
 *
 * Counts the pairs (a, b) with a &lt; b such that prefix[b] - prefix[a], i.e.
 * the sum of nums[a..b-1], lies in [lower, upper]. The counting is done while
 * merge-sorting the prefix-sum array: O(n log n) time, O(n) extra space.
 *
 * The count is returned through the call chain instead of being accumulated
 * in an instance field, so repeated calls on one instance are independent.
 */
class Solution {
  /**
   * @param nums  input values
   * @param lower inclusive lower bound of the accepted range sum
   * @param upper inclusive upper bound of the accepted range sum
   * @return number of contiguous ranges whose sum is within [lower, upper]
   */
  public int countRangeSum(int[] nums, int lower, int upper) {
    final int n = nums.length;
    // prefix[k] holds the sum of the first k elements; long avoids int overflow.
    long[] prefix = new long[n + 1];

    for (int i = 0; i < n; ++i)
      prefix[i + 1] = (long) nums[i] + prefix[i];

    return mergeSort(prefix, 0, n, lower, upper);
  }

  /**
   * Sorts prefix[l..r] in place.
   *
   * @return number of qualifying pairs whose both indices originate from [l, r]
   */
  private int mergeSort(long[] prefix, int l, int r, int lower, int upper) {
    if (l >= r)
      return 0;

    final int m = (l + r) / 2;
    int count = mergeSort(prefix, l, m, lower, upper);
    count += mergeSort(prefix, m + 1, r, lower, upper);
    count += merge(prefix, l, m, r, lower, upper);
    return count;
  }

  /**
   * Expects prefix[l..m] and prefix[m + 1..r] to each be sorted ascending.
   * Counts the cross pairs (i in the left half, j in the right half) with
   * lower &lt;= prefix[j] - prefix[i] &lt;= upper, then merges both halves.
   *
   * @return the number of cross pairs counted
   */
  private int merge(long[] prefix, int l, int m, int r, int lower, int upper) {
    int count = 0;
    int lo = m + 1; // 1st index s.t. prefix[lo] - prefix[i] >= lower
    int hi = m + 1; // 1st index s.t. prefix[hi] - prefix[i] > upper

    // Both halves are sorted, so as prefix[i] grows, lo and hi only ever move
    // right. For each index i in range [l, m], add hi - lo to the count.
    for (int i = l; i <= m; ++i) {
      while (lo <= r && prefix[lo] - prefix[i] < lower)
        ++lo;
      while (hi <= r && prefix[hi] - prefix[i] <= upper)
        ++hi;
      count += hi - lo;
    }

    long[] sorted = new long[r - l + 1];
    int k = 0;     // sorted's index
    int i = l;     // left's index
    int j = m + 1; // right's index

    while (i <= m && j <= r)
      if (prefix[i] < prefix[j])
        sorted[k++] = prefix[i++];
      else
        sorted[k++] = prefix[j++];

    // Put possible remaining left part to the sorted array
    while (i <= m)
      sorted[k++] = prefix[i++];

    // Put possible remaining right part to the sorted array
    while (j <= r)
      sorted[k++] = prefix[j++];

    System.arraycopy(sorted, 0, prefix, l, sorted.length);
    return count;
  }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
class Solution:
  def countRangeSum(self, nums: List[int], lower: int, upper: int) -> int:
    """Counts the contiguous ranges of `nums` whose sum is in [lower, upper].

    A range sum equals the difference of two prefix sums, so the task becomes
    counting index pairs a < b with lower <= sums[b] - sums[a] <= upper, which
    is done during a merge sort of the prefix sums.
    """
    # sums[k] is the sum of the first k elements of nums.
    sums = [0]
    for value in nums:
      sums.append(sums[-1] + value)

    self.ans = 0
    self._mergeSort(sums, 0, len(nums), lower, upper)
    return self.ans

  def _mergeSort(self, prefix: List[int], l: int, r: int, lower: int, upper: int) -> None:
    """Sorts prefix[l..r] in place, adding the qualifying pairs to self.ans."""
    if l < r:
      mid = (l + r) // 2
      self._mergeSort(prefix, l, mid, lower, upper)
      self._mergeSort(prefix, mid + 1, r, lower, upper)
      self._merge(prefix, l, mid, r, lower, upper)

  def _merge(self, prefix: List[int], l: int, m: int, r: int, lower: int, upper: int) -> None:
    """Counts cross pairs between the two sorted halves, then merges them.

    prefix[l..m] and prefix[m + 1..r] are expected to be sorted ascending.
    """
    start = m + 1  # first right index with prefix[start] - base >= lower
    stop = m + 1   # first right index with prefix[stop] - base > upper

    # Every right index in [start, stop) pairs with the current left element.
    for left in range(l, m + 1):
      base = prefix[left]
      while start <= r and prefix[start] - base < lower:
        start += 1
      while stop <= r and prefix[stop] - base <= upper:
        stop += 1
      self.ans += stop - start

    merged = []
    a = l      # cursor in the left half
    b = m + 1  # cursor in the right half

    while a <= m and b <= r:
      if prefix[a] < prefix[b]:
        merged.append(prefix[a])
        a += 1
      else:
        merged.append(prefix[b])
        b += 1

    # At most one of the halves still has elements left; append them as-is.
    merged.extend(prefix[a:m + 1])
    merged.extend(prefix[b:r + 1])

    prefix[l:r + 1] = merged