Skip to content

410. Split Array Largest Sum 👍

Approach 1: Top-down

  • Time: $O(kn^2)$
  • Space: $O(nk)$
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
class Solution {
 public:
  // Splits `nums` into `k` non-empty contiguous groups so that the largest
  // group sum is as small as possible, and returns that largest sum.
  // Memoized recursion over (prefix length, number of groups).
  int splitArray(vector<int>& nums, int k) {
    const int n = nums.size();
    // prefix[i] := nums[0] + ... + nums[i - 1]
    vector<int> prefix(n + 1, 0);
    for (int i = 0; i < n; ++i)
      prefix[i + 1] = prefix[i] + nums[i];
    // INT_MAX marks a (length, groups) state that has not been computed yet.
    vector<vector<int>> memo(n + 1, vector<int>(k + 1, INT_MAX));
    return solve(n, k, prefix, memo);
  }

 private:
  // Minimum achievable largest group sum when the first `len` numbers are
  // split into `groups` groups.
  int solve(int len, int groups, const vector<int>& prefix,
            vector<vector<int>>& memo) {
    if (groups == 1)
      return prefix[len];

    int& best = memo[len][groups];
    if (best != INT_MAX)
      return best;

    // The last group is nums[cut, len); the first `cut` numbers must still be
    // able to form `groups - 1` non-empty groups, hence cut >= groups - 1.
    for (int cut = groups - 1; cut < len; ++cut) {
      const int lastGroup = prefix[len] - prefix[cut];
      const int rest = solve(cut, groups - 1, prefix, memo);
      best = min(best, max(rest, lastGroup));
    }
    return best;
  }
};
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
class Solution {
  // Returns the minimum achievable largest-group sum when nums is cut into k
  // non-empty contiguous groups, using memoized recursion.
  public int splitArray(int[] nums, int k) {
    final int n = nums.length;

    // prefix[i] := nums[0] + ... + nums[i - 1]
    int[] prefix = new int[n + 1];
    for (int i = 1; i <= n; ++i)
      prefix[i] = prefix[i - 1] + nums[i - 1];

    // Integer.MAX_VALUE marks a state that has not been computed yet.
    int[][] memo = new int[n + 1][k + 1];
    for (int[] row : memo)
      Arrays.fill(row, Integer.MAX_VALUE);

    return solve(n, k, prefix, memo);
  }

  // Best largest-group sum for the first `len` numbers split into `groups`
  // groups.
  private int solve(int len, int groups, int[] prefix, int[][] memo) {
    if (groups == 1)
      return prefix[len];
    if (memo[len][groups] != Integer.MAX_VALUE)
      return memo[len][groups];

    int best = Integer.MAX_VALUE;
    // The last group covers nums[cut, len); at least groups - 1 numbers must
    // remain before it so every earlier group is non-empty.
    for (int cut = groups - 1; cut < len; ++cut) {
      final int lastGroup = prefix[len] - prefix[cut];
      final int rest = solve(cut, groups - 1, prefix, memo);
      best = Math.min(best, Math.max(rest, lastGroup));
    }
    memo[len][groups] = best;
    return best;
  }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
class Solution:
  def splitArray(self, nums: list[int], k: int) -> int:
    """Return the smallest possible largest-group sum when `nums` is split
    into `k` non-empty contiguous groups (memoized top-down search)."""
    # prefix[i] == sum(nums[:i])
    prefix = [0]
    for num in nums:
      prefix.append(prefix[-1] + num)

    @functools.lru_cache(maxsize=None)
    def best(length: int, groups: int) -> int:
      """Best largest-group sum for the first `length` numbers in `groups` groups."""
      if groups == 1:
        return prefix[length]
      # The last group is nums[cut:length]; the first `cut` numbers must still
      # form `groups - 1` non-empty groups, so cut >= groups - 1.
      candidates = (
          max(best(cut, groups - 1), prefix[length] - prefix[cut])
          for cut in range(groups - 1, length)
      )
      return min(candidates)

