Skip to content

2572. Count the Number of Square-Free Subsets

  • Time: $O(n\cdot 2^{11}) = O(n)$
  • Space: $O(n\cdot 2^{11}) = O(n)$
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
class Solution {
 public:
  // Counts the non-empty subsets of `nums` whose product is square-free,
  // modulo 1'000'000'007. Each number is reduced to a bitmask of the primes
  // dividing it; a subset is square-free iff every element is itself
  // square-free and the elements' masks are pairwise disjoint.
  int squareFreeSubsets(std::vector<int>& nums) {
    std::vector<std::vector<int>> mem(
        nums.size(), std::vector<int>(1 << (kPrimesCount + 1), -1));
    std::vector<int> masks;
    masks.reserve(nums.size());

    for (const int num : nums)
      masks.push_back(getMask(num));

    // The recursion also counts the empty subset, hence the `- 1`.
    // `used` starts with bit 0 set: valid masks never touch bit 0 (they are
    // shifted left by one), while the invalid mask -1 has every bit set, so
    // -1 & used != 0 and such numbers can never be picked.
    return (squareFreeSubsets(masks, 0, /*used=*/1, mem) - 1 + kMod) % kMod;
  }

 private:
  static constexpr int kMod = 1'000'000'007;
  static constexpr int kPrimesCount = 10;
  static constexpr int primes[] = {2, 3, 5, 7, 11, 13, 17, 19, 23, 29};
  static_assert(sizeof(primes) / sizeof(primes[0]) == kPrimesCount,
                "kPrimesCount must match the length of primes");

  // Returns the number of subsets of masks[i..] (including the empty one)
  // whose masks are pairwise disjoint and disjoint from `used`, modulo kMod.
  // mem[i][used] caches results; -1 marks an entry not yet computed.
  int squareFreeSubsets(const std::vector<int>& masks, int i, int used,
                        std::vector<std::vector<int>>& mem) {
    if (i == static_cast<int>(masks.size()))
      return 1;
    // `mem` is never resized during the recursion, so this reference stays
    // valid across the recursive calls below.
    int& cached = mem[i][used];
    if (cached != -1)
      return cached;
    const int pick = (masks[i] & used) == 0
                         ? squareFreeSubsets(masks, i + 1, used | masks[i], mem)
                         : 0;
    const int skip = squareFreeSubsets(masks, i + 1, used, mem);
    // pick, skip < kMod, so pick + skip < 2^31 - 1 and cannot overflow.
    return cached = (pick + skip) % kMod;
  }

  // Returns the bitmask of primes (from `primes`) dividing `num`, shifted left
  // by one so bit 0 stays reserved, or -1 if `num` is divisible by the square
  // of one of those primes. Prime factors larger than 29 are not checked.
  // e.g. num = 10 = 2 * 5, so mask = 0b101 -> 0b1010 (append a 0)
  //      num = 15 = 3 * 5, so mask = 0b110 -> 0b1100 (append a 0)
  //      num = 25 = 5 * 5, so mask = -1 (all bits set; invalid)
  static int getMask(int num) {
    // 0 is divisible by every square; without this guard the loop below
    // would never terminate because 0 % p == 0 for every p.
    if (num == 0)
      return -1;
    int mask = 0;
    for (int i = 0; i < kPrimesCount; ++i) {
      int rootCount = 0;
      while (num % primes[i] == 0) {
        num /= primes[i];
        ++rootCount;
      }
      if (rootCount >= 2)
        return -1;
      if (rootCount == 1)
        mask |= 1 << i;
    }
    return mask << 1;
  }
};
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
class Solution {
  /**
   * Counts the non-empty subsets of {@code nums} whose product is square-free,
   * modulo 1_000_000_007. Each number is reduced to a bitmask of the primes
   * dividing it; a subset is square-free iff every element is itself
   * square-free and the elements' masks are pairwise disjoint.
   */
  public int squareFreeSubsets(int[] nums) {
    Integer[][] mem = new Integer[nums.length][1 << (kPrimesCount + 1)];
    int[] masks = new int[nums.length];

    for (int i = 0; i < nums.length; ++i)
      masks[i] = getMask(nums[i]);

    // The recursion also counts the empty subset, hence the `- 1`.
    // `used` starts with bit 0 set: valid masks never touch bit 0 (they are
    // shifted left by one), while the invalid mask -1 has every bit set, so
    // -1 & used != 0 and such numbers can never be picked.
    return (squareFreeSubsets(masks, 0, /*used=*/1, mem) - 1 + kMod) % kMod;
  }

