class Solution:
    def minTotalDistance(self, grid: list[list[int]]) -> int:
        """Return the minimal total Manhattan distance from all homes to one meeting point.

        Every truthy cell of *grid* is a home. Manhattan distance splits into
        independent row and column terms, so the optimal meeting point sits at
        the median row and the median column. The answer is the sum of the two
        one-dimensional costs.

        Args:
            grid: Rectangular matrix; truthy cells (1 in the usual formulation)
                mark homes.

        Returns:
            The minimum of sum(|row - r| + |col - c|) over all homes. Returns 0
            when the grid is empty or contains no homes.
        """
        # Guard: the original indexed grid[0] unconditionally and crashed on [].
        if not grid or not grid[0]:
            return 0

        m = len(grid)
        n = len(grid[0])
        # Row indices of homes. The row-major scan emits them already sorted.
        rows = [i for i in range(m) for j in range(n) if grid[i][j]]
        # Column indices of homes. The column-major scan emits them already sorted.
        cols = [j for j in range(n) for i in range(m) if grid[i][j]]

        def _median_cost(coords: list[int]) -> int:
            """Return sum(|x - median| for x in coords), given *coords* sorted ascending.

            Pairing the k-th smallest value with the k-th largest value
            contributes (largest - smallest) for each pair, whatever point
            between them is chosen. Summing those spans gives the distance to
            the median without computing the median explicitly. An unpaired
            middle element (odd length) contributes 0.
            """
            return sum(coords[-1 - k] - coords[k] for k in range(len(coords) // 2))

        return _median_cost(rows) + _median_cost(cols)