Skip to content

99. Recover Binary Search Tree 👍

Approach 1: Recursive

  • Time: $O(n)$
  • Space: $O(h)$
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
class Solution {
 public:
  void recoverTree(TreeNode* root) {
    inorder(root);
    swap(x, y);
  }

 private:
  TreeNode* pred = nullptr;
  TreeNode* x = nullptr;  // the first wrong node
  TreeNode* y = nullptr;  // the second wrong node

  void inorder(TreeNode* root) {
    if (root == nullptr)
      return;

    inorder(root->left);

    if (pred && root->val < pred->val) {
      y = root;
      if (x == nullptr)
        x = pred;
      else
        return;
    }
    pred = root;

    inorder(root->right);
  }

  void swap(TreeNode* x, TreeNode* y) {
    const int temp = x->val;
    x->val = y->val;
    y->val = temp;
  }
};
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
class Solution {
  /**
   * Restores, in place, a BST in which the values of exactly two nodes were
   * swapped. An in-order walk of a valid BST is strictly increasing; swapping
   * two values produces one inversion (nodes adjacent in order) or two.
   *
   * @param root the root of the tree; may be null
   */
  public void recoverTree(TreeNode root) {
    // Reset per-call state so the same Solution object can be reused safely.
    pred = null;
    x = null;
    y = null;

    inorder(root);

    // Nothing to swap for an empty tree or one that is already a valid BST.
    if (x != null && y != null)
      swap(x, y);
  }

  private TreeNode pred = null; // the previously visited node in in-order
  private TreeNode x = null; // the first wrong node
  private TreeNode y = null; // the second wrong node

  /**
   * Visits the subtree rooted at {@code root} in order, recording wrong nodes.
   *
   * @return true once both wrong nodes are identified, so every caller up the
   *     recursion stops traversing immediately
   */
  private boolean inorder(TreeNode root) {
    if (root == null)
      return false;

    if (inorder(root.left))
      return true;

    if (pred != null && root.val < pred.val) {
      y = root;
      if (x != null)
        return true; // Second inversion: both wrong nodes are known.
      x = pred;
    }
    pred = root;

    return inorder(root.right);
  }

  /** Exchanges the values stored in two nodes. */
  private void swap(TreeNode a, TreeNode b) {
    final int temp = a.val;
    a.val = b.val;
    b.val = temp;
  }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
class Solution:
  def recoverTree(self, root: 'TreeNode | None') -> None:
    """Fix, in place, a BST in which the values of exactly two nodes were swapped.

    An in-order walk of a valid BST is strictly increasing. Swapping two
    values creates one inversion (nodes adjacent in order) or two. The first
    wrong node is the predecessor at the first inversion; the second wrong
    node is the current node at the last inversion.

    All traversal state is local to this call, so a Solution instance can be
    reused for any number of trees.

    Args:
      root: Root of the tree; may be None.
    """
    pred = None  # the previously visited node in in-order
    x = None  # the first wrong node
    y = None  # the second wrong node

    def inorder(node) -> bool:
      """Visit `node`'s subtree in order; return True once both wrong nodes are known."""
      nonlocal pred, x, y
      if node is None:
        return False

      if inorder(node.left):
        return True  # Propagate the early stop up the recursion.

      if pred is not None and node.val < pred.val:
        y = node
        if x is not None:
          return True  # Second inversion: both wrong nodes are known.
        x = pred
      pred = node

      return inorder(node.right)

    inorder(root)

    # Nothing to swap for an empty tree or one that is already a valid BST.
    if x is not None and y is not None:
      x.val, y.val = y.val, x.val

Approach 2: Iterative (stack)

  • Time: $O(n)$
  • Space: $O(h)$
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
class Solution {
 public:
  void recoverTree(TreeNode* root) {
    TreeNode* pred = nullptr;
    TreeNode* x = nullptr;  // the first wrong node
    TreeNode* y = nullptr;  // the second wrong node

    stack<TreeNode*> stack;

    while (root != nullptr || !stack.empty()) {
      while (root != nullptr) {
        stack.push(root);
        root = root->left;
      }
      root = stack.top(), stack.pop();
      if (pred && root->val < pred->val) {
        y = root;
        if (x == nullptr)
          x = pred;
      }
      pred = root;
      root = root->right;
    }

    swap(x, y);
  }

  void swap(TreeNode* x, TreeNode* y) {
    const int temp = x->val;
    x->val = y->val;
    y->val = temp;
  }
};
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
class Solution {
  /**
   * Restores, in place, a BST in which the values of exactly two nodes were
   * swapped, using an explicit-stack in-order traversal.
   *
   * @param root the root of the tree; may be null
   */
  public void recoverTree(TreeNode root) {
    TreeNode pred = null; // the previously visited node in in-order
    TreeNode x = null; // the first wrong node
    TreeNode y = null; // the second wrong node

    Deque<TreeNode> stack = new ArrayDeque<>();

    while (root != null || !stack.isEmpty()) {
      while (root != null) {
        stack.push(root);
        root = root.left;
      }
      root = stack.pop();
      if (pred != null && root.val < pred.val) {
        y = root;
        if (x != null)
          break; // Second inversion: both wrong nodes are known.
        x = pred;
      }
      pred = root;
      root = root.right;
    }

    // Nothing to swap for an empty tree or one that is already a valid BST.
    if (x != null && y != null)
      swap(x, y);
  }

