3241. Time Taken to Mark All Nodes 👍

  • Time: $O(n)$
  • Space: $O(n)$
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
// Summary of one child subtree: the node that roots it and how long it takes,
// counted from the moment its parent is marked, to mark every node in it.
struct Node {
  int node{0};  // id of the subtree's root
  int time{0};  // time to mark that whole subtree once its parent is marked
};

// The two direct children of a node whose subtrees take the longest to mark.
// Members left unset are default-constructed `Node{}` values.
struct Top2 {
  Node top1{};  // child whose subtree needs the most time to mark
  Node top2{};  // child whose subtree needs the second-most time to mark
};

class Solution {
 public:
  vector<int> timeTaken(vector<vector<int>>& edges) {
    const int n = edges.size() + 1;
    vector<int> ans(n);
    vector<vector<int>> tree(n);
    // dp[i] := the top two direct children of node i (tree rooted at node 0),
    // each paired with the time needed to mark that child's entire subtree once
    // node i is marked
    vector<Top2> dp(n);

    for (const vector<int>& edge : edges) {
      const int u = edge[0];
      const int v = edge[1];
      tree[u].push_back(v);
      tree[v].push_back(u);
    }

    dfs(tree, 0, /*prev=*/-1, dp);
    reroot(tree, 0, /*prev=*/-1, /*maxTime=*/0, dp, ans);
    return ans;
  }

 private:
  // Returns the time taken to mark node u once a neighbor is marked:
  // 2 for even nodes, 1 for odd nodes.
  static int getTime(int u) {
    return u % 2 == 0 ? 2 : 1;
  }

  // Post-order DFS of the subtree rooted at `u` (whose parent is `prev`).
  // For every visited node it records in dp the two direct children whose
  // subtrees take the longest to mark. It returns the time needed to mark the
  // entire subtree of `u` once `u` is marked (0 for a leaf).
  //
  // These values are used later in the rerooting pass.
  int dfs(const vector<vector<int>>& tree, int u, int prev, vector<Top2>& dp) {
    Node top1;
    Node top2;
    for (const int v : tree[u]) {
      if (v == prev)
        continue;
      // Marking v costs getTime(v); v's own subtree then needs dfs(...) more.
      const int time = dfs(tree, v, u, dp) + getTime(v);
      if (time >= top1.time) {
        top2 = top1;
        // Brace initialization keeps this valid before C++20 (parenthesized
        // aggregate initialization is C++20-only).
        top1 = Node{v, time};
      } else if (time > top2.time) {
        top2 = Node{v, time};
      }
    }
    dp[u] = Top2{top1, top2};
    return top1.time;
  }

  // Fills ans for every node in the subtree rooted at `u` (whose parent is
  // `prev`). `maxTime` is the time needed, once `u` is marked, to mark every
  // node outside `u`'s subtree (0 when `u` is the root).
  void reroot(const vector<vector<int>>& tree, int u, int prev, int maxTime,
              const vector<Top2>& dp, vector<int>& ans) {
    ans[u] = max(maxTime, dp[u].top1.time);
    for (const int v : tree[u]) {
      if (v == prev)
        continue;
      // From v, everything outside its subtree is reached through u: pay
      // getTime(u), then the slower of u's outside part and u's best other
      // child (top2 when v itself is top1).
      const int newMaxTime =
          getTime(u) + max(maxTime, dp[u].top1.node == v ? dp[u].top2.time
                                                         : dp[u].top1.time);
      reroot(tree, v, u, newMaxTime, dp, ans);
    }
  }
};
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
class Solution {
  @SuppressWarnings("unchecked") // `new List[n]` creates a raw generic array
  public int[] timeTaken(int[][] edges) {
    final int n = edges.length + 1;
    int[] ans = new int[n];
    List<Integer>[] tree = new List[n];
    // dp[i] := the top two direct children of node i (tree rooted at node 0),
    // each paired with the time needed to mark that child's entire subtree once
    // node i is marked
    Top2[] dp = new Top2[n];

    for (int i = 0; i < n; ++i) {
      tree[i] = new ArrayList<>();
      dp[i] = new Top2();
    }

    for (int[] edge : edges) {
      final int u = edge[0];
      final int v = edge[1];
      tree[u].add(v);
      tree[v].add(u);
    }

    dfs(tree, 0, /*prev=*/-1, dp);
    reroot(tree, 0, /*prev=*/-1, /*maxTime=*/0, dp, ans);
    return ans;
  }

  // A child subtree: its root `node` and the `time` needed to mark all of it,
  // counted from the moment its parent is marked.
  private record Node(int node, int time) {
    Node() {
      this(0, 0);
    }
  }

  // The children whose subtrees take the longest (top1) and second-longest
  // (top2) to mark.
  private record Top2(Node top1, Node top2) {
    Top2() {
      this(new Node(), new Node());
    }
  }

  // Returns the time taken to mark node u once a neighbor is marked:
  // 2 for even nodes, 1 for odd nodes.
  private static int getTime(int u) {
    return u % 2 == 0 ? 2 : 1;
  }

  // Post-order DFS of the subtree rooted at `u` (whose parent is `prev`).
  // For every visited node it records in dp the two direct children whose
  // subtrees take the longest to mark. It returns the time needed to mark the
  // entire subtree of `u` once `u` is marked (0 for a leaf).
  //
  // These values are used later in the rerooting pass.
  private int dfs(List<Integer>[] tree, int u, int prev, Top2[] dp) {
    Node top1 = new Node();
    Node top2 = new Node();
    for (final int v : tree[u]) {
      if (v == prev)
        continue;
      // Marking v costs getTime(v); v's own subtree then needs dfs(...) more.
      final int time = dfs(tree, v, u, dp) + getTime(v);
      if (time >= top1.time()) {
        top2 = top1;
        top1 = new Node(v, time);
      } else if (time > top2.time()) {
        top2 = new Node(v, time);
      }
    }
    dp[u] = new Top2(top1, top2);
    return top1.time();
  }

