Skip to content

2157. Groups of Strings 👍

  • Time: $O(26n \cdot \alpha(n))$
  • Space: $O(26n)$
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
// Disjoint-set forest over the elements 0..n-1 using union by size and path
// compression.
class UnionFind {
 public:
  // Creates n singleton sets {0}, {1}, ..., {n - 1}.
  explicit UnionFind(int n) : count(n), id(n), sz(n, 1) {
    std::iota(id.begin(), id.end(), 0);
  }

  // Merges the sets containing u and v, attaching the smaller tree under the
  // larger one (ties attach v's root under u's). No-op if already joined.
  void unionBySize(int u, int v) {
    const int i = find(u);
    const int j = find(v);
    if (i == j)
      return;
    if (sz[i] < sz[j]) {
      sz[j] += sz[i];
      id[i] = j;
    } else {
      sz[i] += sz[j];
      id[j] = i;
    }
    --count;
  }

  // Returns the current number of disjoint sets.
  int getCount() const {
    return count;
  }

  // Returns the size of the largest set, or 0 when there are no elements.
  // A non-root's sz entry is frozen at the size it had when it was merged,
  // which never exceeds its root's size, so the max over all entries equals
  // the size of the largest set.
  int getMaxSize() const {
    return sz.empty() ? 0 : *std::max_element(sz.begin(), sz.end());
  }

 private:
  int count;            // number of disjoint sets
  std::vector<int> id;  // parent pointer; a root points to itself
  std::vector<int> sz;  // for a root, the number of elements in its set

  // Returns the root of u, re-pointing every node on the path at the root.
  int find(int u) {
    return id[u] == u ? u : id[u] = find(id[u]);
  }
};

class Solution {
 public:
  vector<int> groupStrings(vector<string>& words) {
    UnionFind uf(words.size());
    unordered_map<int, int> maskToIndex;
    unordered_map<int, int> deletedMaskToIndex;

    for (int i = 0; i < words.size(); ++i) {
      const int mask = getMask(words[i]);
      for (int j = 0; j < 26; ++j)
        if (mask >> j & 1) {
          // Going to delete this bit.
          const int m = mask ^ 1 << j;
          if (const auto it = maskToIndex.find(m); it != maskToIndex.cend())
            uf.unionBySize(i, it->second);
          if (const auto it = deletedMaskToIndex.find(m);
              it != deletedMaskToIndex.cend())
            uf.unionBySize(i, it->second);
          else
            deletedMaskToIndex[m] = i;
        } else {
          // Going to add this bit.
          const int m = mask | 1 << j;
          if (const auto it = maskToIndex.find(m); it != maskToIndex.cend())
            uf.unionBySize(i, it->second);
        }
      maskToIndex[mask] = i;
    }

    return {uf.getCount(), uf.getMaxSize()};
  }

 private:
  int getMask(const string& s) {
    int mask = 0;
    for (const char c : s)
      mask |= 1 << c - 'a';
    return mask;
  }
};
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
/** Disjoint-set forest over 0..n-1 using union by size and path compression. */
class UnionFind {
  /** Creates {@code n} singleton sets {0}, {1}, ..., {n - 1}. */
  public UnionFind(int n) {
    count = n;
    id = new int[n];
    sz = new int[n];
    for (int i = 0; i < n; ++i)
      id[i] = i;
    Arrays.fill(sz, 1);
  }

  /**
   * Merges the sets containing {@code u} and {@code v}, attaching the smaller
   * tree under the larger one (ties attach v's root under u's). No-op if they
   * are already in the same set.
   */
  public void unionBySize(int u, int v) {
    final int i = find(u);
    final int j = find(v);
    if (i == j)
      return;
    if (sz[i] < sz[j]) {
      sz[j] += sz[i];
      id[i] = j;
    } else {
      sz[i] += sz[j];
      id[j] = i;
    }
    --count;
  }

  /** Returns the current number of disjoint sets. */
  public int getCount() {
    return count;
  }

  /**
   * Returns the size of the largest set, or 0 when there are no elements.
   * A non-root's sz entry is frozen at its size when merged, which never
   * exceeds its root's size, so the max over all entries is the answer.
   */
  public int getMaxSize() {
    return Arrays.stream(sz).max().orElse(0);
  }

  private int count;       // number of disjoint sets
  private final int[] id;  // parent pointer; a root points to itself
  private final int[] sz;  // for a root, the number of elements in its set

  /** Returns the root of u, re-pointing every node on the path at the root. */
  private int find(int u) {
    return id[u] == u ? u : (id[u] = find(id[u]));
  }
}

