Skip to content

3256. Maximum Value Sum by Placing Three Rooks I 👍

  • Time: $O(mn)$
  • Space: $O(mn)$
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
// LeetCode 3256: place three mutually non-attacking rooks (pairwise distinct
// rows and pairwise distinct columns) so that the sum of their cells is maximal.
//
// Cells are ranked by the tuple (val, i, j) -- a total order that breaks ties
// between equal values consistently. Why a handful of candidates suffices:
//  * Some optimal placement uses only cells that are in the top 3 of their row
//    AND in the top 3 of their column: the other two rooks block at most two of
//    a line's top 3 cells, so a rook outside them can move to a free,
//    higher-ranked cell of its line without lowering the sum.
//  * Among those candidates, one rook's row and column hold at most
//    3 + 3 - 1 = 5, so the other two rooks block at most 10 candidates. Hence
//    the 11 best candidates always contain an optimal placement. Keeping only 9
//    is wrong: with 100s at (0,0) and (1,1) whose rows/columns also hold eight
//    10s, the required third rook (a 9 at (6,6)) is only the 11th candidate.
class Solution {
 public:
  long long maximumValueSum(vector<vector<int>>& board) {
    const int m = board.size();
    const int n = board[0].size();
    // `long long`, not `long`: `long` is 32 bits on LLP64 targets (Windows) and
    // the sum of three `int` cells can exceed the 32-bit range.
    using T = tuple<long long, int, int>;  // (val, i, j)
    static constexpr size_t kNumCandidates = 11;

    vector<vector<T>> rows(m);  // rows[i] = [(val, i, j)]
    vector<vector<T>> cols(n);  // cols[j] = [(val, i, j)]
    set<T> rowSet;              // cells in the top 3 of their row
    set<T> colSet;              // cells in the top 3 of their column
    set<T> candidates;          // cells in both sets (ascending order)

    for (int i = 0; i < m; ++i)
      for (int j = 0; j < n; ++j) {
        rows[i].emplace_back(board[i][j], i, j);
        cols[j].emplace_back(board[i][j], i, j);
      }

    // Shrinks `line` in place to its (up to) 3 highest-ranked cells.
    auto keepTop3 = [](vector<T>& line) {
      const int k = min(3, static_cast<int>(line.size()));
      partial_sort(line.begin(), line.begin() + k, line.end(), greater<>());
      line.resize(k);
    };

    for (vector<T>& row : rows) {
      keepTop3(row);
      rowSet.insert(row.begin(), row.end());
    }

    for (vector<T>& col : cols) {
      keepTop3(col);
      colSet.insert(col.begin(), col.end());
    }

    set_intersection(rowSet.begin(), rowSet.end(), colSet.begin(), colSet.end(),
                     inserter(candidates, candidates.begin()));

    // `candidates` is ordered ascending, so keep only its last 11 elements.
    if (candidates.size() > kNumCandidates) {
      auto it = candidates.begin();
      advance(it, candidates.size() - kNumCandidates);
      candidates.erase(candidates.begin(), it);
    }

    long long ans = LLONG_MIN;
    for (auto it1 = candidates.begin(); it1 != candidates.end(); ++it1)
      for (auto it2 = next(it1); it2 != candidates.end(); ++it2)
        for (auto it3 = next(it2); it3 != candidates.end(); ++it3) {
          const auto& [val1, i1, j1] = *it1;
          const auto& [val2, i2, j2] = *it2;
          const auto& [val3, i3, j3] = *it3;
          if (i1 == i2 || i1 == i3 || i2 == i3 ||  //
              j1 == j2 || j1 == j3 || j2 == j3)
            continue;  // two of the rooks would attack each other
          ans = max(ans, val1 + val2 + val3);
        }

    return ans;
  }
};
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
/**
 * LeetCode 3256: maximum sum of three cells in pairwise distinct rows and
 * pairwise distinct columns (three mutually non-attacking rooks).
 *
 * <p>Cells are ranked by (value, row, column) in descending order, a total order
 * that breaks ties between equal values consistently. Some optimal placement
 * uses only cells that are in the top 3 of their row AND the top 3 of their
 * column: the other two rooks block at most two of a line's top 3 cells, so a
 * rook outside them can move to a free, higher-ranked cell without lowering the
 * sum. Among those candidates a rook's row and column hold at most 3 + 3 - 1 = 5,
 * so two rooks block at most 10 and the 11 best candidates always contain an
 * optimal placement (keeping only 9 is not enough in general).
 */
