Skip to content

3395. Subsequences with a Unique Middle Mode I

Approach 1: $O(n^2)$ increment counting

  • Time: $O(n^2)$
  • Space: $O(n)$
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
class Solution {
 public:
  // Counts length-5 subsequences whose middle element (index 2 of the
  // subsequence) is the unique mode, modulo 1e9+7.
  //
  // Increment counting: for each middle index i with value `a`, two elements
  // are picked from nums[0, i) and two from nums(i, n). Picks containing at
  // least two more copies of `a` always make `a` the unique mode. Picks with
  // exactly one more copy are counted by `calc`, which excludes the ones where
  // another value ties or beats `a`. Picks with no other copy of `a` never
  // qualify.
  int subsequencesWithMiddleMode(vector<int>& nums) {
    const int n = nums.size();
    // No subsequence of length 5 exists; the guard also keeps nums[0] and
    // nums[1] below in bounds.
    if (n < 5)
      return 0;

    long long ans = 0;
    unordered_map<int, int> left;   // counts of nums[0, i)
    unordered_map<int, int> right;  // counts of nums(i, n)

    for (int i = 0; i < 2; ++i)
      ++left[nums[i]];

    for (int i = 2; i < n; ++i)
      ++right[nums[i]];

    for (int i = 2; i < n - 2; ++i) {
      const int num = nums[i];
      if (--right[num] == 0)
        right.erase(num);

      // 64-bit on purpose: leftCount * leftOther * rightCount * rightOther can
      // exceed INT_MAX (e.g. 250 * 250 * 250 * 249 for n = 1000).
      const long long leftCount = countOf(left, num);
      const long long rightCount = countOf(right, num);
      const long long leftOther = i - leftCount;
      const long long rightOther = n - 1 - i - rightCount;

      // count[mode] = 5 -- [a a] a [a a]
      ans = (ans + nC2(leftCount) * nC2(rightCount)) % kMod;

      // count[mode] = 4 -- [a a] a [a ?]
      ans = (ans + nC2(leftCount) * (rightCount * rightOther % kMod)) % kMod;

      // count[mode] = 4 -- [a ?] a [a a]
      ans = (ans + leftCount * leftOther % kMod * nC2(rightCount)) % kMod;

      // count[mode] = 3 -- [a a] a [? ?]
      ans = (ans + nC2(leftCount) * nC2(rightOther)) % kMod;

      // count[mode] = 3 -- [? ?] a [a a]
      ans = (ans + nC2(leftOther) * nC2(rightCount)) % kMod;

      // count[mode] = 3 -- [a ?] a [a ?]
      ans = (ans + leftCount * leftOther % kMod *
                       (rightCount * rightOther % kMod)) %
            kMod;

      // count[mode] = 2 -- [a ?] a [? ?]
      ans = (ans + leftCount * calc(num, leftOther, rightOther, left, right)) %
            kMod;

      // count[mode] = 2 -- [? ?] a [a ?]
      ans = (ans + rightCount * calc(num, rightOther, leftOther, right, left)) %
            kMod;

      ++left[num];
    }

    return static_cast<int>(ans);
  }

 private:
  static constexpr int kMod = 1'000'000'007;

  // Returns C(n, 2) mod kMod.
  static long long nC2(long long n) {
    return n * (n - 1) / 2 % kMod;
  }

  // Returns count[key], or 0 if `key` is absent, without inserting it.
  static long long countOf(const unordered_map<int, int>& count, int key) {
    const auto it = count.find(key);
    return it == count.end() ? 0 : it->second;
  }

