Skip to content

3446. Sort Matrix by Diagonals 👍

  • Time: $O(n^2 \log n)$ (sorting the $2n - 1$ diagonals dominates)
  • Space: $O(n^2)$
C++
class Solution {
 public:
  // Returns a copy of the n x n matrix `grid` in which every diagonal running
  // from top-left to bottom-right is sorted:
  //   * bottom-left triangle, main diagonal included (i - j >= 0):
  //     non-increasing;
  //   * top-right triangle (i - j < 0): non-decreasing.
  // `grid` is only read, never modified. An empty grid yields an empty result.
  // Time O(n^2 log n), extra space O(n^2).
  std::vector<std::vector<int>> sortMatrix(std::vector<std::vector<int>>& grid) {
    const int n = static_cast<int>(grid.size());
    if (n == 0)
      return {};

    // Cells on one diagonal share d = i - j, which lies in [-(n - 1), n - 1];
    // shifting by n - 1 maps d onto a bucket index in [0, 2n - 2].
    const int shift = n - 1;
    const int numDiagonals = 2 * n - 1;
    std::vector<std::vector<int>> diag(numDiagonals);

    for (int i = 0; i < n; ++i)
      for (int j = 0; j < n; ++j)
        diag[i - j + shift].push_back(grid[i][j]);

    // Values are consumed from the back of each bucket, so every bucket is
    // sorted in the opposite order to the one it must read in along the
    // diagonal.
    for (int k = 0; k < numDiagonals; ++k) {
      if (k < shift)  // d < 0: top-right triangle must read ascending.
        std::sort(diag[k].begin(), diag[k].end(), std::greater<int>());
      else  // d >= 0: bottom-left triangle must read descending.
        std::sort(diag[k].begin(), diag[k].end());
    }

    // Row-major traversal visits each diagonal from top-left to bottom-right,
    // so popping the back of its bucket yields values in the required order.
    std::vector<std::vector<int>> ans(n, std::vector<int>(n));
    for (int i = 0; i < n; ++i) {
      for (int j = 0; j < n; ++j) {
        std::vector<int>& bucket = diag[i - j + shift];
        ans[i][j] = bucket.back();
        bucket.pop_back();
      }
    }

    return ans;
  }
};
Java
class Solution {
  public int[][] sortMatrix(int[][] grid) {
    final int n = grid.length;
    int[][] ans = new int[n][n];
    Map<Integer, List<Integer>> diag = new HashMap<>();

    for (int i = 0; i < n; ++i)
      for (int j = 0; j < n; ++j) {
        final int key = i - j;
        diag.putIfAbsent(key, new ArrayList<>());
        diag.get(key).add(grid[i][j]);
      }

    for (Map.Entry<Integer, List<Integer>> entry : diag.entrySet()) {
      List<Integer> values = entry.getValue();
      if (entry.getKey() < 0)
        Collections.sort(values, Collections.reverseOrder());
      else
        Collections.sort(values);
    }

    for (int i = 0; i < n; i++)
      for (int j = 0; j < n; j++) {
        final int key = i - j;
        ans[i][j] = diag.get(key).remove(diag.get(key).size() - 1);
      }

    return ans;
  }
}
Python
class Solution:
  def sortMatrix(self, grid: list[list[int]]) -> list[list[int]]:
    """Return a copy of ``grid`` with every top-left-to-bottom-right diagonal sorted.

    A diagonal is identified by ``r - c``. Diagonals with ``r - c >= 0`` (the
    bottom-left triangle, main diagonal included) are made non-increasing;
    diagonals with ``r - c < 0`` (the top-right triangle) are made
    non-decreasing. ``grid`` itself is not modified.
    """
    size = len(grid)
    buckets = collections.defaultdict(list)

    # Group every cell value under its diagonal id.
    for r, row_vals in enumerate(grid):
      for c, val in enumerate(row_vals):
        buckets[r - c].append(val)

    # Each list is consumed from its tail, so sort it opposite to the order in
    # which it must read along the diagonal.
    for diag_id, vals in buckets.items():
      if diag_id < 0:
        vals.sort(reverse=True)
      else:
        vals.sort()

    # Row-major order walks each diagonal top-left to bottom-right, so popping
    # the tail hands out values in the required order.
    return [[buckets[r - c].pop() for c in range(size)] for r in range(size)]