Skip to content

2867. Count Valid Paths in a Tree 👍

  • Time: $O(n \log \log n)$ — sieve of Eratosthenes up to $n$, plus an $O(n)$ tree traversal
  • Space: $O(n)$
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
class Solution {
 public:
  // Counts unordered node pairs (a, b), a != b, whose tree path contains
  // exactly one node with a prime label.
  //
  // n:     number of nodes, labelled 1..n.
  // edges: the undirected tree edges as {u, v} pairs.
  // Returns the number of such paths.
  long long countPaths(int n, std::vector<std::vector<int>>& edges) {
    // A single node (or an empty tree) has no path between two distinct nodes.
    if (n < 2)
      return 0;

    const std::vector<bool> isPrime = sieveEratosthenes(n + 1);
    std::vector<std::vector<int>> graph(n + 1);
    for (const std::vector<int>& edge : edges) {
      const int u = edge[0];
      const int v = edge[1];
      graph[u].push_back(v);
      graph[v].push_back(u);
    }

    // Iterative preorder from root 1. An explicit stack replaces recursion
    // because a path-shaped tree is ~n levels deep.
    std::vector<int> parent(n + 1, 0);
    std::vector<int> order;
    order.reserve(n);
    std::vector<int> stack{1};
    parent[1] = -1;
    while (!stack.empty()) {
      const int u = stack.back();
      stack.pop_back();
      order.push_back(u);
      for (const int v : graph[u])
        if (v != parent[u]) {
          parent[v] = u;
          stack.push_back(v);
        }
    }

    // zero[u] / one[u]: number of downward paths starting at u (including the
    // single-node path [u]) that contain zero / exactly one prime.
    // long long (not long): the answer can exceed 2^31, and long is 32-bit on
    // LLP64 platforms.
    std::vector<long long> zero(n + 1, 0);
    std::vector<long long> one(n + 1, 0);
    long long ans = 0;

    // Reverse preorder visits every child before its parent.
    for (int i = static_cast<int>(order.size()) - 1; i >= 0; --i) {
      const int u = order[i];
      long long countZero = isPrime[u] ? 0 : 1;
      long long countOne = isPrime[u] ? 1 : 0;
      for (const int v : graph[u]) {
        if (v == parent[u])
          continue;
        // Join paths already hanging off u with paths going down into v.
        ans += countZero * one[v] + countOne * zero[v];
        if (isPrime[u]) {
          countOne += zero[v];
        } else {
          countZero += zero[v];
          countOne += one[v];
        }
      }
      zero[u] = countZero;
      one[u] = countOne;
    }
    return ans;
  }

 private:
  // Returns a vector of size max(n, 0) where element k tells whether k is prime.
  static std::vector<bool> sieveEratosthenes(int n) {
    if (n <= 0)
      return {};
    std::vector<bool> isPrime(n, true);
    isPrime[0] = false;
    if (n > 1)
      isPrime[1] = false;
    for (int i = 2; i * i < n; ++i)
      if (isPrime[i])
        for (int j = i * i; j < n; j += i)
          isPrime[j] = false;
    return isPrime;
  }
};
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
class Solution {
  /**
   * Counts unordered node pairs (a, b), a != b, whose tree path contains exactly one node with a
   * prime label.
   *
   * @param n number of nodes, labelled 1..n
   * @param edges the undirected tree edges as {u, v} pairs
   * @return the number of such paths
   */
  public long countPaths(int n, int[][] edges) {
    // A single node (or an empty tree) has no path between two distinct nodes.
    if (n < 2)
      return 0;

    final boolean[] isPrime = sieveEratosthenes(n + 1);
    List<Integer>[] graph = new List[n + 1];

    for (int i = 1; i <= n; ++i)
      graph[i] = new ArrayList<>();

    for (int[] edge : edges) {
      final int u = edge[0];
      final int v = edge[1];
      graph[u].add(v);
      graph[v].add(u);
    }

    // Iterative preorder from root 1. An explicit stack replaces recursion because a path-shaped
    // tree is ~n levels deep. Each node is pushed once, so n slots suffice.
    int[] parent = new int[n + 1];
    int[] order = new int[n];
    int[] stack = new int[n];
    int orderSize = 0;
    int top = 0;
    parent[1] = -1;
    stack[top++] = 1;
    while (top > 0) {
      final int u = stack[--top];
      order[orderSize++] = u;
      for (final int v : graph[u])
        if (v != parent[u]) {
          parent[v] = u;
          stack[top++] = v;
        }
    }

    // zero[u] / one[u]: number of downward paths starting at u (including the single-node path
    // [u]) that contain zero / exactly one prime.
    long[] zero = new long[n + 1];
    long[] one = new long[n + 1];
    // Local accumulator: the result must not leak between calls on the same instance.
    long ans = 0;

    // Reverse preorder visits every child before its parent.
    for (int i = orderSize - 1; i >= 0; --i) {
      final int u = order[i];
      long countZeroPrimePath = isPrime[u] ? 0 : 1;
      long countOnePrimePath = isPrime[u] ? 1 : 0;
      for (final int v : graph[u]) {
        if (v == parent[u])
          continue;
        // Join paths already hanging off u with paths going down into v.
        ans += countZeroPrimePath * one[v] + countOnePrimePath * zero[v];
        if (isPrime[u]) {
          countOnePrimePath += zero[v];
        } else {
          countZeroPrimePath += zero[v];
          countOnePrimePath += one[v];
        }
      }
      zero[u] = countZeroPrimePath;
      one[u] = countOnePrimePath;
    }
    return ans;
  }

