Skip to content

3422. Minimum Operations to Make Subarray Elements Equal

  • Time: $O(n\log k)$
  • Space: $O(k)$
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
// A multiset of ints that also keeps the running sum of its elements, so the
// total of everything stored can be read without iterating.
struct SumMultiset {
  multiset<int> nums;  // The stored values (duplicates allowed), ordered.
  long sum = 0;        // Sum of every value currently in `nums`.

  // Adds one occurrence of `val` and folds it into `sum`.
  void insert(int val) {
    nums.insert(val);
    sum += val;
  }

  // Removes a single occurrence of `val` (not all duplicates).
  // Does nothing when `val` is not stored: passing `end()` to
  // `multiset::erase` is undefined behavior, and subtracting a value that was
  // never added would desynchronize `sum` from `nums`.
  void erase(int val) {
    const auto it = nums.find(val);
    if (it == nums.end())
      return;
    nums.erase(it);
    sum -= val;
  }
};

// Tracks a window of values split into a lower half (`below`, at most k / 2
// values) and an upper half (`above`, everything else), with every value in
// `below` <= every value in `above` after each `add`. Because both halves
// keep running sums, the total distance of all values to the median can be
// read in O(1) once the window holds `k` values.
class MedianTracker {
 public:
  // `k` is the intended window size; it fixes the capacity of the lower half.
  MedianTracker(int k) : k(k) {}

  // Inserts `val` and restores the half sizes and ordering invariant.
  void add(int val) {
    below.insert(val);
    balance();
  }

  // Removes one occurrence of `val` from whichever half holds it (the lower
  // half is checked first). Does not rebalance: the next `add` does that.
  // A value present in neither half is ignored.
  void remove(int val) {
    if (below.nums.find(val) != below.nums.cend())
      below.erase(val);
    else if (above.nums.find(val) != above.nums.cend())
      above.erase(val);
  }

  // Returns sum(above) - sum(below), minus the smallest value of `above` when
  // `k` is odd. With a full window of `k` values that smallest value is the
  // median, so the result is the sum of |x - median| over the window.
  // Precondition: for odd `k`, `above` must be non-empty.
  long getCost() const {
    const long medianShare = k % 2 == 1 ? *above.nums.begin() : 0;
    return above.sum - below.sum - medianShare;
  }

 private:
  SumMultiset below;
  SumMultiset above;
  const int k;

  void balance() {
    // Move excessive numbers from `below` to `above`.
    const size_t lowerCapacity = static_cast<size_t>(k / 2);
    while (below.nums.size() > lowerCapacity) {
      const int mx = *below.nums.rbegin();
      below.erase(mx);
      above.insert(mx);
    }

    // Swap across the boundary until max(below) <= min(above). `below` is
    // empty when k == 1 (capacity 0), so it must be checked before reading
    // its largest element.
    while (!below.nums.empty() && !above.nums.empty()) {
      const int mx = *below.nums.rbegin();
      const int mn = *above.nums.begin();
      if (mx <= mn)
        break;
      below.erase(mx);
      above.erase(mn);
      below.insert(mn);
      above.insert(mx);
    }
  }
};

class Solution {
 public:
  // Slides a window of length `k` over `nums` and returns the smallest cost
  // reported by `MedianTracker::getCost()` across all windows (the sum of
  // distances of the window's values to its median).
  // Returns 0 when no window of length `k` exists (k <= 0 or k > nums.size()).
  long long minOperations(vector<int>& nums, int k) {
    const int n = static_cast<int>(nums.size());
    // Without this guard the priming loop below would index past the end of
    // `nums` for k > n.
    if (k <= 0 || k > n)
      return 0;

    MedianTracker tracker(k);

    // Prime the tracker with the first window.
    for (int i = 0; i < k; ++i)
      tracker.add(nums[i]);

    long long ans = tracker.getCost();

    // Slide: drop the value leaving on the left, then add the one entering on
    // the right (the add is what rebalances the tracker).
    for (int i = k; i < n; ++i) {
      tracker.remove(nums[i - k]);
      tracker.add(nums[i]);
      ans = min<long long>(ans, tracker.getCost());
    }

    return ans;
  }
};
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
/**
 * A multiset of ints backed by a value -> occurrence-count map, which also
 * keeps the running sum and the total number of stored elements.
 */
class SumMultiset {
  /** Maps each stored value to how many times it occurs; counts are always >= 1. */
  public TreeMap<Integer, Integer> nums = new TreeMap<>();
  /** Sum of all stored elements, duplicates included. */
  public long sum = 0;
  /** Number of stored elements, duplicates included. */
  public int size = 0;

  /** Adds one occurrence of {@code val}. */
  public void insert(int val) {
    nums.merge(val, 1, Integer::sum);
    sum += val;
    ++size;
  }