  // Counts picks of one non-`a` element from side 1 and two non-`a` elements
  // from side 2 (i.e. [a ?] a [? ?] with the extra `a` on side 1) whose three
  // non-`a` elements are pairwise distinct, modulo kMod. These are exactly the
  // picks in which `a` (appearing twice) is the unique mode.
  //
  // other1/other2: number of non-`a` elements on side 1/side 2.
  // count1/count2: value -> count on side 1/side 2.
  static long long calc(int a, long long other1, long long other2,
                        const unordered_map<int, int>& count1,
                        const unordered_map<int, int>& count2) {
    // [a ?] a [? ?]
    long long res = other1 * nC2(other2) % kMod;

    for (const auto& [b, b1] : count1) {
      if (b == a)
        continue;
      const long long b2 = countOf(count2, b);
      // Exclude triples -- [a b] a [b b].
      res = (res - b1 * nC2(b2) % kMod + kMod) % kMod;
      // Exclude doubles -- [a b] a [b ?].
      res = (res - b1 * b2 % kMod * (other2 - b2) % kMod + kMod) % kMod;
    }

    for (const auto& [b, b2] : count2) {
      if (b == a)
        continue;
      const long long b1 = countOf(count1, b);
      // Exclude doubles -- [a ?] a [b b].
      res = (res - (other1 - b1) * nC2(b2) % kMod + kMod) % kMod;
    }

    return res;
  }
};
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
class Solution {
  /**
   * Counts length-5 subsequences whose middle element (index 2 of the
   * subsequence) is the unique mode, modulo 1e9+7.
   *
   * <p>For each middle index i with value {@code a}, two elements are picked from
   * nums[0, i) and two from nums(i, n). Picks containing at least two more
   * copies of {@code a} are always valid. Picks with exactly one more copy are
   * counted by {@link #calc}, which excludes picks where another value ties or
   * beats {@code a}.
   *
   * @param nums the input array
   * @return the number of valid subsequences modulo 1e9+7 (0 if fewer than 5 elements)
   */
  public int subsequencesWithMiddleMode(int[] nums) {
    final int n = nums.length;
    // No subsequence of length 5 exists; also keeps nums[0] and nums[1] in bounds.
    if (n < 5)
      return 0;

    long ans = 0;
    Map<Integer, Integer> left = new HashMap<>(); // counts of nums[0, i)
    Map<Integer, Integer> right = new HashMap<>(); // counts of nums(i, n)

    for (int i = 0; i < 2; ++i)
      left.merge(nums[i], 1, Integer::sum);

    for (int i = 2; i < n; ++i)
      right.merge(nums[i], 1, Integer::sum);

    for (int i = 2; i < n - 2; ++i) {
      final int num = nums[i];
      if (right.merge(num, -1, Integer::sum) == 0)
        right.remove(num);

      // long on purpose: leftCount * leftOther * rightCount * rightOther can
      // exceed Integer.MAX_VALUE (e.g. 250 * 250 * 250 * 249 for n = 1000).
      final long leftCount = left.getOrDefault(num, 0);
      final long rightCount = right.getOrDefault(num, 0);
      final long leftOther = i - leftCount;
      final long rightOther = n - 1 - i - rightCount;

      // count[mode] = 5 -- [a a] a [a a]
      ans = (ans + nC2(leftCount) * nC2(rightCount)) % MOD;

      // count[mode] = 4 -- [a a] a [a ?]
      ans = (ans + nC2(leftCount) * (rightCount * rightOther % MOD)) % MOD;

      // count[mode] = 4 -- [a ?] a [a a]
      ans = (ans + leftCount * leftOther % MOD * nC2(rightCount)) % MOD;

      // count[mode] = 3 -- [a a] a [? ?]
      ans = (ans + nC2(leftCount) * nC2(rightOther)) % MOD;

      // count[mode] = 3 -- [? ?] a [a a]
      ans = (ans + nC2(leftOther) * nC2(rightCount)) % MOD;

      // count[mode] = 3 -- [a ?] a [a ?]
      ans = (ans + leftCount * leftOther % MOD * (rightCount * rightOther % MOD)) % MOD;

      // count[mode] = 2 -- [a ?] a [? ?]
      ans = (ans + leftCount * calc(num, leftOther, rightOther, left, right)) % MOD;

      // count[mode] = 2 -- [? ?] a [a ?]
      ans = (ans + rightCount * calc(num, rightOther, leftOther, right, left)) % MOD;

      // Update left map
      left.merge(num, 1, Integer::sum);
    }

    return (int) ans;
  }

  private static final int MOD = 1_000_000_007;

  // Returns C(n, 2) modulo MOD
  private long nC2(long n) {
    return n * (n - 1) / 2 % MOD;
  }