    return best(len(nums), k)

Approach 2: Bottom-up

  • Time: $O(kn^2)$
  • Space: $O(nk)$
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
class Solution {
 public:
  // Returns the minimum possible value of the largest subarray sum when `nums`
  // is split into `k` non-empty contiguous subarrays (bottom-up DP).
  int splitArray(vector<int>& nums, int k) {
    const int n = nums.size();
    // dp[i][l] := the minimum of the maximum sum to split the first i numbers
    // into l groups. INT_MAX acts as "unreachable"; it is large enough because
    // the result is ultimately returned as an int.
    vector<vector<long long>> dp(n + 1, vector<long long>(k + 1, INT_MAX));

    // prefix[i] := nums[0] + ... + nums[i - 1], accumulated explicitly in 64
    // bits. std::partial_sum would keep its running total in the input's value
    // type (int) and overflow before the widening store into the output.
    vector<long long> prefix(n + 1, 0);
    for (int i = 0; i < n; ++i)
      prefix[i + 1] = prefix[i] + nums[i];

    for (int i = 1; i <= n; ++i)
      dp[i][1] = prefix[i];

    for (int l = 2; l <= k; ++l)
      for (int i = l; i <= n; ++i)
        // The last group is nums[j, i); the first j numbers hold l - 1 groups.
        for (int j = l - 1; j < i; ++j)
          dp[i][l] = min(dp[i][l], max(dp[j][l - 1], prefix[i] - prefix[j]));

    return static_cast<int>(dp[n][k]);
  }
};
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
class Solution {
  // Returns the minimum achievable largest-group sum when nums is split into
  // k non-empty contiguous groups, filled in bottom-up.
  public int splitArray(int[] nums, int k) {
    final int n = nums.length;

    // prefix[i] := nums[0] + ... + nums[i - 1]
    int[] prefix = new int[n + 1];
    for (int i = 0; i < n; ++i)
      prefix[i + 1] = prefix[i] + nums[i];

    // best[len][groups] := minimum largest-group sum for the first len
    // numbers split into `groups` groups; Integer.MAX_VALUE = unreachable.
    int[][] best = new int[n + 1][k + 1];
    for (int[] row : best)
      Arrays.fill(row, Integer.MAX_VALUE);
    for (int len = 1; len <= n; ++len)
      best[len][1] = prefix[len];

    for (int groups = 2; groups <= k; ++groups) {
      for (int len = groups; len <= n; ++len) {
        int cur = Integer.MAX_VALUE;
        // The last group is nums[cut, len).
        for (int cut = groups - 1; cut < len; ++cut) {
          final int candidate =
              Math.max(best[cut][groups - 1], prefix[len] - prefix[cut]);
          if (candidate < cur)
            cur = candidate;
        }
        best[len][groups] = cur;
      }
    }

    return best[n][k];
  }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
class Solution:
  def splitArray(self, nums: list[int], k: int) -> int:
    """Return the minimized largest-group sum for splitting `nums` into `k`
    non-empty contiguous groups, computed with a bottom-up table."""
    n = len(nums)
    # prefix[i] == sum(nums[:i])
    prefix = [0] * (n + 1)
    for idx, num in enumerate(nums):
      prefix[idx + 1] = prefix[idx] + num

    # best[length][groups] := minimum largest-group sum for the first
    # `length` numbers split into `groups` groups (math.inf = unreachable).
    best = [[math.inf] * (k + 1) for _ in range(n + 1)]
    for length in range(1, n + 1):
      best[length][1] = prefix[length]

    for groups in range(2, k + 1):
      for length in range(groups, n + 1):
        # The last group is nums[cut:length].
        best[length][groups] = min(
            max(best[cut][groups - 1], prefix[length] - prefix[cut])
            for cut in range(groups - 1, length)
        )

