Skip to content

3501. Maximize Active Section with Trade II

  • Time: $O(n\log n + q)$, where $q$ is the number of queries
  • Space: $O(n\log n)$
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
#include <ranges>

// A maximal run of consecutive '0' characters in the input string.
//
// Members carry default initializers so that a default-constructed Group is
// never read uninitialized; the struct is still an aggregate, so `{start,
// length}` initialization keeps working.
struct Group {
  int start = 0;   // Index of the run's first character.
  int length = 0;  // Number of characters in the run.
};

// Sparse table answering range-maximum queries over a fixed array.
//
// The constructor fills one row per power of two that fits in the array; a
// query then combines two (possibly overlapping) precomputed windows.
class SparseTable {
 public:
  // Builds the table from `nums`. Marked `explicit` so that a vector is never
  // silently converted into a SparseTable.
  explicit SparseTable(const std::vector<int>& nums)
      : n(nums.size()),
        st(std::bit_width(n) + 1, std::vector<int>(n + 1)) {
    std::copy(nums.begin(), nums.end(), st[0].begin());
    // Only levels with 2^i <= n contain any window, i.e. i < bit_width(n).
    for (int i = 1; i < std::bit_width(n); ++i)
      for (unsigned j = 0; j + (1u << i) <= n; ++j)
        st[i][j] = std::max(st[i - 1][j], st[i - 1][j + (1u << (i - 1))]);
  }

  // Returns max(nums[l..r]), both bounds inclusive.
  //
  // Precondition: l <= r < nums.size(). The bounds are not checked; with
  // l > r the unsigned length wraps around and the lookup goes out of range.
  int query(unsigned l, unsigned r) const {
    // Largest power of two that fits into the window length.
    const int i = std::bit_width(r - l + 1) - 1;
    return std::max(st[i][l], st[i][r - (1u << i) + 1]);
  }

 private:
  const unsigned n;
  std::vector<std::vector<int>> st;  // st[i][j] := max(nums[j..j + 2^i - 1])
};

class Solution {
 public:
  vector<int> maxActiveSectionsAfterTrade(string s,
                                          vector<vector<int>>& queries) {
    const int n = s.length();
    const int ones = ranges::count(s, '1');
    const auto [zeroGroups, zeroGroupIndex] = getZeroGroups(s);
    if (zeroGroups.empty())
      return vector<int>(queries.size(), ones);

    const SparseTable st(getZeroMergeLengths(zeroGroups));
    vector<int> ans;

    for (const vector<int>& query : queries) {
      const int l = query[0];
      const int r = query[1];
      const int left = zeroGroupIndex[l] == -1
                           ? -1
                           : (zeroGroups[zeroGroupIndex[l]].length -
                              (l - zeroGroups[zeroGroupIndex[l]].start));
      const int right = zeroGroupIndex[r] == -1
                            ? -1
                            : (r - zeroGroups[zeroGroupIndex[r]].start + 1);
      const auto [startAdjacentGroupIndex, endAdjacentGroupIndex] =
          mapToAdjacentGroupIndices(
              zeroGroupIndex[l] + 1,
              s[r] == '1' ? zeroGroupIndex[r] : zeroGroupIndex[r] - 1);
      int activeSections = ones;
      if (s[l] == '0' && s[r] == '0' &&
          zeroGroupIndex[l] + 1 == zeroGroupIndex[r])
        activeSections = max(activeSections, ones + left + right);
      else if (startAdjacentGroupIndex <= endAdjacentGroupIndex)
        activeSections = max(
            activeSections,
            ones + st.query(startAdjacentGroupIndex, endAdjacentGroupIndex));
      if (s[l] == '0' &&
          zeroGroupIndex[l] + 1 <=
              (s[r] == '1' ? zeroGroupIndex[r] : zeroGroupIndex[r] - 1))
        activeSections =
            max(activeSections,
                ones + left + zeroGroups[zeroGroupIndex[l] + 1].length);
      if (s[r] == '0' && zeroGroupIndex[l] < zeroGroupIndex[r] - 1)
        activeSections =
            max(activeSections,
                ones + right + zeroGroups[zeroGroupIndex[r] - 1].length);
      ans.push_back(activeSections);
    }

    return ans;
  }