  /**
   * Counts picks of one non-{@code a} element from side 1 and two non-{@code a}
   * elements from side 2 ([a ?] a [? ?]) whose three non-{@code a} elements are
   * pairwise distinct, modulo MOD; i.e. picks where {@code a} (appearing twice)
   * stays the unique mode.
   *
   * @param other1 number of non-{@code a} elements on side 1
   * @param other2 number of non-{@code a} elements on side 2
   * @param count1 value counts on side 1
   * @param count2 value counts on side 2
   */
  private long calc(int a, long other1, long other2, Map<Integer, Integer> count1,
                    Map<Integer, Integer> count2) {
    // [a ?] a [? ?]
    long res = other1 * nC2(other2) % MOD;

    for (Map.Entry<Integer, Integer> entry : count1.entrySet()) {
      final int b = entry.getKey();
      final long b1 = entry.getValue();
      if (b == a)
        continue;
      final long b2 = count2.getOrDefault(b, 0);
      // Exclude triples -- [a b] a [b b]
      res = (res - b1 * nC2(b2) % MOD + MOD) % MOD;
      // Exclude doubles -- [a b] a [b ?]
      res = (res - b1 * b2 % MOD * (other2 - b2) % MOD + MOD) % MOD;
    }

    for (Map.Entry<Integer, Integer> entry : count2.entrySet()) {
      final int b = entry.getKey();
      final long b2 = entry.getValue();
      if (b == a)
        continue;
      final long b1 = count1.getOrDefault(b, 0);
      // Exclude doubles -- [a ?] a [b b]
      res = (res - (other1 - b1) * nC2(b2) % MOD + MOD) % MOD;
    }

    return res;
  }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
class Solution:
  def __init__(self):
    self.MOD = 1_000_000_007

  def subsequencesWithMiddleMode(self, nums: list[int]) -> int:
    """Returns how many length-5 subsequences have a unique mode at index 2.

    Each position `mid` in [2, n - 3] is tried as the middle element `x`; the
    subsequence then picks two elements before `mid` and two after it. Picks
    are grouped by how many extra copies of `x` they contain, and the running
    total is reduced modulo `self.MOD`.
    """
    n = len(nums)
    before = collections.Counter((nums[0], nums[1]))
    after = collections.Counter(nums[2:])
    comb = math.comb
    total = 0

    for mid in range(2, n - 2):
      x = nums[mid]
      after[x] -= 1
      if not after[x]:
        del after[x]

      lx, rx = before[x], after[x]  # copies of x on each side
      lo = mid - lx                 # non-x elements before mid
      ro = n - 1 - mid - rx         # non-x elements after mid

      # Three or four extra copies of x.
      total += comb(lx, 2) * comb(rx, 2)                      # [x x] x [x x]
      total += comb(lx, 2) * rx * ro + lx * lo * comb(rx, 2)  # [x x] x [x ?] / [x ?] x [x x]

      # Exactly two extra copies of x: no other value can reach three.
      total += comb(lx, 2) * comb(ro, 2)  # [x x] x [? ?]
      total += comb(lo, 2) * comb(rx, 2)  # [? ?] x [x x]
      total += lx * lo * rx * ro          # [x ?] x [x ?]

      # Exactly one extra copy of x: the other three picks must be distinct.
      total += lx * self._calc(x, lo, ro, before, after)
      total += rx * self._calc(x, ro, lo, after, before)

      total %= self.MOD
      before[x] += 1

    return total

  def _calc(
      self,
      a: int,
      other1: int,
      other2: int,
      count1: dict[int, int],
      count2: dict[int, int]
  ) -> int:
    """Counts [a ?] a [? ?] picks whose three non-`a` elements are distinct.

    One non-`a` element comes from side 1 and two from side 2; `other1` and
    `other2` are the numbers of non-`a` elements on each side, and `count1` /
    `count2` map values to their counts there. The result is reduced modulo
    `self.MOD`.
    """
    total = other1 * math.comb(other2, 2)

    for b, left_b in count1.items():
      if b == a:
        continue
      right_b = count2[b]
      total -= left_b * math.comb(right_b, 2)         # [a b] a [b b]
      total -= left_b * right_b * (other2 - right_b)  # [a b] a [b ?]

    for b, right_b in count2.items():
      if b == a:
        continue
      total -= (other1 - count1[b]) * math.comb(right_b, 2)  # [a ?] a [b b]

    return total % self.MOD

Approach 2: $O(n^2)$ complement counting