  /** Returns an array of length max(n, 0) where element k tells whether k is prime. */
  private boolean[] sieveEratosthenes(int n) {
    if (n <= 0)
      return new boolean[0];
    boolean[] isPrime = new boolean[n];
    Arrays.fill(isPrime, true);
    isPrime[0] = false;
    if (n > 1)
      isPrime[1] = false;
    for (int i = 2; i * i < n; ++i)
      if (isPrime[i])
        for (int j = i * i; j < n; j += i)
          isPrime[j] = false;
    return isPrime;
  }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
class Solution:
  def countPaths(self, n: int, edges: list[list[int]]) -> int:
    """Counts unordered node pairs (a, b), a != b, whose tree path contains
    exactly one node with a prime label.

    Args:
      n: number of nodes, labelled 1..n.
      edges: the undirected tree edges as [u, v] pairs.

    Returns:
      The number of such paths.
    """
    # A single node (or an empty tree) has no path between two distinct nodes.
    if n < 2:
      return 0

    isPrime = self._sieveEratosthenes(n + 1)
    graph = [[] for _ in range(n + 1)]

    for u, v in edges:
      graph[u].append(v)
      graph[v].append(u)

    # Iterative preorder from root 1. An explicit stack replaces recursion:
    # a path-shaped tree is ~n levels deep, far past Python's default
    # recursion limit.
    parent = [0] * (n + 1)
    parent[1] = -1
    order = []
    stack = [1]
    while stack:
      u = stack.pop()
      order.append(u)
      for v in graph[u]:
        if v != parent[u]:
          parent[v] = u
          stack.append(v)

    # zero[u] / one[u]: number of downward paths starting at u (including the
    # single-node path [u]) that contain zero / exactly one prime.
    zero = [0] * (n + 1)
    one = [0] * (n + 1)
    ans = 0

    # Reverse preorder visits every child before its parent.
    for u in reversed(order):
      countZeroPrimePath = int(not isPrime[u])
      countOnePrimePath = int(isPrime[u])
      for v in graph[u]:
        if v == parent[u]:
          continue
        # Join paths already hanging off u with paths going down into v.
        ans += (countZeroPrimePath * one[v] +
                countOnePrimePath * zero[v])
        if isPrime[u]:
          countOnePrimePath += zero[v]
        else:
          countZeroPrimePath += zero[v]
          countOnePrimePath += one[v]
      zero[u] = countZeroPrimePath
      one[u] = countOnePrimePath

    return ans

  def _sieveEratosthenes(self, n: int) -> list[bool]:
    """Returns a list of length max(n, 0) whose k-th entry tells whether k is prime."""
    if n <= 0:
      return []
    isPrime = [True] * n
    isPrime[0] = False
    if n > 1:
      isPrime[1] = False
    # Integer bound (i * i < n) avoids relying on a float square root.
    i = 2
    while i * i < n:
      if isPrime[i]:
        for j in range(i * i, n, i):
          isPrime[j] = False
      i += 1
    return isPrime