 private:
  // Returns the zero groups and the index of the zero group that contains the
  // i-th character.
  pair<vector<Group>, vector<int>> getZeroGroups(const string& s) {
    vector<Group> zeroGroups;
    vector<int> zeroGroupIndex;
    for (int i = 0; i < s.length(); i++) {
      if (s[i] == '0') {
        if (i > 0 && s[i - 1] == '0')
          ++zeroGroups.back().length;
        else
          zeroGroups.push_back({i, 1});
      }
      zeroGroupIndex.push_back(zeroGroups.size() - 1);
    }
    return {zeroGroups, zeroGroupIndex};
  }

  // Returns the sums of the lengths of the adjacent groups.
  vector<int> getZeroMergeLengths(const vector<Group>& zeroGroups) {
    vector<int> zeroMergeLengths;
    for (const auto& [a, b] : zeroGroups | views::pairwise)
      zeroMergeLengths.push_back(a.length + b.length);
    return zeroMergeLengths;
  }

  // Returns the indices of the adjacent groups that contain l and r completely.
  //
  // e.g.    groupIndices = [0, 1, 2, 3]
  // adjacentGroupIndices = [0 (0, 1), 1 (1, 2), 2 (2, 3)]
  // map(startGroupIndex = 1, endGroupIndex = 3) -> (1, 2)
  pair<int, int> mapToAdjacentGroupIndices(int startGroupIndex,
                                           int endGroupIndex) {
    return {startGroupIndex, endGroupIndex - 1};
  }
};
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
// A maximal run of consecutive '0' characters in the input string.
class Group {
  // Index of the run's first character.
  public int start;
  // Number of characters in the run; incremented in place while the run is being scanned.
  public int length;
  public Group(int start, int length) {
    this.start = start;
    this.length = length;
  }
}

// Sparse table answering range-maximum queries over a fixed int array.
class SparseTable {
  // Builds the table from nums; the values are copied into row 0.
  public SparseTable(int[] nums) {
    n = nums.length;
    st = new int[bitLength(n) + 1][n + 1];
    System.arraycopy(nums, 0, st[0], 0, n);
    // Rows are indexed 0..st.length - 1, so the bound must be strict.
    for (int i = 1; i < st.length; ++i)
      for (int j = 0; j + (1 << i) <= n; ++j)
        st[i][j] = Math.max(st[i - 1][j], st[i - 1][j + (1 << (i - 1))]);
  }

  // Returns max(nums[l..r]), both bounds inclusive.
  // Precondition: 0 <= l <= r < nums.length; the bounds are not checked.
  public int query(int l, int r) {
    // Largest power of two that fits into the window length.
    final int i = bitLength(r - l + 1) - 1;
    return Math.max(st[i][l], st[i][r - (1 << i) + 1]);
  }

  private final int n;
  private final int[][] st; // st[i][j] := max(nums[j..j + 2^i - 1])

  // Returns the number of bits needed to represent n (0 for n == 0).
  private static int bitLength(int n) {
    return 32 - Integer.numberOfLeadingZeros(n);
  }
}

