Skip to content

1973. Count Nodes Equal to Sum of Descendants 👍

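  • Time: $O(n)$, where $n$ is the number of nodes in the tree (each node is visited once).
  • Space: $O(h)$, where $h$ is the height of the tree (the depth of the traversal stack).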
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
// Result of the post-order DFS for one subtree.
struct T {
  // Sum of every node value in the subtree. `long long` is guaranteed to be
  // at least 64 bits; plain `long` is only 32 bits on LLP64 platforms (e.g.
  // Windows), where the accumulated sum of many `int` values could overflow.
  long long sum;
  // Number of nodes in the subtree whose value equals the sum of their
  // descendants.
  int count;
};

class Solution {
 public:
  int equalToDescendants(TreeNode* root) {
    return dfs(root).count;
  }

 private:
  T dfs(TreeNode* root) {
    if (root == nullptr)
      return T{.sum = 0, .count = 0};
    T left = dfs(root->left);
    T right = dfs(root->right);
    return T{.sum = root->val + left.sum + right.sum,
             .count = left.count + right.count +
                      (root->val == left.sum + right.sum ? 1 : 0)};
  }
};
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
class Solution {
  // Returns how many nodes hold a value equal to the sum of all values
  // strictly below them in the tree.
  public int equalToDescendants(TreeNode root) {
    return visit(root).count();
  }

  // Post-order walk: for the subtree at `node`, yields its total value sum and
  // the number of qualifying nodes it contains. An empty subtree yields (0, 0).
  private Summary visit(TreeNode node) {
    if (node == null) {
      return new Summary(0, 0);
    }

    final Summary fromLeft = visit(node.left);
    final Summary fromRight = visit(node.right);

    // Sum of everything below `node` (its descendants, excluding itself).
    final long descendants = fromLeft.sum() + fromRight.sum();
    final int matches = node.val == descendants ? 1 : 0;

    return new Summary(descendants + node.val,
                       fromLeft.count() + fromRight.count() + matches);
  }

  // Per-subtree result: total value sum and count of qualifying nodes.
  private record Summary(long sum, int count) {}
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
from dataclasses import dataclass


@dataclass(frozen=True)
class T:
  """Immutable per-subtree result: total value sum and qualifying-node count."""

  # Sum of every node value in the subtree ("summ" avoids shadowing builtin sum).
  summ: int
  # Number of subtree nodes whose value equals the sum of their descendants.
  count: int


class Solution:
  def equalToDescendants(self, root: "TreeNode | None") -> int:
    """Count nodes whose value equals the sum of all values below them.

    Uses an iterative post-order traversal instead of recursion, so
    degenerate (linked-list-shaped) trees deeper than Python's recursion
    limit (1000 by default) do not raise RecursionError. Auxiliary space
    stays proportional to the tree height.

    The TreeNode annotation is quoted so this class can be defined even
    when TreeNode is not yet in scope.

    Args:
      root: Root of the binary tree, or None for an empty tree.

    Returns:
      The number of qualifying nodes (0 for an empty tree).
    """
    count = 0
    # Subtree sums of finished subtrees, waiting to be consumed by their parent.
    sums = []
    # Frames of (node, children_done). A node is pushed once to schedule its
    # children and once more to be combined after both children finish.
    stack = [(root, False)]
    while stack:
      node, children_done = stack.pop()
      if node is None:
        sums.append(0)  # An empty subtree contributes a sum of 0.
      elif children_done:
        # The right subtree finished last, so its sum is on top.
        right_sum = sums.pop()
        left_sum = sums.pop()
        descendants = left_sum + right_sum
        if node.val == descendants:
          count += 1
        sums.append(node.val + descendants)
      else:
        stack.append((node, True))
        stack.append((node.right, False))
        # Pushed last, so the left subtree is processed first.
        stack.append((node.left, False))
    return count