  • Time: $O(n^2)$
  • Space: $O(n)$
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
class Solution {
 public:
  // Counts length-5 subsequences whose middle element (index 2 of the
  // subsequence) is the unique mode, modulo 1e9+7.
  //
  // Complement counting: for each middle index i with value `a`, start from
  // all C(l, 2) * C(r, 2) ways to pick two elements on each side, then subtract
  // the picks in which `a` is not the unique mode.
  int subsequencesWithMiddleMode(vector<int>& nums) {
    const int n = nums.size();
    long long ans = 0;
    // Counts are 64-bit so that products of four counts (which can exceed
    // INT_MAX, e.g. 250 * 250 * 250 * 249 for n = 1000) do not overflow.
    unordered_map<int, long long> p;  // prefix counter: nums[0, i)
    unordered_map<int, long long> s;  // suffix counter: nums(i, n)

    for (const int num : nums)
      ++s[num];

    for (int i = 0; i < n; ++i) {
      const int a = nums[i];
      const long long sa = --s[a];
      const long long pa = countOf(p, a);
      const long long l = i;          // elements before index i
      const long long r = n - i - 1;  // elements after index i

      // Start with all possible subsequences with `a` as the middle number.
      ans = (ans + nC2(l) * nC2(r)) % kMod;

      // Minus cases where frequency of `a` is 1, so it's not a mode.
      ans = (ans - nC2(l - pa) * nC2(r - sa) % kMod + kMod) % kMod;

      // Every value of `nums` keeps a key in `s` (its count may drop to 0), so
      // iterating `s` visits every value that can occur on either side.
      for (const auto& [b, sb] : s) {
        if (b == a)
          continue;
        const long long pb = countOf(p, b);
        const long long leftRest = l - pa - pb;   // left picks that are not a/b
        const long long rightRest = r - sa - sb;  // right picks that are not a/b

        long long subtract = 0;
        // `a` and `b` both appear twice: `a` is not a unique mode.
        subtract += pa * pb * sb % kMod * rightRest;  // [a b] a [b c]
        subtract += sa * sb * pb % kMod * leftRest;   // [b c] a [a b]
        subtract += nC2(pb) * sa % kMod * rightRest;  // [b b] a [a c]
        subtract += nC2(sb) * pa % kMod * leftRest;   // [a c] a [b b]
        // `b` appears three times: `a` is not a mode.
        subtract += nC2(pb) * sa % kMod * sb;  // [b b] a [a b]
        subtract += nC2(sb) * pa % kMod * pb;  // [a b] a [b b]

        ans = (ans - subtract % kMod + kMod) % kMod;
      }

      ++p[a];
    }

    return static_cast<int>(ans);
  }

 private:
  static constexpr int kMod = 1'000'000'007;

  // Returns C(n, 2) mod kMod.
  static long long nC2(long long n) {
    return n * (n - 1) / 2 % kMod;
  }

  // Returns count[key], or 0 if `key` is absent, without inserting it.
  static long long countOf(const unordered_map<int, long long>& count,
                           int key) {
    const auto it = count.find(key);
    return it == count.end() ? 0 : it->second;
  }
};
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
class Solution {
  public int subsequencesWithMiddleMode(int[] nums) {
    int ans = 0;
    Map<Integer, Integer> p = new HashMap<>(); // prefix counter
    Map<Integer, Integer> s = new HashMap<>(); // suffix counter

    for (final int num : nums)
      s.merge(num, 1, Integer::sum);

    for (int i = 0; i < nums.length; ++i) {
      final int a = nums[i];
      s.merge(a, -1, Integer::sum);

      final int l = i;
      final int r = nums.length - i - 1;

      final int pa = p.getOrDefault(a, 0);
      final int sa = s.get(a);

      // Start with all possible subsequences with `a` as the middle number.
      ans = (int) ((ans + (long) nC2(l) * nC2(r)) % MOD);

      // Minus cases where frequency of 'a' is 1, so it's not a mode.
      ans = (int) ((ans - (long) nC2(l - pa) * nC2(r - sa)) % MOD);

      for (final int b : getUniqueNums(p, s)) {
        if (b == a)
          continue;

        final int pb = p.getOrDefault(b, 0);
        final int sb = s.get(b);

        // Minus cases where the middle number is not a "unique" mode
        int subtract = 0;

        // [a b] a [b c]
        subtract = (int) ((subtract + (long) pa * pb * sb * (r - sa - sb)) % MOD);

        // [b c] a [a b]
        subtract = (int) ((subtract + (long) sa * sb * pb * (l - pa - pb)) % MOD);

        // [b b] a [a c]
        subtract = (int) ((subtract + (long) nC2(pb) * sa * (r - sa - sb)) % MOD);

        // [a c] a [b b]
        subtract = (int) ((subtract + (long) nC2(sb) * pa * (l - pa - pb)) % MOD);

        // [b b] a [a b]
        subtract = (int) ((subtract + (long) nC2(pb) * sa * sb) % MOD);

        // [a b] a [b b]
        subtract = (int) ((subtract + (long) nC2(sb) * pa * pb) % MOD);

        ans = (int) ((ans - subtract + MOD) % MOD);
      }

      p.merge(a, 1, Integer::sum);
    }

    return (ans + MOD) % MOD;
  }

  private static final int MOD = 1_000_000_007;

  private int nC2(long n) {
    return (int) (n * (n - 1) / 2 % MOD);
  }