// NOTE(review): Pair is neither declared nor imported in this snippet; presumably it is supplied
// by the judge environment (e.g. javafx.util.Pair) - confirm before compiling elsewhere.
class Solution {
  // For each query [l, r], returns the number of '1's in s plus the best gain available inside
  // s[l..r]. The gain is the combined length of two neighbouring zero groups, each clipped to
  // [l, r]; it is 0 when the range offers no such pair.
  public List<Integer> maxActiveSectionsAfterTrade(String s, int[][] queries) {
    final int ones = (int) s.chars().filter(c -> c == '1').count();
    final Pair<List<Group>, int[]> zeroGroupsInfo = getZeroGroups(s);
    final List<Group> zeroGroups = zeroGroupsInfo.getKey();
    final int[] zeroGroupIndex = zeroGroupsInfo.getValue();

    // Return a mutable list here as well, matching the ArrayList returned below.
    if (zeroGroups.isEmpty())
      return new ArrayList<>(Collections.nCopies(queries.length, ones));

    // st.query(i, j) := max combined length over the unclipped pairs (group k, group k + 1)
    // for i <= k <= j.
    final SparseTable st = new SparseTable(getZeroMergeLengths(zeroGroups));
    final List<Integer> ans = new ArrayList<>(queries.length);

    for (int[] query : queries) {
      final int l = query[0];
      final int r = query[1];
      final int leftGroup = zeroGroupIndex[l];
      final int rightGroup = zeroGroupIndex[r];
      // Zeros of l's group at indices >= l; only used when s[l] == '0'.
      final int left = leftGroup == -1
          ? -1
          : (zeroGroups.get(leftGroup).length - (l - zeroGroups.get(leftGroup).start));
      // Zeros of r's group at indices <= r; only used when s[r] == '0'.
      final int right = rightGroup == -1 ? -1 : (r - zeroGroups.get(rightGroup).start + 1);
      // When firstInner <= lastInner, groups firstInner..lastInner start after l and end
      // before r, so they are never clipped.
      final int firstInner = leftGroup + 1;
      final int lastInner = s.charAt(r) == '1' ? rightGroup : rightGroup - 1;
      final Pair<Integer, Integer> adjacentIndices =
          mapToAdjacentGroupIndices(firstInner, lastInner);
      final int startAdjacentGroupIndex = adjacentIndices.getKey();
      final int endAdjacentGroupIndex = adjacentIndices.getValue();

      int activeSections = ones;
      if (s.charAt(l) == '0' && s.charAt(r) == '0' && leftGroup + 1 == rightGroup) {
        // l and r sit in neighbouring groups: merge the two clipped pieces.
        activeSections = Math.max(activeSections, ones + left + right);
      } else if (startAdjacentGroupIndex <= endAdjacentGroupIndex) {
        // Best pair formed by two unclipped inner groups.
        activeSections = Math.max(activeSections,
                                  ones + st.query(startAdjacentGroupIndex, endAdjacentGroupIndex));
      }
      // Clipped left group merged with the first inner group.
      if (s.charAt(l) == '0' && firstInner <= lastInner)
        activeSections =
            Math.max(activeSections, ones + left + zeroGroups.get(firstInner).length);
      // Group just before r's group merged with the clipped right group.
      if (s.charAt(r) == '0' && leftGroup < rightGroup - 1)
        activeSections =
            Math.max(activeSections, ones + right + zeroGroups.get(rightGroup - 1).length);
      ans.add(activeSections);
    }

    return ans;
  }

  // Returns the zero groups of s and, for every position i, the index of the last zero group
  // starting at or before i (-1 if there is none). For a position holding '0' this is the group
  // that contains it.
  private Pair<List<Group>, int[]> getZeroGroups(String s) {
    final List<Group> zeroGroups = new ArrayList<>();
    final int[] zeroGroupIndex = new int[s.length()];

    for (int i = 0; i < s.length(); i++) {
      if (s.charAt(i) == '0') {
        if (i > 0 && s.charAt(i - 1) == '0')
          zeroGroups.get(zeroGroups.size() - 1).length++;
        else
          zeroGroups.add(new Group(i, 1));
      }
      zeroGroupIndex[i] = zeroGroups.size() - 1;
    }

    return new Pair<>(zeroGroups, zeroGroupIndex);
  }

  // Returns the sums of the lengths of the adjacent groups; element k is the combined length of
  // groups k and k + 1. Requires a non-empty list (the caller returns early otherwise).
  private int[] getZeroMergeLengths(List<Group> zeroGroups) {
    final int[] zeroMergeLengths = new int[zeroGroups.size() - 1];
    for (int i = 0; i < zeroGroups.size() - 1; ++i)
      zeroMergeLengths[i] = zeroGroups.get(i).length + zeroGroups.get(i + 1).length;
    return zeroMergeLengths;
  }

  // Maps an inclusive range of group indices to the inclusive range of adjacent-pair indices
  // lying completely inside it, e.g. groups (1, 3) -> pairs (1, 2). The result is empty
  // (key > value) when the range holds fewer than two groups.
  private Pair<Integer, Integer> mapToAdjacentGroupIndices(int startGroupIndex, int endGroupIndex) {
    return new Pair<>(startGroupIndex, endGroupIndex - 1);
  }
}
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
from dataclasses import dataclass


@dataclass
class Group:
  """A maximal run of consecutive '0' characters in the input string."""
  start: int  # Index of the run's first character.
  length: int  # Number of characters in the run.


