Skip to content

3013. Divide an Array Into Subarrays With Minimum Cost II 👍

  • Time: $O(n\log n)$
  • Space: $O(n)$
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
class Solution {
 public:
  // Returns the minimum total cost of splitting `nums` into `k` contiguous
  // subarrays. A subarray costs its first element, and the start indices of
  // the second and the k-th subarrays may differ by at most `dist`.
  //
  // Equivalently, the answer is nums[0] plus the minimum, over every window
  // nums[i..i + dist] with i > 0 and i + dist < n, of the sum of the k - 1
  // smallest numbers in that window. The window slides one index at a time.
  // Two multisets track the window's k - 1 smallest values (`selected`) and
  // everything else in it (`candidates`).
  //
  // NOTE(review): assumes the problem's input constraints (3 <= k <= n,
  // k - 2 <= dist <= n - 2); TODO confirm callers never pass anything else.
  long long minimumCost(std::vector<int>& nums, int k, int dist) {
    const int n = static_cast<int>(nums.size());
    // 64-bit accumulator: `long` is only 32 bits on LLP64 platforms (e.g.
    // Windows), where summing k - 1 values near 1e9 would overflow.
    long long windowSum = 0;
    std::multiset<int> selected;    // The k - 1 smallest values in the window.
    std::multiset<int> candidates;  // Every other value in the window.

    for (int i = 1; i <= dist + 1; ++i) {
      windowSum += nums[i];
      selected.insert(nums[i]);
    }

    windowSum = balance(windowSum, selected, candidates, k);
    long long minWindowSum = windowSum;

    for (int i = dist + 2; i < n; ++i) {
      // Drop the value that just left the window. Erasing through the
      // iterator removes only one copy of a duplicated value.
      const int outOfScope = nums[i - dist - 1];
      const auto selectedIt = selected.find(outOfScope);
      if (selectedIt != selected.end()) {
        windowSum -= outOfScope;
        selected.erase(selectedIt);
      } else {
        candidates.erase(candidates.find(outOfScope));
      }
      // Admit nums[i] directly into `selected` only if it beats the largest
      // selected value. `selected` can be empty here when k - 1 == 1.
      if (!selected.empty() && nums[i] < *selected.rbegin()) {
        windowSum += nums[i];
        selected.insert(nums[i]);
      } else {
        candidates.insert(nums[i]);
      }
      windowSum = balance(windowSum, selected, candidates, k);
      minWindowSum = std::min(minWindowSum, windowSum);
    }

    return nums[0] + minWindowSum;
  }

 private:
  // Moves values between the two multisets until `selected` holds exactly
  // the k - 1 smallest values. Returns `windowSum` adjusted so that it is
  // still the sum of `selected`.
  long long balance(long long windowSum, std::multiset<int>& selected,
                    std::multiset<int>& candidates, int k) {
    const auto target = static_cast<std::multiset<int>::size_type>(k - 1);
    while (selected.size() < target) {
      // Promote the smallest candidate.
      const auto minIt = candidates.begin();
      windowSum += *minIt;
      selected.insert(*minIt);
      candidates.erase(minIt);
    }
    while (selected.size() > target) {
      // Demote the largest selected value.
      auto maxIt = selected.end();
      --maxIt;
      windowSum -= *maxIt;
      candidates.insert(*maxIt);
      selected.erase(maxIt);
    }
    return windowSum;
  }
};
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
/**
 * A minimal sorted multiset of ints, stored as a {@link TreeMap} from value to
 * its occurrence count. {@code sz} tracks the total number of occurrences.
 */
class Multiset {
  /** Adds one occurrence of {@code num}. */
  public void add(int num) {
    map.merge(num, 1, Integer::sum);
    ++sz;
  }

  /**
   * Removes one occurrence of {@code num}. Does nothing if {@code num} is not
   * present, so the stored counts and {@code size()} stay consistent.
   */
  public void remove(int num) {
    final Integer count = map.get(num);
    if (count == null)
      return;
    if (count == 1)
      map.remove(num);
    else
      map.put(num, count - 1);
    --sz;
  }

  /** Returns the smallest value; throws NullPointerException if empty. */
  public int min() {
    return map.firstEntry().getKey();
  }

