Skip to content

823. Binary Trees With Factors 👍

  • Time: $O(n^2)$
  • Space: $O(n)$
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
class Solution {
 public:
  // Counts binary trees in which every non-leaf node's value equals the
  // product of its two children's values, with node values drawn (repetition
  // allowed) from arr. Returns the count modulo 1'000'000'007.
  // Note: arr is sorted in place.
  int numFactoredBinaryTrees(vector<int>& arr) {
    constexpr int kMod = 1'000'000'007;
    const int n = arr.size();
    // dp[i] := # of binary trees with arr[i] as root (a lone leaf counts as 1).
    // `long long` is guaranteed to be at least 64 bits; plain `long` is only
    // 32 bits on LLP64 platforms (e.g. 64-bit Windows), where the product
    // dp[j] * dp[k] of two values below kMod would overflow.
    vector<long long> dp(n, 1);
    unordered_map<int, int> numToIndex;

    // Sorting guarantees that both children of arr[i] appear before index i.
    sort(begin(arr), end(arr));

    for (int i = 0; i < n; ++i)
      numToIndex[arr[i]] = i;

    for (int i = 0; i < n; ++i)  // arr[i] is root
      for (int j = 0; j < i; ++j)
        if (arr[i] % arr[j] == 0) {  // arr[j] is left subtree
          // Single hash lookup for the matching right-subtree value.
          const auto it = numToIndex.find(arr[i] / arr[j]);
          if (it != numToIndex.end())
            dp[i] = (dp[i] + dp[j] * dp[it->second]) % kMod;
        }

    return accumulate(begin(dp), end(dp), 0LL) % kMod;
  }
};
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
class Solution {
  /**
   * Counts binary trees whose every non-leaf node equals the product of its two children, using
   * values from {@code arr} (each value may be reused any number of times).
   *
   * <p>Side effect: {@code arr} is sorted in place.
   *
   * @param arr candidate node values
   * @return the number of such trees, modulo 1_000_000_007
   */
  public int numFactoredBinaryTrees(int[] arr) {
    final int kMod = 1_000_000_007;
    Arrays.sort(arr);

    Map<Integer, Integer> position = new HashMap<>();
    for (int idx = 0; idx < arr.length; ++idx)
      position.put(arr[idx], idx);

    // ways[k] := number of trees rooted at arr[k]; a single leaf already counts as one tree.
    long[] ways = new long[arr.length];
    Arrays.fill(ways, 1L);

    long total = 0;
    for (int root = 0; root < arr.length; ++root) {
      for (int left = 0; left < root; ++left) {
        if (arr[root] % arr[left] != 0)
          continue; // arr[left] cannot be a factor of the root
        Integer right = position.get(arr[root] / arr[left]);
        if (right == null)
          continue; // the complementary factor is not available
        ways[root] = (ways[root] + ways[left] * ways[right]) % kMod;
      }
      // ways[root] is final here: later iterations only modify their own entry.
      total += ways[root];
    }
    return (int) (total % kMod);
  }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
class Solution:
  def numFactoredBinaryTrees(self, arr: List[int]) -> int:
    """Count binary trees whose every non-leaf node equals the product of its children.

    Node values come from ``arr`` and may be reused any number of times.

    Args:
      arr: candidate node values; sorted in place as a side effect.

    Returns:
      The number of such trees, modulo 1_000_000_007.
    """
    kMod = 1_000_000_007
    arr.sort()
    position = {value: idx for idx, value in enumerate(arr)}
    # ways[k] := number of trees rooted at arr[k]; a lone leaf counts as one.
    ways = [1] * len(arr)

    for idx in range(len(arr)):
      root = arr[idx]
      for left_idx in range(idx):
        quotient, remainder = divmod(root, arr[left_idx])
        if remainder:
          continue  # arr[left_idx] does not divide the root
        right_idx = position.get(quotient)
        if right_idx is None:
          continue  # complementary factor is not in arr
        ways[idx] = (ways[idx] + ways[left_idx] * ways[right_idx]) % kMod

    return sum(ways) % kMod