Skip to content

864. Shortest Path to Get All Keys 👍

  • Time: $O(mn \cdot 2^k)$
  • Space: $O(mn \cdot 2^k)$
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
// One BFS state: a grid cell plus the set of keys collected on the way there.
struct T {
  int i;     // row
  int j;     // column
  int keys;  // the keys in the bitmask: bit b is set once key 'a' + b is held
  T(int i, int j, int keys) : i(i), j(j), keys(keys) {}
};

class Solution {
 public:
  // Returns the minimum number of moves needed to collect every key in
  // `grid`, or -1 if no such walk exists.
  //
  // Cells: '@' start, '#' wall, 'a'-'f' key, 'A'-'F' lock (passable only once
  // the matching lowercase key is held); any other character is open floor.
  // A move is one step up/down/left/right. Rows are assumed to have equal
  // length.
  //
  // Throws invalid_argument if the grid contains no '@' cell.
  //
  // BFS over (row, col, key-mask) states, one level per move, so the first
  // time the full key set is reached is along a shortest walk.
  int shortestPathAllKeys(vector<string>& grid) {
    constexpr int dirs[4][2] = {{0, 1}, {1, 0}, {0, -1}, {-1, 0}};
    // Locate the start first: this also rejects an empty grid before
    // grid[0] is touched below.
    const vector<int> start = getStart(grid);
    const int m = grid.size();
    const int n = grid[0].length();
    // Target = exactly the keys present. This works even when the key letters
    // are not the first k letters or a letter repeats.
    const int allKeys = getAllKeysMask(grid);
    if (allKeys == 0)  // nothing to collect
      return 0;

    // Every reachable mask is a subset of allKeys, hence <= allKeys.
    const int numMasks = allKeys + 1;
    // Flat visited table indexed by stateId(); vector<char> avoids both the
    // vector<bool> bit proxy and m * n separate inner allocations.
    vector<char> seen(static_cast<size_t>(m) * n * numMasks, 0);
    const auto stateId = [n, numMasks](int x, int y, int keys) {
      return (static_cast<size_t>(x) * n + y) * numMasks + keys;
    };

    queue<T> q;
    q.emplace(start[0], start[1], 0);
    seen[stateId(start[0], start[1], 0)] = 1;

    for (int steps = 1; !q.empty(); ++steps) {
      for (int sz = q.size(); sz > 0; --sz) {
        const auto [i, j, keys] = q.front();
        q.pop();
        for (const auto& [dx, dy] : dirs) {
          const int x = i + dx;
          const int y = j + dy;
          if (x < 0 || x >= m || y < 0 || y >= n)
            continue;
          const char c = grid[x][y];
          if (c == '#')
            continue;
          // A lock stays impassable until its key has been collected.
          if ('A' <= c && c <= 'F' && ((keys >> (c - 'A')) & 1) == 0)
            continue;
          const int newKeys =
              ('a' <= c && c <= 'f') ? keys | (1 << (c - 'a')) : keys;
          if (newKeys == allKeys)
            return steps;
          char& visited = seen[stateId(x, y, newKeys)];
          if (visited)
            continue;
          visited = 1;
          q.emplace(x, y, newKeys);
        }
      }
    }

    return -1;
  }

 private:
  // Bitmask of every key letter in the grid: bit b is set if 'a' + b occurs.
  int getAllKeysMask(const vector<string>& grid) {
    int mask = 0;
    for (const string& row : grid)
      for (const char c : row)
        if ('a' <= c && c <= 'f')
          mask |= 1 << (c - 'a');
    return mask;
  }

  // Returns {row, col} of the first '@' cell in row-major order.
  // Throws invalid_argument if there is none (including an empty grid).
  vector<int> getStart(const vector<string>& grid) {
    for (int i = 0; i < static_cast<int>(grid.size()); ++i)
      for (int j = 0; j < static_cast<int>(grid[i].length()); ++j)
        if (grid[i][j] == '@')
          return {i, j};
    throw invalid_argument("grid has no '@' start cell");
  }
};
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
/** One BFS state: a grid cell plus the set of keys collected on the way there. */
class T {
  public int i; // row
  public int j; // column
  public int keys; // the keys in the bitmask: bit b is set once key 'a' + b is held
  public T(int i, int j, int keys) {
    this.i = i;
    this.j = j;
    this.keys = keys;
  }
}

class Solution {
  /**
   * Returns the minimum number of moves needed to collect every key in {@code grid}, or -1 if no
   * such walk exists.
   *
   * <p>Cells: '@' start, '#' wall, 'a'-'f' key, 'A'-'F' lock (passable only once the matching
   * lowercase key is held); any other character is open floor. A move is one step
   * up/down/left/right. Rows are assumed to have equal length.
   *
   * <p>BFS over (row, col, key-mask) states, one level per move, so the first time the full key
   * set is reached is along a shortest walk.
   *
   * @throws IllegalArgumentException if the grid contains no '@' cell (including an empty grid)
   */
  public int shortestPathAllKeys(String[] grid) {
    final int[][] dirs = {{0, 1}, {1, 0}, {0, -1}, {-1, 0}};
    // Locate the start first: this also rejects an empty grid before grid[0] is touched.
    final int[] start = getStart(grid);
    final int m = grid.length;
    final int n = grid[0].length();
    // Target = exactly the keys present, so non-contiguous or repeated key letters still work.
    final int allKeys = getAllKeysMask(grid);
    if (allKeys == 0) // nothing to collect
      return 0;

    // Every reachable mask is a subset of allKeys, hence <= allKeys.
    boolean[][][] seen = new boolean[m][n][allKeys + 1];
    Queue<T> q = new ArrayDeque<>();
    q.offer(new T(start[0], start[1], 0));
    seen[start[0]][start[1]][0] = true;

    for (int steps = 1; !q.isEmpty(); ++steps) {
      for (int sz = q.size(); sz > 0; --sz) {
        final T cur = q.poll();
        for (int[] dir : dirs) {
          final int x = cur.i + dir[0];
          final int y = cur.j + dir[1];
          if (x < 0 || x >= m || y < 0 || y >= n)
            continue;
          final char c = grid[x].charAt(y);
          if (c == '#')
            continue;
          // A lock stays impassable until its key has been collected.
          if ('A' <= c && c <= 'F' && ((cur.keys >> (c - 'A')) & 1) == 0)
            continue;
          final int newKeys = ('a' <= c && c <= 'f') ? cur.keys | (1 << (c - 'a')) : cur.keys;
          if (newKeys == allKeys)
            return steps;
          if (seen[x][y][newKeys])
            continue;
          seen[x][y][newKeys] = true;
          q.offer(new T(x, y, newKeys));
        }
      }
    }

    return -1;
  }

  // Bitmask of every key letter in the grid: bit b is set if 'a' + b occurs.
  private int getAllKeysMask(String[] grid) {
    int mask = 0;
    for (final String s : grid)
      for (int k = 0; k < s.length(); ++k) {
        final char c = s.charAt(k);
        if ('a' <= c && c <= 'f')
          mask |= 1 << (c - 'a');
      }
    return mask;
  }

  // Returns {row, col} of the first '@' cell in row-major order.
  private int[] getStart(String[] grid) {
    for (int i = 0; i < grid.length; ++i)
      for (int j = 0; j < grid[i].length(); ++j)
        if (grid[i].charAt(j) == '@')
          return new int[] {i, j};
    throw new IllegalArgumentException();
  }
}