Skip to content

2603. Collect Coins in a Tree 👍

  • Time: $O(n)$
  • Space: $O(n)$
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
class Solution {
 public:
  // Solves "Collect Coins in a Tree" by pruning the tree.
  //
  // coins[i] != 0 marks vertex i as holding a coin; `edges` lists undirected
  // edges [a, b]. The tree is pruned in two phases:
  //   1. Chains of coinless leaves are peeled off entirely.
  //   2. Two successive layers of the remaining leaves are removed.
  // The return value is the sum of the surviving adjacency-set sizes. That is
  // twice the number of edges left, because every edge is stored at both
  // endpoints.
  int collectTheCoins(vector<int>& coins, vector<vector<int>>& edges) {
    const int n = coins.size();
    vector<unordered_set<int>> adj(n);
    for (const vector<int>& e : edges) {
      adj[e[0]].insert(e[1]);
      adj[e[1]].insert(e[0]);
    }

    // Cuts `leaf` off its only neighbor and returns that neighbor.
    auto detachLeaf = [&adj](int leaf) {
      const int parent = *adj[leaf].begin();
      adj[leaf].clear();
      adj[parent].erase(leaf);
      return parent;
    };

    queue<int> frontier;
    for (int start = 0; start < n; ++start) {
      int node = start;
      // Peel coinless leaves one after another, climbing toward the interior.
      while (adj[node].size() == 1 && coins[node] == 0)
        node = detachLeaf(node);
      // If the walk stopped on a leaf, that leaf must hold a coin.
      if (adj[node].size() == 1)
        frontier.push(node);
    }

    // Strip two layers of leaves. Each layer processes exactly the entries
    // queued before it started; newly exposed leaves feed the next layer.
    for (int layer = 0; layer < 2; ++layer) {
      for (size_t remaining = frontier.size(); remaining > 0; --remaining) {
        const int leaf = frontier.front();
        frontier.pop();
        // Skip entries already isolated: queued twice, or their last
        // neighbor was detached first.
        if (adj[leaf].empty())
          continue;
        const int parent = detachLeaf(leaf);
        if (adj[parent].size() == 1)
          frontier.push(parent);
      }
    }

    int total = 0;
    for (const unordered_set<int>& neighbors : adj)
      total += static_cast<int>(neighbors.size());
    return total;
  }
};
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
class Solution {
  /**
   * Solves "Collect Coins in a Tree" by pruning the tree.
   *
   * <p>Chains of coinless leaves are peeled off first. Then two successive
   * layers of the remaining leaves are removed.
   *
   * @param coins coins[i] != 0 marks vertex i as holding a coin
   * @param edges undirected edges {a, b}
   * @return the sum of the surviving adjacency-set sizes, which is twice the
   *     number of edges left (each edge is stored at both endpoints)
   */
  public int collectTheCoins(int[] coins, int[][] edges) {
    final int n = coins.length;
    @SuppressWarnings("unchecked") // Generic array creation needs a raw type.
    final Set<Integer>[] adj = new Set[n];
    for (int i = 0; i < n; ++i)
      adj[i] = new HashSet<>();
    for (final int[] e : edges) {
      adj[e[0]].add(e[1]);
      adj[e[1]].add(e[0]);
    }

    final Deque<Integer> frontier = new ArrayDeque<>();
    for (int start = 0; start < n; ++start) {
      int node = start;
      // Peel coinless leaves one after another, climbing toward the interior.
      while (adj[node].size() == 1 && coins[node] == 0)
        node = detachLeaf(adj, node);
      // If the walk stopped on a leaf, that leaf must hold a coin.
      if (adj[node].size() == 1)
        frontier.offer(node);
    }

    // Strip two layers of leaves. Each layer processes exactly the entries
    // queued before it started.
    for (int layer = 0; layer < 2; ++layer) {
      for (int remaining = frontier.size(); remaining > 0; --remaining) {
        final int leaf = frontier.poll();
        // Skip entries already isolated: queued twice, or their last
        // neighbor was detached first.
        if (adj[leaf].isEmpty())
          continue;
        final int parent = detachLeaf(adj, leaf);
        if (adj[parent].size() == 1)
          frontier.offer(parent);
      }
    }

    int total = 0;
    for (final Set<Integer> neighbors : adj)
      total += neighbors.size();
    return total;
  }

  /** Cuts {@code leaf} off its only neighbor and returns that neighbor. */
  private static int detachLeaf(Set<Integer>[] adj, int leaf) {
    final int parent = adj[leaf].iterator().next();
    adj[leaf].clear();
    adj[parent].remove(leaf);
    return parent;
  }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
class Solution:
  def collectTheCoins(self, coins: list[int], edges: list[list[int]]) -> int:
    """Solve "Collect Coins in a Tree" by pruning the tree.

    Chains of coinless leaves are peeled off first. Then two successive
    layers of the remaining leaves are removed.

    Args:
      coins: coins[i] != 0 marks vertex i as holding a coin.
      edges: undirected edges [a, b].

    Returns:
      The sum of the surviving adjacency-set sizes. This is twice the number
      of edges left, since each edge is stored at both endpoints.
    """
    adj = [set() for _ in coins]
    for a, b in edges:
      adj[a].add(b)
      adj[b].add(a)

    def detach(leaf: int) -> int:
      """Cut `leaf` off its only neighbor and return that neighbor."""
      parent = adj[leaf].pop()
      adj[parent].remove(leaf)
      return parent

    frontier = collections.deque()
    for start in range(len(coins)):
      node = start
      # Peel coinless leaves one after another, climbing toward the interior.
      while len(adj[node]) == 1 and coins[node] == 0:
        node = detach(node)
      # If the walk stopped on a leaf, that leaf must hold a coin.
      if len(adj[node]) == 1:
        frontier.append(node)

    # Strip two layers of leaves. Each layer handles only the entries queued
    # before it started; newly exposed leaves feed the next layer.
    for _layer in range(2):
      batch = len(frontier)
      while batch > 0:
        batch -= 1
        leaf = frontier.popleft()
        # Skip entries already isolated: queued twice, or their last
        # neighbor was detached first.
        if not adj[leaf]:
          continue
        parent = detach(leaf)
        if len(adj[parent]) == 1:
          frontier.append(parent)

    return sum(map(len, adj))