Skip to content

1439. Find the Kth Smallest Sum of a Matrix With Sorted Rows 👍

  • Time: $O(|\texttt{mat}| \cdot k\log k)$
  • Space: $O(k)$
C++
// Min-heap entry for kSmallestPairSums: an index into each of two sorted rows
// together with the sum of the two referenced elements.
struct T {
  int i = 0;    // index into nums1
  int j = 0;    // index into nums2
  int sum = 0;  // nums1[i] + nums2[j]
};

class Solution {
 public:
  // Returns the k-th smallest sum obtainable by choosing exactly one element
  // from each row of `mat`, where every row is sorted in non-decreasing order.
  //
  // Rows are folded left to right: after processing row r, `row` holds the
  // smallest min(k, #candidates) partial sums of rows 0..r in ascending order.
  // A partial sum ranked beyond k can never be part of one of the k smallest
  // final sums, so it is dropped at every step.
  //
  // Preconditions: `mat` is non-empty, its rows are non-empty, and k >= 1.
  // If k exceeds the number of possible sums, the largest sum is returned.
  int kthSmallest(std::vector<std::vector<int>>& mat, int k) {
    std::vector<int> row = mat[0];
    for (std::size_t r = 1; r < mat.size(); ++r)
      row = kSmallestPairSums(row, mat[r], k);

    // Index by k rather than using row.back(): with a single row, `row` is the
    // untruncated mat[0], whose last element is not the k-th smallest.
    const std::size_t index = std::min<std::size_t>(k, row.size()) - 1;
    return row[index];
  }

 private:
  // Returns the k smallest values of nums1[i] + nums2[j] over all index pairs,
  // in ascending order (fewer than k if there are fewer pairs). Both inputs
  // must be sorted in non-decreasing order.
  // Similar to 373. Find K Pairs with Smallest Sums.
  std::vector<int> kSmallestPairSums(const std::vector<int>& nums1,
                                     const std::vector<int>& nums2, int k) {
    std::vector<int> ans;
    if (k <= 0 || nums1.empty() || nums2.empty())
      return ans;

    const std::size_t limit = static_cast<std::size_t>(k);
    ans.reserve(std::min(limit, nums1.size() * nums2.size()));

    auto compare = [](const T& a, const T& b) { return a.sum > b.sum; };
    std::priority_queue<T, std::vector<T>, decltype(compare)> minHeap(compare);

    // Seed with column 0 of the first k nums1 entries; later nums1 entries
    // cannot appear in the k smallest sums.
    for (std::size_t i = 0; i < limit && i < nums1.size(); ++i)
      minHeap.push(T{static_cast<int>(i), 0, nums1[i] + nums2[0]});

    while (!minHeap.empty() && ans.size() < limit) {
      const T top = minHeap.top();
      minHeap.pop();
      ans.push_back(top.sum);
      // The next candidate for this nums1 index is the next element of nums2.
      const std::size_t nextJ = static_cast<std::size_t>(top.j) + 1;
      if (nextJ < nums2.size())
        minHeap.push(T{top.i, top.j + 1, nums1[top.i] + nums2[nextJ]});
    }

    return ans;
  }
};
Java
class Solution {
  /**
   * Returns the k-th smallest sum formed by picking exactly one element from each row of
   * {@code mat}, where every row is sorted in non-decreasing order.
   *
   * <p>Rows are folded left to right; after each step {@code row} holds only the smallest
   * min(k, #candidates) partial sums in ascending order, because a partial sum ranked beyond k
   * cannot be part of one of the k smallest final sums.
   *
   * <p>Assumes {@code mat} is non-empty and 1 <= k <= the number of possible sums.
   */
  public int kthSmallest(int[][] mat, int k) {
    int[] row = mat[0];

    for (int i = 1; i < mat.length; ++i)
      row = kSmallestPairSums(row, mat[i], k);

    return row[k - 1];
  }

  /** Heap entry: an index into each row plus the sum nums1[i] + nums2[j]. */
  private record T(int i, int j, int sum) {}

  /**
   * Returns the k smallest values of nums1[i] + nums2[j] in ascending order (fewer if there are
   * fewer than k pairs). Both arrays must be sorted in non-decreasing order. Similar to 373. Find
   * K Pairs with Smallest Sums.
   */
  private int[] kSmallestPairSums(int[] nums1, int[] nums2, int k) {
    if (k <= 0 || nums1.length == 0 || nums2.length == 0)
      return new int[0];

    // min(k, n1) seeded nums1 indices can each walk all of nums2, which yields at least
    // min(k, n1 * n2) sums, so exactly `size` results are always produced.
    final int size = (int) Math.min((long) k, (long) nums1.length * nums2.length);
    final int[] ans = new int[size];
    Queue<T> minHeap = new PriorityQueue<>((a, b) -> Integer.compare(a.sum(), b.sum()));

    for (int i = 0; i < k && i < nums1.length; ++i)
      minHeap.offer(new T(i, 0, nums1[i] + nums2[0]));

    for (int count = 0; count < size; ++count) {
      // Poll once and read both indices from the same entry.
      final T top = minHeap.poll();
      ans[count] = top.sum();
      // The next candidate for this nums1 index is the next element of nums2.
      if (top.j() + 1 < nums2.length)
        minHeap.offer(new T(top.i(), top.j() + 1, nums1[top.i()] + nums2[top.j() + 1]));
    }

    return ans;
  }
}