Skip to content

2646. Minimize the Total Price of the Trips 👍

  • Time: $O(n \cdot |\texttt{trips}|)$
  • Space: $O(n)$
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
class Solution {
 public:
  // Returns the minimum total price of performing every trip in `trips`
  // (each trip is {start, end}) on the tree with `n` nodes described by
  // `edges`. Before the trips, the price of any set of nodes may be halved
  // (integer division), provided no two halved nodes are adjacent.
  //
  // Assumes `edges` forms a tree: the DFS helpers only avoid stepping back to
  // the node they came from.
  int minimumTotalPrice(int n, vector<vector<int>>& edges, vector<int>& price,
                        vector<vector<int>>& trips) {
    if (n == 0)
      return 0;

    vector<vector<int>> graph(n);
    for (const vector<int>& edge : edges) {
      const int u = edge[0];
      const int v = edge[1];
      graph[u].push_back(v);
      graph[v].push_back(u);
    }

    // count[i] := the number of trips whose path passes through node i
    vector<int> count(n);
    for (const vector<int>& trip : trips)
      dfsCount(graph, /*u=*/trip[0], /*prev=*/-1, /*end=*/trip[1], count);

    // Root the tree at node 0. The root has no parent, so both choices are
    // allowed for it.
    const SubtreeCost root = dfs(graph, /*u=*/0, /*prev=*/-1, price, count);
    return min(root.full, root.halved);
  }

 private:
  // Minimum total price contributed by a subtree, for each choice at its root.
  struct SubtreeCost {
    int full;    // The root keeps its full price.
    int halved;  // The root's price is halved, so its children must stay full.
  };

  // Walks from `u` (reached from `prev`) looking for `end`. If `end` lies in
  // this direction, increments count[] for every node on the path u -> end and
  // returns true. Otherwise it leaves count[] unchanged and returns false.
  // Because a tree has a unique path between two nodes, the search stops as
  // soon as `end` is found.
  bool dfsCount(const vector<vector<int>>& graph, int u, int prev, int end,
                vector<int>& count) {
    if (u == end) {
      ++count[u];
      return true;
    }
    for (const int v : graph[u])
      if (v != prev && dfsCount(graph, v, u, end, count)) {
        ++count[u];  // `u` lies on the path to `end`.
        return true;
      }
    return false;
  }

  // Returns the minimum price sums for the subtree rooted at `u` (whose parent
  // is `prev`), both when `u` keeps its full price and when it is halved.
  SubtreeCost dfs(const vector<vector<int>>& graph, int u, int prev,
                  const vector<int>& price, const vector<int>& count) {
    SubtreeCost cost{price[u] * count[u], (price[u] / 2) * count[u]};
    for (const int v : graph[u]) {
      if (v == prev)
        continue;
      const SubtreeCost child = dfs(graph, v, u, price, count);
      // A full-price node leaves its child free to take either option...
      cost.full += min(child.full, child.halved);
      // ...but a halved node forbids halving its adjacent child.
      cost.halved += child.full;
    }
    return cost;
  }
};
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
class Solution {
  /**
   * Returns the minimum total price of performing every trip in {@code trips}
   * (each trip is {start, end}) on the tree with {@code n} nodes described by
   * {@code edges}. Before the trips, the price of any set of nodes may be
   * halved (integer division), provided no two halved nodes are adjacent.
   *
   * <p>Assumes {@code edges} forms a tree: the DFS helpers only avoid stepping
   * back to the node they came from.
   */
  @SuppressWarnings("unchecked") // generic array creation for the adjacency list
  public int minimumTotalPrice(int n, int[][] edges, int[] price, int[][] trips) {
    if (n == 0)
      return 0;

    List<Integer>[] graph = new List[n];
    for (int i = 0; i < n; ++i)
      graph[i] = new ArrayList<>();

    for (int[] edge : edges) {
      final int u = edge[0];
      final int v = edge[1];
      graph[u].add(v);
      graph[v].add(u);
    }

    // count[i] := the number of trips whose path passes through node i
    int[] count = new int[n];
    for (int[] trip : trips)
      dfsCount(graph, /*u=*/trip[0], /*prev=*/-1, /*end=*/trip[1], count);

    // Root the tree at node 0. The root has no parent, so both choices are
    // allowed for it.
    final int[] root = dfs(graph, 0, -1, price, count);
    return Math.min(root[FULL], root[HALVED]);
  }

