Skip to content

973. K Closest Points to Origin 👍

Approach 1: Heap

  • Time: $O(n\log K)$
  • Space: $O(K)$
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
class Solution {
 public:
  // Returns up to k points closest to the origin (by squared Euclidean
  // distance), ordered from farthest to nearest -- the order in which the
  // max-heap is drained. Non-positive k yields an empty result.
  //
  // A max-heap keyed on squared distance holds at most k points; whenever it
  // grows past k, the farthest point is evicted. Time O(n log k), space O(k).
  vector<vector<int>> kClosest(vector<vector<int>>& points, int k) {
    auto compare = [](const vector<int>& a, const vector<int>& b) {
      return squareDist(a) < squareDist(b);
    };
    priority_queue<vector<int>, vector<vector<int>>, decltype(compare)> maxHeap(
        compare);

    // Convert k to size_t only after ruling out negatives: comparing size()
    // directly against a negative int would turn it into a huge unsigned
    // value, so nothing would ever be evicted.
    const size_t limit = k > 0 ? static_cast<size_t>(k) : 0;
    for (const vector<int>& point : points) {
      maxHeap.push(point);
      if (maxHeap.size() > limit)
        maxHeap.pop();
    }

    vector<vector<int>> ans;
    ans.reserve(maxHeap.size());
    while (!maxHeap.empty()) {
      ans.push_back(maxHeap.top());
      maxHeap.pop();
    }
    return ans;
  }

 private:
  // Squared distance from the origin; sqrt is unnecessary because only the
  // relative order of distances matters.
  static int squareDist(const vector<int>& p) {
    return p[0] * p[0] + p[1] * p[1];
  }
};
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
class Solution {
  /**
   * Finds the k points nearest the origin by squared Euclidean distance.
   *
   * <p>Keeps a max-heap of at most k points (farthest on top). Each new point is
   * added and, if the heap then holds more than k entries, the farthest one is
   * evicted.
   *
   * @param points input points, each {x, y}
   * @param k number of points to return
   * @return a k-row array ordered nearest first; if fewer than k points were
   *     supplied, the unfilled leading rows stay {0, 0}
   */
  public int[][] kClosest(int[][] points, int k) {
    int[][] nearestFirst = new int[k][2];
    PriorityQueue<int[]> farthestOnTop =
        new PriorityQueue<>((p, q) -> Integer.compare(squareDist(q), squareDist(p)));

    for (int[] candidate : points) {
      farthestOnTop.offer(candidate);
      if (farthestOnTop.size() > k)
        farthestOnTop.poll();
    }

    // Polling yields the farthest remaining point first, so fill from the back.
    for (int slot = k - 1; !farthestOnTop.isEmpty(); --slot)
      nearestFirst[slot] = farthestOnTop.poll();

    return nearestFirst;
  }

  /** Squared distance of p from the origin; sqrt is skipped since only order matters. */
  private int squareDist(int[] p) {
    final int x = p[0];
    final int y = p[1];
    return x * x + y * y;
  }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
class Solution:
  def kClosest(self, points: list[list[int]], k: int) -> list[list[int]]:
    """Returns the k points nearest the origin, in heap (not sorted) order.

    heapq is a min-heap, so distances are stored negated: the root is then the
    farthest kept point, which is the one evicted once more than k are held.
    """
    heap = []

    for x, y in points:
      negatedDist = -(x * x + y * y)
      heapq.heappush(heap, (negatedDist, [x, y]))
      if len(heap) > k:
        heapq.heappop(heap)

    return [point for _, point in heap]

Approach 2: Quick Select

  • Time: $O(n)$ on average, $O(n^2)$ in the worst case
  • Space: $O(K)$
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
class Solution {
 public:
  // Returns the k points closest to the origin (by squared Euclidean
  // distance), in no particular order. `points` may be reordered in place.
  //
  // Quickselect with a Lomuto partition around the last element of the
  // current range: O(n) on average, O(n^2) when the pivot keeps landing at an
  // extreme (e.g. input already sorted by distance).
  vector<vector<int>> kClosest(vector<vector<int>>& points, int k) {
    const int n = static_cast<int>(points.size());
    if (k <= 0)
      return {};
    // Every point qualifies. This also avoids reading past the end of
    // `points` when k > n, including the empty-input case.
    if (k >= n)
      return points;
    quickSelect(points, 0, n - 1, k);
    return {points.begin(), points.begin() + k};
  }

