Skip to content

1397. Find All Good Strings

  • Time: $O(n \cdot |\texttt{evil}| \cdot 26) = O(n \cdot |\texttt{evil}|)$
  • Space: $O(n \cdot |\texttt{evil}| \cdot 2^2 + |\texttt{evil}| \cdot 26) = O(|\texttt{evil}| \cdot n)$
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
class Solution {
 public:
  // Counts the strings that are lexicographically >= s1 and <= s2 and never
  // contain `evil` as a substring, modulo 1e9 + 7. Positions are filled left
  // to right while tracking how many letters of `evil` the current suffix
  // matches (a KMP automaton state).
  int findGoodStrings(int n, string s1, string s2, string evil) {
    const size_t evilSize = evil.length();
    // cache[pos][matched][tightLow][tightHigh]; -1 marks "not computed yet".
    vector<vector<vector<vector<int>>>> cache(
        n, vector<vector<vector<int>>>(
               evilSize, vector<vector<int>>(2, vector<int>(2, -1))));
    // transition[k][c] := number of matched `evil` letters after appending
    // ('a' + c) when k letters are currently matched; -1 marks "unknown".
    vector<vector<int>> transition(evilSize, vector<int>(26, -1));
    const vector<int> fallback = buildFailureTable(evil);
    return countSuffixes(s1, s2, evil, 0, 0, true, true, fallback, transition,
                         cache);
  }

 private:
  static constexpr int kMod = 1'000'000'007;

  // Returns how many ways positions [pos, end) can be filled, given that
  // `matched` letters of `evil` are currently matched. `tightLow` is true
  // while every letter chosen so far equals the corresponding letter of `s1`
  // (so the next letter may not go below s1[pos]); `tightHigh` is the same
  // for `s2` and the upper bound.
  int countSuffixes(const string& s1, const string& s2, const string& evil,
                    int pos, int matched, bool tightLow, bool tightHigh,
                    const vector<int>& fallback,
                    vector<vector<int>>& transition,
                    vector<vector<vector<vector<int>>>>& cache) {
    // The prefix built so far already contains `evil`: a dead end.
    if (matched == evil.length())
      return 0;
    // Every position has been filled without completing `evil`.
    if (pos == s1.length())
      return 1;

    int& cached = cache[pos][matched][tightLow][tightHigh];
    if (cached != -1)
      return cached;

    const char lo = tightLow ? s1[pos] : 'a';
    const char hi = tightHigh ? s2[pos] : 'z';
    int total = 0;
    for (char letter = lo; letter <= hi; ++letter) {
      const int nextMatched =
          advance(transition, evil, matched, letter, fallback);
      // A bound stays tight only if this letter sits exactly on it.
      const bool stillLow = tightLow && letter == s1[pos];
      const bool stillHigh = tightHigh && letter == s2[pos];
      total += countSuffixes(s1, s2, evil, pos + 1, nextMatched, stillLow,
                             stillHigh, fallback, transition, cache);
      total %= kMod;
    }
    cached = total;
    return cached;
  }

  // Builds the KMP failure table: table[i] is the length of the longest
  // proper prefix of pattern[0..i] that is also a suffix of it.
  vector<int> buildFailureTable(const string& pattern) {
    vector<int> table(pattern.length());
    int border = 0;
    for (int i = 1; i < pattern.length(); ++i) {
      while (border > 0 && pattern[border] != pattern[i])
        border = table[border - 1];
      if (pattern[border] == pattern[i]) {
        ++border;
        table[i] = border;
      }
    }
    return table;
  }

  // Returns the number of matched `evil` letters after appending `letter`
  // when `matched` letters are matched so far. Results are memoized in
  // `transition` under the queried (matched, letter) pair.
  int advance(vector<vector<int>>& transition, const string& evil, int matched,
              char letter, const vector<int>& fallback) {
    int& slot = transition[matched][letter - 'a'];
    if (slot == -1) {
      int k = matched;
      // Fall back through shorter borders until `letter` extends one of them.
      while (k > 0 && evil[k] != letter)
        k = fallback[k - 1];
      if (evil[k] == letter)
        ++k;
      slot = k;
    }
    return slot;
  }
};
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
class Solution {
  /**
   * Counts the strings that are lexicographically >= s1 and <= s2 and do not
   * contain `evil` as a substring, modulo 1_000_000_007.
   *
   * @param n    the string length (s1 and s2 are assumed to have length n)
   * @param s1   the inclusive lower bound
   * @param s2   the inclusive upper bound
   * @param evil the forbidden substring
   * @return the number of good strings modulo 1_000_000_007
   */
  public int findGoodStrings(int n, String s1, String s2, String evil) {
    // mem[i][matched][isS1Prefix][isS2Prefix]; null marks "not computed yet".
    Integer[][][][] mem = new Integer[n][evil.length()][2][2];
    // nextMatchedCount[i][j] := the number of matched `evil` letters after
    // appending ('a' + j) when i letters of `evil` are currently matched
    Integer[][] nextMatchedCount = new Integer[evil.length()][26];
    return count(s1, s2, evil, 0, 0, true, true, getLPS(evil), nextMatchedCount, mem);
  }

  private static final int MOD = 1_000_000_007;