  private static final int kMod = 1_000_000_007;
  private static final int kPrimesCount = 10;
  private static final int[] primes = {2, 3, 5, 7, 11, 13, 17, 19, 23, 29};

  /**
   * Returns the number of subsets of masks[i..] (including the empty one) whose
   * masks are pairwise disjoint and disjoint from {@code used}, modulo kMod.
   * {@code mem[i][used]} caches results; null marks an entry not yet computed.
   */
  private int squareFreeSubsets(int[] masks, int i, int used, Integer[][] mem) {
    if (i == masks.length)
      return 1;
    if (mem[i][used] != null)
      return mem[i][used];
    final int pick =
        (masks[i] & used) == 0 ? squareFreeSubsets(masks, i + 1, used | masks[i], mem) : 0;
    final int skip = squareFreeSubsets(masks, i + 1, used, mem);
    // pick, skip < kMod, so pick + skip < Integer.MAX_VALUE and cannot overflow.
    return mem[i][used] = (pick + skip) % kMod;
  }

  // Returns the bitmask of primes (from `primes`) dividing `num`, shifted left
  // by one so bit 0 stays reserved, or -1 if `num` is divisible by the square
  // of one of those primes. Prime factors larger than 29 are not checked.
  // e.g. num = 10 = 2 * 5, so mask = 0b101 -> 0b1010 (append a 0)
  //      num = 15 = 3 * 5, so mask = 0b110 -> 0b1100 (append a 0)
  //      num = 25 = 5 * 5, so mask = -1 (all bits set; invalid)
  private static int getMask(int num) {
    // 0 is divisible by every square; without this guard the loop below
    // would never terminate because 0 % p == 0 for every p.
    if (num == 0)
      return -1;
    int mask = 0;
    for (int i = 0; i < primes.length; ++i) {
      int rootCount = 0;
      while (num % primes[i] == 0) {
        num /= primes[i];
        ++rootCount;
      }
      if (rootCount >= 2)
        return -1;
      if (rootCount == 1)
        mask |= 1 << i;
    }
    return mask << 1;
  }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
class Solution:
  def squareFreeSubsets(self, nums: List[int]) -> int:
    """Counts the non-empty subsets of `nums` whose product is square-free.

    Each number is mapped to a bitmask of the primes <= 29 dividing it. A
    subset's product is square-free iff no element is divisible by the square
    of one of those primes and the elements' masks are pairwise disjoint.

    The DP is iterative (0/1 knapsack over prime masks), so it does not hit
    Python's recursion limit on long inputs.

    Args:
      nums: the input numbers.

    Returns:
      The number of such subsets modulo 1_000_000_007.
    """
    kMod = 1_000_000_007
    primes = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]

    def getMask(num: int) -> int:
      """Returns the bitmask of primes dividing `num`, or -1 if `num` is
      divisible by the square of one of them.

      e.g. num = 10 = 2 * 5, so mask = 0b101
           num = 15 = 3 * 5, so mask = 0b110
           num = 25 = 5 * 5, so mask = -1 (invalid)

      Only the primes in `primes` are checked; larger prime factors are
      ignored.
      """
      if num == 0:
        # 0 is divisible by every square; without this guard the loop below
        # would never terminate because 0 % p == 0 for every p.
        return -1
      mask = 0
      for i, prime in enumerate(primes):
        rootCount = 0
        while num % prime == 0:
          num //= prime
          rootCount += 1
        if rootCount >= 2:
          return -1
        if rootCount == 1:
          mask |= 1 << i
      return mask

    # ways[used] = number of subsets chosen so far whose prime mask is `used`.
    ways = [0] * (1 << len(primes))
    ways[0] = 1  # the empty subset
    for mask in map(getMask, nums):
      if mask == -1:
        continue  # such a number can never be in a square-free subset
      # used | mask >= used, so walking `used` downward reads each ways[used]
      # before this pass updates it (standard 0/1-knapsack ordering).
      for used in range(len(ways) - 1, -1, -1):
        if used & mask == 0:
          ways[used | mask] = (ways[used | mask] + ways[used]) % kMod

    # Exclude the empty subset.
    return (sum(ways) - 1) % kMod