Skip to content

3359. Find Sorted Submatrices With Maximum Element at Most K

  • Time: $O(mn)$
  • Space: $O(mn)$
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
// One entry of a per-column monotonic stack used by Solution::countSubmatrices.
struct T {
  // Width of the non-increasing run (of values <= k) ending at this cell;
  // 0 for a sentinel entry.
  int subarrayWidth;
  // Row of this entry; for a sentinel, the row of the last cell > k (or -1).
  int rowIndex;
  // Number of valid submatrices whose bottom-right corner is this cell.
  int accumulatedSubmatrices;
};

class Solution {
 public:
  // Counts submatrices whose elements are all <= k and whose rows are each
  // non-increasing from left to right.
  //
  // For a cell (i, j), the number of valid submatrices with bottom-right corner
  // (i, j) is the sum, over every start row r <= i, of the minimum run width
  // among rows r..i in column j. A monotonic stack per column maintains that
  // sum incrementally, so each cell is pushed and popped at most once.
  //
  // Returns 0 for an empty grid.
  long long countSubmatrices(std::vector<std::vector<int>>& grid, int k) {
    if (grid.empty() || grid[0].empty())
      return 0;
    const int m = static_cast<int>(grid.size());
    const int n = static_cast<int>(grid[0].size());
    // long long, not long: the total can exceed 2^31 and long is 32-bit on
    // LLP64 platforms.
    long long ans = 0;

    // stacks[j] := entries for column j with non-decreasing widths from bottom
    // to top. The bottom entry is a width-0 sentinel, which is never popped
    // because every real width is >= 1.
    std::vector<std::stack<T>> stacks(n);
    for (std::stack<T>& columnStack : stacks)
      columnStack.push(T{0, -1, 0});  // T{...}: valid for aggregates in C++17

    for (int i = 0; i < m; ++i) {
      // Length of the non-increasing run of values <= k ending at grid[i][j].
      int width = 0;
      for (int j = 0; j < n; ++j) {
        std::stack<T>& columnStack = stacks[j];
        if (grid[i][j] > k) {
          // No valid submatrix may contain this cell: restart the column and
          // the run.
          columnStack = std::stack<T>();
          columnStack.push(T{0, i, 0});
          width = 0;
          continue;
        }
        // Extend the run when the left neighbor is >= the current value. If
        // the left neighbor exceeded k, width is already 0, so this yields 1.
        width = (j > 0 && grid[i][j - 1] >= grid[i][j]) ? width + 1 : 1;
        // Entries wider than the current run can no longer bound a submatrix
        // that extends down to row i.
        while (width < columnStack.top().subarrayWidth)
          columnStack.pop();
        const T& below = columnStack.top();
        const int height = i - below.rowIndex;
        const int accumulated = below.accumulatedSubmatrices + width * height;
        ans += accumulated;
        columnStack.push(T{width, i, accumulated});
      }
    }
    return ans;
  }
};
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
class Solution {
  /**
   * Counts submatrices whose elements are all at most {@code k} and whose rows are each
   * non-increasing from left to right.
   *
   * <p>For a cell (i, j), the number of valid submatrices with bottom-right corner (i, j) is the
   * sum, over every start row r <= i, of the minimum run width among rows r..i in column j. A
   * monotonic stack per column maintains that sum incrementally.
   *
   * @param grid rectangular matrix of values
   * @param k inclusive upper bound on every element of a counted submatrix
   * @return the number of valid submatrices, or 0 for an empty grid
   */
  public long countSubmatrices(int[][] grid, int k) {
    // subarrayWidth: run width ending at this cell (0 for a sentinel);
    // rowIndex: row of this entry; accumulatedSubmatrices: valid submatrices
    // whose bottom-right corner is this cell.
    record T(int subarrayWidth, int rowIndex, int accumulatedSubmatrices) {}
    if (grid.length == 0 || grid[0].length == 0)
      return 0;
    final int m = grid.length;
    final int n = grid[0].length;
    long ans = 0;
    // stacks[j] := entries for column j with non-decreasing widths from bottom to top. The bottom
    // entry is a width-0 sentinel, which is never popped because every real width is >= 1.
    List<Deque<T>> stacks = new ArrayList<>(n);

    for (int j = 0; j < n; ++j) {
      Deque<T> stack = new ArrayDeque<>();
      stack.push(new T(0, -1, 0));
      stacks.add(stack);
    }

    for (int i = 0; i < m; ++i) {
      // Length of the non-increasing run of values <= k ending at grid[i][j].
      int width = 0;
      for (int j = 0; j < n; ++j) {
        Deque<T> stack = stacks.get(j);
        if (grid[i][j] > k) {
          // No valid submatrix may contain this cell: restart the column and the run.
          stack.clear();
          stack.push(new T(0, i, 0));
          width = 0;
          continue;
        }
        // Extend the run when the left neighbor is >= the current value. If the left neighbor
        // exceeded k, width is already 0, so this yields 1.
        width = (j > 0 && grid[i][j - 1] >= grid[i][j]) ? width + 1 : 1;
        // Entries wider than the current run can no longer bound a submatrix reaching row i.
        while (width < stack.peek().subarrayWidth())
          stack.pop();
        final T below = stack.peek();
        final int height = i - below.rowIndex();
        final int accumulatedSubmatrices = below.accumulatedSubmatrices() + width * height;
        ans += accumulatedSubmatrices;
        stack.push(new T(width, i, accumulatedSubmatrices));
      }
    }

    return ans;
  }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
from dataclasses import dataclass


@dataclass(frozen=True)
class T:
  """One entry of a per-column monotonic stack used by `Solution`."""
  # Width of the non-increasing run (of values <= k) ending at this cell;
  # 0 for a sentinel entry.
  subarrayWidth: int
  # Row of this entry; for a sentinel, the row of the last cell > k (or -1).
  rowIndex: int
  # Number of valid submatrices whose bottom-right corner is this cell.
  accumulatedSubmatrices: int


class Solution:
  def countSubmatrices(self, grid: list[list[int]], k: int) -> int:
    """Counts submatrices whose elements are all <= k and whose rows are each
    non-increasing from left to right.

    For a cell (i, j), the number of valid submatrices with bottom-right corner
    (i, j) is the sum, over every start row r <= i, of the minimum run width
    among rows r..i in column j. A monotonic stack per column maintains that
    sum incrementally.

    Args:
      grid: rectangular matrix of ints.
      k: inclusive upper bound on every element of a counted submatrix.

    Returns:
      The number of valid submatrices (0 for an empty grid).
    """
    if not grid or not grid[0]:
      return 0
    ans = 0
    # stacks[j] := entries for column j with non-decreasing widths from bottom
    # to top. The bottom entry is a width-0 sentinel, which is never popped
    # because every real width is >= 1.
    stacks: list[list[T]] = [[T(0, -1, 0)] for _ in grid[0]]

    for i, row in enumerate(grid):
      # Length of the non-increasing run of values <= k ending at row[j].
      width = 0
      for j, num in enumerate(row):
        if num > k:
          # No valid submatrix may contain this cell: restart column and run.
          stacks[j] = [T(0, i, 0)]
          width = 0
          continue
        # Extend the run when the left neighbor is >= num. If the left
        # neighbor exceeded k, width is already 0, so this yields 1.
        width = width + 1 if j > 0 and row[j - 1] >= num else 1
        stack = stacks[j]
        # Entries wider than the current width can no longer bound a
        # submatrix that extends down to row i.
        while width < stack[-1].subarrayWidth:
          stack.pop()
        below = stack[-1]
        accumulated = (below.accumulatedSubmatrices +
                       width * (i - below.rowIndex))
        ans += accumulated
        stack.append(T(width, i, accumulated))

    return ans