Skip to content

2852. Sum of Remoteness of All Cells 👍

  • Time: $O(n^2)$
  • Space: $O(n^2)$
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
class Solution {
 public:
  // Returns the sum of remoteness over all cells of `grid`.
  //
  // A cell equal to -1 is treated as blocked. For every 4-directionally
  // connected component of non-blocked cells that is reached from a cell with
  // a positive value, each of its `count` cells contributes
  // (total - componentSum), where `total` is the sum of all positive values in
  // the grid and `componentSum` is the sum of the values in that component.
  //
  // Side effect: every visited cell is overwritten with -1, so `grid` is
  // consumed by the call.
  //
  // All arithmetic is done in `long long` (the declared return type) so the
  // result does not depend on the platform width of `long`.
  long long sumRemoteness(std::vector<std::vector<int>>& grid) {
    long long total = 0;
    for (const std::vector<int>& row : grid)
      for (const int cell : row)
        total += std::max(0, cell);

    long long ans = 0;
    // One explicit stack shared by all components; avoids deep recursion.
    std::vector<std::pair<int, int>> stack;

    const int rows = static_cast<int>(grid.size());
    for (int i = 0; i < rows; ++i) {
      const int cols = static_cast<int>(grid[i].size());
      for (int j = 0; j < cols; ++j)
        if (grid[i][j] > 0) {
          const auto [count, componentSum] = dfs(grid, i, j, stack);
          ans += (total - componentSum) * count;
        }
    }

    return ans;
  }

 private:
  static constexpr int dirs[4][2] = {{0, 1}, {1, 0}, {0, -1}, {-1, 0}};

  // Returns the (count, componentSum) of the connected component that contains
  // (i, j), marking every cell of that component as visited (-1).
  //
  // The traversal is an iterative depth-first search: the recursion of a naive
  // DFS can be as deep as the number of cells in the component, which may
  // exhaust the call stack on large grids. `stack` is scratch space owned by
  // the caller so it can be reused between components.
  std::pair<long long, long long> dfs(std::vector<std::vector<int>>& grid,
                                      int i, int j,
                                      std::vector<std::pair<int, int>>& stack) {
    const int rows = static_cast<int>(grid.size());

    long long count = 1;
    long long componentSum = grid[i][j];
    grid[i][j] = -1;  // Mark as visited when pushed, so it is pushed once.
    stack.clear();
    stack.emplace_back(i, j);

    while (!stack.empty()) {
      const auto [curI, curJ] = stack.back();
      stack.pop_back();

      for (const auto& [dx, dy] : dirs) {
        const int x = curI + dx;
        const int y = curJ + dy;
        if (x < 0 || x >= rows)
          continue;
        if (y < 0 || y >= static_cast<int>(grid[x].size()))
          continue;
        if (grid[x][y] == -1)  // Blocked or already visited.
          continue;
        ++count;
        componentSum += grid[x][y];
        grid[x][y] = -1;
        stack.emplace_back(x, y);
      }
    }

    return {count, componentSum};
  }
};
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
class Solution {
  // Returns the sum of remoteness over all cells of `grid`.
  //
  // A cell equal to -1 is treated as blocked. For every 4-directionally
  // connected component of non-blocked cells that is reached from a cell with
  // a positive value, each of its `count` cells contributes
  // (sum - componentSum), where `sum` is the total of all positive values in
  // the grid and `componentSum` is the total of the values in that component.
  //
  // Side effect: every visited cell is overwritten with -1.
  public long sumRemoteness(int[][] grid) {
    long ans = 0;
    long sum = 0;

    for (int[] row : grid)
      for (int cell : row)
        sum += Math.max(0, cell);

    if (grid.length == 0)
      return 0;

    // Explicit DFS stack of encoded cells (i * cols + j). A cell is marked
    // visited when it is pushed, so across the whole call at most
    // rows * cols entries are ever pushed and this capacity is sufficient.
    final int[] stack = new int[grid.length * grid[0].length];

    for (int i = 0; i < grid.length; ++i)
      for (int j = 0; j < grid[0].length; ++j)
        if (grid[i][j] > 0) {
          long[] res = dfs(grid, i, j, stack);
          final long count = res[0];
          final long componentSum = res[1];
          ans += (sum - componentSum) * count;
        }

    return ans;
  }

  private static final int[][] dirs = {{0, 1}, {1, 0}, {0, -1}, {-1, 0}};

  // Returns the (count, componentSum) of the connected component that contains
  // (i, j), marking every cell of that component as visited (-1).
  //
  // The traversal is iterative: a recursive DFS can nest as deep as the number
  // of cells in the component and overflow the thread stack on large grids.
  // `stack` is caller-provided scratch space reused between components.
  private long[] dfs(int[][] grid, int i, int j, int[] stack) {
    final int rows = grid.length;
    final int cols = grid[0].length;

    long count = 1;
    long componentSum = grid[i][j];
    grid[i][j] = -1; // Mark as visited when pushed, so it is pushed once.
    int top = 0;
    stack[top++] = i * cols + j;

    while (top > 0) {
      final int cell = stack[--top];
      final int curI = cell / cols;
      final int curJ = cell % cols;

      for (int[] dir : dirs) {
        final int x = curI + dir[0];
        final int y = curJ + dir[1];
        if (x < 0 || x >= rows || y < 0 || y >= cols)
          continue;
        if (grid[x][y] == -1) // Blocked or already visited.
          continue;
        ++count;
        componentSum += grid[x][y];
        grid[x][y] = -1;
        stack[top++] = x * cols + y;
      }
    }

    return new long[] {count, componentSum};
  }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
class Solution:
  def sumRemoteness(self, grid: list[list[int]]) -> int:
    """
    Returns the sum of remoteness over all cells of `grid`.

    A cell equal to -1 is treated as blocked. For every 4-directionally
    connected component of non-blocked cells that is reached from a cell with
    a positive value, each of its `count` cells contributes
    (summ - componentSum), where `summ` is the total of all positive values in
    the grid and `componentSum` is the total of the values in that component.

    Side effect: every visited cell is overwritten with -1.
    """
    if not grid:
      return 0

    dirs = ((0, 1), (1, 0), (0, -1), (-1, 0))
    rows = len(grid)
    cols = len(grid[0])
    summ = sum(max(0, cell) for row in grid for cell in row)
    ans = 0

    def dfs(i: int, j: int) -> tuple[int, int]:
      """
      Returns the (count, componentSum) of the connected component that contains
      (i, j), marking every cell of that component as visited (-1).

      Uses an explicit stack instead of recursion: a recursive traversal can
      nest as deep as the number of cells in the component, which exceeds
      Python's default recursion limit on even moderately sized grids.
      """
      count = 1
      componentSum = grid[i][j]
      grid[i][j] = -1  # Mark as visited when pushed, so it is pushed once.
      stack = [(i, j)]

      while stack:
        curI, curJ = stack.pop()
        for dx, dy in dirs:
          x = curI + dx
          y = curJ + dy
          if x < 0 or x >= rows or y < 0 or y >= cols:
            continue
          if grid[x][y] == -1:  # Blocked or already visited.
            continue
          count += 1
          componentSum += grid[x][y]
          grid[x][y] = -1
          stack.append((x, y))

      return (count, componentSum)

    # Scan rows and columns independently so non-square grids are covered.
    for i in range(rows):
      for j in range(cols):
        if grid[i][j] > 0:
          count, componentSum = dfs(i, j)
          ans += (summ - componentSum) * count

    return ans