Skip to content

410. Split Array Largest Sum 👍

Approach 1: Top-down

  • Time: $O(mn^2)$
  • Space: $O(mn)$
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
class Solution {
 public:
  // Returns the smallest possible value of the largest group sum when `nums`
  // is cut into exactly `m` non-empty contiguous groups.
  //
  // Top-down DP with memoization: O(m * n^2) time, O(m * n) space.
  int splitArray(vector<int>& nums, int m) {
    const int n = nums.size();
    // dp[i][k] := min of largest sum to split first i nums into k groups
    //
    // assign() rather than resize(): resize() keeps cells that already exist,
    // so a second call on the same Solution object would read memoized
    // answers that belong to the previous input.
    dp.assign(n + 1, vector<int>(m + 1, INT_MAX));
    prefix.assign(n + 1, 0);

    // prefix[i] = nums[0] + ... + nums[i - 1], prefix[0] = 0
    partial_sum(begin(nums), end(nums), begin(prefix) + 1);
    return splitArray(nums, n, m);
  }

 private:
  vector<vector<int>> dp;  // Memo table; INT_MAX marks "not computed yet"
  vector<int> prefix;      // Prefix sums of the current input

  // Min of the largest group sum when the first `i` numbers form `k` groups.
  int splitArray(const vector<int>& nums, int i, int k) {
    // A single group must take everything seen so far.
    if (k == 1)
      return prefix[i];
    if (dp[i][k] < INT_MAX)
      return dp[i][k];

    // Try all possible partitions: the last group is nums[j..i - 1], and the
    // first j numbers must still be able to form k - 1 non-empty groups.
    for (int j = k - 1; j < i; ++j)
      dp[i][k] =
          min(dp[i][k], max(splitArray(nums, j, k - 1), prefix[i] - prefix[j]));

    return dp[i][k];
  }
};
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
class Solution {
  // memo[end][groups] caches the answer for "first `end` numbers, `groups`
  // groups"; Integer.MAX_VALUE means the cell has not been filled yet.
  private int[][] memo;
  // sums[i] holds nums[0] + ... + nums[i - 1].
  private int[] sums;

  /**
   * Splits {@code nums} into {@code m} contiguous groups and returns the
   * minimum achievable value of the largest group sum (memoized recursion).
   */
  public int splitArray(int[] nums, int m) {
    final int size = nums.length;

    memo = new int[size + 1][m + 1];
    for (int[] row : memo)
      Arrays.fill(row, Integer.MAX_VALUE);

    sums = new int[size + 1];
    for (int idx = 1; idx <= size; ++idx)
      sums[idx] = sums[idx - 1] + nums[idx - 1];

    return minLargestSum(size, m);
  }

  /** Best result for the first {@code end} numbers split into {@code groups} groups. */
  private int minLargestSum(int end, int groups) {
    if (groups == 1)
      return sums[end];

    final int cached = memo[end][groups];
    if (cached != Integer.MAX_VALUE)
      return cached;

    int best = Integer.MAX_VALUE;
    // The last group covers indices [cut, end); everything before it is
    // solved recursively with one group fewer.
    for (int cut = groups - 1; cut < end; ++cut) {
      final int head = minLargestSum(cut, groups - 1);
      final int tail = sums[end] - sums[cut];
      best = Math.min(best, Math.max(head, tail));
    }

    memo[end][groups] = best;
    return best;
  }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
class Solution:
  def splitArray(self, nums: List[int], m: int) -> int:
    """Returns the minimum possible largest group sum over all ways to cut
    `nums` into `m` contiguous groups (memoized recursion)."""
    # sums[i] is the total of the first i numbers.
    sums = [0]
    for value in nums:
      sums.append(sums[-1] + value)

    @functools.lru_cache(maxsize=None)
    def solve(end: int, groups: int) -> int:
      """Best result for the first `end` numbers split into `groups` groups."""
      if groups == 1:
        return sums[end]

