Skip to content

1740. Find Distance in a Binary Tree 👍

  • Time: $O(n)$
  • Space: $O(h)$
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
class Solution {
 public:
  int findDistance(TreeNode* root, int p, int q) {
    TreeNode* lca = getLCA(root, p, q);
    return dist(lca, p) + dist(lca, q);
  }

 private:
  TreeNode* getLCA(TreeNode* root, int p, int q) {
    if (root == nullptr || root->val == p || root->val == q)
      return root;
    TreeNode* left = getLCA(root->left, p, q);
    TreeNode* right = getLCA(root->right, p, q);
    if (left != nullptr && right != nullptr)
      return root;
    return left == nullptr ? right : left;
  }

  int dist(TreeNode* lca, int target) {
    if (lca == nullptr)
      return 10000;
    if (lca->val == target)
      return 0;
    return 1 + min(dist(lca->left, target), dist(lca->right, target));
  }
};
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
class Solution {
  /**
   * Marker returned by {@link #dist} when the target is not in the subtree. Kept negative so it
   * can never be confused with (or added into) a real distance.
   */
  private static final int NOT_FOUND = -1;

  /**
   * Returns the number of edges on the path between the node whose value is {@code p} and the
   * node whose value is {@code q}.
   *
   * <p>The path is split at the lowest common ancestor (LCA): the answer is the depth of
   * {@code p} below the LCA plus the depth of {@code q} below the LCA.
   *
   * <p>NOTE(review): nodes are matched by value, so this presumes node values are unique --
   * confirm against the caller's contract.
   *
   * @param root root of the tree; may be {@code null}
   * @param p value of the first node
   * @param q value of the second node
   * @return the edge count between the two nodes, or -1 if the tree is empty or either value
   *     cannot be found
   */
  public int findDistance(TreeNode root, int p, int q) {
    TreeNode lca = getLCA(root, p, q);
    final int distP = dist(lca, p);
    final int distQ = dist(lca, q);
    if (distP == NOT_FOUND || distQ == NOT_FOUND)
      return NOT_FOUND;
    return distP + distQ;
  }

  /**
   * Returns the lowest node whose subtree contains every occurrence of {@code p} and {@code q}
   * reachable from {@code root}, or {@code null} if neither value is present. If only one of the
   * two values exists, the node holding it is returned.
   */
  private TreeNode getLCA(TreeNode root, int p, int q) {
    if (root == null || root.val == p || root.val == q)
      return root;

    TreeNode l = getLCA(root.left, p, q);
    TreeNode r = getLCA(root.right, p, q);

    // One value on each side means `root` is where the two paths diverge.
    if (l != null && r != null)
      return root;
    return l == null ? r : l;
  }

  /**
   * Returns the number of edges from {@code node} down to the node whose value is
   * {@code target}, or {@link #NOT_FOUND} if the subtree rooted at {@code node} lacks it.
   */
  private int dist(TreeNode node, int target) {
    if (node == null)
      return NOT_FOUND;
    if (node.val == target)
      return 0;
    // Stop as soon as the target is located; the right subtree is searched only when the left
    // one does not contain it.
    final int viaLeft = dist(node.left, target);
    if (viaLeft != NOT_FOUND)
      return viaLeft + 1;
    final int viaRight = dist(node.right, target);
    return viaRight == NOT_FOUND ? NOT_FOUND : viaRight + 1;
  }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
class Solution:
  # `TreeNode` is not defined in this snippet, so the annotation is written as
  # a string forward reference; the class can then be defined without it.
  def findDistance(self, root: "TreeNode", p: int, q: int) -> int:
    """Return the number of edges between the nodes valued `p` and `q`.

    The path is split at the lowest common ancestor (LCA): the answer is the
    depth of `p` below the LCA plus the depth of `q` below the LCA.

    NOTE(review): nodes are matched by value, so this presumes node values are
    unique -- confirm against the caller's contract.

    Args:
      root: Root of the tree (any object exposing `val`, `left`, `right`), or
        None for an empty tree.
      p: Value of the first node.
      q: Value of the second node.

    Returns:
      The edge count between the two nodes, or -1 if the tree is empty or
      either value cannot be found.
    """
    NOT_FOUND = -1  # Negative so it can never be mistaken for a real distance.

    def getLCA(node, p, q):
      """Return the lowest node covering both values, or None if neither exists.

      If only one of the two values is present, the node holding it is returned.
      """
      if node is None or node.val == p or node.val == q:
        return node

      l = getLCA(node.left, p, q)
      r = getLCA(node.right, p, q)

      # One value on each side means `node` is where the two paths diverge.
      if l is not None and r is not None:
        return node
      return l if l is not None else r

    def dist(node, target):
      """Return the edge count from `node` down to `target`, or NOT_FOUND."""
      if node is None:
        return NOT_FOUND
      if node.val == target:
        return 0
      # Stop as soon as the target is located; the right child is searched
      # only when the left subtree does not contain it.
      for child in (node.left, node.right):
        below = dist(child, target)
        if below != NOT_FOUND:
          return below + 1
      return NOT_FOUND

    lca = getLCA(root, p, q)
    distP = dist(lca, p)
    distQ = dist(lca, q)
    if distP == NOT_FOUND or distQ == NOT_FOUND:
      return NOT_FOUND
    return distP + distQ