  private Set<Integer> getUniqueNums(final Map<Integer, Integer> p, final Map<Integer, Integer> s) {
    final Set<Integer> uniqueNums = new HashSet<>();
    uniqueNums.addAll(p.keySet());
    uniqueNums.addAll(s.keySet());
    return uniqueNums;
  }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
class Solution:
  def subsequencesWithMiddleMode(self, nums: list[int]) -> int:
    """Returns how many length-5 subsequences have a unique mode at index 2.

    Complement counting: for every middle position, all ways to pick two
    elements on each side are counted, then the picks in which the middle value
    is not the unique mode are subtracted. The running total is reduced modulo
    1_000_000_007.
    """
    MOD = 1_000_000_007
    before = collections.Counter()     # counts of nums[:i]
    after = collections.Counter(nums)  # counts of nums[i + 1:] once decremented
    size = len(nums)
    total = 0

    def choose2(k: int) -> int:
      return k * (k - 1) // 2

    for i, a in enumerate(nums):
      after[a] -= 1
      l, r = i, size - i - 1
      pa, sa = before[a], after[a]

      # All picks, minus the ones containing no other copy of `a`.
      total += choose2(l) * choose2(r) - choose2(l - pa) * choose2(r - sa)

      for b in before | after:
        if b == a:
          continue
        pb, sb = before[b], after[b]
        rest_l = l - pa - pb  # left elements equal to neither a nor b
        rest_r = r - sa - sb  # right elements equal to neither a nor b
        # `a` and `b` both appear twice: `a` is not the unique mode.
        ties = (pa * pb * sb * rest_r          # [a b] a [b c]
                + sa * sb * pb * rest_l        # [b c] a [a b]
                + choose2(pb) * sa * rest_r    # [b b] a [a c]
                + choose2(sb) * pa * rest_l)   # [a c] a [b b]
        # `b` appears three times: `a` is not a mode at all.
        beaten = (choose2(pb) * sa * sb        # [b b] a [a b]
                  + choose2(sb) * pa * pb)     # [a b] a [b b]
        total -= ties + beaten

      total %= MOD
      before[a] += 1

    return total

Approach 3: $O(n)$ complement counting