class Solution {
  @SuppressWarnings("unchecked")
  public long maximumValueSum(int[][] board) {
    final int m = board.length;
    final int n = board[0].length;
    long ans = Long.MIN_VALUE;
    List<int[]>[] rows = new ArrayList[m];
    List<int[]>[] cols = new ArrayList[n];
    // Each cell is one int[] shared by its row list and its column list. Arrays
    // use identity equals/hashCode, so the intersection below matches exactly
    // the cells present in both sets.
    Set<int[]> rowSet = new HashSet<>();
    Set<int[]> colSet = new HashSet<>();
    Set<int[]> boardSet = new HashSet<>();

    for (int i = 0; i < m; ++i)
      rows[i] = new ArrayList<>();

    for (int j = 0; j < n; ++j)
      cols[j] = new ArrayList<>();

    for (int i = 0; i < m; ++i)
      for (int j = 0; j < n; ++j) {
        int[] cell = new int[] {board[i][j], i, j};
        rows[i].add(cell);
        cols[j].add(cell);
      }

    // Descending by (value, row, column): the same total order is used for the
    // row, column and final rankings, which the correctness argument relies on.
    Comparator<int[]> comparator = Comparator.comparingInt((int[] a) -> a[0])
                                       .thenComparingInt(a -> a[1])
                                       .thenComparingInt(a -> a[2])
                                       .reversed();

    for (List<int[]> row : rows) {
      row.sort(comparator);
      rowSet.addAll(row.subList(0, Math.min(3, row.size())));
    }

    for (List<int[]> col : cols) {
      col.sort(comparator);
      colSet.addAll(col.subList(0, Math.min(3, col.size())));
    }

    boardSet.addAll(rowSet);
    boardSet.retainAll(colSet);

    // The other two rooks can block up to 10 candidates, so keep the best 11.
    List<int[]> candidates = new ArrayList<>(boardSet);
    candidates.sort(comparator);
    candidates = candidates.subList(0, Math.min(11, candidates.size()));

    for (int i = 0; i < candidates.size(); ++i)
      for (int j = i + 1; j < candidates.size(); ++j)
        for (int k = j + 1; k < candidates.size(); ++k) {
          int[] t1 = candidates.get(i);
          int[] t2 = candidates.get(j);
          int[] t3 = candidates.get(k);
          if (t1[1] == t2[1] || t1[1] == t3[1] || t2[1] == t3[1] || //
              t1[2] == t2[2] || t1[2] == t3[2] || t2[2] == t3[2])
            continue; // two of the rooks would attack each other
          ans = Math.max(ans, (long) t1[0] + t2[0] + t3[0]);
        }

    return ans;
  }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
class Solution:
  def maximumValueSum(self, board: list[list[int]]) -> int:
    """Returns the maximum sum of three cells in distinct rows and columns.

    Cells are ranked by the tuple (val, i, j), a total order that breaks ties
    between equal values consistently. Some optimal placement uses only cells
    that are in the top 3 of their row and the top 3 of their column (the other
    two rooks block at most two of a line's top 3 cells). Among those
    candidates a rook's row and column hold at most 3 + 3 - 1 = 5, so two rooks
    block at most 10 and the 11 best candidates always contain an optimal
    placement; keeping only 9 is not enough in general.
    """
    rows = [heapq.nlargest(3, [(val, i, j)
            for j, val in enumerate(row)])
            for i, row in enumerate(board)]
    cols = [heapq.nlargest(3, [(val, i, j)
            for i, val in enumerate(col)])
            for j, col in enumerate(zip(*board))]
    # Cells that are top-3 in both their row and their column; keep the best 11.
    candidates = heapq.nlargest(11,
                                set(itertools.chain(*rows)) &
                                set(itertools.chain(*cols)))
    return max(
        (val1 + val2 + val3 for
         (val1, i1, j1),
         (val2, i2, j2),
         (val3, i3, j3) in itertools.combinations(candidates, 3)
         if len({i1, i2, i3}) == 3 and len({j1, j2, j3}) == 3))