  // Indices into the two-element cost arrays returned by dfs().
  private static final int FULL = 0;   // the subtree root keeps its full price
  private static final int HALVED = 1; // the subtree root's price is halved

  /**
   * Walks from {@code u} (reached from {@code prev}) looking for {@code end}.
   * If {@code end} lies in this direction, increments {@code count} for every
   * node on the path u -> end and returns true. Otherwise it leaves
   * {@code count} unchanged and returns false. Because a tree has a unique
   * path between two nodes, the search stops as soon as {@code end} is found.
   */
  private boolean dfsCount(List<Integer>[] graph, int u, int prev, int end, int[] count) {
    if (u == end) {
      ++count[u];
      return true;
    }
    for (final int v : graph[u])
      if (v != prev && dfsCount(graph, v, u, end, count)) {
        ++count[u]; // u lies on the path to end.
        return true;
      }
    return false;
  }

  /**
   * Returns {minimum sum if u keeps its full price, minimum sum if u is halved}
   * for the subtree rooted at {@code u}, whose parent is {@code prev}.
   */
  private int[] dfs(List<Integer>[] graph, int u, int prev, int[] price, int[] count) {
    int[] cost = new int[2];
    cost[FULL] = price[u] * count[u];
    cost[HALVED] = (price[u] / 2) * count[u];
    for (final int v : graph[u]) {
      if (v == prev)
        continue;
      final int[] child = dfs(graph, v, u, price, count);
      // A full-price node leaves its child free to take either option...
      cost[FULL] += Math.min(child[FULL], child[HALVED]);
      // ...but a halved node forbids halving its adjacent child.
      cost[HALVED] += child[FULL];
    }
    return cost;
  }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
class Solution:
  def minimumTotalPrice(self, n: int, edges: list[list[int]], price: list[int],
                        trips: list[list[int]]) -> int:
    """Returns the minimum total price of performing every trip.

    Each trip is [start, end] on the tree with `n` nodes described by `edges`.
    Before the trips, the price of any set of nodes may be halved (integer
    division), provided no two halved nodes are adjacent.

    Assumes `edges` forms a tree: the DFS helpers only avoid stepping back to
    the node they came from.
    """
    if n == 0:
      return 0

    graph = [[] for _ in range(n)]
    for u, v in edges:
      graph[u].append(v)
      graph[v].append(u)

    # count[i] := the number of trips whose path passes through node i
    count = [0] * n

    def dfsCount(u: int, prev: int, end: int) -> bool:
      """Marks the path u -> end in `count` if `end` lies in this direction.

      Returns True (after incrementing count[] for every node on the path) if
      `end` was found, otherwise False with count[] unchanged. Because a tree
      has a unique path between two nodes, the search stops as soon as `end`
      is found.
      """
      if u == end:
        count[u] += 1
        return True
      for v in graph[u]:
        if v != prev and dfsCount(v, u, end):
          count[u] += 1  # u lies on the path to end.
          return True
      return False

    for start, end in trips:
      dfsCount(start, -1, end)

    def dfs(u: int, prev: int) -> tuple[int, int]:
      """Returns (min sum if u keeps full price, min sum if u is halved).

      Both sums cover the subtree rooted at `u`, whose parent is `prev`.
      """
      full = price[u] * count[u]
      halved = (price[u] // 2) * count[u]
      for v in graph[u]:
        if v == prev:
          continue
        childFull, childHalved = dfs(v, u)
        # A full-price node leaves its child free to take either option...
        full += min(childFull, childHalved)
        # ...but a halved node forbids halving its adjacent child.
        halved += childFull
      return full, halved

    # Root the tree at node 0; the root has no parent, so both choices apply.
    return min(dfs(0, -1))