  // Returns the number of good strings for s[i..n), where `matchedEvilCount`
  // letters of `evil` are currently matched, `isS1Prefix` indicates if the
  // current letter is tightly bound for `s1` and `isS2Prefix` indicates if the
  // current letter is tightly bound for `s2`.
  private int count(final String s1, final String s2, final String evil, int i,
                    int matchedEvilCount, boolean isS1Prefix, boolean isS2Prefix, int[] evilLPS,
                    Integer[][] nextMatchedCount, Integer[][][][] mem) {
    // s[0..i) contains `evil`, so don't consider any ongoing strings.
    if (matchedEvilCount == evil.length())
      return 0;
    // Run out of positions, so contribute one.
    if (i == s1.length())
      return 1;
    final int k1 = isS1Prefix ? 1 : 0;
    final int k2 = isS2Prefix ? 1 : 0;
    if (mem[i][matchedEvilCount][k1][k2] != null)
      return mem[i][matchedEvilCount][k1][k2];
    final char minChar = isS1Prefix ? s1.charAt(i) : 'a';
    final char maxChar = isS2Prefix ? s2.charAt(i) : 'z';
    // Accumulate in a primitive to avoid boxing on every iteration; the memo
    // entry is written once at the end.
    int res = 0;
    for (char c = minChar; c <= maxChar; ++c) {
      final int nextMatchedEvilCount =
          getNextMatchedEvilCount(nextMatchedCount, evil, matchedEvilCount, c, evilLPS);
      res += count(s1, s2, evil, i + 1, nextMatchedEvilCount, isS1Prefix && c == s1.charAt(i),
                   isS2Prefix && c == s2.charAt(i), evilLPS, nextMatchedCount, mem);
      res %= MOD;
    }
    return mem[i][matchedEvilCount][k1][k2] = res;
  }

  // Returns the lps array, where lps[i] is the length of the longest proper
  // prefix of pattern[0..i] which is also a suffix of this substring.
  private int[] getLPS(final String pattern) {
    int[] lps = new int[pattern.length()];
    for (int i = 1, j = 0; i < pattern.length(); ++i) {
      while (j > 0 && pattern.charAt(j) != pattern.charAt(i))
        j = lps[j - 1];
      if (pattern.charAt(i) == pattern.charAt(j))
        lps[i] = ++j;
    }
    return lps;
  }

  // Returns the number of matched `evil` letters after appending `currChar`
  // when j letters are matched so far (j is the next index of `evil` to try).
  private int getNextMatchedEvilCount(Integer[][] nextMatchedCount, final String evil, int j,
                                      char currChar, int[] lps) {
    // Remember the queried state: `j` is overwritten by the fallback loop
    // below, and the result must be cached under the state that was asked for,
    // otherwise the queried slot is never filled and the cache never hits.
    final int state = j;
    final int letter = currChar - 'a';
    if (nextMatchedCount[state][letter] != null)
      return nextMatchedCount[state][letter];
    while (j > 0 && evil.charAt(j) != currChar)
      j = lps[j - 1];
    return nextMatchedCount[state][letter] = (evil.charAt(j) == currChar ? j + 1 : j);
  }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
class Solution:
  def findGoodStrings(self, n: int, s1: str, s2: str, evil: str) -> int:
    """
    Counts the strings of length n that are lexicographically >= s1 and <= s2
    and do not contain `evil` as a substring, modulo 1_000_000_007.
    """
    MOD = 1_000_000_007
    evilLength = len(evil)
    fallback = self._getLPS(evil)

    @functools.cache
    def advance(matched: int, letter: str) -> int:
      """
      Returns how many letters of `evil` are matched after appending `letter`
      when `matched` letters are matched so far.
      """
      # Fall back through shorter borders until `letter` extends one of them.
      while matched > 0 and evil[matched] != letter:
        matched = fallback[matched - 1]
      if evil[matched] == letter:
        return matched + 1
      return matched

    @functools.cache
    def countFrom(pos: int, matched: int, tightLow: bool, tightHigh: bool) -> int:
      """
      Returns the number of ways to fill positions [pos, n), where `matched`
      letters of `evil` are currently matched, `tightLow` says every letter so
      far equals the one in `s1` (so the lower bound still applies) and
      `tightHigh` says the same for `s2` and the upper bound.
      """
      # The prefix built so far already contains `evil`: a dead end.
      if matched == evilLength:
        return 0
      # All positions are filled without completing `evil`.
      if pos == n:
        return 1
      low = s1[pos] if tightLow else 'a'
      high = s2[pos] if tightHigh else 'z'
      total = 0
      for code in range(ord(low), ord(high) + 1):
        letter = chr(code)
        # A bound stays tight only if this letter sits exactly on it.
        stillLow = tightLow and letter == s1[pos]
        stillHigh = tightHigh and letter == s2[pos]
        total = (total + countFrom(pos + 1, advance(matched, letter),
                                   stillLow, stillHigh)) % MOD
      return total

    return countFrom(0, 0, True, True)

  def _getLPS(self, pattern: str) -> list[int]:
    """
    Returns the KMP failure table, where table[i] is the length of the longest
    proper prefix of pattern[0..i] which is also a suffix of that substring.
    """
    size = len(pattern)
    table = [0] * size
    border = 0
    for idx in range(1, size):
      letter = pattern[idx]
      while border > 0 and pattern[border] != letter:
        border = table[border - 1]
      if pattern[border] == letter:
        border += 1
        table[idx] = border
    return table