Skip to content

3510. Minimum Pair Removal to Sort Array II 👍

  • Time: $O(n\log n)$
  • Space: $O(n)$
C++ solution:
class Solution {
 public:
  // Returns how many times the adjacent pair with the smallest sum (leftmost
  // pair on ties) has to be replaced by its sum before the sequence becomes
  // non-decreasing.
  //
  // The surviving elements form a doubly linked list stored in index arrays;
  // an ordered set of (pair sum, left index) yields the next pair to merge.
  // A running count of adjacent inversions tells when the sequence is sorted,
  // so no rescan is needed after a merge.
  //
  // `nums` is only read; all merged values live in a private 64-bit copy.
  int minimumPairRemoval(std::vector<int>& nums) {
    const int n = static_cast<int>(nums.size());

    // `long long` rather than `long`: merged sums exceed 32 bits, and `long`
    // is only 32 bits wide on LLP64 platforms.
    std::vector<long long> values(nums.begin(), nums.end());
    std::vector<int> nextIndices(n);
    std::vector<int> prevIndices(n);
    for (int i = 0; i < n; ++i) {
      nextIndices[i] = i + 1;  // n means "no successor"
      prevIndices[i] = i - 1;  // -1 means "no predecessor"
    }

    // The default std::pair ordering is already (sum, then index), which is
    // exactly the "smallest sum, leftmost first" rule.
    std::set<std::pair<long long, int>> pairSums;
    int inversionsCount = 0;
    for (int i = 0; i + 1 < n; ++i) {
      pairSums.emplace(values[i] + values[i + 1], i);
      if (values[i + 1] < values[i])
        ++inversionsCount;
    }

    int ans = 0;
    while (inversionsCount > 0) {
      ++ans;
      const std::pair<long long, int> smallestPair = *pairSums.begin();
      pairSums.erase(pairSums.begin());

      const long long pairSum = smallestPair.first;
      const int currIndex = smallestPair.second;
      const int nextIndex = nextIndices[currIndex];
      const int prevIndex = prevIndices[currIndex];

      // Left neighbour: its pair now ends in the merged value.
      if (prevIndex >= 0) {
        const long long oldPairSum = values[prevIndex] + values[currIndex];
        const long long newPairSum = values[prevIndex] + pairSum;
        pairSums.erase(std::make_pair(oldPairSum, prevIndex));
        pairSums.emplace(newPairSum, prevIndex);
        if (values[prevIndex] > values[currIndex])
          --inversionsCount;
        if (values[prevIndex] > pairSum)
          ++inversionsCount;
      }

      // The merged pair itself disappears, together with its inversion.
      if (values[nextIndex] < values[currIndex])
        --inversionsCount;

      // Right neighbour: its pair is now keyed by the merged element's index.
      const int nextNextIndex = (nextIndex < n) ? nextIndices[nextIndex] : n;
      if (nextNextIndex < n) {
        const long long oldPairSum = values[nextIndex] + values[nextNextIndex];
        const long long newPairSum = pairSum + values[nextNextIndex];
        pairSums.erase(std::make_pair(oldPairSum, nextIndex));
        pairSums.emplace(newPairSum, currIndex);
        if (values[nextNextIndex] < values[nextIndex])
          --inversionsCount;
        if (values[nextNextIndex] < pairSum)
          ++inversionsCount;
        prevIndices[nextNextIndex] = currIndex;
      }

      // Unlink `nextIndex`; `currIndex` now holds the merged value.
      nextIndices[currIndex] = nextNextIndex;
      values[currIndex] = pairSum;
    }

    return ans;
  }
};
Java solution:
class Solution {
  /**
   * Counts how many times the adjacent pair with the smallest sum (leftmost
   * pair on ties) must be replaced by its sum before the sequence becomes
   * non-decreasing.
   *
   * <p>The surviving elements form a doubly linked list stored in index
   * arrays. A {@code TreeSet} of {@code {pairSum, leftIndex}} entries yields
   * the next pair to merge, and a running count of adjacent inversions tells
   * when the sequence is sorted.
   *
   * @param nums the input sequence; it is only read, never modified
   * @return the number of merge operations performed
   */
  public int minimumPairRemoval(int[] nums) {
    final int n = nums.length;
    int ans = 0;
    int inversionsCount = 0;
    int[] nextIndices = new int[n];
    int[] prevIndices = new int[n];
    // Merged sums no longer fit in an int, so work on a long copy.
    long[] values = Arrays.stream(nums).asLongStream().toArray();
    // Entries are {pairSum, leftIndex}. Plain long[] keeps this block free of
    // any non-JDK Pair class; the comparator (not equals) decides identity,
    // so remove() with a freshly built array finds the stored entry.
    TreeSet<long[]> pairSums = new TreeSet<>(
        (a, b) -> a[0] != b[0] ? Long.compare(a[0], b[0]) : Long.compare(a[1], b[1]));

    for (int i = 0; i < n; ++i) {
      nextIndices[i] = i + 1; // n means "no successor"
      prevIndices[i] = i - 1; // -1 means "no predecessor"
    }

    for (int i = 0; i + 1 < n; ++i) {
      pairSums.add(new long[] {values[i] + values[i + 1], i});
      if (values[i + 1] < values[i])
        ++inversionsCount;
    }

    while (inversionsCount > 0) {
      ++ans;
      final long[] smallestPair = pairSums.pollFirst();
      final long pairSum = smallestPair[0];
      final int currIndex = (int) smallestPair[1];
      final int nextIndex = nextIndices[currIndex];
      final int prevIndex = prevIndices[currIndex];

      // Left neighbour: its pair now ends in the merged value.
      if (prevIndex >= 0) {
        final long oldPairSum = values[prevIndex] + values[currIndex];
        final long newPairSum = values[prevIndex] + pairSum;
        pairSums.remove(new long[] {oldPairSum, prevIndex});
        pairSums.add(new long[] {newPairSum, prevIndex});
        if (values[prevIndex] > values[currIndex])
          --inversionsCount;
        if (values[prevIndex] > pairSum)
          ++inversionsCount;
      }

      // The merged pair itself disappears, together with its inversion.
      if (values[nextIndex] < values[currIndex])
        --inversionsCount;

      // Right neighbour: its pair is now keyed by the merged element's index.
      final int nextNextIndex = (nextIndex < n) ? nextIndices[nextIndex] : n;
      if (nextNextIndex < n) {
        final long oldPairSum = values[nextIndex] + values[nextNextIndex];
        final long newPairSum = pairSum + values[nextNextIndex];
        pairSums.remove(new long[] {oldPairSum, nextIndex});
        pairSums.add(new long[] {newPairSum, currIndex});
        if (values[nextNextIndex] < values[nextIndex])
          --inversionsCount;
        if (values[nextNextIndex] < pairSum)
          ++inversionsCount;
        prevIndices[nextNextIndex] = currIndex;
      }

      // Unlink nextIndex; currIndex now holds the merged value.
      nextIndices[currIndex] = nextNextIndex;
      values[currIndex] = pairSum;
    }

    return ans;
  }
}
Python solution:
from sortedcontainers import SortedList