  • Time: $O(n)$
  • Space: $O(n)$
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
// Recall from Approach 2 that after counting all the subsequences with `a` as
// the middle number, we need to subtract the cases where `a` is not a unique
// mode or not a mode at all.
//
// To avoid looping through every number other than `a`, we maintain running
// sums over all numbers and subtract the terms that belong to `a` itself.
//
// So, while simplifying the formula, rearrange each product so that its first
// factor is one of the running sums pss, spp, pp, ss, or ps.
// (For cleaner notation, abbreviate p[b] and s[b] to just p and s.)
//
//   sum(b != a) (p[a] * p * s) * (r - s[a] - s)
//             + (s[a] * s * p) * (l - p[a] - p)
//             + C(p, 2) * s[a] * (r - s[a])
//             + C(s, 2) * p[a] * (l - p[a])
//
//   = sum(b != a) (p * s) * (p[a] * (r - s[a])) + (p * s^2) * (-p[a])
//               + (s * p) * (s[a] * (l - p[a])) + (s * p^2) * (-s[a])
//               + (p^2 - p) * (s[a] * (r - s[a]) / 2)
//               + (s^2 - s) * (p[a] * (l - p[a]) / 2)

// Single-pass complement counting. As in the pairwise version, start from all
// picks with `a` in the middle and subtract the invalid ones. The loop over
// every other value b is replaced by running sums over all values b
// (p = prefix counts, s = suffix counts):
//   pss = sum p[b] * s[b]^2    spp = sum s[b] * p[b]^2
//   pp  = sum p[b]^2           ss  = sum s[b]^2
//   ps  = sum p[b] * s[b]
// The b == a terms are removed inside the loop (the `_`-suffixed locals).
//
// NOTE(review): all of this arithmetic uses `long`. That is 64-bit on LP64
// platforms but only 32-bit on LLP64 (e.g. 64-bit Windows), where products
// such as ps_ * (pa * (r - sa)) would overflow. Confirm the target platform or
// switch to `long long`.
class Solution {
 public:
  int subsequencesWithMiddleMode(vector<int>& nums) {
    int ans = 0;
    unordered_map<int, int> p;  // prefix counter
    unordered_map<int, int> s;  // suffix counter

    for (const int num : nums)
      ++s[num];

    // Running sums over all values b (see the comment above the class).
    long pss = 0;
    long spp = 0;
    long pp = 0;
    long ss = 0;
    long ps = 0;

    // p starts empty, so ss = sum s[b]^2 is the only nonzero running sum.
    for (const auto& [_, freq] : s)
      ss = (ss + static_cast<long>(freq) * freq) % kMod;

    for (int i = 0; i < nums.size(); ++i) {
      const int a = nums[i];
      long sa = s[a];
      const long pa = p[a];

      // Update running sums after decrementing sa.
      pss = (pss + pa * (-sa * sa + (sa - 1) * (sa - 1))) % kMod;
      spp = (spp - pa * pa) % kMod;  // (-sa + (sa - 1)) * pa * pa
      ss = (ss - sa * sa + (sa - 1) * (sa - 1)) % kMod;
      ps = (ps - pa) % kMod;  // -pa * (-sa + (sa - 1))

      // From here on, s counts only the elements after index i.
      sa = --s[a];

      const int l = i;  // number of elements before index i
      const int r = nums.size() - i - 1;  // number of elements after index i

      // Start with all possible subsequences with `a` as the middle number.
      ans = (ans + nC2(l) * nC2(r)) % kMod;

      // Minus cases where frequency of `a` is 1, so it's not a mode.
      ans = (ans - nC2(l - pa) * nC2(r - sa)) % kMod;

      // Minus the values where `b != a`.
      // Each `x_` is the running sum x restricted to values b != a.
      const long pss_ = (pss - pa * sa * sa) % kMod;
      const long spp_ = (spp - sa * pa * pa) % kMod;
      const long pp_ = (pp - pa * pa) % kMod;
      const long ss_ = (ss - sa * sa) % kMod;
      const long ps_ = (ps - pa * sa) % kMod;
      const long p_ = l - pa;  // sum of p[b] over b != a
      const long s_ = r - sa;  // sum of s[b] over b != a

      // Minus cases where `a` is not a "unique" mode or not a mode.
      // NOTE(review): the `/ 2` terms divide an exact even value only while pp
      // and ss never wrap modulo kMod. Both are at most n^2, so this assumes
      // n^2 < kMod; confirm against the input limits.
      long subtract = 0;
      subtract = (subtract + ps_ * (pa * (r - sa))) % kMod;
      subtract = (subtract + pss_ * (-pa)) % kMod;
      subtract = (subtract + ps_ * (sa * (l - pa))) % kMod;
      subtract = (subtract + spp_ * (-sa)) % kMod;
      subtract = (subtract + (pp_ - p_) * sa * (r - sa) / 2) % kMod;
      subtract = (subtract + (ss_ - s_) * pa * (l - pa) / 2) % kMod;
      ans = (ans - subtract + kMod) % kMod;

      // Update running sums after incrementing p[a].
      pss = (pss + sa * sa) % kMod;  // (-pa + (pa + 1)) * sa * sa
      spp = (spp + sa * (-pa * pa + (pa + 1) * (pa + 1))) % kMod;
      pp = (pp - pa * pa + (pa + 1) * (pa + 1)) % kMod;
      ps = (ps + sa) % kMod;  // (-pa + (pa + 1)) * sa

      ++p[a];
    }

    // ans may be negative after the `%` steps above; normalize to [0, kMod).
    return (ans + kMod) % kMod;
  }

 private:
  static constexpr int kMod = 1'000'000'007;

