Skip to content

3414. Maximum Score of Non-overlapping Intervals 👍

  • Time: $O(n\log n)$
  • Space: $O(n)$
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
// Best result for a suffix of the sorted intervals: the maximum total weight
// and the ascending original indices of the intervals that achieve it.
struct T {
  // `long long` rather than `long`: the sum of several `int` weights can
  // exceed a 32-bit `long` (as on LLP64 platforms such as 64-bit Windows).
  long long weight = 0;
  std::vector<int> selected;
};

class Solution {
 public:
  // Chooses at most four pairwise non-overlapping intervals with the maximum
  // total weight and returns their original indices in ascending order. When
  // several choices tie on weight, the lexicographically smallest index list
  // wins. Each `input[i]` is {left, right, weight}; two intervals overlap if
  // they share any point, including an endpoint.
  std::vector<int> maximumWeight(std::vector<std::vector<int>>& input) {
    const int n = static_cast<int>(input.size());
    std::vector<Interval> intervals;
    intervals.reserve(n);
    for (int i = 0; i < n; ++i)
      intervals.emplace_back(input[i][0], input[i][1], input[i][2], i);
    std::sort(intervals.begin(), intervals.end());

    // best[i][quota] = best result using only intervals[i..n) with at most
    // `quota` picks. Row n and column 0 stay empty (weight 0). The table is
    // filled bottom-up so long inputs do not recurse once per interval, and
    // no "already computed" sentinel is needed.
    std::vector<std::vector<T>> best(n + 1, std::vector<T>(kMaxPicks + 1));
    for (int i = n - 1; i >= 0; --i) {
      const int right = std::get<1>(intervals[i]);
      const int weight = std::get<2>(intervals[i]);
      const int originalIndex = std::get<3>(intervals[i]);
      // First interval that starts strictly after this one ends.
      const int next = findFirstGreater(intervals, i + 1, right);
      for (int quota = 1; quota <= kMaxPicks; ++quota) {
        const T& skip = best[i + 1][quota];
        const T& rest = best[next][quota - 1];
        T pick{static_cast<long long>(weight) + rest.weight, rest.selected};
        pick.selected.push_back(originalIndex);
        // Keeping index lists sorted makes vector's operator< the
        // lexicographic order required for tie-breaking.
        std::sort(pick.selected.begin(), pick.selected.end());
        const bool takePick =
            pick.weight > skip.weight ||
            (pick.weight == skip.weight && pick.selected < skip.selected);
        if (takePick)
          best[i][quota] = std::move(pick);
        else
          best[i][quota] = skip;
      }
    }
    return best[0][kMaxPicks].selected;
  }

 private:
  // (left, right, weight, originalIndex). The default tuple ordering sorts by
  // left endpoint first, which `findFirstGreater` relies on.
  using Interval = std::tuple<int, int, int, int>;

  // Maximum number of intervals that may be chosen.
  static constexpr int kMaxPicks = 4;

  // Returns the index of the first interval in [startFrom, size) whose left
  // endpoint is strictly greater than `rightBoundary`, or intervals.size() if
  // there is none. `intervals` must be sorted by left endpoint.
  static int findFirstGreater(const std::vector<Interval>& intervals,
                              int startFrom, int rightBoundary) {
    const auto it = std::partition_point(
        intervals.begin() + startFrom, intervals.end(),
        [rightBoundary](const Interval& interval) {
          return std::get<0>(interval) <= rightBoundary;
        });
    return static_cast<int>(it - intervals.begin());
  }
};
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
class Solution {
  public int[] maximumWeight(List<List<Integer>> intervals) {
    // Convert input to Interval objects
    List<Interval> indexedIntervals = new ArrayList<>();
    for (int i = 0; i < intervals.size(); ++i) {
      List<Integer> interval = intervals.get(i);
      indexedIntervals.add(new Interval(interval.get(0), interval.get(1), interval.get(2), i));
    }
    indexedIntervals.sort(Comparator.comparingInt(Interval::left));
    T[][] memo = new T[indexedIntervals.size()][5];
    return dp(indexedIntervals, memo, 0, 4).selected.stream().mapToInt(Integer::intValue).toArray();
  }