  /** Exchanges the values stored in two nodes. */
  private void swap(TreeNode a, TreeNode b) {
    final int temp = a.val;
    a.val = b.val;
    b.val = temp;
  }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
class Solution:
  def recoverTree(self, root: 'TreeNode | None') -> None:
    """Fix, in place, a BST in which the values of exactly two nodes were swapped.

    Walks the tree in order with an explicit stack. An in-order walk of a
    valid BST is strictly increasing, so the swapped nodes show up as one or
    two inversions (a node smaller than its in-order predecessor).

    Args:
      root: Root of the tree; may be None.
    """
    pred = None  # the previously visited node in in-order
    x = None  # the first wrong node
    y = None  # the second wrong node
    stack = []

    node = root
    while node is not None or stack:
      while node is not None:
        stack.append(node)
        node = node.left
      node = stack.pop()
      if pred is not None and node.val < pred.val:
        y = node
        if x is not None:
          break  # Second inversion: both wrong nodes are known.
        x = pred
      pred = node
      node = node.right

    # Nothing to swap for an empty tree or one that is already a valid BST.
    if x is not None and y is not None:
      x.val, y.val = y.val, x.val

Approach 3: Morris

  • Time: $O(n)$
  • Space: $O(1)$
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
class Solution {
 public:
  void recoverTree(TreeNode* root) {
    TreeNode* pred = nullptr;
    TreeNode* x = nullptr;  // the first wrong node
    TreeNode* y = nullptr;  // the second wrong node

    while (root != nullptr) {
      if (root->left != nullptr) {
        TreeNode* morrisPred = findPredecessor(root);
        if (morrisPred->right) {
          // The node has already been connected before.
          visit(root, pred, x, y);
          morrisPred->right = nullptr;  // Break the connection.
          root = root->right;
        } else {
          morrisPred->right = root;  // Connect it.
          root = root->left;
        }
      } else {
        visit(root, pred, x, y);
        root = root->right;
      }
    }

    swap(x, y);
  }

 private:
  TreeNode* findPredecessor(TreeNode* root) {
    TreeNode* pred = root->left;
    while (pred->right && pred->right != root)
      pred = pred->right;
    return pred;
  }

  void visit(TreeNode*& root, TreeNode*& pred, TreeNode*& x, TreeNode*& y) {
    if (pred && root->val < pred->val) {
      y = root;
      if (x == nullptr)
        x = pred;
    }
    pred = root;
  }

  void swap(TreeNode* x, TreeNode* y) {
    const int temp = x->val;
    x->val = y->val;
    y->val = temp;
  }
};
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
class Solution {
  /**
   * Restores, in place, a BST in which the values of exactly two nodes were
   * swapped, using Morris in-order traversal (O(1) extra space).
   *
   * <p>The loop must run to completion: each temporary thread is removed only
   * when its target node is revisited, so stopping early would leave the tree
   * modified.
   *
   * @param root the root of the tree; may be null
   */
  public void recoverTree(TreeNode root) {
    TreeNode pred = null; // the previously visited node in in-order
    TreeNode x = null; // the first wrong node
    TreeNode y = null; // the second wrong node

    while (root != null) {
      if (root.left != null) {
        TreeNode morrisPred = findPredecessor(root);
        if (morrisPred.right == null) {
          morrisPred.right = root; // Connect it.
          root = root.left;
          continue;
        }
        // The node has already been connected before: its left subtree is done.
        morrisPred.right = null; // Break the connection.
      }

      // Visit `root` in in-order position (reached for both leaf-left nodes
      // and nodes whose left subtree has just been finished).
      if (pred != null && root.val < pred.val) {
        y = root;
        if (x == null)
          x = pred;
      }
      pred = root;
      root = root.right;
    }

    // Nothing to swap for an empty tree or one that is already a valid BST.
    if (x != null && y != null)
      swap(x, y);
  }

  /**
   * Returns the rightmost node of {@code root}'s left subtree, stopping early
   * if that node is already threaded back to {@code root}. Requires
   * {@code root.left != null}.
   */
  private TreeNode findPredecessor(TreeNode root) {
    TreeNode pred = root.left;
    while (pred.right != null && pred.right != root)
      pred = pred.right;
    return pred;
  }

  /** Exchanges the values stored in two nodes. */
  private void swap(TreeNode a, TreeNode b) {
    final int temp = a.val;
    a.val = b.val;
    b.val = temp;
  }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
class Solution:
  def recoverTree(self, root: 'TreeNode | None') -> None:
    """Fix, in place, a BST in which the values of exactly two nodes were swapped.

    Uses Morris in-order traversal (O(1) extra space): the rightmost node of
    each left subtree is temporarily threaded back to its in-order successor,
    and the thread is removed when that successor is reached the second time.
    The traversal must run to completion so every thread is removed.

    Args:
      root: Root of the tree; may be None.
    """
    pred = None  # the previously visited node in in-order
    x = None  # the first wrong node
    y = None  # the second wrong node

    def visit(node) -> None:
      """Record an inversion against `pred` (if any), then advance `pred`."""
      nonlocal pred, x, y
      if pred is not None and node.val < pred.val:
        y = node
        if x is None:
          x = pred
      pred = node

    def find_predecessor(node):
      """Rightmost node of `node.left`, stopping at a thread back to `node`."""
      curr = node.left
      while curr.right is not None and curr.right is not node:
        curr = curr.right
      return curr

    node = root
    while node is not None:
      if node.left is None:
        visit(node)
        node = node.right
        continue

      morris_pred = find_predecessor(node)
      if morris_pred.right is None:
        morris_pred.right = node  # Connect it.
        node = node.left
      else:
        # The node has already been connected before: its left subtree is done.
        visit(node)
        morris_pred.right = None  # Break the connection.
        node = node.right

    # Nothing to swap for an empty tree or one that is already a valid BST.
    if x is not None and y is not None:
      x.val, y.val = y.val, x.val