      # The last group is nums[cut:end]; the prefix before it is solved
      # recursively with one group fewer.
      return min(
          (max(solve(cut, groups - 1), sums[end] - sums[cut])
           for cut in range(groups - 1, end)),
          default=math.inf,
      )

    return solve(len(nums), m)

Approach 2: Bottom-up

  • Time: $O(mn^2)$
  • Space: $O(mn)$
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
class Solution {
 public:
  // Minimum possible largest group sum when `nums` is cut into `m`
  // contiguous groups, computed with an iterative DP table.
  int splitArray(vector<int>& nums, int m) {
    const int size = nums.size();

    // sums[i] = total of the first i numbers.
    vector<long> sums(size + 1);
    partial_sum(nums.begin(), nums.end(), sums.begin() + 1);

    // best[end][groups] = answer for the first `end` numbers in `groups`
    // groups; INT_MAX stands for "no valid split".
    vector<vector<long>> best(size + 1, vector<long>(m + 1, INT_MAX));

    // One group always swallows the whole prefix.
    for (int end = 1; end <= size; ++end)
      best[end][1] = sums[end];

    for (int groups = 2; groups <= m; ++groups) {
      for (int end = groups; end <= size; ++end) {
        long& cell = best[end][groups];
        // The last group is nums[cut..end - 1].
        for (int cut = groups - 1; cut < end; ++cut) {
          const long lastGroup = sums[end] - sums[cut];
          const long worst = max(best[cut][groups - 1], lastGroup);
          if (worst < cell)
            cell = worst;
        }
      }
    }

    return best[size][m];
  }
};
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
class Solution {
  /**
   * Minimum possible largest group sum when {@code nums} is cut into
   * {@code m} contiguous groups, computed with an iterative DP table.
   */
  public int splitArray(int[] nums, int m) {
    final int size = nums.length;

    // best[end][groups] = answer for the first `end` numbers in `groups`
    // groups; Integer.MAX_VALUE stands for "no valid split".
    int[][] best = new int[size + 1][m + 1];
    for (int[] row : best)
      Arrays.fill(row, Integer.MAX_VALUE);

    // sums[i] = total of the first i numbers.
    int[] sums = new int[size + 1];
    for (int idx = 0; idx < size; ++idx)
      sums[idx + 1] = sums[idx] + nums[idx];

    // One group always swallows the whole prefix.
    for (int end = 1; end <= size; ++end)
      best[end][1] = sums[end];

    for (int groups = 2; groups <= m; ++groups) {
      for (int end = groups; end <= size; ++end) {
        int cell = best[end][groups];
        // The last group is nums[cut..end - 1].
        for (int cut = groups - 1; cut < end; ++cut) {
          final int worst = Math.max(best[cut][groups - 1], sums[end] - sums[cut]);
          if (worst < cell)
            cell = worst;
        }
        best[end][groups] = cell;
      }
    }

    return best[size][m];
  }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
class Solution:
  def splitArray(self, nums: List[int], m: int) -> int:
    """Returns the minimum possible largest group sum over all ways to cut
    `nums` into `m` contiguous groups (iterative DP table)."""
    size = len(nums)

    # sums[i] is the total of the first i numbers.
    sums = [0]
    for value in nums:
      sums.append(sums[-1] + value)

    # best[end][groups] = answer for the first `end` numbers in `groups`
    # groups; math.inf stands for "no valid split".
    best = [[math.inf] * (m + 1) for _ in range(size + 1)]

    # One group always swallows the whole prefix.
    for end in range(1, size + 1):
      best[end][1] = sums[end]

    for groups in range(2, m + 1):
      for end in range(groups, size + 1):
        total = sums[end]
        # The last group is nums[cut:end].
        best[end][groups] = min(
            (max(best[cut][groups - 1], total - sums[cut])
             for cut in range(groups - 1, end)),
            default=math.inf,
        )

    return best[size][m]

Approach 3: Binary search

