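2333. Minimum Sum of Squared Difference

  • Time: $O((n + M)\log n)$, where $n$ is the array length and $M = \max_i |\mathit{nums1}[i] - \mathit{nums2}[i]|$ (the greedy heap loop lowers the current maximum by one per full iteration)
  • Space: $O(n)$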
C++ solution:
class Solution {
 public:
  // Minimum Sum of Squared Difference.
  //
  // Pools the two budgets k1 and k2 into one budget of unit decrements and
  // spends it greedily on the largest absolute differences
  // |nums1[i] - nums2[i]|, then returns the sum of the squared remaining
  // differences. Returns 0 when the budget covers every difference entirely.
  long long minSumSquareDiff(vector<int>& nums1, vector<int>& nums2, int k1,
                             int k2) {
    const vector<int> diff = getDiff(nums1, nums2);
    // Widen before adding: k1 + k2 in `int` is signed-overflow UB once the
    // total exceeds INT_MAX.
    long long k = static_cast<long long>(k1) + k2;
    // `0LL`/`long long` rather than `long`: `long` is only 32 bits on LLP64
    // platforms (e.g. Windows), while this sum can exceed 2^31.
    if (accumulate(diff.begin(), diff.end(), 0LL) <= k)
      return 0;

    unordered_map<int, int> count;  // nonzero diff value -> frequency
    for (const int d : diff)
      if (d != 0)
        ++count[d];

    priority_queue<pair<int, int>> maxHeap;  // (num, freq), largest num on top
    for (const auto& [num, freq] : count)
      maxHeap.emplace(num, freq);

    // Each iteration subtracts the same amount from k and from the total of
    // the heap. That total started strictly above k, so the heap is non-empty
    // whenever k > 0.
    while (k > 0) {
      const auto [maxNum, maxNumFreq] = maxHeap.top();
      maxHeap.pop();
      // Bulk decrease: lower as many copies of maxNum by one as the budget
      // allows in this turn (always <= maxNumFreq, so it fits in an int).
      const int numDecreased = static_cast<int>(min<long long>(k, maxNumFreq));
      k -= numDecreased;
      // Budget ran out part-way through this bucket; keep the untouched
      // copies. (k is now 0, so the loop ends after this iteration.)
      if (maxNumFreq > numDecreased)
        maxHeap.emplace(maxNum, maxNumFreq - numDecreased);
      if (!maxHeap.empty() && maxHeap.top().first + 1 == maxNum) {
        // A bucket for maxNum - 1 already exists: merge into it.
        const auto [secondMaxNum, secondMaxNumFreq] = maxHeap.top();
        maxHeap.pop();
        maxHeap.emplace(secondMaxNum, secondMaxNumFreq + numDecreased);
      } else if (maxNum > 1) {
        maxHeap.emplace(maxNum - 1, numDecreased);
      }
      // When maxNum == 1 the decreased copies become 0 and are dropped,
      // since they contribute nothing to the answer.
    }

    long long ans = 0;
    while (!maxHeap.empty()) {
      const auto [num, freq] = maxHeap.top();
      maxHeap.pop();
      ans += static_cast<long long>(num) * num * freq;
    }

    return ans;
  }

 private:
  // Returns |nums1[i] - nums2[i]| for every index i.
  // Assumes nums1 and nums2 have the same length (indexes nums2 by nums1's
  // size).
  vector<int> getDiff(const vector<int>& nums1, const vector<int>& nums2) {
    vector<int> diff;
    diff.reserve(nums1.size());
    for (int i = 0; i < static_cast<int>(nums1.size()); ++i)
      diff.push_back(abs(nums1[i] - nums2[i]));
    return diff;
  }
};
Java solution:
class Solution {
  /**
   * Minimum Sum of Squared Difference.
   *
   * <p>Pools the budgets k1 and k2 into one budget of unit decrements and spends it greedily on the
   * largest absolute differences |nums1[i] - nums2[i]|. It then returns the sum of the squared
   * remaining differences.
   *
   * @param nums1 first array (assumed to have the same length as nums2)
   * @param nums2 second array
   * @param k1 first part of the decrement budget
   * @param k2 second part of the decrement budget
   * @return the minimal sum of squared differences, or 0 when the budget covers all differences
   */
  public long minSumSquareDiff(int[] nums1, int[] nums2, int k1, int k2) {
    int[] diff = getDiff(nums1, nums2);
    // Widen before adding so that k1 + k2 cannot wrap around in int arithmetic.
    long k = (long) k1 + k2;
    if (Arrays.stream(diff).asLongStream().sum() <= k)
      return 0;

    Map<Integer, Integer> count = new HashMap<>(); // nonzero diff value -> frequency
    for (final int d : diff)
      if (d != 0)
        count.merge(d, 1, Integer::sum);

    // Entries are {num, freq}, largest num first. A plain int[] avoids javafx.util.Pair, which is
    // not part of Java SE.
    Queue<int[]> maxHeap = new PriorityQueue<>((a, b) -> Integer.compare(b[0], a[0]));
    for (Map.Entry<Integer, Integer> entry : count.entrySet())
      maxHeap.offer(new int[] {entry.getKey(), entry.getValue()});

    // Each iteration subtracts the same amount from k and from the heap's total. That total started
    // strictly above k, so the heap is non-empty whenever k > 0.
    while (k > 0) {
      final int[] top = maxHeap.poll();
      final int maxNum = top[0];
      final int maxNumFreq = top[1];
      // Bulk decrease: lower as many copies of maxNum by one as the budget allows in this turn
      // (never more than maxNumFreq, so the cast to int is safe).
      final int numDecreased = (int) Math.min(k, maxNumFreq);
      k -= numDecreased;
      // Budget ran out part-way through this bucket; keep the untouched copies (k is now 0).
      if (maxNumFreq > numDecreased)
        maxHeap.offer(new int[] {maxNum, maxNumFreq - numDecreased});
      if (!maxHeap.isEmpty() && maxHeap.peek()[0] + 1 == maxNum) {
        // A bucket for maxNum - 1 already exists: merge into it.
        final int[] second = maxHeap.poll();
        maxHeap.offer(new int[] {second[0], second[1] + numDecreased});
      } else if (maxNum > 1) {
        maxHeap.offer(new int[] {maxNum - 1, numDecreased});
      }
      // When maxNum == 1 the decreased copies become 0 and are dropped.
    }

    // Summation order is irrelevant, so iterate the heap directly instead of draining it.
    long ans = 0;
    for (final int[] entry : maxHeap)
      ans += (long) entry[0] * entry[0] * entry[1];

    return ans;
  }

  /** Returns |nums1[i] - nums2[i]| for every index i; assumes nums2 is at least as long as nums1. */
  private int[] getDiff(int[] nums1, int[] nums2) {
    int[] diff = new int[nums1.length];
    for (int i = 0; i < nums1.length; ++i)
      diff[i] = Math.abs(nums1[i] - nums2[i]);
    return diff;
  }
}