2313. Minimum Flips in Binary Tree to Get Result

  • Time: $O(n)$
  • Space: $O(n)$
C++ solution:
class Solution {
 public:
  int minimumFlips(TreeNode* root, bool result) {
    return dp(root, result);
  }

 private:
  struct pairHash {
    template <class T1, class T2>
    std::size_t operator()(const std::pair<T1, T2>& p) const {
      return std::hash<T1>{}(p.first) ^ std::hash<T2>{}(p.second);
    }
  };

  unordered_map<pair<TreeNode*, bool>, int, pairHash> memo;

  int dp(TreeNode* root, bool target) {
    const pair<TreeNode*, bool> key{root, target};
    if (memo.count(key))
      return memo[key];
    if (root->val == 0 || root->val == 1)  // leaf
      return root->val == target ? 0 : 1;
    if (root->val == 5)  // NOT
      return dp(root->left ? root->left : root->right, !target);

    vector<pair<int, int>> nextTargets;
    if (root->val == 2)  // OR
      nextTargets = target ? vector<pair<int, int>>{{0, 1}, {1, 0}, {1, 1}}
                           : vector<pair<int, int>>{{0, 0}};
    else if (root->val == 3)  // AND
      nextTargets = target ? vector<pair<int, int>>{{1, 1}}
                           : vector<pair<int, int>>{{0, 0}, {0, 1}, {1, 0}};
    else  // root->val == 4 (XOR)
      nextTargets = target ? vector<pair<int, int>>{{0, 1}, {1, 0}}
                           : vector<pair<int, int>>{{0, 0}, {1, 1}};

    int ans = INT_MAX;
    for (const auto& [leftTarget, rightTarget] : nextTargets)
      ans = min(ans, dp(root->left, leftTarget) + dp(root->right, rightTarget));

    return memo[key] = ans;
  }
};
Python solution:
class Solution:
  def minimumFlips(self, root: Optional['TreeNode'], result: bool) -> int:
    """Returns the minimum number of leaf flips so the tree evaluates to `result`.

    Leaves hold 0 (False) or 1 (True); internal nodes hold 2 (OR), 3 (AND),
    4 (XOR) or 5 (NOT, whose single child may sit on either side).

    cost[id(node)] = (flips to make the subtree False, flips to make it True).
    The table is filled by an explicit-stack post-order traversal, so a deep
    (e.g. skewed or long NOT-chain) tree cannot exceed Python's recursion
    limit the way a recursive dp would.
    """
    # Keys are id(node): every node stays referenced by the tree for the whole
    # call, so ids are unique for the lifetime of this dict.
    cost = {}
    stack = [(root, False)]  # (node, children already pushed?)
    while stack:
      node, expanded = stack.pop()
      if node.val in (0, 1):  # leaf
        cost[id(node)] = (node.val, 1 - node.val)
        continue
      if not expanded:
        # Revisit this node only after all of its children have been costed.
        stack.append((node, True))
        for child in (node.left, node.right):
          if child is not None:
            stack.append((child, False))
        continue
      if node.val == 5:  # NOT: the child's costs with the roles swapped
        to_false, to_true = cost[id(node.left or node.right)]
        cost[id(node)] = (to_true, to_false)
        continue
      l0, l1 = cost[id(node.left)]
      r0, r1 = cost[id(node.right)]
      if node.val == 2:  # OR: False only when both children are False
        cost[id(node)] = (l0 + r0, min(l0 + r1, l1 + r0, l1 + r1))
      elif node.val == 3:  # AND: True only when both children are True
        cost[id(node)] = (min(l0 + r0, l0 + r1, l1 + r0), l1 + r1)
      else:  # node.val == 4, XOR: True exactly when the children differ
        cost[id(node)] = (min(l0 + r0, l1 + r1), min(l0 + r1, l1 + r0))
    return cost[id(root)][1 if result else 0]