Skip to content

3313. Find the Last Marked Nodes in Tree 👍

  • Time: $O(n)$
  • Space: $O(n)$
C++
// A tree node paired with the time (in seconds) at which it got marked.
struct Node {
  int node = 0;  // the node number
  int time = 0;  // the time it got marked
};

// The last two marked nodes of a subtree, latest first.
struct Last2 {
  Node last1;  // the last marked node
  Node last2;  // the second last marked node
};

class Solution {
 public:
  // Similar to 3241. Time Taken to Mark All Nodes
  //
  // Returns `ans` where ans[i] is a node that is marked last when marking
  // starts at node i and spreads one edge per second (a node farthest from i).
  // `edges` holds the n - 1 edges of a tree on nodes 0..n-1.
  //
  // The tree is rooted at node 0 and traversed in an explicit BFS order rather
  // than recursively, so a path-shaped tree with ~1e5 nodes cannot overflow
  // the call stack.
  std::vector<int> lastMarkedNodes(std::vector<std::vector<int>>& edges) {
    const int n = static_cast<int>(edges.size()) + 1;
    std::vector<std::vector<int>> tree(n);
    for (const std::vector<int>& edge : edges) {
      const int u = edge[0];
      const int v = edge[1];
      tree[u].push_back(v);
      tree[v].push_back(u);
    }

    // Root the tree at node 0: parent[v] is v's parent (-1 for the root), and
    // `order` lists the nodes so that every parent precedes its children.
    // Indexing (not iterators) is used because `order` grows inside the loop.
    std::vector<int> parent(n, -1);
    std::vector<int> order;
    order.reserve(n);
    order.push_back(0);
    for (std::size_t i = 0; i < order.size(); ++i) {
      const int u = order[i];
      for (const int v : tree[u]) {
        if (v == parent[u])
          continue;
        parent[v] = u;
        order.push_back(v);
      }
    }

    const std::vector<Last2> dp = computeSubtrees(tree, parent, order);
    return reroot(tree, parent, order, dp);
  }

 private:
  // Returns dp where dp[u] holds the last two marked nodes of u's subtree when
  // marking starts at u (times measured from u). Nodes are processed in
  // reverse BFS order so every child is finished before its parent.
  static std::vector<Last2> computeSubtrees(
      const std::vector<std::vector<int>>& tree, const std::vector<int>& parent,
      const std::vector<int>& order) {
    std::vector<Last2> dp(tree.size());
    for (auto it = order.rbegin(); it != order.rend(); ++it) {
      const int u = *it;
      Node last1{u, 0};  // u itself is marked at time 0
      Node last2;        // stays {0, 0} only for a leaf, where it is never read
      for (const int v : tree[u]) {
        if (v == parent[u])
          continue;
        const Node& child = dp[v].last1;
        const int time = child.time + 1;  // one more second for the edge u -> v
        if (time > last1.time) {
          last2 = last1;
          last1 = Node{child.node, time};
        } else if (time > last2.time) {
          last2 = Node{child.node, time};
        }
      }
      dp[u] = {last1, last2};
    }
    return dp;
  }

  // Rerooting pass: walks nodes parent-first, carrying for each node the last
  // marked node outside its subtree, and returns the per-node answers.
  static std::vector<int> reroot(const std::vector<std::vector<int>>& tree,
                                 const std::vector<int>& parent,
                                 const std::vector<int>& order,
                                 const std::vector<Last2>& dp) {
    std::vector<int> ans(tree.size());
    // up[u] := the last marked node outside u's subtree, `time` counted from
    // u. The root keeps the value-initialized {0, 0}, i.e. node 0 itself.
    std::vector<Node> up(tree.size());
    for (const int u : order) {
      const Node& last = up[u];
      const Node& best = dp[u].last1;
      ans[u] = last.time > best.time ? last.node : best.node;
      for (const int v : tree[u]) {
        if (v == parent[u])
          continue;
        // If u's last marked node lies in v's subtree, use the runner-up.
        const Node& sibling =
            best.node == dp[v].last1.node ? dp[u].last2 : best;
        // Both candidates are one edge (u -> v) farther from v than from u.
        up[v] = sibling.time > last.time ? Node{sibling.node, sibling.time + 1}
                                         : Node{last.node, last.time + 1};
      }
    }
    return ans;
  }
};
Java
class Solution {
  // Similar to 3241. Time Taken to Mark All Nodes
  public int[] lastMarkedNodes(int[][] edges) {
    final int n = edges.length + 1;
    int[] ans = new int[n];
    List<Integer>[] tree = new List[n];
    // dp[i] := the last marked two nodes for subtree rooted at node i, where
    // each node contains the time it got marked
    Last2[] dp = new Last2[n];

    for (int i = 0; i < n; ++i) {
      tree[i] = new ArrayList<>();
      dp[i] = new Last2();
    }

    for (int[] edge : edges) {
      final int u = edge[0];
      final int v = edge[1];
      tree[u].add(v);
      tree[v].add(u);
    }

    dfs(tree, 0, /*prev=*/-1, dp);
    reroot(tree, 0, /*prev=*/-1, /*last=*/new Node(), dp, ans);
    return ans;
  }

  private record Node(int node, int time) {
    Node() {
      this(0, 0);
    }
  }

  private record Last2(Node last1, Node last2) {
    Last2() {
      this(new Node(), new Node());
    }
  }