class SparseTable:
  """Static range-maximum structure built once over a list of numbers."""

  def __init__(self, nums: list[int]):
    self.n = len(nums)
    levels = self.n.bit_length() + 1
    # st[i][j] := max(nums[j..j + 2^i - 1])
    # Row 0 is a copy of the input; higher rows start out zero-filled.
    self.st = [nums.copy()]
    self.st.extend([0] * (self.n + 1) for _ in range(levels - 1))
    for level in range(1, levels):
      half = 1 << (level - 1)
      below = self.st[level - 1]
      row = self.st[level]
      # A window of length 2 * half is the union of two windows of length half.
      for start in range(self.n - 2 * half + 1):
        row[start] = max(below[start], below[start + half])

  def query(self, l: int, r: int) -> int:
    """Returns max(nums[l..r])."""
    level = (r - l + 1).bit_length() - 1
    row = self.st[level]
    head = row[l]
    tail = row[r - (1 << level) + 1]
    return tail if tail > head else head


class Solution:
  def maxActiveSectionsAfterTrade(
      self,
      s: str,
      queries: list[list[int]]
  ) -> list[int]:
    ones = s.count('1')
    zeroGroups, zeroGroupIndex = self._getZeroGroups(s)
    if not zeroGroups:
      return [ones] * len(queries)

    st = SparseTable(self._getZeroMergeLengths(zeroGroups))

    def getMaxActiveSections(l: int, r: int) -> int:
      left = (-1 if zeroGroupIndex[l] == -1
              else (zeroGroups[zeroGroupIndex[l]].length -
                    (l - zeroGroups[zeroGroupIndex[l]].start)))
      right = (-1 if zeroGroupIndex[r] == -1
               else (r - zeroGroups[zeroGroupIndex[r]].start + 1))
      startAdjacentGroupIndex, endAdjacentGroupIndex = self._mapToAdjacentGroupIndices(
          zeroGroupIndex[l] + 1, zeroGroupIndex[r] if s[r] == '1' else zeroGroupIndex[r] - 1)
      activeSections = ones
      if (s[l] == '0' and s[r] == '0' and
              zeroGroupIndex[l] + 1 == zeroGroupIndex[r]):
        activeSections = max(activeSections, ones + left + right)
      elif startAdjacentGroupIndex <= endAdjacentGroupIndex:
        activeSections = max(
            activeSections,
            ones + st.query(startAdjacentGroupIndex, endAdjacentGroupIndex))
      if (s[l] == '0' and
          zeroGroupIndex[l] + 1 <= (zeroGroupIndex[r]
                                    if s[r] == '1' else zeroGroupIndex[r] - 1)):
        activeSections = max(activeSections, ones + left +
                             zeroGroups[zeroGroupIndex[l] + 1].length)
      if (s[r] == '0' and zeroGroupIndex[l] < zeroGroupIndex[r] - 1):
        activeSections = max(activeSections, ones + right +
                             zeroGroups[zeroGroupIndex[r] - 1].length)
      return activeSections

    return [getMaxActiveSections(l, r) for l, r in queries]

  def _getZeroGroups(self, s: str) -> tuple[list[Group], list[int]]:
    """
    Returns the zero groups and the index of the zero group that contains the
    i-th character.
    """
    zeroGroups = []
    zeroGroupIndex = []
    for i in range(len(s)):
      if s[i] == '0':
        if i > 0 and s[i - 1] == '0':
          zeroGroups[-1].length += 1
        else:
          zeroGroups.append(Group(i, 1))
      zeroGroupIndex.append(len(zeroGroups) - 1)
    return zeroGroups, zeroGroupIndex

  def _getZeroMergeLengths(self, zeroGroups: list[Group]) -> list[int]:
    """Returns the sums of the lengths of the adjacent groups."""
    return [a.length + b.length for a, b in itertools.pairwise(zeroGroups)]

  def _mapToAdjacentGroupIndices(
      self,
      startGroupIndex: int,
      endGroupIndex: int
  ) -> tuple[int, int]:
    """
    Returns the indices of the adjacent groups that contain l and r completely.

    e.g.    groupIndices = [0, 1, 2, 3]
    adjacentGroupIndices = [0 (0, 1), 1 (1, 2), 2 (2, 3)]
    map(startGroupIndex = 1, endGroupIndex = 3) -> (1, 2)
    """
    return startGroupIndex, endGroupIndex - 1