Skip to content

378. Kth Smallest Element in a Sorted Matrix 👍

Approach 1: Heap

  • Time: $O(x + k\log x)$, where $x = \min(n, k)$
  • Space: $O(x)$, where $x = \min(n, k)$
C++
// A matrix cell tracked by the heap: its coordinates and its value.
struct T {
  int i;
  int j;
  int num;  // matrix[i][j]
  T(int i, int j, int num) : i(i), j(j), num(num) {}
};

class Solution {
 public:
  // Returns the k-th smallest value (1-indexed) in `matrix`, whose rows and
  // columns are each sorted in non-decreasing order.
  //
  // k-way merge over the rows: the heap holds at most one frontier cell per
  // row, and each pop advances that row by one column. After k - 1 pops the
  // heap top is the k-th smallest value.
  //
  // Assumes a non-empty matrix and 1 <= k <= (number of cells); otherwise
  // top() is eventually called on an empty heap.
  int kthSmallest(vector<vector<int>>& matrix, int k) {
    // Min-heap on the cell value; the comparator needs no captures.
    auto compare = [](const T& a, const T& b) { return a.num > b.num; };
    priority_queue<T, vector<T>, decltype(compare)> minHeap(compare);

    // Sizes are converted to int once so every comparison below is
    // signed-vs-signed (no size_t/int mixing), and the column count is not
    // re-read on every iteration.
    const int n = static_cast<int>(matrix.size());
    const int cols = n > 0 ? static_cast<int>(matrix[0].size()) : 0;

    // The first column is sorted, so a row at index >= k starts with a value
    // no smaller than the k first-column cells above it; it cannot change the
    // k-th smallest value. Seed only the first min(n, k) rows.
    const int rows = k < n ? k : n;
    for (int i = 0; i < rows; ++i)
      minHeap.emplace(i, 0, matrix[i][0]);

    while (k-- > 1) {
      // Copy the top before popping it, then advance that row by one column.
      const auto [i, j, _] = minHeap.top();
      minHeap.pop();
      if (j + 1 < cols)
        minHeap.emplace(i, j + 1, matrix[i][j + 1]);
    }

    return minHeap.top().num;
  }
};
Java
/** A matrix cell tracked by the heap: its coordinates and its value. */
class T {
  public int i;
  public int j;
  public int num; // matrix[i][j]
  public T(int i, int j, int num) {
    this.i = i;
    this.j = j;
    this.num = num;
  }
}

class Solution {
  /**
   * Returns the k-th smallest value (1-indexed) in a matrix whose rows and columns are each sorted
   * in non-decreasing order.
   *
   * <p>k-way merge over the rows: the heap holds at most one frontier cell per row, and each poll
   * advances that row by one column. After k - 1 polls the heap head is the k-th smallest value.
   *
   * <p>Assumes a non-empty matrix and 1 <= k <= (number of cells); otherwise the final
   * {@code peek()} returns null and dereferencing it throws.
   */
  public int kthSmallest(int[][] matrix, int k) {
    // Integer.compare instead of `a.num - b.num`: the subtraction overflows when two values differ
    // by more than Integer.MAX_VALUE (e.g. MIN_VALUE - MAX_VALUE == 1), which misorders the heap.
    Queue<T> minHeap = new PriorityQueue<>((a, b) -> Integer.compare(a.num, b.num));

    // The first column is sorted, so rows at index >= k cannot change the k-th smallest value.
    for (int i = 0; i < k && i < matrix.length; ++i)
      minHeap.offer(new T(i, 0, matrix[i][0]));

    while (k-- > 1) {
      // Remove the smallest cell once, then advance its row by one column.
      final T cell = minHeap.poll();
      if (cell.j + 1 < matrix[0].length)
        minHeap.offer(new T(cell.i, cell.j + 1, matrix[cell.i][cell.j + 1]));
    }

    return minHeap.peek().num;
  }
}
Python
class Solution:
  def kthSmallest(self, matrix: List[List[int]], k: int) -> int:
    """Return the k-th smallest value (1-indexed) of a row- and column-sorted matrix.

    Merges the rows k-way: the heap keeps one (value, row, col) frontier entry
    per seeded row, and every pop advances that row by one column. After k - 1
    pops the heap root holds the k-th smallest value.
    """
    frontier = []  # entries are (matrix[row][col], row, col)

    # Only the first min(k, rows) rows are seeded: the first column is sorted,
    # so later rows cannot change the k-th smallest value.
    for row in range(min(k, len(matrix))):
      heapq.heappush(frontier, (matrix[row][0], row, 0))

    for _ in range(k - 1):
      _, row, col = heapq.heappop(frontier)
      nxt = col + 1
      if nxt < len(matrix[0]):
        heapq.heappush(frontier, (matrix[row][nxt], row, nxt))

    return frontier[0][0]
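Approach 2: Binary Search

  • Time: $O(n\log(\max - \min))$, where $n$ is the matrix dimension and $\max$, $\min$ are its largest and smallest entries
  • Space: $O(1)$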
C++
class Solution {
 public:
  // Returns the k-th smallest value (1-indexed) in `matrix`, whose rows and
  // columns are each sorted in non-decreasing order.
  //
  // Binary-searches the value range [matrix[0][0], matrix.back().back()] for
  // the smallest m such that at least k cells are <= m. Because the count only
  // changes at values present in the matrix, that m is a matrix entry.
  //
  // Assumes a non-empty matrix and 1 <= k <= (number of cells).
  int kthSmallest(vector<vector<int>>& matrix, int k) {
    int l = matrix[0][0];
    int r = matrix.back().back();

    while (l < r) {
      // Midpoint in 64-bit arithmetic: r - l exceeds INT_MAX when the value
      // range is wide (e.g. l = INT_MIN, r = INT_MAX), and signed int
      // overflow is undefined behavior. The result lies in [l, r) so it fits.
      const int m = static_cast<int>(l + (static_cast<long long>(r) - l) / 2);
      if (numsNoGreaterThan(matrix, m) >= k)
        r = m;
      else
        l = m + 1;
    }

    return l;
  }

 private:
  // Counts the cells of `matrix` whose value is <= m in O(rows + cols) time.
  int numsNoGreaterThan(const vector<vector<int>>& matrix, int m) {
    int count = 0;
    // j is the LAST column index in the current row with row[j] <= m (or -1).
    // Columns are sorted, so moving down a row can only move j left; j is
    // therefore carried over between rows instead of being reset.
    int j = static_cast<int>(matrix[0].size()) - 1;
    for (const auto& row : matrix) {
      while (j >= 0 && row[j] > m)
        --j;
      count += j + 1;  // row[0..j] are all <= m
    }
    return count;
  }
};
Java
class Solution {
  /**
   * Returns the k-th smallest value (1-indexed) in a matrix whose rows and columns are each sorted
   * in non-decreasing order.
   *
   * <p>Binary-searches the value range [first cell, last cell] for the smallest m such that at
   * least k cells are <= m. Because the count only changes at values present in the matrix, that m
   * is a matrix entry.
   *
   * <p>Assumes a non-empty matrix and 1 <= k <= (number of cells).
   */
  public int kthSmallest(int[][] matrix, int k) {
    int l = matrix[0][0];
    // The largest value is the bottom-right cell. Index the column with the last row's own length;
    // using matrix.length for the column is only correct for square matrices.
    final int[] lastRow = matrix[matrix.length - 1];
    int r = lastRow[lastRow.length - 1];

    while (l < r) {
      // Midpoint in long arithmetic: r - l wraps around in int when the value range exceeds
      // Integer.MAX_VALUE (e.g. l = MIN_VALUE, r = MAX_VALUE). The result lies in [l, r).
      final int m = (int) (l + ((long) r - l) / 2);
      if (numsNoGreaterThan(matrix, m) >= k)
        r = m;
      else
        l = m + 1;
    }

    return l;
  }

  /** Counts the cells of {@code matrix} whose value is <= m in O(rows + cols) time. */
  private int numsNoGreaterThan(int[][] matrix, int m) {
    int count = 0;
    // j is the LAST column index in the current row with row[j] <= m (or -1). Columns are sorted,
    // so moving down a row can only move j left; j is carried over between rows.
    int j = matrix[0].length - 1;
    for (int[] row : matrix) {
      while (j >= 0 && row[j] > m)
        --j;
      count += j + 1; // row[0..j] are all <= m
    }
    return count;
  }
}
Python
class Solution:
  def kthSmallest(self, matrix: List[List[int]], k: int) -> int:
    """Return the k-th smallest value (1-indexed) of a row- and column-sorted matrix.

    Binary-searches the value range [matrix[0][0], matrix[-1][-1]] for the
    smallest value that has at least k cells less than or equal to it.
    """
    lo, hi = matrix[0][0], matrix[-1][-1]

    def count_at_most(target: int) -> int:
      # Staircase walk: col is the LAST index in the current row whose value
      # is <= target (or -1). Columns are sorted, so col only moves left as
      # the rows go down and is carried over between rows.
      total = 0
      col = len(matrix[0]) - 1
      for row in matrix:
        while col >= 0 and row[col] > target:
          col -= 1
        total += col + 1
      return total

    while lo < hi:
      mid = lo + (hi - lo) // 2
      if count_at_most(mid) < k:
        lo = mid + 1
      else:
        hi = mid

    return lo