Skip to content

3318. Find X-Sum of All K-Long Subarrays I

  • Time: $O(n\log n)$
  • Space: $O(n)$
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
class Solution {
 public:
  // Returns the x-sum of every length-k window of nums, in window order.
  //
  // The x-sum of a window keeps only the occurrences of its x most frequent
  // values (ties broken in favor of the larger value) and sums them; when the
  // window has at most x distinct values, that is the whole window's sum. A
  // value is emitted only once a window is complete, so the result has
  // nums.size() - k + 1 entries (none when k > nums.size()). A non-positive x
  // selects no values, giving 0 for every window.
  //
  // Each distinct value in the window is tracked as a (count, value) pair in
  // one of two ordered multisets: `top` holds the greatest pairs (at most x of
  // them) and `bot` holds the rest. `windowSum` is the sum of count * value
  // over `top`.
  vector<int> findXSum(vector<int>& nums, int k, int x) {
    vector<int> ans;
    long long windowSum = 0;
    unordered_map<int, int> count;
    multiset<pair<int, int>> top;  // the top x (count, value) pairs
    multiset<pair<int, int>> bot;  // the remaining (count, value) pairs
    // Clamp x so a negative value is not converted to a huge unsigned number
    // when compared against top.size().
    const size_t limit = x > 0 ? static_cast<size_t>(x) : 0;

    // count * value, widened first so the product cannot overflow int.
    auto contribution = [](int freq, int value) -> long long {
      return static_cast<long long>(freq) * value;
    };

    // Adds delta to the count of num and re-files its pair: the old pair is
    // removed from whichever set holds it (dropping its contribution if that
    // set was `top`), and the new pair, if its count is still positive, goes
    // into `bot` to be rebalanced by the main loop.
    auto update = [&](int num, int delta) -> void {
      int& freq = count[num];
      if (freq > 0) {
        if (auto it = bot.find({freq, num}); it != bot.end()) {
          bot.erase(it);
        } else {
          top.erase(top.find({freq, num}));
          windowSum -= contribution(freq, num);
        }
      }
      freq += delta;
      if (freq > 0)
        bot.insert({freq, num});
    };

    const int n = static_cast<int>(nums.size());
    for (int i = 0; i < n; ++i) {
      update(nums[i], 1);
      if (i >= k)
        update(nums[i - k], -1);
      // Promote the best `bot` pairs until `top` holds `limit` pairs.
      while (!bot.empty() && top.size() < limit) {
        const auto [countB, b] = *bot.rbegin();
        bot.erase(--bot.end());
        top.insert({countB, b});
        windowSum += contribution(countB, b);
      }
      // Swap while the best `bot` pair outranks the worst `top` pair. The
      // emptiness check on `top` covers x <= 0, where `top` stays empty and
      // dereferencing top.begin() would be undefined behavior.
      while (!bot.empty() && !top.empty() && *bot.rbegin() > *top.begin()) {
        const auto [countB, b] = *bot.rbegin();
        const auto [countT, t] = *top.begin();
        bot.erase(--bot.end());
        top.erase(top.begin());
        bot.insert({countT, t});
        top.insert({countB, b});
        windowSum += contribution(countB, b) - contribution(countT, t);
      }
      if (i >= k - 1)
        ans.push_back(static_cast<int>(windowSum));
    }

    return ans;
  }
};
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
class Solution {
  public int[] findXSum(int[] nums, int k, int x) {
    int[] ans = new int[nums.length - k + 1];
    Map<Integer, Integer> count = new HashMap<>();
    TreeSet<Pair<Integer, Integer>> top =
        new TreeSet<>(Comparator.comparingInt(Pair<Integer, Integer>::getKey)
                          .thenComparingInt(Pair<Integer, Integer>::getValue));
    TreeSet<Pair<Integer, Integer>> bot =
        new TreeSet<>(Comparator.comparingInt(Pair<Integer, Integer>::getKey)
                          .thenComparingInt(Pair<Integer, Integer>::getValue));

    for (int i = 0; i < nums.length; ++i) {
      update(nums[i], 1, count, top, bot);
      if (i >= k)
        update(nums[i - k], -1, count, top, bot);
      // Move the bottom elements to the top if needed.
      while (!bot.isEmpty() && top.size() < x) {
        Pair<Integer, Integer> pair = bot.pollLast();
        top.add(pair);
        windowSum += pair.getValue() * pair.getKey();
      }
      // Swap the bottom and top elements if needed.
      while (!bot.isEmpty() && (bot.last().getKey() > top.first().getKey() ||
                                bot.last().getKey().equals(top.first().getKey()) &&
                                    bot.last().getValue() > top.first().getValue())) {
        Pair<Integer, Integer> pairB = bot.pollLast();
        Pair<Integer, Integer> pairT = top.pollFirst();
        top.add(pairB);
        bot.add(pairT);
        windowSum += pairB.getValue() * pairB.getKey();
        windowSum -= pairT.getValue() * pairT.getKey();
      }
      if (i >= k - 1)
        ans[i - k + 1] = windowSum;
    }
    return ans;
  }

