Skip to content

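230. Kth Smallest Element in a BST 👍

Approach 1: Divide and Conquer

  • Time: $O(n^2)$ in the worst case (a skewed tree, where each step recounts nearly all remaining nodes); $O(n)$ when the tree is balanced
  • Space: $O(h)$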
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
class Solution {
 public:
  int kthSmallest(TreeNode* root, int k) {
    const int leftCount = countNodes(root->left);

    if (leftCount == k - 1)
      return root->val;
    if (leftCount >= k)
      return kthSmallest(root->left, k);
    return kthSmallest(root->right, k - 1 - leftCount);  // leftCount < k
  }

 private:
  int countNodes(TreeNode* root) {
    if (root == nullptr)
      return 0;
    return 1 + countNodes(root->left) + countNodes(root->right);
  }
};
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
class Solution {
  /**
   * Returns the k-th smallest value (1-indexed) in the BST rooted at {@code root}.
   *
   * <p>Walks down from the root. The size of the current node's left subtree tells
   * whether the answer is the node itself, lies in the left subtree, or lies in the
   * right subtree; in the last case the left subtree and the node are skipped.
   * Expects {@code 1 <= k <= node count}; otherwise the walk reaches {@code null}
   * and throws a NullPointerException.
   */
  public int kthSmallest(TreeNode root, int k) {
    TreeNode node = root;
    int remaining = k; // rank still being sought inside node's subtree

    while (true) {
      final int smaller = subtreeSize(node.left);
      if (smaller == remaining - 1)
        return node.val;
      if (smaller < remaining) {
        // Skip the whole left subtree plus the current node.
        remaining -= smaller + 1;
        node = node.right;
      } else {
        node = node.left;
      }
    }
  }

  /** Number of nodes in the subtree rooted at {@code node} (0 for null). */
  private int subtreeSize(TreeNode node) {
    return node == null ? 0 : 1 + subtreeSize(node.left) + subtreeSize(node.right);
  }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
class Solution:
  def kthSmallest(self, root: Optional[TreeNode], k: int) -> int:
    """Return the k-th smallest value (1-indexed) in the BST rooted at `root`.

    Walks down from the root. The size of the current node's left subtree
    tells whether the answer is the node itself, lies in the left subtree,
    or lies in the right subtree; in the last case the left subtree and the
    node are skipped. Expects 1 <= k <= node count; otherwise the walk hits
    None and raises AttributeError.
    """
    def size(node: Optional[TreeNode]) -> int:
      # Number of nodes in the subtree rooted at `node` (0 for an empty tree).
      return 0 if not node else 1 + size(node.left) + size(node.right)

    node = root
    while True:
      smaller = size(node.left)
      if smaller == k - 1:
        return node.val
      if smaller >= k:
        node = node.left
      else:
        # Skip the whole left subtree plus the current node.
        k -= smaller + 1
        node = node.right

Approach 2: Inorder Traversal

  • Time: $O(n)$
  • Space: $O(h)$
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
class Solution {
 public:
  // Returns the k-th smallest value (1-indexed) in the BST rooted at `root`,
  // or -1 if the tree has fewer than k nodes.
  int kthSmallest(TreeNode* root, int k) {
    int ans = -1;
    int rank = 0;
    traverse(root, k, rank, ans);
    return ans;
  }

 private:
  // In-order walk: `rank` counts the nodes visited so far, and the value of
  // the k-th one is stored in `ans`. Once rank reaches k, every pending frame
  // returns immediately instead of walking the rest of the tree, so the cost
  // is O(h + k) rather than always O(n).
  void traverse(TreeNode* root, int k, int& rank, int& ans) {
    if (root == nullptr || rank >= k)
      return;

    traverse(root->left, k, rank, ans);
    if (rank >= k)  // The answer was found in the left subtree.
      return;
    if (++rank == k) {
      ans = root->val;
      return;
    }
    traverse(root->right, k, rank, ans);
  }
};
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
class Solution {
  /**
   * Returns the k-th smallest value (1-indexed) in the BST rooted at {@code root},
   * or -1 if the tree has fewer than {@code k} nodes.
   */
  public int kthSmallest(TreeNode root, int k) {
    // The traversal state lives in fields that outlive a single call. Reset it
    // so repeated calls on the same instance do not continue from a stale rank.
    ans = -1;
    rank = 0;
    traverse(root, k);
    return ans;
  }

  private int ans = -1;
  private int rank = 0;

  /**
   * In-order walk: {@code rank} counts visited nodes, and the value of the k-th
   * one is stored in {@code ans}. Once rank reaches k, pending frames return
   * immediately instead of walking the rest of the tree.
   */
  private void traverse(TreeNode root, int k) {
    if (root == null || rank >= k)
      return;

    traverse(root.left, k);
    if (rank >= k) // The answer was found in the left subtree.
      return;
    if (++rank == k) {
      ans = root.val;
      return;
    }
    traverse(root.right, k);
  }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
class Solution:
  def kthSmallest(self, root: Optional[TreeNode], k: int) -> int:
    """Return the k-th smallest value (1-indexed) in the BST rooted at `root`.

    Uses a recursive in-order traversal that stops once the k-th node has
    been counted. Returns -1 if the tree has fewer than k nodes.
    """
    rank = 0
    ans = -1

    def traverse(root: Optional[TreeNode]) -> None:
      nonlocal rank
      nonlocal ans
      # Once the k-th node has been counted, unwind without visiting the
      # rest of the tree (O(h + k) instead of always O(n)).
      if not root or rank >= k:
        return

      traverse(root.left)
      if rank >= k:  # The answer was found in the left subtree.
        return
      rank += 1
      if rank == k:
        ans = root.val
        return
      traverse(root.right)

    traverse(root)
    return ans

Approach 3: Stack

  • Time: $O(h + k)$, which is $O(n)$ in the worst case
  • Space: $O(h)$
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
class Solution {
 public:
  int kthSmallest(TreeNode* root, int k) {
    stack<TreeNode*> stack;

    while (root) {
      stack.push(root);
      root = root->left;
    }

    for (int i = 0; i < k - 1; ++i) {
      root = stack.top(), stack.pop();
      root = root->right;
      while (root) {
        stack.push(root);
        root = root->left;
      }
    }

    return stack.top()->val;
  }
};
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
class Solution {
  /**
   * Returns the k-th smallest value (1-indexed) in the BST rooted at {@code root}.
   *
   * <p>Simulates an in-order traversal with an explicit stack, so the head of the
   * deque is always the next node in sorted order. The first k - 1 nodes are
   * popped, and the k-th is read off the top. Expects {@code 1 <= k <= node count}.
   */
  public int kthSmallest(TreeNode root, int k) {
    Deque<TreeNode> pending = new ArrayDeque<>();

    pushLeftSpine(root, pending);
    for (int skipped = 1; skipped < k; ++skipped)
      pushLeftSpine(pending.pop().right, pending);

    return pending.peek().val;
  }

  /** Pushes {@code node} followed by its chain of left descendants onto {@code stack}. */
  private void pushLeftSpine(TreeNode node, Deque<TreeNode> stack) {
    for (; node != null; node = node.left)
      stack.push(node);
  }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
class Solution:
  def kthSmallest(self, root: Optional[TreeNode], k: int) -> int:
    """Return the k-th smallest value (1-indexed) in the BST rooted at `root`.

    Simulates an in-order traversal with an explicit stack, so the last
    element of `pending` is always the next node in sorted order. The first
    k - 1 nodes are popped, and the k-th is read off the top. Expects
    1 <= k <= node count; otherwise the list is popped or indexed while empty.
    """
    pending = []

    def push_left_spine(node: Optional[TreeNode]) -> None:
      # Push `node` followed by its chain of left descendants.
      while node:
        pending.append(node)
        node = node.left

    push_left_spine(root)
    for _ in range(1, k):
      push_left_spine(pending.pop().right)

    return pending[-1].val