Skip to content

1895. Largest Magic Square

  • Time: $O(mnk^2)$, where $k = \min(m, n)$
  • Space: $O(mn)$
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
class Solution {
 public:
  // Returns the side length k of the largest k x k sub-square of `grid` whose
  // k row sums, k column sums, and two diagonal sums are all equal (a "magic
  // square"). Every single cell is trivially a 1 x 1 magic square, so any
  // non-empty grid yields at least 1. A grid with no rows or no columns
  // contains no square at all and yields 0.
  //
  // NOTE(review): all sums are accumulated in `int`; this assumes the cell
  // values and grid dimensions are small enough not to overflow -- confirm
  // against the caller's input limits.
  int largestMagicSquare(std::vector<std::vector<int>>& grid) {
    // Guard before touching grid[0]: indexing an empty vector is undefined.
    if (grid.empty() || grid[0].empty())
      return 0;

    const int m = static_cast<int>(grid.size());
    const int n = static_cast<int>(grid[0].size());
    // prefixRow[i][j] := the sum of the first j numbers in the i-th row
    std::vector<std::vector<int>> prefixRow(m, std::vector<int>(n + 1));
    // prefixCol[j][i] := the sum of the first i numbers in the j-th column
    std::vector<std::vector<int>> prefixCol(n, std::vector<int>(m + 1));

    for (int i = 0; i < m; ++i)
      for (int j = 0; j < n; ++j) {
        prefixRow[i][j + 1] = prefixRow[i][j] + grid[i][j];
        prefixCol[j][i + 1] = prefixCol[j][i] + grid[i][j];
      }

    // Try the largest candidate side first so the first hit is the answer.
    for (int k = std::min(m, n); k >= 2; --k)
      if (containsMagicSquare(grid, prefixRow, prefixCol, k))
        return k;

    return 1;
  }

 private:
  // Returns true if the grid contains any magic square of size k x k.
  // Expects a non-empty grid (the public entry point guarantees this).
  static bool containsMagicSquare(
      const std::vector<std::vector<int>>& grid,
      const std::vector<std::vector<int>>& prefixRow,
      const std::vector<std::vector<int>>& prefixCol, int k) {
    // Signed copies of the dimensions keep the loop bounds free of
    // signed/unsigned comparisons.
    const int m = static_cast<int>(grid.size());
    const int n = static_cast<int>(grid[0].size());
    for (int i = 0; i + k <= m; ++i)
      for (int j = 0; j + k <= n; ++j)
        if (isMagicSquare(grid, prefixRow, prefixCol, i, j, k))
          return true;
    return false;
  }

  // Returns true if grid[i..i + k)[j..j + k) is a magic square.
  static bool isMagicSquare(const std::vector<std::vector<int>>& grid,
                            const std::vector<std::vector<int>>& prefixRow,
                            const std::vector<std::vector<int>>& prefixCol,
                            int i, int j, int k) {
    int diag = 0;
    int antiDiag = 0;
    for (int d = 0; d < k; ++d) {
      diag += grid[i + d][j + d];
      antiDiag += grid[i + d][j + k - 1 - d];
    }
    if (diag != antiDiag)
      return false;
    // The diagonal sum is the target every row and column must match.
    for (int d = 0; d < k; ++d) {
      if (getSum(prefixRow, i + d, j, j + k - 1) != diag)
        return false;
      if (getSum(prefixCol, j + d, i, i + k - 1) != diag)
        return false;
    }
    return true;
  }

