Skip to content

3478. Choose K Elements With Maximum Sum 👍

  • Time: $O(n\log n)$ — sorting dominates; the heap operations add $O(n\log k)$
  • Space: $O(n + k)$
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
class Solution {
 public:
  // Returns ans where ans[i] is the largest possible sum of at most k values
  // nums2[j] taken over indices j with nums1[j] < nums1[i]. This is 0 when no
  // such j exists or k <= 0.
  //
  // Indices are visited in increasing order of nums1. A min-heap holds the k
  // largest nums2 values seen so far, and `sum` tracks their total, so each
  // answer is read off in O(1) before the current value is added.
  // Time: O(n log n) (sort) + O(n log k) (heap). Space: O(n + k).
  std::vector<long long> findMaxSum(std::vector<int>& nums1,
                                    std::vector<int>& nums2, int k) {
    const int n = static_cast<int>(nums1.size());
    std::vector<long long> ans(n);

    std::vector<std::pair<int, int>> numAndIndexes;
    numAndIndexes.reserve(n);
    for (int i = 0; i < n; ++i)
      numAndIndexes.emplace_back(nums1[i], i);
    // Sorting by (num, index) places equal nums1 values next to each other.
    std::sort(numAndIndexes.begin(), numAndIndexes.end());

    std::priority_queue<int, std::vector<int>, std::greater<int>> minHeap;
    // Use long long, not long: long is only 32 bits on LLP64 platforms
    // (e.g. Windows), and a sum of up to k values can exceed that range.
    long long sum = 0;

    for (int i = 0; i < n; ++i) {
      const auto& [currNum, currIndex] = numAndIndexes[i];
      if (i > 0 && numAndIndexes[i - 1].first == currNum) {
        // An equal nums1 value is not strictly smaller, so every index in a
        // run of ties shares the answer of the first index in that run.
        ans[currIndex] = ans[numAndIndexes[i - 1].second];
      } else {
        ans[currIndex] = sum;
      }

      if (k <= 0)
        continue;  // Nothing may be chosen, so sum stays 0.
      minHeap.push(nums2[currIndex]);
      sum += nums2[currIndex];
      if (minHeap.size() > static_cast<std::size_t>(k)) {
        // Drop the smallest value so only the k largest remain.
        sum -= minHeap.top();
        minHeap.pop();
      }
    }
    return ans;
  }
};
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
class Solution {
  public long[] findMaxSum(int[] nums1, int[] nums2, int k) {
    final int n = nums1.length;
    long[] ans = new long[n];
    Pair<Integer, Integer>[] numAndIndexes = new Pair[n];
    Queue<Long> minHeap = new PriorityQueue<>();

    for (int i = 0; i < n; ++i)
      numAndIndexes[i] = new Pair<>(nums1[i], i);

    Arrays.sort(numAndIndexes, Comparator.comparingInt(Pair::getKey));

    final int firstIndex = numAndIndexes[0].getValue();
    minHeap.offer((long) nums2[firstIndex]);
    long sum = nums2[firstIndex];

    for (int i = 1; i < n; ++i) {
      final int currNum = numAndIndexes[i].getKey();
      final int currIndex = numAndIndexes[i].getValue();
      final int prevNum = numAndIndexes[i - 1].getKey();
      final int prevIndex = numAndIndexes[i - 1].getValue();
      if (currNum == prevNum)
        ans[currIndex] = ans[prevIndex];
      else
        ans[currIndex] = sum;
      minHeap.offer((long) nums2[currIndex]);
      sum += nums2[currIndex];
      if (minHeap.size() == k + 1)
        sum -= minHeap.poll();
    }

    return ans;
  }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
class Solution:
  def findMaxSum(self, nums1: list[int], nums2: list[int], k: int) -> list[int]:
    """Answers, for every index i, the best sum of at most k values nums2[j]
    over indices j with nums1[j] < nums1[i].

    Indices are visited in increasing order of nums1. A min-heap holds the k
    largest nums2 values seen so far, and `summ` tracks their total, so each
    answer is read before the current value is added.

    Args:
      nums1: keys that decide which indices are "strictly smaller".
      nums2: values that may be chosen; same length as nums1.
      k: maximum number of values to choose; k <= 0 yields all zeros.

    Returns:
      A list whose i-th entry is the answer for index i (empty for empty input).
    """
    ans = [0] * len(nums1)
    minHeap: list[int] = []
    summ = 0
    prevNum = prevIndex = None

    for currNum, currIndex in sorted((num, i) for i, num in enumerate(nums1)):
      if currNum == prevNum:
        # An equal nums1 value is not strictly smaller: reuse the answer of the run.
        ans[currIndex] = ans[prevIndex]
      else:
        ans[currIndex] = summ
      if k > 0:
        heapq.heappush(minHeap, nums2[currIndex])
        summ += nums2[currIndex]
        if len(minHeap) > k:
          # Drop the smallest value so only the k largest remain.
          summ -= heapq.heappop(minHeap)
      prevNum, prevIndex = currNum, currIndex

    return ans