Skip to content

1595. Minimum Cost to Connect Two Groups of Points 👍

  • Time: $O(m \cdot 2^n \cdot n)$
  • Space: $O(m \cdot 2^n)$
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
class Solution {
 public:
  // Returns the minimum total cost of connecting the two groups, where
  // cost[i][j] is the cost of connecting group1's point i (row) with group2's
  // point j (column). Every point of each group ends up connected to at least
  // one point of the other group.
  int connectTwoGroups(vector<vector<int>>& cost) {
    // Guard before touching cost[0]: an empty group 1 has nothing to connect.
    if (cost.empty())
      return 0;
    const int m = static_cast<int>(cost.size());
    const int n = static_cast<int>(cost[0].size());
    // mem[i][mask] caches the helper's result; INT_MAX marks "not computed".
    vector<vector<int>> mem(m, vector<int>(1 << n, INT_MAX));
    // minCosts[j] := the minimum cost of connecting group2's point j
    vector<int> minCosts(n, INT_MAX);
    for (const vector<int>& row : cost)
      for (int j = 0; j < n; ++j)
        minCosts[j] = min(minCosts[j], row[j]);

    return connectTwoGroups(cost, 0, 0, minCosts, mem);
  }

 private:
  // Returns the minimum cost to connect group1's points [i..m) with group2's
  // points, where `mask` is the bitmask of group2's points already connected
  // by group1's points [0..i).
  int connectTwoGroups(const vector<vector<int>>& cost, int i, int mask,
                       const vector<int>& minCosts, vector<vector<int>>& mem) {
    const int m = static_cast<int>(cost.size());
    const int n = static_cast<int>(minCosts.size());
    if (i == m) {
      // All the points in group 1 are connected, so greedily connect each
      // still-unconnected point of group2 through its cheapest edge.
      int res = 0;
      for (int j = 0; j < n; ++j)
        if ((mask >> j & 1) == 0)
          res += minCosts[j];
      return res;
    }

    // `mem` is never resized during recursion, so this reference stays valid.
    int& memo = mem[i][mask];
    if (memo != INT_MAX)
      return memo;

    // Group1's point i picks one partner j here; any other group2 point left
    // uncovered is paid for by the greedy base case above.
    int best = INT_MAX;
    for (int j = 0; j < n; ++j)
      best = min(best, cost[i][j] + connectTwoGroups(cost, i + 1, mask | 1 << j,
                                                     minCosts, mem));
    return memo = best;
  }
};
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
class Solution {
  public int connectTwoGroups(List<List<Integer>> cost) {
    final int m = cost.size();
    final int n = cost.get(0).size();
    Integer[][] mem = new Integer[m][1 << n];
    // minCosts[j] := the minimum cost of connecting group2's point j
    int[] minCosts = new int[n];

    for (int j = 0; j < n; ++j) {
      int minCostIndex = 0;
      for (int i = 1; i < m; ++i)
        if (cost.get(i).get(j) < cost.get(minCostIndex).get(j))
          minCostIndex = i;
      minCosts[j] = cost.get(minCostIndex).get(j);
    }

    return connectTwoGroups(cost, 0, 0, minCosts, mem);
  }

  // Returns the minimum cost to connect group1's points[i..n) with group2's
  // points, where `mask` is the bitmask of the connected points in group2.
  private int connectTwoGroups(List<List<Integer>> cost, int i, int mask, int[] minCosts,
                               Integer[][] mem) {
    if (i == cost.size()) {
      // All the points in group 1 are connected, so greedily assign the
      // minimum cost for the unconnected points of group2.
      int res = 0;
      for (int j = 0; j < cost.get(0).size(); ++j)
        if ((mask >> j & 1) == 0)
          res += minCosts[j];
      return res;
    }
    if (mem[i][mask] != null)
      return mem[i][mask];

    int res = Integer.MAX_VALUE;
    for (int j = 0; j < cost.get(0).size(); ++j)
      res = Math.min(res, cost.get(i).get(j) +
                              connectTwoGroups(cost, i + 1, mask | 1 << j, minCosts, mem));
    return mem[i][mask] = res;
  }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
class Solution:
  def connectTwoGroups(self, cost: list[list[int]]) -> int:
    """Returns the minimum total cost of connecting the two groups.

    cost[i][j] is the cost of connecting group1's point i with group2's point
    j. Every point of each group ends up connected to at least one point of
    the other group. An empty `cost` yields 0.
    """
    m = len(cost)
    # Nothing to connect when group 1 is empty (also keeps cost[0] safe).
    if m == 0:
      return 0
    n = len(cost[0])
    # minCosts[j] := the minimum cost of connecting group2's point j
    minCosts = [min(col) for col in zip(*cost)]

    @functools.lru_cache(None)
    def dp(i: int, mask: int) -> int:
      """
      Returns the minimum cost to connect group1's points[i..m) with group2's
      points, where `mask` is the bitmask of group2's points already connected
      by group1's points[0..i).
      """
      if i == m:
        # All the points in group 1 are connected, so greedily connect each
        # still-unconnected point of group2 through its cheapest edge.
        return sum(minCost for j, minCost in enumerate(minCosts)
                   if (mask >> j & 1) == 0)
      row = cost[i]
      return min(row[j] + dp(i + 1, mask | 1 << j) for j in range(n))

    return dp(0, 0)