Skip to content

3367. Maximize Sum of Weights after Edge Removals 👍

  • Time: $O(n\log k)$
  • Space: $O(n)$
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
class Solution {
 public:
  // Returns the maximum total weight of the edges that can be kept in the tree
  // described by `edges` ({u, v, w} triples over nodes 0..edges.size()) such
  // that every node keeps at most k incident edges.
  long long maximizeSumOfWeights(vector<vector<int>>& edges, int k) {
    const int n = edges.size() + 1;
    vector<vector<pair<int, int>>> graph(n);

    for (const vector<int>& edge : edges) {
      const int u = edge[0];
      const int v = edge[1];
      const int w = edge[2];
      graph[u].emplace_back(v, w);
      graph[v].emplace_back(u, w);
    }

    // BFS from node 0 records each node's parent and an order in which every
    // parent precedes its children. This replaces a recursive DFS whose depth
    // equals the tree height and can overflow the call stack on path-shaped
    // trees.
    vector<int> parent(n, -1);
    vector<int> order;
    order.reserve(n);
    order.push_back(0);
    for (size_t head = 0; head < order.size(); ++head) {
      const int u = order[head];
      for (const pair<int, int>& neighbor : graph[u]) {
        const int v = neighbor.first;
        if (v == parent[u])
          continue;
        parent[v] = u;
        order.push_back(v);
      }
    }

    // withK1[u] / withK[u]: the weight sum of the subtree rooted at u when u
    // keeps at most k - 1 / at most k edges to its children. `long long`
    // rather than `long`: sums can exceed 2^31, and `long` is 32-bit on LLP64
    // platforms such as 64-bit Windows.
    vector<long long> withK1(n, 0);
    vector<long long> withK(n, 0);

    // Reverse BFS order finishes every child before its parent.
    for (auto it = order.rbegin(); it != order.rend(); ++it) {
      const int u = *it;
      long long weightSum = 0;
      vector<long long> diffs;
      for (const auto& [v, w] : graph[u]) {
        if (v == parent[u])
          continue;
        weightSum += withK[v];
        // If picking (u, v) makes the sum larger, we should pick it.
        diffs.push_back(max(0LL, withK1[v] - withK[v] + w));
      }

      // Only the k largest gains are ever used; order just those, descending.
      const int take = min(max(k, 0), static_cast<int>(diffs.size()));
      partial_sort(diffs.begin(), diffs.begin() + take, diffs.end(),
                   [](long long a, long long b) { return a > b; });

      long long topK1 = 0;
      long long topK = 0;
      for (int i = 0; i < take; ++i) {
        if (i < k - 1)
          topK1 += diffs[i];
        topK += diffs[i];
      }

      withK1[u] = weightSum + topK1;
      withK[u] = weightSum + topK;
    }

    return withK[0];
  }
};
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
class Solution {
  /**
   * Returns the maximum total weight of the edges that can be kept in the tree
   * so that every node keeps at most {@code k} incident edges.
   *
   * @param edges tree edges as {u, v, w} triples over nodes 0..edges.length
   * @param k the maximum number of kept edges incident to any node
   * @return the largest achievable sum of kept edge weights
   */
  public long maximizeSumOfWeights(int[][] edges, int k) {
    final int n = edges.length + 1;
    // graph.get(u) holds {neighbor, weight} pairs. Plain int[] avoids relying on
    // javafx.util.Pair, which is not shipped with the JDK since Java 11.
    List<List<int[]>> graph = new ArrayList<>(n);

    for (int i = 0; i < n; ++i)
      graph.add(new ArrayList<>());

    for (int[] edge : edges) {
      final int u = edge[0];
      final int v = edge[1];
      final int w = edge[2];
      graph.get(u).add(new int[] {v, w});
      graph.get(v).add(new int[] {u, w});
    }

    // BFS from node 0 records each node's parent and an order in which every
    // parent precedes its children. This replaces recursion, which can overflow
    // the thread stack on deep (e.g. path-shaped) trees.
    int[] parent = new int[n];
    int[] order = new int[n];
    parent[0] = -1; // every other entry is written before it is read
    int size = 0;
    order[size++] = 0;
    for (int head = 0; head < size; ++head) {
      final int u = order[head];
      for (int[] neighbor : graph.get(u)) {
        final int v = neighbor[0];
        if (v == parent[u])
          continue;
        parent[v] = u;
        order[size++] = v;
      }
    }

    // withK1[u] / withK[u]: the weight sum of the subtree rooted at u when u
    // keeps at most k - 1 / at most k edges to its children.
    long[] withK1 = new long[n];
    long[] withK = new long[n];

    // Reverse BFS order finishes every child before its parent.
    for (int i = size - 1; i >= 0; --i) {
      final int u = order[i];
      long weightSum = 0;
      long[] diffs = new long[graph.get(u).size()];
      int count = 0;
      for (int[] neighbor : graph.get(u)) {
        final int v = neighbor[0];
        if (v == parent[u])
          continue;
        weightSum += withK[v];
        // If picking (u, v) makes the sum larger, we should pick it.
        diffs[count++] = Math.max(0L, withK1[v] - withK[v] + neighbor[1]);
      }

      // Ascending sort: the largest gains sit at the end of diffs[0, count).
      java.util.Arrays.sort(diffs, 0, count);

      long topK1 = 0;
      long topK = 0;
      for (int j = 0; j < k && j < count; ++j) {
        final long gain = diffs[count - 1 - j];
        if (j < k - 1)
          topK1 += gain;
        topK += gain;
      }

      withK1[u] = weightSum + topK1;
      withK[u] = weightSum + topK;
    }

    return withK[0];
  }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
class Solution:
  def maximizeSumOfWeights(self, edges: list[list[int]], k: int) -> int:
    """Returns the maximum total weight of the edges that can be kept.

    Every node of the tree may keep at most k incident edges.

    Args:
      edges: tree edges as [u, v, w] triples over nodes 0..len(edges).
      k: the maximum number of kept edges incident to any node.

    Returns:
      The largest achievable sum of kept edge weights.
    """
    import heapq

    n = len(edges) + 1
    graph = [[] for _ in range(n)]

    for u, v, w in edges:
      graph[u].append((v, w))
      graph[v].append((u, w))

    # BFS from node 0 records each node's parent and an order in which every
    # parent precedes its children. This replaces a recursive DFS, which
    # exceeded Python's default recursion limit (1000) on deep trees.
    parent = [-1] * n
    order = [0]
    for u in order:  # `order` grows while iterating, so it acts as the queue
      for v, _ in graph[u]:
        if v != parent[u]:
          parent[v] = u
          order.append(v)

    # withK1[u] / withK[u]: the weight sum of the subtree rooted at u when u
    # keeps at most k - 1 / at most k edges to its children.
    withK1 = [0] * n
    withK = [0] * n

    # Reverse BFS order finishes every child before its parent.
    for u in reversed(order):
      weightSum = 0
      diffs = []
      for v, w in graph[u]:
        if v == parent[u]:
          continue
        weightSum += withK[v]
        # If picking (u, v) makes the sum larger, we should pick it.
        diffs.append(max(0, withK1[v] - withK[v] + w))
      top = heapq.nlargest(k, diffs)  # sorted in descending order
      withK1[u] = weightSum + sum(top[:k - 1])
      withK[u] = weightSum + sum(top)

    return withK[0]