Skip to content

2791. Count Paths That Can Form a Palindrome in a Tree 👍

  • Time: $O(26n) = O(n)$, where $n$ is the number of nodes (27 hash-map lookups per node)
  • Space: $O(n)$
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
class Solution {
 public:
  // Returns the number of pairs (u, v) with u < v such that the characters on
  // the edges of the u-v path can be rearranged to form a palindrome. s[i] is
  // the character on the edge between parent[i] and i; s[0] is ignored.
  //
  // A valid (u, v) has at most 1 letter with odd frequency on its path. The
  // frequency of a letter on the u-v path is equal to the sum of its
  // frequencies on the root-u and root-v paths minus twice its frequency on
  // the root-LCA(u, v) path. Considering only the parity (even/odd), the part
  // involving root-LCA(u, v) cancels out, so each node only needs the parity
  // mask of its own root path. Two nodes form a valid pair iff their masks
  // differ in at most one bit.
  //
  // The traversal is iterative so deep (e.g. chain-shaped) trees cannot
  // overflow the call stack. The result uses `long long` because the pair count
  // can exceed 2^31 and `long` is only 32 bits on some platforms (e.g. MSVC).
  long long countPalindromePaths(vector<int>& parent, string s) {
    const int n = static_cast<int>(parent.size());
    if (n == 0)
      return 0;

    vector<vector<int>> tree(n);
    for (int i = 1; i < n; ++i)
      tree[parent[i]].push_back(i);

    // mask[u] := 26 bits that represent the parity of each character in the
    // alphabet on the path from node 0 to node u.
    vector<int> mask(n, 0);
    // Number of already-visited nodes per mask; the root (mask 0) is seeded.
    unordered_map<int, int> maskToCount{{0, 1}};
    long long res = 0;
    vector<int> stack{0};

    // Every node is visited exactly once, and each pair is counted when its
    // later-visited endpoint is reached, so visiting order does not matter.
    while (!stack.empty()) {
      const int u = stack.back();
      stack.pop_back();
      for (const int v : tree[u]) {
        const int m = mask[u] ^ (1 << (s[v] - 'a'));
        mask[v] = m;
        // Consider any u-v path with exactly 1 bit set (one odd letter).
        for (int i = 0; i < 26; ++i)
          if (const auto it = maskToCount.find(m ^ (1 << i));
              it != maskToCount.cend())
            res += it->second;
        // Consider u-v path with 0 bits set (identical masks).
        res += maskToCount[m]++;
        stack.push_back(v);
      }
    }
    return res;
  }
};
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
class Solution {
  public long countPalindromePaths(List<Integer> parent, String s) {
    // A valid (u, v) has at most 1 letter with odd frequency on its path. The
    // frequency of a letter on the u-v path is equal to the sum of its
    // frequencies on the root-u and root-v paths substract twice of its
    // frequency on the root-LCA(u, v) path. Considering only the parity
    // (even/odd), the part involving root-LCA(u, v) can be ignored, making it
    // possible to calculate both parts easily using a simple DFS.
    List<Integer>[] tree = new List[parent.size()];

    for (int i = 0; i < parent.size(); ++i)
      tree[i] = new ArrayList<>();

    for (int i = 1; i < parent.size(); ++i)
      tree[parent.get(i)].add(i);

    return dfs(tree, 0, 0, s, new HashMap<>(Map.of(0, 1)));
  }

  // mask := 26 bits that represent the parity of each character in the alphabet
  // on the path from node 0 to node u
  private long dfs(List<Integer>[] tree, int u, int mask, String s,
                   Map<Integer, Integer> maskToCount) {
    long res = 0;
    if (u > 0) {
      mask ^= 1 << (s.charAt(u) - 'a');
      // Consider any u-v path with 1 bit set.
      for (int i = 0; i < 26; ++i)
        if (maskToCount.containsKey(mask ^ (1 << i)))
          res += maskToCount.get(mask ^ (1 << i));
      // Consider u-v path with 0 bit set.
      res += maskToCount.getOrDefault(mask ^ 0, 0);
      maskToCount.merge(mask, 1, Integer::sum);
    }
    for (final int v : tree[u])
      res += dfs(tree, v, mask, s, maskToCount);
    return res;
  }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
class Solution:
  def countPalindromePaths(self, parent: list[int], s: str) -> int:
    """Counts pairs (u, v), u < v, whose path letters can form a palindrome.

    s[i] is the character on the edge between parent[i] and i; s[0] is unused.

    A valid (u, v) has at most 1 letter with odd frequency on its path. The
    frequency of a letter on the u-v path is equal to the sum of its
    frequencies on the root-u and root-v paths minus twice its frequency on the
    root-LCA(u, v) path. Considering only the parity (even/odd), the part
    involving root-LCA(u, v) cancels out, so each node only needs the parity
    mask of its own root path. Two nodes form a valid pair iff their masks
    differ in at most one bit.

    The traversal is iterative: a recursive DFS would raise RecursionError on
    trees deeper than Python's default recursion limit (about 1000).
    """
    n = len(parent)
    if n == 0:
      return 0

    tree = [[] for _ in range(n)]
    for i in range(1, n):
      tree[parent[i]].append(i)

    # mask[u] := 26 bits that represent the parity of each character in the
    # alphabet on the path from node 0 to node u.
    mask = [0] * n
    # Number of already-visited nodes per mask; the root (mask 0) is seeded.
    maskToCount = {0: 1}
    res = 0
    stack = [0]

    # Each pair is counted when its later-visited endpoint is reached, so the
    # visiting order does not matter.
    while stack:
      u = stack.pop()
      for v in tree[u]:
        m = mask[u] ^ (1 << (ord(s[v]) - ord('a')))
        mask[v] = m
        # Consider any u-v path with exactly 1 bit set (one odd letter).
        for i in range(26):
          res += maskToCount.get(m ^ (1 << i), 0)
        # Consider u-v path with 0 bits set (identical masks).
        res += maskToCount.get(m, 0)
        maskToCount[m] = maskToCount.get(m, 0) + 1
        stack.append(v)

    return res