3469. Find Minimum Cost to Remove Array Elements 👍

  • Time: $O(n^2)$
  • Space: $O(n^2)$
C++:
class Solution {
 public:
  // Returns the minimum total cost to remove every element of `nums`.
  // Each step removes two of the first three remaining elements and costs the
  // larger of the two; once only one or two elements remain they are removed
  // together at the cost of their maximum. `nums` must be non-empty.
  int minCost(vector<int>& nums) {
    const int n = nums.size();
    // memo[next][kept] caches solve(kept, next); -1 means "not computed yet"
    // (all real costs are sums of array values, which the problem keeps >= 1).
    vector<vector<int>> memo(n + 1, vector<int>(n + 1, -1));
    return solve(/*kept=*/0, /*next=*/1, nums, memo);
  }

 private:
  // Cheapest way to clear nums[kept] followed by nums[next..n-1], where
  // `kept` (< next) is the element carried over from earlier removals.
  int solve(int kept, int next, const vector<int>& nums,
            vector<vector<int>>& memo) {
    const int n = nums.size();
    const int remaining = n - next;  // how many elements follow nums[kept]
    if (remaining == 0)
      return nums[kept];
    if (remaining == 1)
      return max(nums[kept], nums[next]);

    // Inner vectors never resize, so this reference survives the recursion.
    int& cached = memo[next][kept];
    if (cached == -1) {
      const int x = nums[kept];      // first of the leading three
      const int y = nums[next];      // second
      const int z = nums[next + 1];  // third
      const int after = next + 2;
      const int dropYZ = max(y, z) + solve(kept, after, nums, memo);      // x stays
      const int dropXY = max(x, y) + solve(next + 1, after, nums, memo);  // z stays
      const int dropXZ = max(x, z) + solve(next, after, nums, memo);      // y stays
      cached = min(dropXY, min(dropXZ, dropYZ));
    }
    return cached;
  }
};
Java:
class Solution {
  /**
   * Computes the minimum total cost to remove every element of {@code nums}.
   *
   * <p>Each operation removes two of the first three remaining elements and costs the larger of
   * the two; once fewer than three elements remain they are removed together at the cost of their
   * maximum.
   *
   * <p>{@code best[i][kept]} holds the cheapest way to clear {@code nums[kept]} followed by
   * {@code nums[i..n-1]}, where {@code kept < i} is the single element carried over from earlier
   * operations. Rows are filled from {@code i = n} downwards because row {@code i} only reads
   * row {@code i + 2}.
   *
   * @param nums the array to clear; expected to be non-empty (an empty array makes the final
   *     lookup {@code best[1]} go out of bounds)
   * @return the minimum total cost
   */
  public int minCost(int[] nums) {
    final int n = nums.length;
    final int[][] best = new int[n + 1][n + 1];
    for (int i = n; i >= 1; --i) {
      for (int kept = 0; kept < i; ++kept) {
        final int head = nums[kept];
        if (i == n) {
          // Only nums[kept] is left.
          best[i][kept] = head;
        } else if (i == n - 1) {
          // Exactly two elements are left: remove both at once.
          best[i][kept] = Math.max(head, nums[i]);
        } else {
          final int second = nums[i];
          final int third = nums[i + 1];
          final int[] next = best[i + 2];
          final int keepHead = Math.max(second, third) + next[kept];
          final int keepThird = Math.max(head, second) + next[i + 1];
          final int keepSecond = Math.max(head, third) + next[i];
          best[i][kept] = Math.min(Math.min(keepHead, keepThird), keepSecond);
        }
      }
    }
    return best[1][0];
  }
}
Python:
class Solution:
  def minCost(self, nums: list[int]) -> int:
    """Return the minimum total cost to remove every element of ``nums``.

    Each operation removes two of the first three remaining elements and costs
    the larger of the two; when fewer than three elements remain they are
    removed together at the cost of their maximum. ``nums`` is expected to be
    non-empty (an empty list makes ``dp(0, 1)`` index past the end).
    """
    # Imported here because this file has no module-level imports; a bare
    # ``functools`` reference would raise NameError unless the runtime
    # happens to pre-import it.
    import functools

    n = len(nums)

    @functools.lru_cache(None)
    def dp(last: int, i: int) -> int:
      """Min cost to clear nums[last] followed by nums[i:], where last < i."""
      if i == n:  # Single element left.
        return nums[last]
      if i == n - 1:  # Two elements left.
        return max(nums[last], nums[i])
      # The leading three are nums[last], nums[i], nums[i + 1]: remove any two
      # of them and carry the third forward as the new ``last``.
      a = max(nums[i], nums[i + 1]) + dp(last, i + 2)  # keep nums[last]
      b = max(nums[last], nums[i]) + dp(i + 1, i + 2)  # keep nums[i + 1]
      c = max(nums[last], nums[i + 1]) + dp(i, i + 2)  # keep nums[i]
      return min(a, b, c)

    return dp(0, 1)