Skip to content

1373. Maximum Sum BST in Binary Tree 👍

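  • Time: O(n), where n is the number of nodes (each node is visited once)
  • Space: O(h), where h is the height of the tree (traversal depth)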
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
// Summary of one subtree, produced by Solution's post-order traversal.
//   isBST: whether the subtree is a valid BST.
//   max / min: largest / smallest value in the subtree (meaningful only when
//              isBST is true).
//   sum: sum of all values in the subtree (meaningful only when isBST is true).
struct T {
  // Default member initializers: a default-constructed T means "not a BST",
  // and its numeric fields are zeroed so that reading them is never
  // undefined behavior.
  bool isBST = false;
  int max = 0;
  int min = 0;
  int sum = 0;
  T() = default;
  T(bool isBST, int max, int min, int sum)
      : isBST(isBST), max(max), min(min), sum(sum) {}
};

class Solution {
 public:
  int maxSumBST(TreeNode* root) {
    int ans = 0;
    traverse(root, ans);
    return ans;
  }

 private:
  T traverse(TreeNode* root, int& ans) {
    if (!root)
      return T(true, INT_MIN, INT_MAX, 0);

    T left = traverse(root->left, ans);
    T right = traverse(root->right, ans);

    if (!left.isBST || !right.isBST)
      return T();
    if (root->val <= left.max || root->val >= right.min)
      return T();

    // root is a valid BST
    const int sum = root->val + left.sum + right.sum;
    ans = max(ans, sum);
    return T(true, max(root->val, right.max), min(root->val, left.min), sum);
  }
};
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
class T {
  public boolean isBST;
  public Integer max;
  public Integer min;
  public Integer sum;
  public T() {
    this.isBST = false;
    this.max = null;
    this.min = null;
    this.sum = null;
  }
  public T(boolean isBST, int max, int min, int sum) {
    this.isBST = isBST;
    this.max = max;
    this.min = min;
    this.sum = sum;
  }
}

class Solution {
  /**
   * Returns the largest sum of keys over all subtrees of {@code root} that are
   * valid BSTs, or 0 if no such subtree has a positive sum.
   */
  public int maxSumBST(TreeNode root) {
    // ans is an instance field: reset it on every call, otherwise reusing a
    // Solution object would return a stale maximum from an earlier tree.
    ans = 0;
    traverse(root);
    return ans;
  }

  // Best valid-BST subtree sum seen during the current maxSumBST call.
  private int ans = 0;

  /**
   * Post-order helper: describes the subtree rooted at {@code root} and
   * updates {@code ans} with the sum of every valid-BST subtree it finds.
   * An empty subtree is a BST with max = MIN_VALUE and min = MAX_VALUE, so
   * any parent value compares as "in range".
   */
  private T traverse(TreeNode root) {
    if (root == null)
      return new T(true, Integer.MIN_VALUE, Integer.MAX_VALUE, 0);

    T left = traverse(root.left);
    T right = traverse(root.right);

    // Must come first: a non-BST child has null max/min, and unboxing them
    // below would throw.
    if (!left.isBST || !right.isBST)
      return new T();
    if (root.val <= left.max || root.val >= right.min)
      return new T();

    // root is a valid BST
    final int sum = root.val + left.sum + right.sum;
    ans = Math.max(ans, sum);
    return new T(true, Math.max(root.val, right.max), Math.min(root.val, left.min), sum);
  }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
class T:
  """Summary of one subtree, produced by Solution's post-order traversal.

  Attributes:
    isBST: whether the subtree is a valid BST.
    max / min: largest / smallest value in the subtree; None when the
      subtree is not a BST, -inf / +inf for the empty subtree.
    sum: sum of the subtree's values; None when the subtree is not a BST.
  """

  # One instance is created per visited node, so skip the per-instance dict.
  __slots__ = ('isBST', 'max', 'min', 'sum')

  # Parameter names shadow builtins but are kept for keyword-call compatibility.
  def __init__(self, isBST: bool = False,
               max: Optional[float] = None,
               min: Optional[float] = None,
               sum: Optional[int] = None):
    self.isBST = isBST
    self.max = max
    self.min = min
    self.sum = sum


class Solution:
  def maxSumBST(self, root: 'Optional[TreeNode]') -> int:
    """Return the largest key sum over all subtrees of `root` that are BSTs.

    The result is never below 0, because the best value starts at 0 and only
    grows. The walk is an explicit-stack post-order traversal. A list-shaped
    tree therefore cannot exhaust Python's recursion limit, which a recursive
    helper would do once the depth passes the default limit (about 1000).
    """
    best = 0
    # Summaries (T) of finished subtrees, in completion order.
    done = []
    # Work items: (node, children_already_scheduled). A None node stands for
    # an empty subtree.
    todo = [(root, False)]
    while todo:
      node, expanded = todo.pop()
      if node is None:
        done.append(T(True, -math.inf, math.inf, 0))
        continue
      if not expanded:
        # Revisit `node` after both children. Left is pushed last, so it
        # finishes first and its summary ends up below the right one.
        todo.append((node, True))
        todo.append((node.right, False))
        todo.append((node.left, False))
        continue

      right = done.pop()
      left = done.pop()
      # Short-circuit: max/min are None for non-BST children and are only
      # compared once both children are known to be BSTs.
      if (not (left.isBST and right.isBST)
          or not (left.max < node.val < right.min)):
        done.append(T())
        continue

      # The subtree rooted at `node` is a valid BST.
      total = node.val + left.sum + right.sum
      best = max(best, total)
      done.append(
          T(True, max(node.val, right.max), min(node.val, left.min), total))
    return best