  • Time: $O(n\log(\sum_i \texttt{nums}[i]))$
  • Space: $O(1)$
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
class Solution {
 public:
  // Returns the smallest possible value of the largest group sum when `nums`
  // is cut into at most `m` contiguous groups, found by binary-searching the
  // answer between max(nums) and sum(nums).
  //
  // `nums` must be non-empty (max_element is dereferenced). All arithmetic on
  // sums is done in 64 bits so that inputs whose total is close to INT_MAX
  // do not overflow; the result itself fits in int whenever the total does.
  int splitArray(vector<int>& nums, int m) {
    // No group can be smaller than the largest single element...
    long long l = *max_element(begin(nums), end(nums));
    // ...and one group holding everything is always feasible, so the total
    // is a valid (inclusive) upper bound.
    long long r = accumulate(begin(nums), end(nums), 0LL);

    while (l < r) {
      // Overflow-safe midpoint.
      const long long mid = l + (r - l) / 2;
      if (numGroups(nums, mid) > m)
        l = mid + 1;  // Cap too small: more than m groups are needed
      else
        r = mid;  // Cap works: try a smaller one
    }

    return static_cast<int>(l);
  }

 private:
  // Greedily counts how many groups are needed if no group's sum may exceed
  // `maxSumInGroup`.
  int numGroups(const vector<int>& nums, long long maxSumInGroup) {
    int groupCount = 1;
    long long sumInGroup = 0;

    for (const int num : nums)
      if (sumInGroup + num <= maxSumInGroup) {
        sumInGroup += num;
      } else {
        // Current group is full; `num` starts the next one.
        ++groupCount;
        sumInGroup = num;
      }

    return groupCount;
  }
};
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
class Solution {
  /**
   * Returns the smallest possible value of the largest group sum when
   * {@code nums} is cut into at most {@code m} contiguous groups, found by
   * binary-searching the answer between max(nums) and sum(nums).
   *
   * <p>Sums are handled as {@code long} so that inputs whose total is close
   * to {@code Integer.MAX_VALUE} do not wrap around; the result itself fits
   * in {@code int} whenever the total does.
   *
   * @throws java.util.NoSuchElementException if {@code nums} is empty
   */
  public int splitArray(int[] nums, int m) {
    // No group can be smaller than the largest single element...
    long l = Arrays.stream(nums).max().getAsInt();
    // ...and one group holding everything is always feasible, so the total
    // is a valid (inclusive) upper bound.
    long r = Arrays.stream(nums).asLongStream().sum();

    while (l < r) {
      // Overflow-safe midpoint.
      final long mid = l + (r - l) / 2;
      if (numGroups(nums, mid) > m)
        l = mid + 1; // Cap too small: more than m groups are needed
      else
        r = mid; // Cap works: try a smaller one
    }

    return (int) l;
  }

  /** Greedily counts the groups needed if no group's sum may exceed the cap. */
  private int numGroups(int[] nums, long maxSumInGroup) {
    int groupCount = 1;
    long sumInGroup = 0;

    for (final int num : nums)
      if (sumInGroup + num <= maxSumInGroup) {
        sumInGroup += num;
      } else {
        // Current group is full; `num` starts the next one.
        ++groupCount;
        sumInGroup = num;
      }

    return groupCount;
  }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
class Solution:
  def splitArray(self, nums: List[int], m: int) -> int:
    """Binary-searches the smallest cap on a group's sum such that `nums`
    can be cut into at most `m` contiguous groups."""

    def groups_needed(cap: int) -> int:
      """Greedy count of groups when no group's sum may exceed `cap`."""
      count, running = 1, 0
      for value in nums:
        if running + value > cap:
          # Current group is full; `value` opens the next one.
          count += 1
          running = value
        else:
          running += value
      return count

    # Search window [lo, hi): the answer is at least the largest element.
    lo, hi = max(nums), sum(nums) + 1

    while lo < hi:
      mid = (lo + hi) // 2
      if groups_needed(mid) <= m:
        hi = mid  # Cap works: try a smaller one
      else:
        lo = mid + 1  # Cap too small: more than m groups are needed

    return lo