  // Returns C(n, 2)
  long nC2(long n) {
    return n * (n - 1) / 2 % kMod;
  }
};
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
// Recall from Approach 2 that after counting all the subsequences with `a` as
// the middle number, we need to subtract the cases where `a` is not a unique
// mode or not a mode at all.
//
// To avoid looping through every number other than `a`, we maintain running
// sums over all numbers and subtract the terms that belong to `a` itself.
//
// So, while simplifying the formula, rearrange each product so that its first
// factor is one of the running sums pss, spp, pp, ss, or ps.
// (For cleaner notation, abbreviate p[b] and s[b] to just p and s.)
//
//   sum(b != a) (p[a] * p * s) * (r - s[a] - s)
//             + (s[a] * s * p) * (l - p[a] - p)
//             + (p, 2) * s[a] * (r - s[a])
//             + (s, 2) * p[a] * (l - p[a])
//
//   sum(b != a) (p * s) * (p[a] * (r - s[a])) + (p * s^2) * (-p[a])
//             + (s * p) * (s[a] * (l - p[a])) + (s * p^2) * (-s[a])
//             + (p^2 - p) * (s[a] * (r - s[a]) / 2)
//             + (s^2 - s) * (p[a] * (l - p[a]) / 2)

/**
 * Single-pass complement counting. As in the pairwise version, start from all
 * picks with {@code a} in the middle and subtract the invalid ones. The loop
 * over every other value b is replaced by running sums over all values b
 * (p = prefix counts, s = suffix counts):
 * pss = sum p[b]*s[b]^2, spp = sum s[b]*p[b]^2, pp = sum p[b]^2,
 * ss = sum s[b]^2, ps = sum p[b]*s[b].
 * The b == a terms are removed inside the loop (the {@code _}-suffixed locals).
 */
class Solution {
  /**
   * @param nums the input array
   * @return the number of length-5 subsequences whose middle element is the
   *     unique mode, modulo 1e9+7
   */
  public int subsequencesWithMiddleMode(int[] nums) {
    int ans = 0;
    Map<Integer, Integer> p = new HashMap<>(); // prefix counter
    Map<Integer, Integer> s = new HashMap<>(); // suffix counter

    for (final int num : nums)
      s.merge(num, 1, Integer::sum);

    // Running sums over all values b (see the class comment).
    long pss = 0;
    long spp = 0;
    long pp = 0;
    long ss = 0;
    long ps = 0;

    // p starts empty, so ss = sum s[b]^2 is the only nonzero running sum.
    for (final int freq : s.values())
      ss = (ss + (long) freq * freq) % MOD;

    for (int i = 0; i < nums.length; ++i) {
      final int a = nums[i];
      long sa = s.get(a);
      final long pa = p.getOrDefault(a, 0);

      // Update running sums after decrementing s[a]
      pss = (pss + pa * (-sa * sa + (sa - 1) * (sa - 1))) % MOD;
      spp = (spp - pa * pa) % MOD; // (-sa + (sa - 1)) * pa * pa
      ss = (ss - sa * sa + (sa - 1) * (sa - 1)) % MOD;
      ps = (ps - pa) % MOD; // -pa * (-sa + (sa - 1))

      // From here on, s counts only the elements after index i.
      s.merge(a, -1, Integer::sum);
      sa = s.get(a);

      final int l = i; // number of elements before index i
      final int r = nums.length - i - 1; // number of elements after index i

      // Start with all possible subsequences with `a` as the middle number
      ans = (int) ((ans + (long) nC2(l) * nC2(r)) % MOD);

      // Minus cases where frequency of `a` is 1, so it's not a mode
      ans = (int) ((ans - (long) nC2(l - pa) * nC2(r - sa)) % MOD);

      // Minus the values where `b != a`
      // Each `x_` is the running sum x restricted to values b != a.
      final long pss_ = (pss - pa * sa * sa) % MOD;
      final long spp_ = (spp - sa * pa * pa) % MOD;
      final long pp_ = (pp - pa * pa) % MOD;
      final long ss_ = (ss - sa * sa) % MOD;
      final long ps_ = (ps - pa * sa) % MOD;
      final long p_ = l - pa; // sum of p[b] over b != a
      final long s_ = r - sa; // sum of s[b] over b != a

      // Minus cases where `a` is not a "unique" mode or not a mode
      // NOTE(review): the `/ 2` terms divide an exact even value only while pp
      // and ss never wrap modulo MOD. Both are at most n^2, so this assumes
      // n^2 < MOD; confirm against the input limits.
      long subtract = 0;
      subtract = (subtract + ps_ * (pa * (r - sa))) % MOD;
      subtract = (subtract + pss_ * (-pa)) % MOD;
      subtract = (subtract + ps_ * (sa * (l - pa))) % MOD;
      subtract = (subtract + spp_ * (-sa)) % MOD;
      subtract = (subtract + (pp_ - p_) * sa * (r - sa) / 2) % MOD;
      subtract = (subtract + (ss_ - s_) * pa * (l - pa) / 2) % MOD;
      ans = (int) ((ans - subtract + MOD) % MOD);

      // Update running sums after incrementing p[a]
      pss = (pss + sa * sa) % MOD; // (-pa + (pa + 1)) * sa * sa
      spp = (spp + sa * (-pa * pa + (pa + 1) * (pa + 1))) % MOD;
      pp = (pp - pa * pa + (pa + 1) * (pa + 1)) % MOD;
      ps = (ps + sa) % MOD; // (-pa + (pa + 1)) * sa

      p.merge(a, 1, Integer::sum);
    }

    // ans may be negative after the `%` steps above; normalize to [0, MOD).
    return (int) ((ans + MOD) % MOD);
  }

  private static final int MOD = 1_000_000_007;

