Skip to content

3067. Count Pairs of Connectable Servers in a Weighted Tree Network

  • Time: $O(n^2)$
  • Space: $O(n)$
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
class Solution {
 public:
  vector<int> countPairsOfConnectableServers(vector<vector<int>>& edges,
                                             int signalSpeed) {
    const int n = edges.size() + 1;
    vector<int> ans;
    vector<vector<pair<int, int>>> tree(n);

    for (const vector<int>& edge : edges) {
      const int u = edge[0];
      const int v = edge[1];
      const int w = edge[2];
      tree[u].emplace_back(v, w);
      tree[v].emplace_back(u, w);
    }

    for (int i = 0; i < n; ++i)
      ans.push_back(connectablePairsRootedAt(tree, i, signalSpeed));

    return ans;
  }

 private:
  // Returns the number of server pairs that are connectable through the server
  // `u`.
  int connectablePairsRootedAt(const vector<vector<pair<int, int>>>& tree,
                               int u, int signalSpeed) {
    int pairs = 0;
    int count = 0;
    for (const auto& [v, w] : tree[u]) {
      const int childCount = dfs(tree, v, u, w, signalSpeed);
      pairs += count * childCount;
      count += childCount;
    }
    return pairs;
  }

  // Returns the number of servers that are connectable throught the server `u`
  // (dist % signalSpeed == 0).
  int dfs(const vector<vector<pair<int, int>>>& tree, int u, int prev, int dist,
          int signalSpeed) {
    int count = 0;
    for (const auto& [v, w] : tree[u])
      if (v != prev)
        count += dfs(tree, v, u, dist + w, signalSpeed);
    return (dist % signalSpeed == 0 ? 1 : 0) + count;
  }
};
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
class Solution {
  public int[] countPairsOfConnectableServers(int[][] edges, int signalSpeed) {
    final int n = edges.length + 1;
    int[] ans = new int[n];
    List<Pair<Integer, Integer>>[] graph = new List[n];

    for (int i = 0; i < n; i++)
      graph[i] = new ArrayList<>();

    for (final int[] edge : edges) {
      final int u = edge[0];
      final int v = edge[1];
      final int w = edge[2];
      graph[u].add(new Pair<>(v, w));
      graph[v].add(new Pair<>(u, w));
    }

    for (int i = 0; i < n; ++i)
      ans[i] = connectablePairsRootedAt(graph, i, signalSpeed);

    return ans;
  }

  // Returns the number of server pairs that are connectable through the server
  // `u`.
  private int connectablePairsRootedAt(List<Pair<Integer, Integer>>[] graph, int u,
                                       int signalSpeed) {
    int pairs = 0;
    int count = 0;
    for (Pair<Integer, Integer> pair : graph[u]) {
      final int v = pair.getKey();
      final int w = pair.getValue();
      final int childCount = dfs(graph, v, u, w, signalSpeed);
      pairs += count * childCount;
      count += childCount;
    }
    return pairs;
  }

  // Returns the number of servers that are connectable throught the server `u`
  // (dist % signalSpeed == 0).
  private int dfs(List<Pair<Integer, Integer>>[] graph, int u, int prev, int dist,
                  int signalSpeed) {
    int count = 0;
    for (Pair<Integer, Integer> pair : graph[u]) {
      final int v = pair.getKey();
      final int w = pair.getValue();
      if (v != prev)
        count += dfs(graph, v, u, dist + w, signalSpeed);
    }
    return (dist % signalSpeed == 0 ? 1 : 0) + count;
  }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
class Solution:
  def countPairsOfConnectableServers(
      self,
      edges: list[list[int]],
      signalSpeed: int,
  ) -> list[int]:
    n = len(edges) + 1
    tree = [[] for _ in range(n)]

    for u, v, w in edges:
      tree[u].append((v, w))
      tree[v].append((u, w))

    def connectablePairsRootedAt(u: int) -> int:
      pairs = 0
      count = 0
      for v, w in tree[u]:
        childCount = dfs(v, u, w)
        pairs += count * childCount
        count += childCount
      return pairs

    def dfs(u: int, prev: int, dist: int) -> int:
      return (int(dist % signalSpeed == 0) +
              sum(dfs(v, u, dist + w)
              for v, w in tree[u]
              if v != prev))

    return [connectablePairsRootedAt(i) for i in range(n)]
Was this page helpful?