Skip to content

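2322. Minimum Score After Removals on a Tree

  • Time: $O(n^2)$, where $n$ is the number of nodes; every pair of edges is examined once.
  • Space: $O(n^2)$ in the worst case, for the per-node sets of subtree nodes.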
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
class Solution {
 public:
  // Returns the minimum score obtainable by removing two distinct edges from
  // the tree. The score of a removal is the largest minus the smallest XOR
  // among the three resulting components.
  //
  // `nums[i]` is the value of node i; each `edges[k]` is an undirected
  // {u, v} pair. Node 0 is used as the root. Returns INT_MAX when fewer than
  // two edges are given (no pair of edges can be removed).
  int minimumScore(std::vector<int>& nums,
                   std::vector<std::vector<int>>& edges) {
    const int n = static_cast<int>(nums.size());
    // XOR of every node value; the component that keeps the root is derived
    // from it by XOR-ing out the cut-off subtrees.
    const int xors =
        std::accumulate(nums.begin(), nums.end(), 0, std::bit_xor<int>());
    // subXors[u]: XOR of all values in u's subtree (completed by dfs).
    std::vector<int> subXors(nums);
    std::vector<std::vector<int>> tree(n);
    // children[u]: every node in u's subtree, including u itself.
    std::vector<std::unordered_set<int>> children(n);

    for (int i = 0; i < n; ++i)
      children[i].insert(i);

    for (const std::vector<int>& edge : edges) {
      tree[edge[0]].push_back(edge[1]);
      tree[edge[1]].push_back(edge[0]);
    }

    dfs(tree, 0, -1, subXors, children);

    int ans = INT_MAX;
    const int m = static_cast<int>(edges.size());

    for (int i = 0; i < m; ++i) {
      const int a = childEndpoint(edges[i], children);
      for (int j = 0; j < i; ++j) {
        const int c = childEndpoint(edges[j], children);
        int x;
        int y;
        int z;
        if (a != c && children[a].count(c)) {
          // c's subtree is nested inside a's subtree.
          x = subXors[c];
          y = subXors[a] ^ subXors[c];
          z = xors ^ subXors[a];
        } else if (a != c && children[c].count(a)) {
          // a's subtree is nested inside c's subtree.
          x = subXors[a];
          y = subXors[c] ^ subXors[a];
          z = xors ^ subXors[c];
        } else {
          // The two cut-off subtrees are disjoint.
          x = subXors[a];
          y = subXors[c];
          z = xors ^ subXors[a] ^ subXors[c];
        }
        // Three plain ints avoid a heap allocation per edge pair.
        ans = std::min(ans, std::max({x, y, z}) - std::min({x, y, z}));
      }
    }

    return ans;
  }

 private:
  // Returns the endpoint of `edge` that lies deeper in the rooted tree, i.e.
  // the root of the subtree that removing `edge` cuts off.
  static int childEndpoint(
      const std::vector<int>& edge,
      const std::vector<std::unordered_set<int>>& children) {
    return children[edge[0]].count(edge[1]) ? edge[1] : edge[0];
  }

  // Post-order DFS from `u` (whose parent is `prev`, or -1 for the root).
  // Folds each child's subtree XOR into subXors[u] and merges the child's
  // subtree node set into children[u]. Reads the child's results in place
  // instead of returning a copy of its set.
  void dfs(const std::vector<std::vector<int>>& tree, int u, int prev,
           std::vector<int>& subXors,
           std::vector<std::unordered_set<int>>& children) {
    for (const int v : tree[u]) {
      if (v == prev)
        continue;
      dfs(tree, v, u, subXors, children);
      subXors[u] ^= subXors[v];
      children[u].insert(children[v].begin(), children[v].end());
    }
  }
};
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
class Solution {
  /**
   * Returns the minimum score obtainable by removing two distinct edges from the tree. The score
   * of a removal is the largest minus the smallest XOR among the three resulting components.
   *
   * @param nums nums[i] is the value of node i
   * @param edges undirected tree edges as {u, v} pairs; node 0 is used as the root
   * @return the minimum score, or Integer.MAX_VALUE if fewer than two edges are given
   */
  public int minimumScore(int[] nums, int[][] edges) {
    final int n = nums.length;
    // XOR of all node values; the component that keeps the root is derived from it.
    final int xors = getXors(nums);
    // subXors[u]: XOR of all values in u's subtree (completed by dfs).
    int[] subXors = nums.clone();
    @SuppressWarnings("unchecked")
    List<Integer>[] tree = new List[n];
    // children[u]: every node in u's subtree, including u itself.
    @SuppressWarnings("unchecked")
    Set<Integer>[] children = new Set[n];

    for (int i = 0; i < n; ++i) {
      tree[i] = new ArrayList<>();
      children[i] = new HashSet<>(Arrays.asList(i));
    }

    for (int[] edge : edges) {
      final int u = edge[0];
      final int v = edge[1];
      tree[u].add(v);
      tree[v].add(u);
    }

    dfs(tree, 0, -1, subXors, children);

    int ans = Integer.MAX_VALUE;

    for (int i = 0; i < edges.length; ++i) {
      final int a = childEndpoint(edges[i], children);
      for (int j = 0; j < i; ++j) {
        final int c = childEndpoint(edges[j], children);
        final int x;
        final int y;
        final int z;
        if (a != c && children[a].contains(c)) {
          // c's subtree is nested inside a's subtree.
          x = subXors[c];
          y = subXors[a] ^ subXors[c];
          z = xors ^ subXors[a];
        } else if (a != c && children[c].contains(a)) {
          // a's subtree is nested inside c's subtree.
          x = subXors[a];
          y = subXors[c] ^ subXors[a];
          z = xors ^ subXors[c];
        } else {
          // The two cut-off subtrees are disjoint.
          x = subXors[a];
          y = subXors[c];
          z = xors ^ subXors[a] ^ subXors[c];
        }
        ans = Math.min(ans, Math.max(x, Math.max(y, z)) - Math.min(x, Math.min(y, z)));
      }
    }

    return ans;
  }

  /**
   * Returns the endpoint of {@code edge} that lies deeper in the rooted tree, i.e. the root of the
   * subtree that removing {@code edge} cuts off.
   */
  private static int childEndpoint(int[] edge, Set<Integer>[] children) {
    return children[edge[0]].contains(edge[1]) ? edge[1] : edge[0];
  }

  /**
   * Post-order DFS from {@code u} (whose parent is {@code prev}, or -1 for the root). Folds each
   * child's subtree XOR into subXors[u] and merges the child's subtree node set into children[u].
   */
  private void dfs(List<Integer>[] tree, int u, int prev, int[] subXors, Set<Integer>[] children) {
    for (final int v : tree[u]) {
      if (v == prev)
        continue;
      dfs(tree, v, u, subXors, children);
      subXors[u] ^= subXors[v];
      children[u].addAll(children[v]);
    }
  }

  /** Returns the XOR of every element of {@code nums} (0 for an empty array). */
  private int getXors(int[] nums) {
    int xors = 0;
    for (final int num : nums)
      xors ^= num;
    return xors;
  }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
class Solution:
  def minimumScore(self, nums: list[int], edges: list[list[int]]) -> int:
    """Returns the minimum score after removing two distinct edges of the tree.

    The score of a removal is the largest minus the smallest XOR among the
    three resulting components.

    Args:
      nums: nums[i] is the value of node i.
      edges: Undirected tree edges as [u, v] pairs; node 0 is used as the root.

    Returns:
      The minimum achievable score (math.inf if fewer than two edges are given).
    """
    n = len(nums)
    # XOR of all values; the component that keeps the root is derived from it.
    xors = functools.reduce(operator.xor, nums)
    subXors = nums[:]  # subXors[u]: XOR of u's subtree (completed below)
    tree = [[] for _ in range(n)]
    children = [{i} for i in range(n)]  # children[u]: nodes in u's subtree, incl. u

    for u, v in edges:
      tree[u].append(v)
      tree[v].append(u)

    # Iterative pre-order traversal from the root. A recursive DFS nests once
    # per tree level and overflows Python's default recursion limit on deep
    # (e.g. path-shaped) trees.
    parent = [-1] * n
    order = []
    stack = [0]
    while stack:
      u = stack.pop()
      order.append(u)
      for v in tree[u]:
        if v != parent[u]:
          parent[v] = u
          stack.append(v)

    # Reverse pre-order visits every child before its parent, so each subtree
    # is complete by the time it is folded into its parent.
    for u in reversed(order):
      p = parent[u]
      if p != -1:
        subXors[p] ^= subXors[u]
        children[p] |= children[u]

    ans = math.inf
    for i in range(len(edges)):
      a, b = edges[i]
      if b in children[a]:  # orient so `a` is the deeper endpoint
        a, b = b, a
      for j in range(i):
        c, d = edges[j]
        if d in children[c]:
          c, d = d, c

        if c in children[a] and a != c:
          # c's subtree is nested inside a's subtree.
          cands = [subXors[c], subXors[a] ^ subXors[c], xors ^ subXors[a]]
        elif a in children[c] and a != c:
          # a's subtree is nested inside c's subtree.
          cands = [subXors[a], subXors[c] ^ subXors[a], xors ^ subXors[c]]
        else:
          # The two cut-off subtrees are disjoint.
          cands = [subXors[a], subXors[c], xors ^ subXors[a] ^ subXors[c]]
        ans = min(ans, max(cands) - min(cands))

    return ans