  // Performs a DFS traversal of the subtree rooted at node `u`, computes the
  // time taken to mark all nodes in the subtree, records the last two marked
  // nodes, and returns the last marked node.
  //
  // These values are used later in the rerooting process.
  private Node dfs(List<Integer>[] tree, int u, int prev, Last2[] dp) {
    Node last1 = new Node(u, 0);
    Node last2 = new Node();
    for (final int v : tree[u]) {
      if (v == prev)
        continue;
      Node child = dfs(tree, v, u, dp);
      final int time = child.time() + 1;
      if (time > last1.time) {
        last2 = last1;
        last1 = new Node(child.node(), time);
      } else if (time > last2.time) {
        last2 = new Node(child.node(), time);
      }
    }
    dp[u] = new Last2(last1, last2);
    return last1;
  }

  // Reroots the tree at node `u` and updates the answer array, where `last`
  // is the last marked node that doesn't go through `u`'s subtree.
  private void reroot(List<Integer>[] tree, int u, int prev, Node last, Last2[] dp, int[] ans) {
    ans[u] = last.time() > dp[u].last1().time() ? last.node() : dp[u].last1().node();
    for (final int v : tree[u]) {
      if (v == prev)
        continue;
      Node newLast = new Node(last.node(), last.time() + 1); // 1 := u -> v
      if (dp[u].last1().node() == dp[v].last1().node()) {
        final int alternativeTime = 1 + dp[u].last2().time();
        if (alternativeTime > newLast.time())
          newLast = new Node(dp[u].last2().node(), alternativeTime);
      } else {
        final int alternativeTime = 1 + dp[u].last1().time();
        if (alternativeTime > newLast.time())
          newLast = new Node(dp[u].last1().node(), alternativeTime);
      }
      reroot(tree, v, u, newLast, dp, ans);
    }
  }
}
Python
from dataclasses import dataclass, field


@dataclass
class Node:
  node: int = 0  # the node number
  time: int = 0  # the time it got marked


@dataclass(eq=False)
class Last2:
  # default_factory gives every instance its own Node objects; a `Node()`
  # default argument would alias one mutable Node across all instances.
  # eq=False keeps the identity-based equality/hashing of a plain class.
  last1: Node = field(default_factory=Node)  # the last marked node
  last2: Node = field(default_factory=Node)  # the second last marked node


class Solution:
  # Similar to 3241. Time Taken to Mark All Nodes
  def lastMarkedNodes(self, edges: list[list[int]]) -> list[int]:
    """
    Returns ans where ans[i] is a node marked last when marking starts at node
    i and spreads one edge per second (a node farthest from i). `edges` holds
    the n - 1 edges of a tree on nodes 0..n-1.

    The tree is traversed in an explicit BFS order rather than recursively, so
    path-shaped trees deeper than the interpreter's recursion limit work.
    """
    n = len(edges) + 1
    tree = [[] for _ in range(n)]
    for u, v in edges:
      tree[u].append(v)
      tree[v].append(u)

    parent = [-1] * n
    order = self._bfsOrder(tree, parent)
    dp = self._computeSubtrees(tree, parent, order)
    return self._reroot(tree, parent, order, dp)

  @staticmethod
  def _bfsOrder(tree: list[list[int]], parent: list[int]) -> list[int]:
    """
    Roots the tree at node 0: fills parent[v] with v's parent (the root keeps
    -1) and returns the nodes in BFS order, every parent before its children.
    """
    order = [0]
    head = 0
    while head < len(order):  # `order` doubles as the BFS queue
      u = order[head]
      head += 1
      for v in tree[u]:
        if v != parent[u]:
          parent[v] = u
          order.append(v)
    return order

  @staticmethod
  def _computeSubtrees(
      tree: list[list[int]],
      parent: list[int],
      order: list[int]
  ) -> list[Last2]:
    """
    Returns dp where dp[u] holds the last two marked nodes of u's subtree when
    marking starts at u (times measured from u). Nodes are visited in reverse
    BFS order so every child is finished before its parent.
    """
    dp = [None] * len(tree)
    for u in reversed(order):
      last1 = Node(u, 0)  # u itself is marked at time 0
      last2 = Node()  # stays (0, 0) only for a leaf, where it is never read
      for v in tree[u]:
        if v == parent[u]:
          continue
        child = dp[v].last1
        time = child.time + 1  # one more second for the edge u -> v
        if time > last1.time:
          last2 = last1
          last1 = Node(child.node, time)
        elif time > last2.time:
          last2 = Node(child.node, time)
      dp[u] = Last2(last1, last2)
    return dp

  @staticmethod
  def _reroot(
      tree: list[list[int]],
      parent: list[int],
      order: list[int],
      dp: list[Last2]
  ) -> list[int]:
    """
    Rerooting pass: walks nodes parent-first, carrying for each node the last
    marked node outside its subtree, and returns the per-node answers.
    """
    ans = [0] * len(tree)
    # up[u] := the last marked node outside u's subtree, time counted from u.
    up = [None] * len(tree)
    up[0] = Node()  # the root is node 0 itself, marked at time 0
    for u in order:
      last = up[u]
      best = dp[u].last1
      ans[u] = last.node if last.time > best.time else best.node
      for v in tree[u]:
        if v == parent[u]:
          continue
        # If u's last marked node lies in v's subtree, use the runner-up.
        sibling = dp[u].last2 if best.node == dp[v].last1.node else best
        # Both candidates are one edge (u -> v) farther from v than from u.
        if sibling.time > last.time:
          up[v] = Node(sibling.node, sibling.time + 1)
        else:
          up[v] = Node(last.node, last.time + 1)
    return ans