class Solution:
  def minimumPairRemoval(self, nums: list[int]) -> int:
    """Counts the merges needed to make `nums` non-decreasing.

    Each operation replaces the adjacent pair with the smallest sum (the
    leftmost such pair on ties) by that sum. The surviving elements form a
    doubly linked list kept in index lists; a SortedList of
    (pair sum, left index) tuples yields the next pair to merge, and a running
    count of adjacent inversions tells when the sequence is sorted.

    Args:
      nums: The input sequence. It is only read; merging happens on a copy.

    Returns:
      The number of merge operations performed.
    """
    n = len(nums)
    ans = 0
    # Work on a private copy so the caller's list is left untouched.
    values = list(nums)
    inversionsCount = sum(values[i + 1] < values[i] for i in range(n - 1))
    nextIndices = [i + 1 for i in range(n)]  # n means "no successor"
    prevIndices = [i - 1 for i in range(n)]  # -1 means "no predecessor"
    pairSums = SortedList((values[i] + values[i + 1], i)
                          for i in range(n - 1))

    while inversionsCount > 0:
      ans += 1
      pairSum, currIndex = pairSums.pop(0)
      nextIndex = nextIndices[currIndex]
      prevIndex = prevIndices[currIndex]

      # Left neighbour: its pair now ends in the merged value.
      if prevIndex >= 0:
        oldPairSum = values[prevIndex] + values[currIndex]
        newPairSum = values[prevIndex] + pairSum
        pairSums.remove((oldPairSum, prevIndex))
        pairSums.add((newPairSum, prevIndex))
        if values[prevIndex] > values[currIndex]:
          inversionsCount -= 1
        if values[prevIndex] > pairSum:
          inversionsCount += 1

      # The merged pair itself disappears, together with its inversion.
      if values[nextIndex] < values[currIndex]:
        inversionsCount -= 1

      # Right neighbour: its pair is now keyed by the merged element's index.
      nextNextIndex = nextIndices[nextIndex] if nextIndex < n else n
      if nextNextIndex < n:
        oldPairSum = values[nextIndex] + values[nextNextIndex]
        newPairSum = pairSum + values[nextNextIndex]
        pairSums.remove((oldPairSum, nextIndex))
        pairSums.add((newPairSum, currIndex))
        if values[nextNextIndex] < values[nextIndex]:
          inversionsCount -= 1
        if values[nextNextIndex] < pairSum:
          inversionsCount += 1
        prevIndices[nextNextIndex] = currIndex

      # Unlink nextIndex; currIndex now holds the merged value.
      nextIndices[currIndex] = nextNextIndex
      values[currIndex] = pairSum

    return ans