2973. Find Number of Coins to Place in Tree Nodes 👍

  • Time: $O(n)$
  • Space: $O(n)$
C++ solution:
// Compact summary of a subtree's costs: just enough to compute the maximum
// product of three of its costs without storing every cost.
//   - numNodes:    number of nodes in the subtree.
//   - maxPosCosts: up to three largest positive costs, in descending order.
//   - minNegCosts: up to two smallest (most negative) costs, in ascending order.
// Zero costs are not stored: any product involving them is 0, and 0 is already
// the floor maxProduct() returns for subtrees with at least 3 nodes.
class ChildCost {
 public:
  // Summary of a single node whose cost is `cost`.
  ChildCost(int cost) {
    if (cost > 0)
      maxPosCosts.push_back(cost);
    else if (cost < 0)
      minNegCosts.push_back(cost);
  }

  // Merges a child subtree's summary into this one.
  // `childCost` is taken by value so that the call stays well-defined even if
  // it aliases *this (inserting a vector's own range into itself is UB);
  // callers pass temporaries, which are moved rather than copied.
  void update(ChildCost childCost) {
    numNodes += childCost.numNodes;
    maxPosCosts.insert(maxPosCosts.end(), childCost.maxPosCosts.begin(),
                       childCost.maxPosCosts.end());
    minNegCosts.insert(minNegCosts.end(), childCost.minNegCosts.begin(),
                       childCost.minNegCosts.end());
    std::sort(maxPosCosts.begin(), maxPosCosts.end(), std::greater<int>());
    std::sort(minNegCosts.begin(), minNegCosts.end());
    // Only the 3 largest positives and the 2 most negative values can ever
    // take part in the best product, so drop the rest.
    if (maxPosCosts.size() > 3)
      maxPosCosts.resize(3);
    if (minNegCosts.size() > 2)
      minNegCosts.resize(2);
  }

  // Returns the coins for this subtree:
  //   - 1 if it has fewer than 3 nodes;
  //   - otherwise the larger of 0, the product of the three largest positive
  //     costs, and the product of the two most negative costs with the largest
  //     positive cost (each only when enough values are available).
  // Computed in long long: a product of three ints can exceed 32 bits, and
  // `long` is only 32 bits on some platforms (e.g. 64-bit Windows).
  long long maxProduct() const {
    if (numNodes < 3)
      return 1;
    if (maxPosCosts.empty())
      return 0;  // No positive cost: every product of three is <= 0.
    long long res = 0;
    if (maxPosCosts.size() == 3)
      res = static_cast<long long>(maxPosCosts[0]) * maxPosCosts[1] *
            maxPosCosts[2];
    if (minNegCosts.size() == 2)
      res = std::max(res, static_cast<long long>(minNegCosts[0]) *
                              minNegCosts[1] * maxPosCosts[0]);
    return res;
  }

 private:
  int numNodes = 1;
  std::vector<int> maxPosCosts;
  std::vector<int> minNegCosts;
};

class Solution {
 public:
  // Builds an undirected adjacency list from `edges`, roots the tree at node 0,
  // and returns, for every node, ChildCost::maxProduct() of its subtree.
  vector<long long> placedCoins(vector<vector<int>>& edges, vector<int>& cost) {
    const int n = cost.size();
    vector<vector<int>> graph(n);
    for (const vector<int>& e : edges) {
      graph[e[0]].push_back(e[1]);
      graph[e[1]].push_back(e[0]);
    }

    vector<long long> coins(n);
    collect(graph, /*node=*/0, /*parent=*/-1, cost, coins);
    return coins;
  }

 private:
  // Post-order traversal: folds every child's summary into this node's
  // summary, records this node's answer in `coins`, then hands the summary
  // back to the caller so the parent can fold it in as well.
  ChildCost collect(const vector<vector<int>>& graph, int node, int parent,
                    const vector<int>& cost, vector<long long>& coins) {
    ChildCost subtree(cost[node]);
    for (const int child : graph[node]) {
      if (child == parent)
        continue;
      subtree.update(collect(graph, child, node, cost, coins));
    }
    coins[node] = subtree.maxProduct();
    return subtree;
  }
};
Java solution:
class ChildCost {
  public ChildCost(int cost) {
    if (cost > 0)
      maxPosCosts.add(cost);
    else
      minNegCosts.add(cost);
  }

  public void update(ChildCost childCost) {
    numNodes += childCost.numNodes;
    maxPosCosts.addAll(childCost.maxPosCosts);
    minNegCosts.addAll(childCost.minNegCosts);
    maxPosCosts.sort(Comparator.reverseOrder());
    minNegCosts.sort(Comparator.naturalOrder());
    if (maxPosCosts.size() > 3)
      maxPosCosts = maxPosCosts.subList(0, 3);
    if (minNegCosts.size() > 2)
      minNegCosts = minNegCosts.subList(0, 2);
  }

