Skip to content

2040. Kth Smallest Product of Two Sorted Arrays 👍

  • Time: $O((|\texttt{nums1}| + |\texttt{nums2}|) \cdot \log 10^{10})$
  • Space: $O(|\texttt{nums1}| + |\texttt{nums2}|)$
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
class Solution {
 public:
  // Returns the k-th (1-indexed) smallest product nums1[i] * nums2[j].
  //
  // Both arrays must be sorted in ascending order. Products are split into two
  // groups: those with exactly one negative factor (every product <= 0) and
  // all others (every product >= 0). Because the first group lies entirely at
  // or below the second, the answer is the k-th element of the first group or
  // the (k - negCount)-th element of the second. Inside a group the absolute
  // value of the answer is binary searched, counting candidates with a
  // two-pointer sweep.
  //
  // All arithmetic is done in `long long`: `long` is only 32 bits on some
  // platforms (e.g. Windows), where products up to 1e10 and pair counts up to
  // 2.5e9 would overflow it.
  long long kthSmallestProduct(std::vector<int>& nums1, std::vector<int>& nums2,
                               long long k) {
    std::vector<long long> A1;  // |x| for each negative x in nums1, ascending
    std::vector<long long> A2;  // each non-negative x in nums1, ascending
    std::vector<long long> B1;  // |y| for each negative y in nums2, ascending
    std::vector<long long> B2;  // each non-negative y in nums2, ascending

    separate(nums1, A1, A2);
    separate(nums2, B1, B2);

    // Number of products with exactly one negative factor (all are <= 0).
    const long long negCount =
        static_cast<long long>(A1.size()) * static_cast<long long>(B2.size()) +
        static_cast<long long>(A2.size()) * static_cast<long long>(B1.size());
    int sign = 1;

    if (k > negCount) {
      k -= negCount;  // Find the (k - negCount)-th smallest product >= 0.
    } else {
      // The k-th smallest product <= 0 is the (negCount - k + 1)-th smallest
      // in absolute value.
      k = negCount - k + 1;
      sign = -1;
      // Pair |negatives| of nums1 with non-negatives of nums2 (A1 x B1) and
      // non-negatives of nums1 with |negatives| of nums2 (A2 x B2).
      std::swap(B1, B2);
    }

    // |answer| lies in [0, max|nums1| * max|nums2|]; deriving the bound from
    // the inputs (instead of a fixed 1e10) keeps it valid for any int values.
    long long l = 0;
    long long r = maxAbs(nums1) * maxAbs(nums2);

    while (l < r) {
      const long long m = l + (r - l) / 2;  // Avoids overflow of l + r.
      if (numProductNoGreaterThan(A1, B1, m) +
              numProductNoGreaterThan(A2, B2, m) >=
          k)
        r = m;
      else
        l = m + 1;
    }

    return sign * l;
  }

 private:
  // Splits sorted `A` into |negatives| (A1) and non-negatives (A2), both
  // ascending. Negation is done in `long long` so INT_MIN cannot overflow.
  static void separate(const std::vector<int>& A, std::vector<long long>& A1,
                       std::vector<long long>& A2) {
    for (const int a : A)
      if (a < 0)
        A1.push_back(-static_cast<long long>(a));
      else
        A2.push_back(a);
    // Negatives were visited in ascending order, so their absolute values are
    // descending; reverse to make A1 ascending.
    std::reverse(A1.begin(), A1.end());
  }

  // Returns the largest |x| in `nums`, or 0 if `nums` is empty.
  static long long maxAbs(const std::vector<int>& nums) {
    long long best = 0;
    for (const int x : nums) {
      const long long absX = x < 0 ? -static_cast<long long>(x) : x;
      best = std::max(best, absX);
    }
    return best;
  }