    return best[n][k]

Approach 3: Binary Search

  • Time: $O(n\log\Sigma)$, where $\Sigma = \sum_i \texttt{nums}[i]$
  • Space: $O(1)$
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
class Solution {
 public:
  // Returns the minimum possible value of the largest subarray sum when `nums`
  // is split into `k` non-empty contiguous subarrays. Expects `nums` to be
  // non-empty (ranges::max is undefined on an empty range).
  //
  // Binary-searches the answer in [max(nums), sum(nums)]: a cap `m` is
  // feasible iff greedily packing elements into groups whose sum stays <= m
  // needs at most `k` groups. Bounds, midpoint and running sums are kept in
  // `long long` so totals near or above INT_MAX do not overflow.
  int splitArray(vector<int>& nums, int k) {
    long long l = ranges::max(nums);
    // sum(nums) is always feasible (a single group), so it bounds the answer.
    long long r = accumulate(nums.begin(), nums.end(), 0LL);

    while (l < r) {
      const long long m = l + (r - l) / 2;
      if (numGroups(nums, m) > k)
        l = m + 1;
      else
        r = m;
    }

    // Exact whenever the true answer fits in the int return type.
    return static_cast<int>(l);
  }

 private:
  // Returns how many groups greedy left-to-right packing produces when each
  // group's sum must stay <= maxSumInGroup.
  int numGroups(const vector<int>& nums, long long maxSumInGroup) {
    int groupCount = 1;
    long long sumInGroup = 0;

    for (const int num : nums)
      if (sumInGroup + num <= maxSumInGroup) {
        sumInGroup += num;
      } else {
        ++groupCount;
        sumInGroup = num;
      }

    return groupCount;
  }
};
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
class Solution {
  // Returns the smallest possible value of the largest subarray sum when
  // `nums` is split into `k` non-empty contiguous subarrays.
  //
  // Binary-searches the answer in [max(nums), sum(nums)]: a cap `m` is
  // feasible iff greedily packing elements into groups whose sum stays <= m
  // needs at most `k` groups. Bounds, midpoint and running sums are `long`
  // so totals near or above Integer.MAX_VALUE do not wrap around.
  public int splitArray(int[] nums, int k) {
    long l = Arrays.stream(nums).max().getAsInt();
    // sum(nums) is always feasible (a single group), so it bounds the answer.
    long r = Arrays.stream(nums).asLongStream().sum();

    while (l < r) {
      final long m = l + (r - l) / 2;
      if (numGroups(nums, m) > k)
        l = m + 1;
      else
        r = m;
    }

    // Exact whenever the true answer fits in an int.
    return (int) l;
  }

  // Returns how many groups greedy left-to-right packing produces when each
  // group's sum must stay <= maxSumInGroup.
  private int numGroups(int[] nums, long maxSumInGroup) {
    int groupCount = 1;
    long sumInGroup = 0;

    for (final int num : nums)
      if (sumInGroup + num <= maxSumInGroup) {
        sumInGroup += num;
      } else {
        ++groupCount;
        sumInGroup = num;
      }

    return groupCount;
  }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
class Solution:
  def splitArray(self, nums: list[int], k: int) -> int:
    """Return the minimized largest-group sum for splitting `nums` into `k`
    non-empty contiguous groups, by binary searching on the group-sum cap."""

    def groups_needed(cap: int) -> int:
      """Count groups formed by greedy left-to-right packing under `cap`."""
      count, running = 1, 0
      for num in nums:
        running += num
        if running > cap:
          # `num` does not fit; it starts a fresh group on its own.
          count += 1
          running = num
      return count

    # The answer lies in [max(nums), sum(nums)]; `hi` is exclusive.
    lo, hi = max(nums), sum(nums) + 1
    while lo < hi:
      mid = (lo + hi) // 2
      if groups_needed(mid) <= k:
        hi = mid
      else:
        lo = mid + 1
    return lo