  private record T(long weight, List<Integer> selected) {}
  private record Interval(int left, int right, int weight, int originalIndex) {}

  private T dp(List<Interval> intervals, T[][] memo, int i, int quota) {
    if (i == intervals.size() || quota == 0)
      return new T(0, List.of());
    if (memo[i][quota] != null)
      return memo[i][quota];

    T skip = dp(intervals, memo, i + 1, quota);

    Interval interval = intervals.get(i);
    final int j = findFirstGreater(intervals, i + 1, interval.right);
    T nextRes = dp(intervals, memo, j, quota - 1);

    List<Integer> newSelected = new ArrayList<>(nextRes.selected);
    newSelected.add(interval.originalIndex);
    Collections.sort(newSelected);
    T pick = new T(interval.weight + nextRes.weight, newSelected);
    return memo[i][quota] =
               (pick.weight > skip.weight ||
                (pick.weight == skip.weight && compareLists(pick.selected, skip.selected) < 0))
                   ? pick
                   : skip;
  }

  // Binary searches the first interval that starts after `rightBoundary`.
  private int findFirstGreater(List<Interval> intervals, int startFrom, int rightBoundary) {
    int l = startFrom;
    int r = intervals.size();
    while (l < r) {
      final int m = (l + r) / 2;
      if (intervals.get(m).left > rightBoundary)
        r = m;
      else
        l = m + 1;
    }
    return l;
  }

  // Compares two lists of integers lexicographically.
  private int compareLists(List<Integer> list1, List<Integer> list2) {
    final int minSize = Math.min(list1.size(), list2.size());
    for (int i = 0; i < minSize; ++i) {
      final int comparison = Integer.compare(list1.get(i), list2.get(i));
      if (comparison != 0)
        return comparison;
    }
    return Integer.compare(list1.size(), list2.size());
  }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import bisect
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class T:
  """Best result for a suffix of the sorted intervals.

  `weight` is the total weight of the chosen intervals and `selected` holds
  their original indices in ascending order. Iterating yields
  `(weight, selected)`, so an instance unpacks like a pair.
  """
  weight: int
  selected: tuple[int, ...]

  def __iter__(self):
    yield self.weight
    yield self.selected


class Solution:
  def maximumWeight(self, intervals: list[list[int]]) -> list[int]:
    """Returns the ascending original indices of at most four pairwise
    non-overlapping intervals with maximum total weight.

    Each interval is [left, right, weight]. Intervals that share any point,
    including an endpoint, overlap. Ties in total weight are broken by
    choosing the lexicographically smallest index list.
    """
    MAX_PICKS = 4
    # (left, right, weight, originalIndex), ordered by left endpoint first.
    intervals = sorted((*interval, i) for i, interval in enumerate(intervals))
    n = len(intervals)
    empty = T(0, ())

    # best[i][quota] is the best result for intervals[i:] with at most `quota`
    # picks. Filled bottom-up instead of recursively so that inputs longer
    # than Python's recursion limit do not raise RecursionError.
    best = [[empty] * (MAX_PICKS + 1) for _ in range(n + 1)]
    for i in range(n - 1, -1, -1):
      _, r, weight, originalIndex = intervals[i]
      # First interval starting strictly after r: (r, inf) sorts after every
      # tuple whose left endpoint equals r. Every index <= i has left <= r, so
      # searching from i + 1 gives the same answer as searching the whole list.
      j = bisect.bisect_right(intervals, (r, math.inf), lo=i + 1)
      for quota in range(1, MAX_PICKS + 1):
        skip = best[i + 1][quota]
        nextRes = best[j][quota - 1]
        # Tuples (not lists) so every `selected` compares with every other.
        pick = T(weight + nextRes.weight,
                 tuple(sorted((originalIndex, *nextRes.selected))))
        best[i][quota] = (pick if pick.weight > skip.weight or
                          (pick.weight == skip.weight and
                           pick.selected < skip.selected)
                          else skip)

    return list(best[0][MAX_PICKS].selected)