Skip to content

3410. Maximize Subarray Sum After Removing All Occurrences of One Element

Approach 1: Segment Tree

  • Time: $O(n\log n)$
  • Space: $O(n)$
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
// Same as 53. Maximum Subarray
//
// Summary of the values in one segment-tree range [lo, hi]. Fields are
// `long long` rather than `long`: a sum of many ints can exceed 2^31 - 1,
// and `long` is only 32 bits on LLP64 platforms such as 64-bit Windows.
struct T {
  long long sum = 0;                   // sum of every element in the range
  long long maxSubarraySumPrefix = 0;  // best non-empty sum starting at lo
  long long maxSubarraySumSuffix = 0;  // best non-empty sum ending at hi
  long long maxSubarraySum = 0;        // best non-empty sum anywhere inside
  // Zeroed through the member initializers above, so default-constructed
  // nodes (e.g. unused slots of the tree vector) hold defined values.
  T() = default;
  // Leaf for a single element: every summary equals the element itself.
  T(int num)
      : sum(num),
        maxSubarraySumPrefix(num),
        maxSubarraySumSuffix(num),
        maxSubarraySum(num) {}
  T(long long sum, long long prefix, long long suffix, long long maxSum)
      : sum(sum),
        maxSubarraySumPrefix(prefix),
        maxSubarraySumSuffix(suffix),
        maxSubarraySum(maxSum) {}
};

// Segment tree over `nums` supporting point assignment and an O(1) query for
// the maximum non-empty subarray sum of the whole array. Node `i` has children
// 2i + 1 and 2i + 2; the root (index 0) covers [0, n - 1].
class SegmentTree {
 public:
  SegmentTree(const vector<int>& nums) : n(nums.size()), tree(nums.size() * 4) {
    // An empty input has no root range; building it would read nums[0].
    if (n > 0)
      build(nums, 0, 0, n - 1);
  }

  // Updates nums[i] to val in O(log n). Indices outside [0, n) are ignored:
  // the recursion would otherwise steer them onto the first or last leaf and
  // silently overwrite that element.
  void update(int i, int val) {
    if (i < 0 || i >= n)
      return;
    update(0, 0, n - 1, i, val);
  }

  // Returns the maximum non-empty subarray sum of the current array, or
  // LLONG_MIN for an empty array (the maximum over no subarrays).
  long long getMaxSubarraySum() const {
    return n == 0 ? LLONG_MIN : tree[0].maxSubarraySum;
  }

 private:
  const int n;     // the size of the input array
  vector<T> tree;  // the segment tree

  // Fills node `treeIndex`, which covers nums[lo..hi], and its whole subtree.
  void build(const vector<int>& nums, int treeIndex, int lo, int hi) {
    if (lo == hi) {
      tree[treeIndex] = T(nums[lo]);
      return;
    }
    const int mid = (lo + hi) / 2;
    build(nums, 2 * treeIndex + 1, lo, mid);
    build(nums, 2 * treeIndex + 2, mid + 1, hi);
    tree[treeIndex] = merge(tree[2 * treeIndex + 1], tree[2 * treeIndex + 2]);
  }

  // Sets the leaf for index i (which must lie in [lo, hi]) to val, then
  // recomputes the summaries on the way back up to node `treeIndex`.
  void update(int treeIndex, int lo, int hi, int i, int val) {
    if (lo == hi) {
      tree[treeIndex] = T(val);
      return;
    }
    const int mid = (lo + hi) / 2;
    if (i <= mid)
      update(2 * treeIndex + 1, lo, mid, i, val);
    else
      update(2 * treeIndex + 2, mid + 1, hi, i, val);
    tree[treeIndex] = merge(tree[2 * treeIndex + 1], tree[2 * treeIndex + 2]);
  }

  // Combines the summaries of two adjacent ranges (`left` directly precedes
  // `right`).
  T merge(const T& left, const T& right) const {
    return T(
        left.sum + right.sum,
        // Best prefix: stays inside left, or spans all of left into right.
        max(left.maxSubarraySumPrefix, left.sum + right.maxSubarraySumPrefix),
        // Best suffix: stays inside right, or spans all of right into left.
        max(right.maxSubarraySumSuffix, right.sum + left.maxSubarraySumSuffix),
        // Best overall: entirely in one half, or crossing the boundary.
        max({left.maxSubarraySum, right.maxSubarraySum,
             left.maxSubarraySumSuffix + right.maxSubarraySumPrefix}));
  }
};

class Solution {
 public:
  // Returns the maximum subarray sum obtainable after removing all occurrences
  // of at most one value from `nums`.
  long long maxSubarraySum(vector<int>& nums) {
    // No negatives: removing nothing and taking the whole array is optimal.
    // (An empty input also lands here and yields 0.)
    if (all_of(nums.begin(), nums.end(), [](int num) { return num >= 0; }))
      return accumulate(nums.begin(), nums.end(), 0LL);
    // No non-negatives: every candidate subarray is a run of negatives, so the
    // single largest element is the best achievable.
    const int maxNum = *max_element(nums.begin(), nums.end());
    if (maxNum < 0)
      return maxNum;

    // Group positions by value so each distinct value can be removed at once.
    unordered_map<int, vector<int>> numToIndices;
    for (int i = 0; i < static_cast<int>(nums.size()); ++i)
      numToIndices[nums[i]].push_back(i);

    SegmentTree tree(nums);
    long long ans = LLONG_MIN;
    // "Removing nothing" needs no separate case: a negative value exists here,
    // and removing it never lowers any subarray sum.
    for (const auto& [num, indices] : numToIndices) {
      // Zeroing stands in for deletion. A zero-only "subarray" can report 0
      // when every remaining value is negative, but that never exceeds the
      // true answer, which is at least maxNum >= 0.
      for (const int index : indices)
        tree.update(index, 0);
      ans = max(ans, tree.getMaxSubarraySum());
      for (const int index : indices)
        tree.update(index, num);  // restore the original values
    }

    return ans;
  }
};

