Skip to content

528. Random Pick with Weight

  • Time: Constructor: $O(n)$, pickIndex(): $O(\log n)$
  • Space: $O(n)$
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
#include <random>  // std::mt19937, std::random_device, std::uniform_int_distribution

/// Weighted random index picker.
///
/// Stores the inclusive prefix sums of the weights; pickIndex() draws a
/// uniform integer in [0, total) and binary-searches for the first prefix sum
/// strictly greater than it, so index i is chosen with probability
/// w[i] / total.
class Solution {
 public:
  /// Builds the prefix-sum table from `w`.
  ///
  /// Precondition: `w` is non-empty and its weights add up to a positive
  /// value that fits in `int` (pickIndex() reads `prefix.back()`).
  Solution(std::vector<int>& w) : prefix(w.size()) {
    std::partial_sum(w.begin(), w.end(), prefix.begin());
  }

  /// Returns an index in [0, w.size()) chosen proportionally to its weight.
  int pickIndex() {
    // `rand() % total` over-represents small targets whenever total does not
    // divide RAND_MAX + 1 and can never produce a target above RAND_MAX;
    // uniform_int_distribution is exact over the whole closed range.
    std::uniform_int_distribution<int> dist(0, prefix.back() - 1);
    const int target = dist(gen);

    // Half-open search window [lo, hi) for the first prefix sum > target.
    int lo = 0;
    int hi = static_cast<int>(prefix.size());

    while (lo < hi) {
      const int mid = lo + (hi - lo) / 2;
      if (prefix[mid] > target)
        hi = mid;
      else
        lo = mid + 1;
    }

    return lo;
  }

 private:
  std::vector<int> prefix;                   // prefix[i] = w[0] + ... + w[i]
  std::mt19937 gen{std::random_device{}()};  // per-instance engine
};
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
/**
 * Weighted random index picker.
 *
 * <p>Keeps the inclusive prefix sums of the weights. {@link #pickIndex()} draws a uniform integer
 * in {@code [0, total)} and binary-searches for the first prefix sum strictly greater than it, so
 * index {@code i} is returned with probability {@code w[i] / total}.
 */
class Solution {
  /**
   * Builds the prefix-sum table.
   *
   * @param w the weights; must be non-empty with a positive total (required by
   *     {@code Random.nextInt(bound)} in {@link #pickIndex()}). The array is copied, so the
   *     caller's data is left untouched.
   */
  public Solution(int[] w) {
    // Work on a private copy: accumulating in place would silently rewrite the caller's array.
    prefix = w.clone();
    for (int i = 1; i < prefix.length; ++i)
      prefix[i] += prefix[i - 1];
  }

  /**
   * Picks an index proportionally to its weight.
   *
   * @return an index in {@code [0, w.length)}
   */
  public int pickIndex() {
    final int target = rand.nextInt(prefix[prefix.length - 1]);
    // Half-open search window [lo, hi) for the first prefix sum > target.
    int lo = 0;
    int hi = prefix.length;

    while (lo < hi) {
      final int mid = lo + (hi - lo) / 2;
      if (prefix[mid] > target)
        hi = mid;
      else
        lo = mid + 1;
    }

    return lo;
  }

  // prefix[i] = w[0] + ... + w[i]
  private final int[] prefix;
  private final Random rand = new Random();
}
1
2
3
4
5
6
7
8
class Solution:
  """Weighted random index picker backed by inclusive prefix sums."""

  def __init__(self, w: List[int]):
    """Stores the running totals of `w`; `w` must be non-empty with a positive sum."""
    # prefix[i] = w[0] + ... + w[i]
    self.prefix = list(itertools.accumulate(w))

  def pickIndex(self) -> int:
    """Returns index i with probability w[i] / sum(w)."""
    # Uniform integer in [0, total).
    target = random.randrange(self.prefix[-1])
    # First position whose prefix sum is strictly greater than target. The
    # prefix list is already sorted, so it can be bisected directly.
    return bisect.bisect_right(self.prefix, target)

Approach 2: Built-in

  • Time: Constructor: $O(n)$, pickIndex(): $O(\log n)$
  • Space: $O(n)$
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
#include <random>  // std::mt19937, std::random_device, std::uniform_int_distribution

/// Weighted random index picker using the standard library's binary search.
///
/// Stores the inclusive prefix sums of the weights; pickIndex() draws a
/// uniform integer in [0, total) and returns the position of the first prefix
/// sum strictly greater than it, so index i is chosen with probability
/// w[i] / total.
class Solution {
 public:
  /// Builds the prefix-sum table from `w`.
  ///
  /// Precondition: `w` is non-empty and its weights add up to a positive
  /// value that fits in `int` (pickIndex() reads `prefix.back()`).
  Solution(std::vector<int>& w) : prefix(w.size()) {
    std::partial_sum(w.begin(), w.end(), prefix.begin());
  }

  /// Returns an index in [0, w.size()) chosen proportionally to its weight.
  int pickIndex() {
    // `rand() % total` over-represents small targets whenever total does not
    // divide RAND_MAX + 1 and can never produce a target above RAND_MAX;
    // uniform_int_distribution is exact over the whole closed range.
    std::uniform_int_distribution<int> dist(0, prefix.back() - 1);
    const int target = dist(gen);
    // upper_bound yields the first element > target; its offset is the index.
    const auto it = std::upper_bound(prefix.begin(), prefix.end(), target);
    return static_cast<int>(it - prefix.begin());
  }

 private:
  std::vector<int> prefix;                   // prefix[i] = w[0] + ... + w[i]
  std::mt19937 gen{std::random_device{}()};  // per-instance engine
};
1
2
3
4
5
6
class Solution:
  """Weighted random index picker backed by inclusive prefix sums."""

  def __init__(self, w: List[int]):
    """Stores the running totals of `w`; `w` must be non-empty with a positive sum."""
    # prefix[i] = w[0] + ... + w[i]
    self.prefix = list(itertools.accumulate(w))

  def pickIndex(self) -> int:
    """Returns index i with probability w[i] / sum(w)."""
    # Draw an integer in [1, total]: exactly w[i] of those values satisfy
    # prefix[i - 1] < target <= prefix[i], so the pick is exact and an index
    # with zero weight can never be returned. A float-scaled target
    # (random() * total) would select a zero-weight leading index whenever
    # random() produced 0.0.
    target = random.randint(1, self.prefix[-1])
    # First position whose prefix sum is >= target.
    return bisect_left(self.prefix, target)