class Solution {
  /**
   * Counts pairs of indices {@code (i, j)} with {@code i < j}, {@code nums[i] == nums[j]} and
   * {@code i * j} divisible by {@code k}.
   *
   * <p>Indices are first grouped by the value they hold. Within one group, {@code i * j} is
   * divisible by {@code k} exactly when {@code gcd(i, k) * gcd(j, k)} is, so each index is reduced
   * to its gcd with {@code k}. Only the distinct gcd values seen earlier in the group then need to
   * be checked.
   *
   * @param nums the input values
   * @param k the divisor; must be non-zero (a zero {@code k} raises {@link ArithmeticException}
   *     as soon as any value occurs more than once)
   * @return the number of qualifying pairs; the count is accumulated in an {@code int} and is not
   *     guarded against overflow for extremely large inputs
   */
  public int countPairs(int[] nums, int k) {
    // Group the indices by the value stored there; only equal values can form a pair.
    Map<Integer, List<Integer>> numToIndices = new HashMap<>();
    for (int i = 0; i < nums.length; ++i) {
      numToIndices.computeIfAbsent(nums[i], key -> new ArrayList<>()).add(i);
    }

    int ans = 0;
    for (List<Integer> indices : numToIndices.values()) {
      // Multiplicity of each gcd(index, k) among the indices already visited in this group.
      Map<Integer, Integer> gcdCounts = new HashMap<>();
      for (final int i : indices) {
        final int gcdI = gcd(i, k);
        for (Map.Entry<Integer, Integer> entry : gcdCounts.entrySet()) {
          // Multiply in long: a gcd can be as large as |k| (gcd(0, k) is k itself), so the
          // int product of two of them can overflow and give a wrong remainder.
          if ((long) gcdI * entry.getKey() % k == 0) {
            ans += entry.getValue();
          }
        }
        gcdCounts.merge(gcdI, 1, Integer::sum);
      }
    }
    return ans;
  }

  /**
   * Returns the greatest common divisor of {@code a} and {@code b} using Euclid's algorithm.
   * When an argument is negative, the result may be negative as well.
   */
  private static int gcd(int a, int b) {
    return b == 0 ? a : gcd(b, a % b);
  }
}