  // Fills ans for every node in the subtree rooted at `u` (whose parent is
  // `prev`). `maxTime` is the time needed, once `u` is marked, to mark every
  // node outside `u`'s subtree (0 when `u` is the root).
  private void reroot(List<Integer>[] tree, int u, int prev, int maxTime, Top2[] dp, int[] ans) {
    ans[u] = Math.max(maxTime, dp[u].top1().time());
    for (final int v : tree[u]) {
      if (v == prev)
        continue;
      // From v, everything outside its subtree is reached through u: pay
      // getTime(u), then the slower of u's outside part and u's best other
      // child (top2 when v itself is top1).
      final int newMaxTime =
          getTime(u) +
          Math.max(maxTime, dp[u].top1().node() == v ? dp[u].top2().time() : dp[u].top1().time());
      reroot(tree, v, u, newMaxTime, dp, ans);
    }
  }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
from dataclasses import dataclass


@dataclass
class Node:
  """One child subtree: the node rooting it and the time to mark all of it."""

  node: int = 0  # id of the child that roots the subtree
  time: int = 0  # time to mark that whole subtree once its parent is marked


class Top2:
  """The two direct children of a node whose subtrees take the longest to mark.

  `top1` is the child with the maximum marking time and `top2` the child with
  the second maximum. Omitted arguments default to a fresh `Node()`.
  """

  def __init__(self, top1: "Node | None" = None, top2: "Node | None" = None):
    # Build defaults per instance: a `Node()` default in the signature would be
    # evaluated once and shared by every Top2 created without arguments.
    self.top1 = Node() if top1 is None else top1
    self.top2 = Node() if top2 is None else top2


class Solution:
  def timeTaken(self, edges: list[list[int]]) -> list[int]:
    n = len(edges) + 1
    ans = [0] * n
    tree = [[] for _ in range(n)]
    # dp[i] := the top two direct children of node i (tree rooted at node 0),
    # each paired with the time needed to mark that child's entire subtree once
    # node i is marked. Each slot gets its own Top2 instead of n references to
    # one shared instance.
    dp = [Top2() for _ in range(n)]

    for u, v in edges:
      tree[u].append(v)
      tree[v].append(u)

    self._dfs(tree, 0, -1, dp)
    self._reroot(tree, 0, -1, 0, dp, ans)
    return ans

  def _getTime(self, u: int) -> int:
    """Returns the time taken to mark node u once a neighbor is marked.

    Even-numbered nodes take 2, odd-numbered nodes take 1.
    """
    return 2 if u % 2 == 0 else 1

  def _dfs(
      self,
      tree: list[list[int]],
      u: int,
      prev: int,
      dp: list[Top2]
  ) -> int:
    """
    Fills dp for every node in the subtree rooted at `u` (whose parent is
    `prev`). It returns the time needed to mark that entire subtree once `u` is
    marked, i.e. dp[u].top1.time (0 when `u` has no children).

    For each node, dp records the two direct children whose subtrees take the
    longest to mark. These values are used later in the rerooting pass.

    The traversal is iterative because a recursive version would exceed
    Python's default recursion limit on deep trees such as a long path.
    """
    # Pre-order (node, parent) pairs. Walking this list backwards visits every
    # child before its parent, so dp[child] is ready when the parent is combined.
    order: list[tuple[int, int]] = []
    stack = [(u, prev)]
    while stack:
      node, parent = stack.pop()
      order.append((node, parent))
      for child in tree[node]:
        if child != parent:
          stack.append((child, node))

    for node, parent in reversed(order):
      top1 = Node()
      top2 = Node()
      # Scan children in adjacency order so ties resolve the same way
      # (`>=` lets a later child with an equal time take top1).
      for child in tree[node]:
        if child == parent:
          continue
        time = dp[child].top1.time + self._getTime(child)
        if time >= top1.time:
          top2 = top1
          top1 = Node(child, time)
        elif time > top2.time:
          top2 = Node(child, time)
      dp[node] = Top2(top1, top2)

    return dp[u].top1.time

  def _reroot(
      self,
      tree: list[list[int]],
      u: int,
      prev: int,
      maxTime: int,
      dp: list[Top2],
      ans: list[int]
  ) -> None:
    """
    Fills ans for every node in the subtree rooted at `u` (whose parent is
    `prev`). `maxTime` is the time needed, once `u` is marked, to mark every
    node outside `u`'s subtree (0 when `u` is the root).

    Iterative for the same recursion-depth reason as `_dfs`. Each ans entry
    depends only on the values pushed with its node, so visit order is free.
    """
    stack = [(u, prev, maxTime)]
    while stack:
      node, parent, outside = stack.pop()
      best = dp[node]
      ans[node] = max(outside, best.top1.time)
      for child in tree[node]:
        if child == parent:
          continue
        # Seen from `child`, everything outside its subtree is reached through
        # `node`: pay `node`'s marking time, then the slower of `node`'s outside
        # part and its best other child (top2 when `child` itself is top1).
        sibling = best.top2.time if best.top1.node == child else best.top1.time
        stack.append((child, node, self._getTime(node) + max(outside, sibling)))