 private:
  // Reorders points[l..r] so that its first k entries are the k closest
  // points of that range. Requires 1 <= k <= r - l + 1.
  void quickSelect(vector<vector<int>>& points, int l, int r, int k) {
    // Each round either finishes or narrows [l, r] to the side that holds the
    // k-th closest point. Looping instead of tail-recursing keeps stack use
    // constant on worst-case inputs.
    while (true) {
      // The pivot stays at index r during partitioning (only indices < r are
      // swapped), so its distance is computed once and the point is not copied.
      const int pivotDist = squareDist(points[r]);

      int nextSwapped = l;
      for (int i = l; i < r; ++i)
        if (squareDist(points[i]) <= pivotDist)
          swap(points[nextSwapped++], points[i]);
      swap(points[nextSwapped], points[r]);

      const int count = nextSwapped - l + 1;  // the number of points <= pivot
      if (count == k)
        return;
      if (count > k) {
        r = nextSwapped - 1;
      } else {
        l = nextSwapped + 1;
        k -= count;
      }
    }
  }

  // Squared distance from the origin; only relative order matters.
  static int squareDist(const vector<int>& p) {
    return p[0] * p[0] + p[1] * p[1];
  }
};
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
class Solution {
  public int[][] kClosest(int[][] points, int k) {
    quickSelect(points, 0, points.length - 1, k);
    return Arrays.copyOfRange(points, 0, k);
  }

  private void quickSelect(int[][] points, int l, int r, int k) {
    final int[] pivot = points[r];

    int nextSwapped = l;
    for (int i = l; i < r; ++i)
      if (squareDist(points[i]) <= squareDist(pivot))
        swap(points, nextSwapped++, i);
    swap(points, nextSwapped, r);

    final int count = nextSwapped - l + 1; // the number of points <= pivot
    if (count == k)
      return;
    if (count > k)
      quickSelect(points, l, nextSwapped - 1, k);
    else
      quickSelect(points, nextSwapped + 1, r, k - count);
  }

  private int squareDist(int[] p) {
    return p[0] * p[0] + p[1] * p[1];
  }

  private void swap(int[][] points, int i, int j) {
    final int[] temp = points[i];
    points[i] = points[j];
    points[j] = temp;
  }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
class Solution:
  def kClosest(self, points: list[list[int]], k: int) -> list[list[int]]:
    """Returns the k points closest to the origin, in no particular order.

    Reorders `points` in place with Lomuto-partition quickselect, always using
    the last element of the current range as the pivot, then returns the
    first k entries. Non-positive k yields []; k >= len(points) yields all.
    """
    def squareDist(p: list[int]) -> int:
      return p[0] * p[0] + p[1] * p[1]

    def quickSelect(l: int, r: int, k: int) -> None:
      # Requires 1 <= k <= r - l + 1. This is the iterative form of the tail
      # recursion: on adversarial input (e.g. points already sorted by
      # distance) each round peels off one element, and recursing that deep
      # would exceed Python's recursion limit.
      while True:
        # The pivot stays at index r while partitioning (only indices < r are
        # swapped), so its distance is computed once.
        pivotDist = squareDist(points[r])

        nextSwapped = l
        for i in range(l, r):
          if squareDist(points[i]) <= pivotDist:
            points[nextSwapped], points[i] = points[i], points[nextSwapped]
            nextSwapped += 1
        points[nextSwapped], points[r] = points[r], points[nextSwapped]

        count = nextSwapped - l + 1  # the number of points <= pivot
        if count == k:
          return
        if count > k:
          r = nextSwapped - 1
        else:
          l = nextSwapped + 1
          k -= count

    if k <= 0:
      return []
    if k < len(points):
      quickSelect(0, len(points) - 1, k)
    return points[0:k]

Approach 3: Quick Select with random pivot

  • Time: $O(n)$ (average)
  • Space: $O(K)$
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
class Solution {
 public:
  // Returns the k points closest to the origin (by squared Euclidean
  // distance), in no particular order. `points` may be reordered in place.
  //
  // Quickselect with a Lomuto partition around a pivot chosen at a random
  // index via rand(), which makes the O(n^2) worst case unlikely.
  vector<vector<int>> kClosest(vector<vector<int>>& points, int k) {
    const int n = static_cast<int>(points.size());
    if (k <= 0)
      return {};
    // Every point qualifies. This also avoids reading past the end of
    // `points` when k > n, including the empty-input case.
    if (k >= n)
      return points;
    quickSelect(points, 0, n - 1, k);
    return {points.begin(), points.begin() + k};
  }

 private:
  // Reorders points[l..r] so that its first k entries are the k closest
  // points of that range. Requires 1 <= k <= r - l + 1.
  void quickSelect(vector<vector<int>>& points, int l, int r, int k) {
    // Iterative form of the tail recursion: stack use stays constant.
    while (true) {
      // Move a randomly chosen element into the pivot slot at index r.
      const int randIndex = rand() % (r - l + 1) + l;
      swap(points[randIndex], points[r]);
      // The pivot stays at index r during partitioning (only indices < r are
      // swapped), so its distance is computed once and the point is not copied.
      const int pivotDist = squareDist(points[r]);

      int nextSwapped = l;
      for (int i = l; i < r; ++i)
        if (squareDist(points[i]) <= pivotDist)
          swap(points[nextSwapped++], points[i]);
      swap(points[nextSwapped], points[r]);

      const int count = nextSwapped - l + 1;  // the number of points <= pivot
      if (count == k)
        return;
      if (count > k) {
        r = nextSwapped - 1;
      } else {
        l = nextSwapped + 1;
        k -= count;
      }
    }
  }

