Skip to content

2713. Maximum Strictly Increasing Cells in a Matrix 👍

  • Time: $O(mn \log(mn))$
  • Space: $O(mn)$
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
class Solution {
 public:
  // Returns the maximum number of cells that can be visited in `mat`, where
  // each move jumps from the current cell to any cell in the same row or the
  // same column that holds a strictly greater value.
  //
  // Cells are processed from the largest value down. For a cell (i, j):
  //   best(i, j) = 1 + max(best of any strictly larger cell in row i or col j)
  // rows[i] / cols[j] hold exactly that maximum over the values processed so
  // far, which are all strictly larger than the current one.
  //
  // Returns 0 for an empty matrix. Assumes `mat` is rectangular.
  // Time: O(mn log(mn)), Space: O(mn).
  int maxIncreasingCells(vector<vector<int>>& mat) {
    if (mat.empty() || mat[0].empty())
      return 0;

    const int m = mat.size();
    const int n = mat[0].size();
    // rows[i] := the maximum path length starting in the i-th row so far
    vector<int> rows(m);
    // cols[j] := the maximum path length starting in the j-th column so far
    vector<int> cols(n);

    // Cell coordinates grouped by value. The greater<> comparator makes the
    // map iterate the distinct values in strictly decreasing order, so no
    // separate sorted set is needed.
    map<int, vector<pair<int, int>>, greater<>> valToIndices;
    for (int i = 0; i < m; ++i)
      for (int j = 0; j < n; ++j)
        valToIndices[mat[i][j]].emplace_back(i, j);

    int ans = 0;
    // Reused buffer: path lengths of the cells in the current value group.
    vector<int> lengths;
    for (const auto& entry : valToIndices) {
      const vector<pair<int, int>>& indices = entry.second;
      // Compute every length in the group before updating rows/cols, so two
      // cells with the same value never extend each other's path.
      lengths.clear();
      for (const auto& [i, j] : indices)
        lengths.push_back(max(rows[i], cols[j]) + 1);
      for (size_t k = 0; k < indices.size(); ++k) {
        const auto& [i, j] = indices[k];
        rows[i] = max(rows[i], lengths[k]);
        cols[j] = max(cols[j], lengths[k]);
        ans = max(ans, lengths[k]);
      }
    }
    return ans;
  }
};
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
class Solution {
  /**
   * Returns the maximum number of cells that can be visited in {@code mat}, where each move
   * jumps from the current cell to any cell in the same row or the same column that holds a
   * strictly greater value.
   *
   * <p>Cells are processed from the largest value down. {@code rows[i]} / {@code cols[j]} hold
   * the best path length found so far among strictly larger cells in that row / column.
   *
   * @param mat a rectangular matrix of integers
   * @return the maximum path length, or 0 if {@code mat} is empty
   */
  public int maxIncreasingCells(int[][] mat) {
    if (mat.length == 0 || mat[0].length == 0)
      return 0;

    final int m = mat.length;
    final int n = mat[0].length;
    // rows[i] := the maximum path length for the i-th row
    int[] rows = new int[m];
    // cols[j] := the maximum path length for the j-th column
    int[] cols = new int[n];
    // Coordinates stored as {i, j}. This avoids javafx.util.Pair, which is not
    // part of standard Java SE, and avoids boxing.
    Map<Integer, ArrayList<int[]>> valToIndices = new HashMap<>();
    // maxPathLength[i][j] := the maximum path length from mat[i][j]
    int[][] maxPathLength = new int[m][n];
    // The distinct values in the matrix, iterated in strictly decreasing order.
    TreeSet<Integer> decreasingSet = new TreeSet<>(Comparator.reverseOrder());

    for (int i = 0; i < m; ++i)
      for (int j = 0; j < n; ++j) {
        final int val = mat[i][j];
        valToIndices.computeIfAbsent(val, k -> new ArrayList<>()).add(new int[] {i, j});
        decreasingSet.add(val);
      }

    int ans = 0;
    for (final int val : decreasingSet) {
      final ArrayList<int[]> indices = valToIndices.get(val);
      // Compute the whole equal-valued group first so cells sharing a value
      // never extend each other's path.
      for (final int[] cell : indices) {
        final int i = cell[0];
        final int j = cell[1];
        maxPathLength[i][j] = Math.max(rows[i], cols[j]) + 1;
      }
      for (final int[] cell : indices) {
        final int i = cell[0];
        final int j = cell[1];
        rows[i] = Math.max(rows[i], maxPathLength[i][j]);
        cols[j] = Math.max(cols[j], maxPathLength[i][j]);
        ans = Math.max(ans, maxPathLength[i][j]);
      }
    }

    return ans;
  }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
class Solution:
  def maxIncreasingCells(self, mat: list[list[int]]) -> int:
    """Return the maximum number of cells visitable in `mat`.

    Each move jumps from the current cell to any cell in the same row or the
    same column holding a strictly greater value. Cells are processed from the
    largest value down. rows[i] / cols[j] hold the best path length found so
    far among strictly larger cells in that row / column.

    Args:
      mat: a rectangular matrix of integers.

    Returns:
      The maximum path length, or 0 if `mat` is empty.
    """
    if not mat or not mat[0]:
      return 0

    m = len(mat)
    n = len(mat[0])
    rows = [0] * m  # rows[i] := the maximum path length for the i-th row
    cols = [0] * n  # cols[j] := the maximum path length for the j-th column

    # Cell coordinates grouped by value. The dict keys double as the set of
    # distinct values, so no separate set is needed.
    valToIndices: dict[int, list[tuple[int, int]]] = {}
    for i in range(m):
      for j in range(n):
        valToIndices.setdefault(mat[i][j], []).append((i, j))

    ans = 0
    for val in sorted(valToIndices, reverse=True):
      indices = valToIndices[val]
      # Compute the whole equal-valued group before updating rows/cols so
      # cells sharing a value never extend each other's path.
      lengths = [max(rows[i], cols[j]) + 1 for i, j in indices]
      for (i, j), length in zip(indices, lengths):
        rows[i] = max(rows[i], length)
        cols[j] = max(cols[j], length)
        ans = max(ans, length)

    return ans