  /** Returns the largest value; throws NullPointerException if empty. */
  public int max() {
    return map.lastEntry().getKey();
  }

  /** Returns the total number of occurrences, counting duplicates. */
  public int size() {
    return sz;
  }

  /** Returns whether at least one occurrence of {@code num} is present. */
  public boolean contains(int num) {
    return map.containsKey(num);
  }

  // value -> number of occurrences; every stored count is >= 1.
  private TreeMap<Integer, Integer> map = new TreeMap<>();
  private int sz = 0;
}

class Solution {
  /**
   * Returns the minimum cost of splitting {@code nums} into {@code k} subarrays.
   *
   * <p>The answer is nums[0] plus the smallest sum of k - 1 values that all lie
   * in one window nums[i..i + dist], where i > 0 and i + dist < nums.length.
   * The window slides right one index at a time. {@code chosen} holds its
   * k - 1 smallest values and {@code spare} holds the rest.
   */
  public long minimumCost(int[] nums, int k, int dist) {
    Multiset chosen = new Multiset();
    Multiset spare = new Multiset();
    long total = 0;

    for (int j = 1; j <= dist + 1; ++j) {
      final int value = nums[j];
      chosen.add(value);
      total += value;
    }
    total = balance(total, chosen, spare, k);
    long best = total;

    for (int right = dist + 2; right < nums.length; ++right) {
      // Evict the value that slid out on the left.
      final int leaving = nums[right - dist - 1];
      if (!chosen.contains(leaving)) {
        spare.remove(leaving);
      } else {
        chosen.remove(leaving);
        total -= leaving;
      }

      // The new value joins `chosen` only if it beats the current largest pick.
      final int entering = nums[right];
      final boolean improvesChosen = entering < chosen.max();
      if (improvesChosen) {
        chosen.add(entering);
        total += entering;
      } else {
        spare.add(entering);
      }

      total = balance(total, chosen, spare, k);
      if (total < best)
        best = total;
    }

    return best + nums[0];
  }

  /**
   * Shifts values between {@code chosen} and {@code spare} until
   * {@code chosen} holds exactly its k - 1 smallest values. Returns
   * {@code sum} updated to match.
   */
  private long balance(long sum, Multiset chosen, Multiset spare, int k) {
    final int target = k - 1;
    while (chosen.size() < target) {
      final int promoted = spare.min();
      spare.remove(promoted);
      chosen.add(promoted);
      sum += promoted;
    }
    while (chosen.size() > target) {
      final int demoted = chosen.max();
      chosen.remove(demoted);
      spare.add(demoted);
      sum -= demoted;
    }
    return sum;
  }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
from sortedcontainers import SortedList


class Solution:
  def minimumCost(self, nums: list[int], k: int, dist: int) -> int:
    """Returns the minimum cost of splitting `nums` into `k` subarrays.

    The result is nums[0] plus the smallest possible sum of k - 1 values taken
    from one window nums[i..i + dist], where i > 0 and i + dist < len(nums).
    The window slides right one step at a time. `chosen` keeps the window's
    k - 1 smallest values and `spare` keeps the remainder.
    """
    need = k - 1
    first_window = [nums[i] for i in range(1, dist + 2)]
    chosen = SortedList(first_window)
    spare = SortedList()
    total = sum(first_window)

    def rebalance() -> None:
      """Shift values so `chosen` holds exactly `need` of the smallest values."""
      nonlocal total
      while len(chosen) < need:
        smallest = spare.pop(0)
        chosen.add(smallest)
        total += smallest
      while len(chosen) > need:
        largest = chosen.pop()
        spare.add(largest)
        total -= largest

    rebalance()
    best = total

    for right in range(dist + 2, len(nums)):
      # Evict the value that slid out on the left.
      leaving = nums[right - dist - 1]
      if leaving in chosen:
        chosen.remove(leaving)
        total -= leaving
      else:
        spare.remove(leaving)

      # The new value joins `chosen` only if it beats the largest current pick.
      entering = nums[right]
      if entering < chosen[-1]:
        chosen.add(entering)
        total += entering
      else:
        spare.add(entering)

      rebalance()
      best = min(best, total)

    return nums[0] + best