Skip to content

3343. Count Number of Balanced Permutations 👍

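  • Time: $O(n^3)$, where $n$ is the length of `num`: the memoized states are (`even`, `odd`, `evenBalance`) with $O(n)$ choices each for `even` and `odd`, and $\texttt{evenBalance} \le 9n/2$.
  • Space: $O(n^3)$ for the memo table over those states.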
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
class Solution {
 public:
  // Returns the number of distinct permutations of the digit string `num` in
  // which the digits at even indices and the digits at odd indices have equal
  // sums, modulo 1e9+7.
  int countBalancedPermutations(string num) {
    vector<int> nums = getNums(num);
    const int sum = accumulate(nums.begin(), nums.end(), 0);
    if (sum % 2 == 1)
      return 0;

    // The count does not depend on the order of `nums`; placing larger digits
    // first makes `evenBalance` turn negative (and prune) sooner.
    sort(nums.begin(), nums.end(), greater<>());

    const int even = (nums.size() + 1) / 2;
    const int odd = nums.size() / 2;
    const int evenBalance = sum / 2;
    vector<vector<vector<long long>>> mem(
        even + 1, vector<vector<long long>>(
                      odd + 1, vector<long long>(evenBalance + 1, -1)));
    // The recursion treats equal digits as distinguishable, so divide by the
    // product of the factorials of each digit's multiplicity.
    const long long perm = getPerm(nums);
    return countBalancedPermutations(nums, even, odd, evenBalance, mem) *
           modInverse(perm) % kMod;
  }

 private:
  // All modular products below multiply two residues < kMod, which needs
  // 64 bits. `long` is only guaranteed to be 32 bits (it is 32 bits on
  // MSVC/Windows), so `long long` is used throughout.
  static constexpr int kMod = 1'000'000'007;

  // Returns the number of permutations where there are `even` even indices
  // left, `odd` odd indices left, and `evenBalance` is the target sum of the
  // remaining numbers to be placed in even indices. Equal digits are counted
  // as distinguishable; the result is reduced modulo kMod.
  long long countBalancedPermutations(const vector<int>& nums, int even,
                                      int odd, int evenBalance,
                                      vector<vector<vector<long long>>>& mem) {
    if (evenBalance < 0)
      return 0;
    if (even == 0)
      return evenBalance == 0 ? factorial(odd) : 0;
    // Digits nums[0, index) have already been placed.
    const int index = static_cast<int>(nums.size()) - (even + odd);
    if (odd == 0) {
      // Every remaining digit must go to an even index.
      const long long remainingSum =
          accumulate(nums.begin() + index, nums.end(), 0LL);
      return remainingSum == evenBalance ? factorial(even) : 0;
    }
    if (mem[even][odd][evenBalance] != -1)
      return mem[even][odd][evenBalance];
    // nums[index] goes into one of the `even` free even slots...
    const long long placeEven =
        countBalancedPermutations(nums, even - 1, odd,
                                  evenBalance - nums[index], mem) *
        even % kMod;
    // ...or into one of the `odd` free odd slots.
    const long long placeOdd =
        countBalancedPermutations(nums, even, odd - 1, evenBalance, mem) * odd %
        kMod;
    return mem[even][odd][evenBalance] = (placeEven + placeOdd) % kMod;
  }

  // Converts a string of decimal digit characters into their integer values.
  vector<int> getNums(const string& num) {
    vector<int> nums;
    nums.reserve(num.size());
    for (const char c : num)
      nums.push_back(c - '0');
    return nums;
  }

  // Returns the product of freq! over each digit's frequency, modulo kMod:
  // the number of reorderings of equal digits that yield the same string.
  long long getPerm(const vector<int>& nums) {
    long long res = 1;
    vector<int> count(10);
    for (const int num : nums)
      ++count[num];
    for (const int freq : count)
      res = res * factorial(freq) % kMod;
    return res;
  }

  // Returns n! modulo kMod (1 for n <= 1).
  long long factorial(int n) {
    long long res = 1;
    for (int i = 2; i <= n; ++i)
      res = res * i % kMod;
    return res;
  }

  // Returns the multiplicative inverse of `a` modulo kMod using the extended
  // Euclidean algorithm. Assumes gcd(a, kMod) == 1.
  long long modInverse(long long a) {
    long long m = kMod;
    long long y = 0;
    long long x = 1;
    while (a > 1) {
      const long long q = a / m;
      long long t = m;
      m = a % m;
      a = t;
      t = y;
      y = x - q * y;
      x = t;
    }
    return x < 0 ? x + kMod : x;
  }
};
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
class Solution {
  private static final int MOD = 1_000_000_007;

  /**
   * Returns the number of distinct rearrangements of the digit string {@code num} in which the
   * digits at even indices and the digits at odd indices have equal sums, modulo 1e9+7.
   */
  public int countBalancedPermutations(String num) {
    int[] nums = getNums(num);
    int sum = 0;
    for (final int digit : nums)
      sum += digit;
    if (sum % 2 == 1)
      return 0;

    // Descending order: sort ascending, then mirror the array in place.
    Arrays.sort(nums);
    for (int lo = 0, hi = nums.length - 1; lo < hi; ++lo, --hi) {
      final int tmp = nums[lo];
      nums[lo] = nums[hi];
      nums[hi] = tmp;
    }

    final int evenSlots = (nums.length + 1) / 2;
    final int oddSlots = nums.length / 2;
    final int target = sum / 2;
    // -1 marks "not computed yet"; real entries are always in [0, MOD).
    long[][][] memo = new long[evenSlots + 1][oddSlots + 1][target + 1];
    for (long[][] plane : memo)
      for (long[] row : plane)
        Arrays.fill(row, -1L);

    final long arrangements = countBalancedPermutations(nums, evenSlots, oddSlots, target, memo);
    // Equal digits were treated as distinguishable; divide out their orderings.
    return (int) (arrangements * modInverse(getPerm(nums)) % MOD);
  }

  /**
   * Counts placements (equal digits distinguishable) of the digits not yet placed, given
   * {@code even} free even indices, {@code odd} free odd indices, and {@code evenBalance} as the
   * sum still required at even indices. Result is modulo MOD.
   */
  private long countBalancedPermutations(int[] nums, int even, int odd, int evenBalance,
                                         long[][][] mem) {
    if (evenBalance < 0)
      return 0;
    if (even == 0)
      return evenBalance == 0 ? factorial(odd) : 0;
    final int next = nums.length - even - odd;
    if (odd == 0) {
      // All remaining digits are forced onto even indices.
      long rest = 0;
      for (int i = nums.length - 1; i >= next; --i)
        rest += nums[i];
      return rest == evenBalance ? factorial(even) : 0;
    }
    if (mem[even][odd][evenBalance] == -1) {
      final long toEven =
          countBalancedPermutations(nums, even - 1, odd, evenBalance - nums[next], mem) * even % MOD;
      final long toOdd = countBalancedPermutations(nums, even, odd - 1, evenBalance, mem) * odd % MOD;
      mem[even][odd][evenBalance] = (toEven + toOdd) % MOD;
    }
    return mem[even][odd][evenBalance];
  }

  /** Maps each decimal digit character of {@code num} to its integer value. */
  private int[] getNums(String num) {
    return num.chars().map(c -> c - '0').toArray();
  }

  /** Returns the product of freq! over the frequency of each digit 0-9, modulo MOD. */
  private long getPerm(int[] nums) {
    int[] freq = new int[10];
    for (final int digit : nums)
      freq[digit]++;
    long product = 1;
    for (int d = 0; d < 10; ++d)
      product = product * factorial(freq[d]) % MOD;
    return product;
  }

  /** Returns n! modulo MOD (1 for n <= 1). */
  private long factorial(int n) {
    long acc = 1;
    for (int k = n; k > 1; --k)
      acc = acc * k % MOD;
    return acc;
  }

  /** Returns the inverse of {@code a} modulo MOD via the extended Euclidean algorithm. */
  private long modInverse(long a) {
    long r0 = a;   // current remainder
    long r1 = MOD; // previous divisor
    long s0 = 1;   // Bezout coefficient of a for r0
    long s1 = 0;   // Bezout coefficient of a for r1
    while (r0 > 1) {
      final long q = r0 / r1;
      final long r2 = r0 % r1;
      r0 = r1;
      r1 = r2;
      final long s2 = s0 - q * s1;
      s0 = s1;
      s1 = s2;
    }
    return s0 < 0 ? s0 + MOD : s0;
  }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
class Solution:
  def countBalancedPermutations(self, num: str) -> int:
    """Counts distinct rearrangements of `num` whose even-index digit sum equals
    its odd-index digit sum, modulo 1e9+7."""
    nums = sorted((int(ch) for ch in num), reverse=True)
    summ = sum(nums)
    if summ % 2:
      return 0
    n = len(nums)

    @functools.lru_cache(None)
    def dp(even: int, odd: int, evenBalance: int) -> int:
      """
      Number of ways (equal digits treated as distinguishable) to place the
      digits nums[n - even - odd:] into `even` free even indices and `odd` free
      odd indices so that the even indices receive exactly `evenBalance`.
      Exact integer, not reduced modulo anything.
      """
      if evenBalance < 0:
        return 0
      if even == 0:
        return math.factorial(odd) if evenBalance == 0 else 0
      i = n - even - odd  # index of the next digit to place
      if odd == 0:
        # Every remaining digit is forced onto an even index.
        return math.factorial(even) if sum(nums[i:]) == evenBalance else 0
      return (even * dp(even - 1, odd, evenBalance - nums[i]) +
              odd * dp(even, odd - 1, evenBalance))

    MOD = 1_000_000_007
    # Orderings among identical digits, divided out below to get distinct strings.
    perm = 1
    for freq in collections.Counter(nums).values():
      perm *= math.factorial(freq)
    return dp((n + 1) // 2, n // 2, summ // 2) // perm % MOD