  // Squared distance from the origin; only relative order matters.
  static int squareDist(const vector<int>& p) {
    return p[0] * p[0] + p[1] * p[1];
  }
};
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
class Solution {
  public int[][] kClosest(int[][] points, int k) {
    quickSelect(points, 0, points.length - 1, k);
    return Arrays.copyOfRange(points, 0, k);
  }

  private void quickSelect(int[][] points, int l, int r, int k) {
    final int randIndex = new Random().nextInt(r - l + 1) + l;
    swap(points, randIndex, r);
    final int[] pivot = points[r];

    int nextSwapped = l;
    for (int i = l; i < r; ++i)
      if (squareDist(points[i]) <= squareDist(pivot))
        swap(points, nextSwapped++, i);
    swap(points, nextSwapped, r);

    final int count = nextSwapped - l + 1; // the number of points <= pivot
    if (count == k)
      return;
    if (count > k)
      quickSelect(points, l, nextSwapped - 1, k);
    else
      quickSelect(points, nextSwapped + 1, r, k - count);
  }

  private int squareDist(int[] p) {
    return p[0] * p[0] + p[1] * p[1];
  }

  private void swap(int[][] points, int i, int j) {
    final int[] temp = points[i];
    points[i] = points[j];
    points[j] = temp;
  }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
class Solution:
  def kClosest(self, points: list[list[int]], k: int) -> list[list[int]]:
    """Returns the k points closest to the origin, in no particular order.

    Reorders `points` in place with Lomuto-partition quickselect around a
    randomly chosen pivot, then returns the first k entries. Non-positive k
    yields []; k >= len(points) yields all points.
    """
    import random

    def squareDist(p: list[int]) -> int:
      return p[0] * p[0] + p[1] * p[1]

    def quickSelect(l: int, r: int, k: int) -> None:
      # Requires 1 <= k <= r - l + 1. Iterative form of the tail recursion, so
      # an unlucky run of pivots cannot exceed Python's recursion limit.
      while True:
        # random.randint is inclusive on both ends, so this picks from exactly
        # [l, r]; never r + 1, which lies outside the active range.
        randIndex = random.randint(l, r)
        points[randIndex], points[r] = points[r], points[randIndex]
        # The pivot stays at index r while partitioning (only indices < r are
        # swapped), so its distance is computed once.
        pivotDist = squareDist(points[r])

        nextSwapped = l
        for i in range(l, r):
          if squareDist(points[i]) <= pivotDist:
            points[nextSwapped], points[i] = points[i], points[nextSwapped]
            nextSwapped += 1
        points[nextSwapped], points[r] = points[r], points[nextSwapped]

        count = nextSwapped - l + 1  # the number of points <= pivot
        if count == k:
          return
        if count > k:
          r = nextSwapped - 1
        else:
          l = nextSwapped + 1
          k -= count

    if k <= 0:
      return []
    if k < len(points):
      quickSelect(0, len(points) - 1, k)
    return points[0:k]


class Solution:
  def kClosest(self, points: list[list[int]], k: int) -> list[list[int]]:
    """Returns the k points closest to the origin, in no particular order.

    Reorders `points` in place with Lomuto-partition quickselect, always using
    the last element of the current range as the pivot, then returns the
    first k entries. Non-positive k yields []; k >= len(points) yields all.
    """
    def squareDist(p: list[int]) -> int:
      return p[0] * p[0] + p[1] * p[1]

    def quickSelect(l: int, r: int, k: int) -> None:
      # Requires 1 <= k <= r - l + 1. This is the iterative form of the tail
      # recursion: on adversarial input (e.g. points already sorted by
      # distance) each round peels off one element, and recursing that deep
      # would exceed Python's recursion limit.
      while True:
        # The pivot stays at index r while partitioning (only indices < r are
        # swapped), so its distance is computed once.
        pivotDist = squareDist(points[r])

        nextSwapped = l
        for i in range(l, r):
          if squareDist(points[i]) <= pivotDist:
            points[nextSwapped], points[i] = points[i], points[nextSwapped]
            nextSwapped += 1
        points[nextSwapped], points[r] = points[r], points[nextSwapped]

        count = nextSwapped - l + 1  # the number of points <= pivot
        if count == k:
          return
        if count > k:
          r = nextSwapped - 1
        else:
          l = nextSwapped + 1
          k -= count

    if k <= 0:
      return []
    if k < len(points):
      quickSelect(0, len(points) - 1, k)
    return points[0:k]