  // Returns sum(grid[i][l..r]) when given prefixRow, or sum(grid[l..r][i])
  // when given prefixCol; both bounds are inclusive.
  static int getSum(const std::vector<std::vector<int>>& prefix, int i, int l,
                    int r) {
    return prefix[i][r + 1] - prefix[i][l];
  }
};
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
class Solution {
  /**
   * Finds the side length of the largest magic square inside {@code grid}: a k x k sub-square
   * whose row sums, column sums and both diagonal sums are all equal.
   *
   * @param grid the input matrix; its first row determines the column count
   * @return the largest such side length, or 1 when no square of side 2 or more qualifies
   */
  public int largestMagicSquare(int[][] grid) {
    final int rows = grid.length;
    final int cols = grid[0].length;
    // rowSums[r][c] holds the sum of the first c values of row r.
    final int[][] rowSums = new int[rows][cols + 1];
    // colSums[c][r] holds the sum of the first r values of column c.
    final int[][] colSums = new int[cols][rows + 1];

    for (int r = 0; r < rows; ++r) {
      for (int c = 0; c < cols; ++c) {
        final int cell = grid[r][c];
        rowSums[r][c + 1] = rowSums[r][c] + cell;
        colSums[c][r + 1] = colSums[c][r] + cell;
      }
    }

    // Shrink the candidate side until a magic square of that size shows up.
    int side = Math.min(rows, cols);
    while (side >= 2 && !containsMagicSquare(grid, rowSums, colSums, side)) {
      --side;
    }
    // Anything below 2 collapses to the trivial 1 x 1 answer.
    return Math.max(side, 1);
  }

  /** Reports whether any side x side window of the grid is a magic square. */
  private boolean containsMagicSquare(int[][] grid, int[][] rowSums, int[][] colSums, int side) {
    final int lastTop = grid.length - side;
    final int lastLeft = grid[0].length - side;
    for (int top = 0; top <= lastTop; ++top) {
      for (int left = 0; left <= lastLeft; ++left) {
        if (isMagicSquare(grid, rowSums, colSums, top, left, side)) {
          return true;
        }
      }
    }
    return false;
  }

  /**
   * Reports whether the side x side window whose top-left corner is (top, left) is a magic
   * square.
   */
  private boolean isMagicSquare(int[][] grid, int[][] rowSums, int[][] colSums, int top, int left,
                                int side) {
    int mainDiag = 0;
    int antiDiag = 0;
    for (int step = 0; step < side; ++step) {
      mainDiag += grid[top + step][left + step];
      antiDiag += grid[top + step][left + side - 1 - step];
    }
    if (mainDiag != antiDiag) {
      return false;
    }

    // Every row and column of the window has to match the shared diagonal sum.
    final int target = mainDiag;
    for (int step = 0; step < side; ++step) {
      final int rowSum = rowSums[top + step][left + side] - rowSums[top + step][left];
      final int colSum = colSums[left + step][top + side] - colSums[left + step][top];
      if (rowSum != target || colSum != target) {
        return false;
      }
    }
    return true;
  }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
class Solution:
  def largestMagicSquare(self, grid: list[list[int]]) -> int:
    """Returns the side length of the largest magic square found in `grid`.

    A magic square is a k x k sub-square whose row sums, column sums and both
    diagonal sums are all equal. When no square of side 2 or more qualifies,
    the result is 1.
    """
    rows, cols = len(grid), len(grid[0])
    # rowSums[r][c] := sum of the first c values of row r
    rowSums = [[0] * (cols + 1) for _ in range(rows)]
    # colSums[c][r] := sum of the first r values of column c
    colSums = [[0] * (rows + 1) for _ in range(cols)]

    for r, row in enumerate(grid):
      for c in range(cols):
        cell = row[c]
        rowSums[r][c + 1] = rowSums[r][c] + cell
        colSums[c][r + 1] = colSums[c][r] + cell

    def isMagic(top: int, left: int, side: int) -> bool:
      """Returns True if the side x side window at (top, left) is magic."""
      target = sum(grid[top + d][left + d] for d in range(side))
      anti = sum(grid[top + d][left + side - 1 - d] for d in range(side))
      if anti != target:
        return False
      # Each row and column of the window must equal the diagonal sum.
      return all(
          self._getSum(rowSums, top + d, left, left + side - 1) == target and
          self._getSum(colSums, left + d, top, top + side - 1) == target
          for d in range(side))

    # Largest side first, so the first size with a hit is the answer.
    for side in range(min(rows, cols), 1, -1):
      if any(isMagic(top, left, side)
             for top in range(rows - side + 1)
             for left in range(cols - side + 1)):
        return side

    return 1

  def _getSum(self, prefix: list[list[int]], i: int, l: int, r: int) -> int:
    """Returns the inclusive range sum l..r along line i of a prefix table."""
    return prefix[i][r + 1] - prefix[i][l]