class Solution {
  /**
   * Groups words whose letter sets are connected by adding, deleting, or
   * replacing exactly one letter (connectivity is transitive). Assumes every
   * word consists of distinct lowercase letters 'a'..'z'.
   *
   * @return {number of groups, size of the largest group}
   */
  public int[] groupStrings(String[] words) {
    UnionFind uf = new UnionFind(words.length);
    // Letter-set mask -> index of the most recently seen word with that mask.
    Map<Integer, Integer> maskToIndex = new HashMap<>();
    // Mask with one letter removed -> first word that produced it.
    Map<Integer, Integer> deletedMaskToIndex = new HashMap<>();

    // Each pair of words is checked when the later of the two is processed.
    for (int i = 0; i < words.length; ++i) {
      final int mask = getMask(words[i]);
      for (int j = 0; j < 26; ++j) {
        final int bit = 1 << j;
        if ((mask & bit) != 0) {
          final int m = mask ^ bit;
          // Deletion: an earlier word's set equals this set minus letter j.
          final Integer subset = maskToIndex.get(m);
          if (subset != null)
            uf.unionBySize(i, subset);
          // Replacement: an earlier word also reduces to m by dropping one
          // letter (also joins identical letter sets). putIfAbsent records m
          // and returns an earlier owner in a single lookup.
          final Integer sibling = deletedMaskToIndex.putIfAbsent(m, i);
          if (sibling != null)
            uf.unionBySize(i, sibling);
        } else {
          // Addition: an earlier word's set equals this set plus letter j.
          final Integer superset = maskToIndex.get(mask | bit);
          if (superset != null)
            uf.unionBySize(i, superset);
        }
      }
      maskToIndex.put(mask, i);
    }

    return new int[] {uf.getCount(), uf.getMaxSize()};
  }

  /** Returns a 26-bit mask with bit (c - 'a') set for every letter c in s. */
  private static int getMask(final String s) {
    int mask = 0;
    for (int k = 0; k < s.length(); ++k)
      mask |= 1 << (s.charAt(k) - 'a');
    return mask;
  }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
class UnionFind:
  """Disjoint-set forest over 0..n-1 with union by size and path compression.

  Attributes:
    count: number of disjoint sets currently present.
    id: parent pointer of each element; a root points to itself.
    sz: for a root, the number of elements in its set (a non-root keeps the
      size it had when it was merged).
  """

  def __init__(self, n: int):
    self.count = n
    self.id = [i for i in range(n)]
    self.sz = [1 for _ in range(n)]

  def unionBySize(self, u: int, v: int) -> None:
    """Merges the sets of u and v; the smaller tree is attached to the larger.

    On a size tie, v's root is attached under u's root.
    """
    big, small = self._find(u), self._find(v)
    if big == small:
      return
    if self.sz[big] < self.sz[small]:
      big, small = small, big
    self.id[small] = big
    self.sz[big] += self.sz[small]
    self.count -= 1

  def _find(self, u: int) -> int:
    """Returns the root of u, re-pointing every node on the path at the root."""
    root = u
    while self.id[root] != root:
      root = self.id[root]
    while u != root:
      parent = self.id[u]
      self.id[u] = root
      u = parent
    return root


class Solution:
  def groupStrings(self, words: list[str]) -> list[int]:
    """Groups words whose letter sets differ by one add, delete, or replace.

    Connectivity is transitive. Assumes every word consists of distinct
    lowercase letters 'a'..'z'.

    Returns:
      [number of groups, size of the largest group] ([0, 0] for no words).
    """
    uf = UnionFind(len(words))

    def getMask(s: str) -> int:
      # Bit (c - 'a') is set for every letter c in s; ord() avoids both the
      # dependency on the `string` module and a linear search per character.
      mask = 0
      for c in s:
        mask |= 1 << (ord(c) - ord('a'))
      return mask

    def getAddedMasks(mask: int):
      # Every mask reachable by adding one letter not already present.
      for i in range(26):
        if not (mask >> i & 1):
          yield mask | 1 << i

    def getDeletedMasks(mask: int):
      # Every mask reachable by removing one letter that is present.
      for i in range(26):
        if mask >> i & 1:
          yield mask ^ 1 << i

    # Compute each word's mask once and reuse it below.
    masks = [getMask(word) for word in words]
    # mask -> index of the last word with that mask (duplicates overwrite).
    maskToIndex = {mask: i for i, mask in enumerate(masks)}
    # mask with one letter removed -> first word that produced it.
    deletedMaskToIndex: dict[int, int] = {}

    for i, mask in enumerate(masks):
      # Addition: some word's set equals this set plus one letter.
      for m in getAddedMasks(mask):
        if m in maskToIndex:
          uf.unionBySize(i, maskToIndex[m])
      for m in getDeletedMasks(mask):
        # Deletion: some word's set equals this set minus one letter.
        if m in maskToIndex:
          uf.unionBySize(i, maskToIndex[m])
        # Replacement: an earlier word also reduces to m by dropping one
        # letter (this also joins words with identical letter sets).
        if m in deletedMaskToIndex:
          uf.unionBySize(i, deletedMaskToIndex[m])
        else:
          deletedMaskToIndex[m] = i

    return [uf.count, max(uf.sz, default=0)]