  // Returns C(n, 2)
  private long nC2(long n) {
    return n * (n - 1) / 2 % MOD;
  }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
# Recall from Approach 2 that after counting all the subsequences with `a` as
# the middle number, we need to subtract the cases where `a` is not a unique
# mode or not a mode at all.
#
# To avoid looping through every number other than `a`, we maintain running
# sums over all numbers and subtract the terms that belong to `a` itself.
#
# So, while simplifying the formula, rearrange each product so that its first
# factor is one of the running sums pss, spp, pp, ss, or ps.
# (For cleaner notation, abbreviate p[b] and s[b] to just p and s.)
#
#   sum(b != a) (p[a] * p * s) * (r - s[a] - s)
#             + (s[a] * s * p) * (l - p[a] - p)
#             + (p, 2) * s[a] * (r - s[a])
#             + (s, 2) * p[a] * (l - p[a])
#
#   sum(b != a) (p * s) * (p[a] * (r - s[a])) + (p * s^2) * (-p[a])
#             + (s * p) * (s[a] * (l - p[a])) + (s * p^2) * (-s[a])
#             + (p^2 - p) * (s[a] * (r - s[a]) / 2)
#             + (s^2 - s) * (p[a] * (l - p[a]) / 2)


class Solution:
  def subsequencesWithMiddleMode(self, nums: list[int]) -> int:
    """Counts length-5 subsequences whose middle element is the unique mode.

    Single-pass complement counting. Every pick with nums[i] as the middle is
    counted, then the invalid picks are subtracted. Sums over the other values
    b are kept as running sums over all values (p = prefix counts, s = suffix
    counts):
      pss = sum p[b] * s[b]**2    spp = sum s[b] * p[b]**2
      pp  = sum p[b]**2           ss  = sum s[b]**2
      ps  = sum p[b] * s[b]
    The b == a terms are removed inside the loop (the `_`-suffixed locals).

    Returns:
      The count modulo 1_000_000_007.
    """
    MOD = 1_000_000_007
    ans = 0
    p = collections.Counter()  # prefix counter
    s = collections.Counter(nums)  # suffix counter

    def nC2(n: int) -> int:
      return n * (n - 1) // 2

    # Running sums over all values b (see the docstring).
    pss = 0
    spp = 0
    pp = 0
    # p starts empty, so ss is the only nonzero running sum.
    ss = sum(freq**2 for freq in s.values())
    ps = 0

    for i, a in enumerate(nums):
      # Update running sums after decrementing s[a].
      pss += p[a] * (-s[a]**2 + (s[a] - 1)**2)
      spp += -p[a]**2  # (-s[a] + (s[a] - 1)) * p[a]**2
      ss += -s[a]**2 + (s[a] - 1)**2
      ps += -p[a]  # -p[a] * (-s[a] + (s[a] - 1))

      # From here on, s counts only the elements after index i.
      s[a] -= 1

      l = i  # number of elements before index i
      r = len(nums) - i - 1  # number of elements after index i

      # Start with all possible subsequences with `a` as the middle number.
      ans += nC2(l) * nC2(r)

      # Minus the cases where the frequency of `a` is 1, so it's not a mode.
      ans -= nC2(l - p[a]) * nC2(r - s[a])

      # Minus the values where `b != a`.
      # Each `x_` is the running sum x restricted to values b != a.
      pss_ = pss - p[a] * s[a]**2
      spp_ = spp - s[a] * p[a]**2
      pp_ = pp - p[a]**2
      ss_ = ss - s[a]**2
      ps_ = ps - p[a] * s[a]
      p_ = l - p[a]  # sum of p[b] over b != a
      s_ = r - s[a]  # sum of s[b] over b != a

      # Minus the cases where the `a` is not a "unique" mode or not a mode.
      # (pp_ - p_) and (ss_ - s_) are sums of k * (k - 1), so the `// 2` is exact.
      ans -= ps_ * (p[a] * (r - s[a])) + pss_ * (-p[a])
      ans -= ps_ * (s[a] * (l - p[a])) + spp_ * (-s[a])
      ans -= (pp_ - p_) * s[a] * (r - s[a]) // 2
      ans -= (ss_ - s_) * p[a] * (l - p[a]) // 2
      ans %= MOD

      # Update running sums after incrementing p[a].
      pss += s[a]**2  # (-p[a] + (p[a] + 1)) * s[a]**2
      spp += s[a] * (-p[a]**2 + (p[a] + 1)**2)
      pp += -p[a]**2 + (p[a] + 1)**2
      ps += s[a]  # (-p[a] + (p[a] + 1)) * s[a]

      p[a] += 1

    return ans