Approach 2: Kadane's Algorithm

  • Time: $O(n)$
  • Space: $O(n)$
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
class Solution {
 public:
  // Returns the maximum subarray sum obtainable after removing all occurrences
  // of at most one value from `nums`, in one left-to-right prefix-sum scan.
  // The best sum ending at the current index is `prefix - modifiedMinPrefix`,
  // where modifiedMinPrefix is the smallest earlier prefix, optionally with
  // one negative value's occurrences in between discounted.
  // All accumulators are `long long`: `long` is 32 bits on LLP64 platforms and
  // prefix sums of many ints can exceed that.
  long long maxSubarraySum(vector<int>& nums) {
    if (nums.empty())
      return 0;  // nothing to pick; avoids dereferencing an empty max_element

    long long ans = *max_element(nums.begin(), nums.end());
    long long prefix = 0;
    long long minPrefix = 0;
    // the minimum prefix sum that can have a negative number removed
    long long modifiedMinPrefix = 0;
    // minPrefixPlusRemoval[num] := the minimum prefix sum plus removed `num`
    unordered_map<int, long long> minPrefixPlusRemoval;

    for (const int num : nums) {
      prefix += num;
      ans = max(ans, prefix - modifiedMinPrefix);
      if (num < 0) {
        // operator[] default-inserts 0 the first time `num` is seen; no other
        // insertion happens before `removal` is used, so the reference stays valid.
        long long& removal = minPrefixPlusRemoval[num];
        removal = min(removal, minPrefix) + num;
        // No separate `occurrences(num) * num` candidate is needed: minPrefix
        // never exceeds 0, so after k occurrences removal <= k * num already.
        modifiedMinPrefix = min(modifiedMinPrefix, removal);
      }
      minPrefix = min(minPrefix, prefix);
      modifiedMinPrefix = min(modifiedMinPrefix, minPrefix);
    }

    return ans;
  }
};
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
class Solution {
  public long maxSubarraySum(int[] nums) {
    long ans = Arrays.stream(nums).max().getAsInt();
    long prefix = 0;
    long minPrefix = 0;
    // the minimum prefix sum that can have a negative number removed
    long modifiedMinPrefix = 0;
    Map<Integer, Integer> count = new HashMap<>();
    // minPrefixPlusRemoval[num] := the minimum prefix sum plus removed `num`
    Map<Integer, Long> minPrefixPlusRemoval = new HashMap<>();

    for (int num : nums) {
      prefix += num;
      ans = Math.max(ans, prefix - modifiedMinPrefix);
      if (num < 0) {
        count.merge(num, 1, Integer::sum);
        minPrefixPlusRemoval.put(
            num, Math.min(minPrefixPlusRemoval.getOrDefault(num, 0L), minPrefix) + num);
        modifiedMinPrefix = Math.min(modifiedMinPrefix,
                                     Math.min(count.get(num) * num, minPrefixPlusRemoval.get(num)));
      }
      minPrefix = Math.min(minPrefix, prefix);
      modifiedMinPrefix = Math.min(modifiedMinPrefix, minPrefix);
    }

    return ans;
  }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
class Solution:
  def maxSubarraySum(self, nums: list[int]) -> int:
    """Returns the max subarray sum after removing all occurrences of at most one value.

    One left-to-right prefix-sum scan: the best sum ending at the current index
    is ``prefix - modifiedMinPrefix``, where modifiedMinPrefix is the smallest
    earlier prefix, optionally with one negative value's occurrences in between
    discounted.

    Raises:
      ValueError: if ``nums`` is empty (from ``max``).
    """
    ans = max(nums)
    prefix = 0
    minPrefix = 0
    # the minimum prefix sum that can have a negative number removed
    modifiedMinPrefix = 0
    # minPrefixPlusRemoval[num] := the minimum prefix sum plus removed `num`
    minPrefixPlusRemoval: dict[int, int] = {}

    for num in nums:
      prefix += num
      ans = max(ans, prefix - modifiedMinPrefix)
      if num < 0:
        removal = min(minPrefixPlusRemoval.get(num, 0), minPrefix) + num
        minPrefixPlusRemoval[num] = removal
        # No separate occurrences(num) * num candidate is needed: minPrefix never
        # exceeds 0, so after k occurrences removal <= k * num already.
        modifiedMinPrefix = min(modifiedMinPrefix, removal)
      minPrefix = min(minPrefix, prefix)
      modifiedMinPrefix = min(modifiedMinPrefix, minPrefix)

    return ans