  public long maxProduct() {
    if (numNodes < 3)
      return 1;
    if (maxPosCosts.isEmpty())
      return 0;
    long res = 0;
    if (maxPosCosts.size() == 3)
      res = (long) maxPosCosts.get(0) * maxPosCosts.get(1) * maxPosCosts.get(2);
    if (minNegCosts.size() == 2)
      res = Math.max(res, (long) minNegCosts.get(0) * minNegCosts.get(1) * maxPosCosts.get(0));
    return res;
  }

  private int numNodes = 1;
  private List<Integer> maxPosCosts = new ArrayList<>();
  private List<Integer> minNegCosts = new ArrayList<>();
}

class Solution {
  /**
   * Builds an undirected adjacency list from {@code edges}, roots the tree at
   * node 0, and returns, for every node, {@code ChildCost.maxProduct()} of its
   * subtree.
   */
  public long[] placedCoins(int[][] edges, int[] cost) {
    final int n = cost.length;
    List<List<Integer>> graph = new ArrayList<>(n);
    for (int i = 0; i < n; ++i)
      graph.add(new ArrayList<>());

    for (int[] e : edges) {
      graph.get(e[0]).add(e[1]);
      graph.get(e[1]).add(e[0]);
    }

    long[] coins = new long[n];
    collect(graph, 0, -1, cost, coins);
    return coins;
  }

  /**
   * Post-order traversal: folds each child's summary into this node's summary,
   * records this node's answer in {@code coins}, and returns the summary so the
   * parent can fold it in as well.
   */
  private ChildCost collect(List<List<Integer>> graph, int node, int parent, int[] cost,
                            long[] coins) {
    ChildCost subtree = new ChildCost(cost[node]);
    for (int child : graph.get(node)) {
      if (child == parent)
        continue;
      subtree.update(collect(graph, child, node, cost, coins));
    }
    coins[node] = subtree.maxProduct();
    return subtree;
  }
}
Python solution:
class ChildCost:
  """Compact summary of a subtree's costs.

  Keeps the subtree's node count, up to three largest positive costs
  (descending) and up to two most negative costs (ascending). Zero costs are
  not stored.
  """

  def __init__(self, cost: int):
    self.numNodes = 1
    self.maxPosCosts = []
    self.minNegCosts = []
    if cost > 0:
      self.maxPosCosts.append(cost)
    elif cost < 0:
      self.minNegCosts.append(cost)

  def update(self, childCost: 'ChildCost') -> None:
    """Merges a child subtree's summary into this one."""
    self.numNodes += childCost.numNodes
    # Only the 3 largest positives and the 2 most negative values can ever
    # take part in the best product of three.
    self.maxPosCosts = sorted(self.maxPosCosts + childCost.maxPosCosts,
                              reverse=True)[:3]
    self.minNegCosts = sorted(self.minNegCosts + childCost.minNegCosts)[:2]

  def maxProduct(self) -> int:
    """Returns the coins for this subtree.

    1 if the subtree has fewer than 3 nodes; 0 if it has no positive cost;
    otherwise the largest of 0, the product of the three largest positive
    costs, and the product of the two most negative costs with the largest
    positive cost (each only when enough values are available).
    """
    if self.numNodes < 3:
      return 1
    pos, neg = self.maxPosCosts, self.minNegCosts
    if not pos:
      return 0
    candidates = [0]
    if len(pos) == 3:
      candidates.append(pos[0] * pos[1] * pos[2])
    if len(neg) == 2:
      candidates.append(neg[0] * neg[1] * pos[0])
    return max(candidates)


class Solution:
  def placedCoins(self, edges: list[list[int]], cost: list[int]) -> list[int]:
    """Returns, for every node, ChildCost.maxProduct() of its subtree.

    The tree is rooted at node 0. It is traversed iteratively because a
    recursive DFS recurses once per tree level and raises RecursionError on
    deep (path-like) trees once the depth exceeds Python's recursion limit
    (1000 by default).
    """
    n = len(cost)
    ans = [0] * n
    if n == 0:
      return ans
    tree = [[] for _ in range(n)]

    for u, v in edges:
      tree[u].append(v)
      tree[v].append(u)

    # Pre-order traversal from the root, recording each node's parent.
    parent = [-1] * n
    order = []
    stack = [0]
    while stack:
      u = stack.pop()
      order.append(u)
      for v in tree[u]:
        if v != parent[u]:
          parent[v] = u
          stack.append(v)

    # Reverse pre-order visits every node after all of its descendants, so a
    # node's summary is complete when we read it and fold it into its parent.
    # The merge result does not depend on the order children are folded in.
    subtrees = [ChildCost(c) for c in cost]
    for u in reversed(order):
      ans[u] = subtrees[u].maxProduct()
      if parent[u] != -1:
        subtrees[parent[u]].update(subtrees[u])
    return ans