Skip to content

2736. Maximum Sum Queries 👍

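  • Time: $O(\texttt{sort}(n) + \texttt{sort}(q) + n + q\log n)$ — each pair is pushed onto and popped off the monotonic stack at most once, and each query performs one binary search over the stack.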
  • Space: $O(n + q)$
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
// One (nums1[i], nums2[i]) pair.
struct Pair {
  int x;
  int y;
};

// A query (minX, minY) tagged with its position in the caller's `queries`.
struct IndexedQuery {
  int queryIndex;
  int minX;
  int minY;
};

// For each query (minX, minY), finds the maximum nums1[j] + nums2[j] over all
// j with nums1[j] >= minX and nums2[j] >= minY, or -1 if no such j exists.
//
// Offline sweep: queries are answered in decreasing minX while pairs are
// admitted in decreasing x. A monotonic stack of (y, x + y) keeps only
// non-dominated candidates. From bottom to top, y is strictly increasing and
// x + y is strictly decreasing, so the first entry with y >= minY holds the
// best sum for that query.
class Solution {
 public:
  std::vector<int> maximumSumQueries(std::vector<int>& nums1,
                                     std::vector<int>& nums2,
                                     std::vector<std::vector<int>>& queries) {
    const std::vector<Pair> pairs = getPairs(nums1, nums2);
    std::vector<int> ans(queries.size(), -1);
    std::vector<std::pair<int, int>> stack;  // [(y, x + y)]
    stack.reserve(pairs.size());  // each pair is pushed at most once

    std::size_t pairsIndex = 0;
    for (const IndexedQuery& query : getIndexedQueries(queries)) {
      // Admit every pair with x >= minX. Because minX only decreases from
      // query to query, admitted pairs stay eligible for all later queries.
      while (pairsIndex < pairs.size() && pairs[pairsIndex].x >= query.minX) {
        const auto [x, y] = pairs[pairsIndex++];
        // NOTE(review): assumes x + y fits in int (true when both values are
        // <= 1e9) — confirm input bounds.
        const int sum = x + y;
        // x never increases, so a stacked entry with sum' <= sum must also
        // have y' <= y (otherwise x' >= x and y' > y would give sum' > sum).
        // Such an entry is therefore dominated by (x, y).
        while (!stack.empty() && sum >= stack.back().second)
          stack.pop_back();
        // The surviving top has a strictly larger sum. If it also has
        // y' >= y, the new pair is dominated and is dropped.
        if (stack.empty() || y > stack.back().first)
          stack.emplace_back(y, sum);
      }
      // y is strictly increasing along the stack, so this is the first
      // candidate with y >= minY, which also carries the largest sum.
      const auto it = std::lower_bound(
          stack.begin(), stack.end(), query.minY,
          [](const std::pair<int, int>& entry, int minY) {
            return entry.first < minY;
          });
      if (it != stack.end())
        ans[query.queryIndex] = it->second;
    }

    return ans;
  }

 private:
  // Zips nums1 and nums2 into pairs sorted by x in descending order. If the
  // lengths differ, extra elements of the longer input are ignored rather
  // than read out of bounds.
  static std::vector<Pair> getPairs(const std::vector<int>& nums1,
                                    const std::vector<int>& nums2) {
    const std::size_t n = std::min(nums1.size(), nums2.size());
    std::vector<Pair> pairs;
    pairs.reserve(n);
    for (std::size_t i = 0; i < n; ++i)
      pairs.push_back({nums1[i], nums2[i]});
    std::sort(pairs.begin(), pairs.end(),
              [](const Pair& a, const Pair& b) { return a.x > b.x; });
    return pairs;
  }

  // Tags each query with its original index and sorts by minX descending.
  // Each query is expected to hold at least two values: [minX, minY].
  static std::vector<IndexedQuery> getIndexedQueries(
      const std::vector<std::vector<int>>& queries) {
    std::vector<IndexedQuery> indexedQueries;
    indexedQueries.reserve(queries.size());
    for (std::size_t i = 0; i < queries.size(); ++i)
      indexedQueries.push_back(
          {static_cast<int>(i), queries[i][0], queries[i][1]});
    std::sort(indexedQueries.begin(), indexedQueries.end(),
              [](const IndexedQuery& a, const IndexedQuery& b) {
                return a.minX > b.minX;
              });
    return indexedQueries;
  }
};
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
/**
 * LeetCode 2736: for each query (minX, minY), returns the maximum
 * nums1[j] + nums2[j] over all j with nums1[j] >= minX and nums2[j] >= minY,
 * or -1 if no such j exists.
 *
 * <p>Offline sweep: queries are answered in decreasing minX while pairs are
 * admitted in decreasing x, and a monotonic stack keeps only non-dominated
 * (y, x + y) candidates.
 */
class Solution {
  public int[] maximumSumQueries(int[] nums1, int[] nums2, int[][] queries) {
    MyPair[] pairs = getPairs(nums1, nums2);
    IndexedQuery[] indexedQueries = getIndexedQueries(queries);
    int[] ans = new int[queries.length];
    // Monotonic stack of (y, x + y), kept in parallel primitive arrays (no
    // boxing and no non-JDK Pair type). From bottom to top, y is strictly
    // increasing and x + y is strictly decreasing. Each pair is pushed at
    // most once, so pairs.length slots suffice.
    int[] stackY = new int[pairs.length];
    int[] stackSum = new int[pairs.length];
    int stackSize = 0;

    int pairsIndex = 0;
    for (IndexedQuery indexedQuery : indexedQueries) {
      final int minX = indexedQuery.minX();
      final int minY = indexedQuery.minY();
      // Admit every pair with x >= minX. minX only decreases, so admitted
      // pairs stay eligible for all later queries.
      while (pairsIndex < pairs.length && pairs[pairsIndex].x() >= minX) {
        final MyPair pair = pairs[pairsIndex++];
        final int x = pair.x();
        final int y = pair.y();
        // NOTE(review): assumes x + y fits in int (true when both values are
        // <= 1e9) — confirm input bounds.
        final int sum = x + y;
        // x never increases, so a stacked entry with sum' <= sum must also
        // have y' <= y, and is therefore dominated by (x, y).
        while (stackSize > 0 && sum >= stackSum[stackSize - 1])
          --stackSize;
        // The surviving top has a strictly larger sum. If it also has
        // y' >= y, the new pair is dominated and is dropped.
        if (stackSize == 0 || y > stackY[stackSize - 1]) {
          stackY[stackSize] = y;
          stackSum[stackSize] = sum;
          ++stackSize;
        }
      }
      // The first entry with y >= minY carries the largest qualifying sum.
      final int j = firstGreaterEqual(stackY, stackSize, minY);
      ans[indexedQuery.queryIndex()] = j == stackSize ? -1 : stackSum[j];
    }

    return ans;
  }

  /** One (nums1[i], nums2[i]) pair. */
  private record MyPair(int x, int y) {}

  /** A query (minX, minY) tagged with its position in the input array. */
  private record IndexedQuery(int queryIndex, int minX, int minY) {}

  /**
   * Returns the first index i in [0, size) with A[i] >= target, or size if
   * there is none. A[0, size) must be sorted in ascending order.
   */
  private int firstGreaterEqual(int[] A, int size, int target) {
    int l = 0;
    int r = size;
    while (l < r) {
      final int m = (l + r) >>> 1;
      if (A[m] >= target)
        r = m;
      else
        l = m + 1;
    }
    return l;
  }

  /**
   * Zips nums1 and nums2 into pairs sorted by x in descending order. If the
   * lengths differ, extra elements of the longer input are ignored.
   */
  private MyPair[] getPairs(int[] nums1, int[] nums2) {
    final int n = Math.min(nums1.length, nums2.length);
    MyPair[] pairs = new MyPair[n];
    for (int i = 0; i < n; ++i)
      pairs[i] = new MyPair(nums1[i], nums2[i]);
    Arrays.sort(pairs, (a, b) -> Integer.compare(b.x(), a.x()));
    return pairs;
  }

  /** Tags each query with its original index and sorts by minX descending. */
  private IndexedQuery[] getIndexedQueries(int[][] queries) {
    IndexedQuery[] indexedQueries = new IndexedQuery[queries.length];
    for (int i = 0; i < queries.length; ++i)
      indexedQueries[i] = new IndexedQuery(i, queries[i][0], queries[i][1]);
    Arrays.sort(indexedQueries, (a, b) -> Integer.compare(b.minX(), a.minX()));
    return indexedQueries;
  }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import bisect
from dataclasses import dataclass


@dataclass(frozen=True)
class Pair:
  """One (nums1[i], nums2[i]) pair; iterable so it unpacks as `x, y`."""
  x: int
  y: int

  def __iter__(self):
    yield self.x
    yield self.y


@dataclass(frozen=True)
class IndexedQuery:
  """A query (minX, minY) tagged with its position in the input list.

  Iterable so it unpacks as `queryIndex, minX, minY`.
  """
  queryIndex: int
  minX: int
  minY: int

  def __iter__(self):
    yield self.queryIndex
    yield self.minX
    yield self.minY


class Solution:
  def maximumSumQueries(
      self,
      nums1: list[int],
      nums2: list[int],
      queries: list[list[int]],
  ) -> list[int]:
    """Answers each query [minX, minY] with max(nums1[j] + nums2[j]).

    The maximum is taken over all j with nums1[j] >= minX and
    nums2[j] >= minY; the answer is -1 when no such j exists.

    Offline sweep: queries are answered in decreasing minX while pairs are
    admitted in decreasing x. A monotonic stack of (y, x + y) keeps only
    non-dominated candidates. From bottom to top, y is strictly increasing
    and x + y is strictly decreasing.
    """
    # zip stops at the shorter input instead of indexing past its end.
    pairs = sorted((Pair(x, y) for x, y in zip(nums1, nums2)),
                   key=lambda pair: pair.x, reverse=True)
    indexedQueries = sorted(
        (IndexedQuery(i, query[0], query[1])
         for i, query in enumerate(queries)),
        key=lambda indexedQuery: indexedQuery.minX, reverse=True)
    ans = [-1] * len(queries)
    stack: list[tuple[int, int]] = []  # [(y, x + y)]

    pairsIndex = 0
    for queryIndex, minX, minY in indexedQueries:
      # Admit every pair with x >= minX. minX only decreases, so admitted
      # pairs stay eligible for all later queries.
      while pairsIndex < len(pairs) and pairs[pairsIndex].x >= minX:
        x, y = pairs[pairsIndex]
        pairsIndex += 1
        # x never increases, so a stacked entry with sum' <= x + y must also
        # have y' <= y, and is therefore dominated by (x, y).
        while stack and x + y >= stack[-1][1]:
          stack.pop()
        # The surviving top has a strictly larger sum. If it also has
        # y' >= y, the new pair is dominated and is dropped.
        if not stack or y > stack[-1][0]:
          stack.append((y, x + y))
      # The first entry with y >= minY carries the largest qualifying sum.
      j = self._firstGreaterEqual(stack, minY)
      if j < len(stack):
        ans[queryIndex] = stack[j][1]

    return ans

  def _firstGreaterEqual(self, A: list[tuple[int, int]], target: int) -> int:
    """Returns the first index i with A[i][0] >= target, or len(A) if none.

    A must be sorted by its first element.
    """
    # (target,) compares below every (target, s) tuple and above every (y, s)
    # with y < target, so bisect_left lands on the first y >= target.
    return bisect.bisect_left(A, (target,))