  // Counts pairs (a, b) from ascending non-negative `A` and `B` with a * b <= m.
  // As a grows the largest valid index into B can only shrink, so j moves
  // monotonically and each call is O(|A| + |B|).
  static long long numProductNoGreaterThan(const std::vector<long long>& A,
                                           const std::vector<long long>& B,
                                           long long m) {
    long long count = 0;
    int j = static_cast<int>(B.size()) - 1;
    for (const long long a : A) {
      while (j >= 0 && a * B[j] > m)
        --j;
      count += j + 1;  // B[0..j] all satisfy a * B[i] <= m.
    }
    return count;
  }
};
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
class Solution {
  public long kthSmallestProduct(int[] nums1, int[] nums2, long k) {
    List<Integer> A1 = new ArrayList<>();
    List<Integer> A2 = new ArrayList<>();
    List<Integer> B1 = new ArrayList<>();
    List<Integer> B2 = new ArrayList<>();

    seperate(nums1, A1, A2);
    seperate(nums2, B1, B2);

    final long negCount = A1.size() * B2.size() + A2.size() * B1.size();
    int sign = 1;

    if (k > negCount) {
      k -= negCount; //  Find (k - negCount)-th positive.
    } else {
      k = negCount - k + 1; // Find (negCount - k + 1)-th abs(negative).
      sign = -1;
      List<Integer> temp = B1;
      B1 = B2;
      B2 = temp;
    }

    long l = 0;
    long r = (long) 1e10;

    while (l < r) {
      final long m = (l + r) / 2;
      if (numProductNoGreaterThan(A1, B1, m) + numProductNoGreaterThan(A2, B2, m) >= k)
        r = m;
      else
        l = m + 1;
    }

    return sign * l;
  }

  private void seperate(int[] A, List<Integer> A1, List<Integer> A2) {
    for (final int a : A)
      if (a < 0)
        A1.add(-a);
      else
        A2.add(a);
    Collections.reverse(A1); // Reverse to sort ascending
  }

  private long numProductNoGreaterThan(List<Integer> A, List<Integer> B, long m) {
    long count = 0;
    int j = B.size() - 1;
    // For each a, find the first index j s.t. a * B[j] <= m
    // So numProductNoGreaterThan m for this row will be j + 1
    for (final long a : A) {
      while (j >= 0 && a * B.get(j) > m)
        --j;
      count += j + 1;
    }
    return count;
  }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
class Solution:
  def kthSmallestProduct(
      self,
      nums1: list[int],
      nums2: list[int],
      k: int,
  ) -> int:
    """Returns the k-th (1-indexed) smallest product nums1[i] * nums2[j].

    Both lists must be sorted in ascending order. Products with exactly one
    negative factor are all <= 0 and every other product is >= 0, so the
    answer is either the k-th element of the first group or the
    (k - negCount)-th element of the second. Within a group, the absolute
    value of the answer is binary searched, counting candidates with a
    two-pointer sweep.
    """
    # Negatives arrive ascending, so their absolute values arrive descending;
    # reverse them so every list below is ascending.
    A1 = [-num for num in nums1 if num < 0][::-1]
    A2 = [num for num in nums1 if num >= 0]
    B1 = [-num for num in nums2 if num < 0][::-1]
    B2 = [num for num in nums2 if num >= 0]

    # Number of products with exactly one negative factor (all are <= 0).
    negCount = len(A1) * len(B2) + len(A2) * len(B1)

    if k > negCount:
      k -= negCount  # Find the (k - negCount)-th smallest product >= 0.
      sign = 1
    else:
      # The k-th smallest product <= 0 is the (negCount - k + 1)-th smallest
      # in absolute value; pair |negatives| with non-negatives across lists.
      k = negCount - k + 1
      sign = -1
      B1, B2 = B2, B1

    def numProductNoGreaterThan(A: list[int], B: list[int], m: int) -> int:
      """Counts pairs (a, b) from ascending non-negative A, B with a * b <= m."""
      count = 0
      j = len(B) - 1
      for a in A:
        # j only moves left as a grows, so the sweep is O(len(A) + len(B)).
        while j >= 0 and a * B[j] > m:
          j -= 1
        count += j + 1  # B[0..j] all satisfy a * B[i] <= m.
      return count

    # |answer| lies in [0, max|nums1| * max|nums2|]; deriving the bound from
    # the inputs instead of a fixed 10**10 keeps it correct for any values.
    l = 0
    r = (max(map(abs, nums1), default=0) *
         max(map(abs, nums2), default=0))

    while l < r:
      m = (l + r) // 2
      if (numProductNoGreaterThan(A1, B1, m) +
              numProductNoGreaterThan(A2, B2, m) >= k):
        r = m
      else:
        l = m + 1

    return sign * l