  private int windowSum = 0;

  // Updates the count of num by freq and the window sum accordingly.
  private void update(int num, int freq, Map<Integer, Integer> count,
                      TreeSet<Pair<Integer, Integer>> top, TreeSet<Pair<Integer, Integer>> bot) {
    if (count.getOrDefault(num, 0) > 0) { // Clean up the old count.
      Pair<Integer, Integer> pair = new Pair<>(count.get(num), num);
      if (bot.remove(pair)) {
        // Do nothing.
      } else {
        top.remove(pair);
        windowSum -= num * count.get(num);
      }
    }
    count.merge(num, freq, Integer::sum);
    if (count.get(num) > 0)
      bot.add(new Pair<>(count.get(num), num));
  }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import collections

from sortedcontainers import SortedList


class Solution:
  def findXSum(self, nums: list[int], k: int, x: int) -> list[int]:
    """Returns the x-sum of every length-k window of nums, in window order.

    The x-sum of a window keeps only the occurrences of its x most frequent
    values (ties broken in favor of the larger value) and sums them. A window
    is reported only once complete, so the result has len(nums) - k + 1
    entries (none when k > len(nums)). A non-positive x yields 0 per window.

    Each distinct value in the window is tracked as a (count, value) tuple in
    one of two sorted lists: `top` holds the greatest tuples (at most x of
    them) and `bot` holds the rest. `windowSum` is the sum of count * value
    over `top`.
    """
    ans = []
    windowSum = 0
    count = collections.Counter()
    top = SortedList()
    bot = SortedList()

    def update(num: int, freq: int) -> None:
      """Adds freq to the count of num and re-files its (count, num) tuple.

      The old tuple is removed from whichever list holds it (dropping its
      contribution if that was `top`); the new tuple, if its count is still
      positive, goes into `bot` to be rebalanced by the main loop.
      """
      nonlocal windowSum
      old = count[num]
      if old > 0:  # Clean up the tuple for the old count.
        if (old, num) in bot:
          bot.remove((old, num))
        else:
          top.remove((old, num))
          windowSum -= num * old
      count[num] = old + freq
      if count[num] > 0:
        bot.add((count[num], num))

    for i, num in enumerate(nums):
      update(num, 1)
      if i >= k:
        update(nums[i - k], -1)
      # Move the best bottom tuples to the top until it holds x of them.
      while bot and len(top) < x:
        countB, b = bot.pop()
        top.add((countB, b))
        windowSum += b * countB
      # Swap while the best bottom tuple outranks the worst top tuple. The
      # `top` check covers x <= 0, where `top` stays empty and `top[0]` would
      # raise IndexError.
      while bot and top and bot[-1] > top[0]:
        countB, b = bot.pop()
        countT, t = top.pop(0)
        bot.add((countT, t))
        windowSum -= t * countT
        top.add((countB, b))
        windowSum += b * countB
      if i >= k - 1:
        ans.append(windowSum)

    return ans