  /**
   * Removes one occurrence of {@code val}. Does nothing when {@code val} is
   * not stored, so the map never gains a negative count and {@code sum} /
   * {@code size} stay consistent with its contents.
   */
  public void erase(int val) {
    final Integer count = nums.get(val);
    if (count == null)
      return;
    if (count == 1)
      nums.remove(val);
    else
      nums.put(val, count - 1);
    sum -= val;
    --size;
  }
}

/**
 * Tracks a window of values split into a lower half ({@code below}, at most
 * k / 2 values) and an upper half ({@code above}, everything else), with every
 * value in {@code below} <= every value in {@code above} after each
 * {@link #add}. Both halves keep running sums, so the total distance of all
 * values to the median is available without iterating.
 */
class MedianTracker {
  /** @param k the intended window size; fixes the capacity of the lower half */
  public MedianTracker(int k) {
    this.k = k;
  }

  /** Inserts {@code val} and restores the half sizes and ordering invariant. */
  public void add(int val) {
    below.insert(val);
    balance();
  }

  /**
   * Removes one occurrence of {@code val} from whichever half holds it (the
   * lower half is checked first). Does not rebalance; the next {@link #add}
   * does. A value present in neither half is ignored.
   */
  public void remove(int val) {
    if (below.nums.containsKey(val))
      below.erase(val);
    else if (above.nums.containsKey(val))
      above.erase(val);
  }

  /**
   * Returns sum(above) - sum(below), minus the smallest value of
   * {@code above} when {@code k} is odd. With a full window of {@code k}
   * values that smallest value is the median, so the result is the sum of
   * |x - median| over the window. For odd {@code k}, {@code above} must be
   * non-empty ({@code firstKey()} throws otherwise).
   */
  public long getCost() {
    return above.sum - below.sum - (k % 2 == 1 ? above.nums.firstKey() : 0L);
  }

  private SumMultiset below = new SumMultiset();
  private SumMultiset above = new SumMultiset();
  private final int k;

  private void balance() {
    // Move excessive numbers from `below` to `above`.
    while (below.size > k / 2) {
      final int mx = below.nums.lastKey();
      below.erase(mx);
      above.insert(mx);
    }

    // Swap across the boundary until max(below) <= min(above). `below` is
    // empty when k == 1 (capacity 0), and lastKey() throws on an empty map,
    // so both halves must be non-empty before comparing.
    while (!below.nums.isEmpty() && !above.nums.isEmpty()) {
      final int mx = below.nums.lastKey();
      final int mn = above.nums.firstKey();
      if (mx <= mn)
        break;
      below.erase(mx);
      above.erase(mn);
      below.insert(mn);
      above.insert(mx);
    }
  }
}

class Solution {
  /**
   * Slides a window of length {@code k} across {@code nums} and returns the
   * smallest cost that {@code MedianTracker.getCost()} reports for any window.
   */
  public long minOperations(int[] nums, int k) {
    final MedianTracker window = new MedianTracker(k);

    // Load the first k values to form the initial window.
    int filled = 0;
    while (filled < k) {
      window.add(nums[filled]);
      ++filled;
    }

    long best = window.getCost();

    // Advance one position at a time: evict the leftmost value, admit the
    // next one, and keep the cheapest cost seen so far.
    int right = k;
    while (right < nums.length) {
      window.remove(nums[right - k]);
      window.add(nums[right]);
      final long cost = window.getCost();
      if (cost < best)
        best = cost;
      ++right;
    }

    return best;
  }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
class Solution:
  def minOperations(self, nums: list[int], k: int) -> int:
    """Returns the minimum, over every length-`k` window of `nums`, of the sum
    of absolute differences between the window's values and its lower median.

    The sum is maintained incrementally as the window slides instead of being
    recomputed for each window.

    NOTE(review): `SortedList` is not defined in the visible code; presumably
    it is `sortedcontainers.SortedList` -- confirm the import.
    """
    window = SortedList(nums[:k])
    # Lower median: index (k - 1) // 2 of the sorted window.
    median = window[(k - 1) // 2]
    # Cost of the first window, measured against its median.
    ops = sum(abs(median - nums[j]) for j in range(k))
    ans = ops

    for i in range(k, len(nums)):
      # Slide the window: nums[i - k] leaves, nums[i] enters.
      window.remove(nums[i - k])
      window.add(nums[i])
      # Re-base `ops` onto the new window while still measuring against the
      # OLD median; the shift to the new median is applied below.
      ops -= abs(median - nums[i - k])
      ops += abs(median - nums[i])
      newMedian = window[(k - 1) // 2]
      # Cost saved per unit of distance when the reference point moves from
      # the old median to the new one:
      #   - odd k:  1 (k % 2);
      #   - even k: 2 when the old median is <= the new lower median,
      #             otherwise 0 (k % 2).
      medianMultiplier = (
          2
          if k % 2 == 0 and median <= newMedian <= window[k // 2]
          else k % 2)
      ops -= abs(newMedian - median) * medianMultiplier
      median = newMedian
      ans = min(ans, ops)

    return ans