Skip to content

3508. Implement Router 👍

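  • Time:
    • Constructor: $O(1)$
    • addPacket(source: int, destination: int, timestamp: int): $O(1)$ with hash-based containers, or $O(\log n)$ with ordered ones (`std::set`/`std::map`, `TreeSet`)
    • forwardPacket(): $O(1)$ with hash-based containers, or $O(\log n)$ with ordered ones
    • getCount(destination: int, startTime: int, endTime: int): $O(\log n)$, where $n$ is the number of packets ever added for `destination`
  • Space: $O(|\texttt{addPacket()}|)$, i.e. linear in the total number of `addPacket` calls, because the per-destination timestamp lists are never trimmed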
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
struct Packet {
  int source;
  int destination;
  int timestamp;

  // Strict weak ordering: compare by source, then destination, then
  // timestamp. This lets Packet be used as a key in an ordered std::set.
  bool operator<(const Packet& other) const {
    if (source != other.source)
      return source < other.source;
    if (destination != other.destination)
      return destination < other.destination;
    return timestamp < other.timestamp;
  }
};

class Router {
 public:
  Router(int memoryLimit) : memoryLimit(memoryLimit) {}

  bool addPacket(int source, int destination, int timestamp) {
    const Packet packet{source, destination, timestamp};
    if (uniquePackets.find(packet) != uniquePackets.end())
      return false;
    if (packetQueue.size() == memoryLimit)
      forwardPacket();
    packetQueue.push(packet);
    uniquePackets.insert(packet);
    destinationTimestamps[destination].push_back(timestamp);
    return true;
  }

  vector<int> forwardPacket() {
    if (packetQueue.empty())
      return {};
    const Packet nextPacket = packetQueue.front();
    packetQueue.pop();
    uniquePackets.erase(nextPacket);
    ++processedPacketIndex[nextPacket.destination];
    return {nextPacket.source, nextPacket.destination, nextPacket.timestamp};
  }

  int getCount(int destination, int startTime, int endTime) {
    if (destinationTimestamps.find(destination) == destinationTimestamps.end())
      return 0;
    const vector<int>& timestamps = destinationTimestamps[destination];
    const int startIndex = processedPacketIndex[destination];
    const auto lowerBound = lower_bound(timestamps.begin() + startIndex,
                                        timestamps.end(), startTime);
    const auto upperBound =
        upper_bound(timestamps.begin() + startIndex, timestamps.end(), endTime);
    return upperBound - lowerBound;
  }

 private:
  const int memoryLimit;
  set<Packet> uniquePackets;
  queue<Packet> packetQueue;
  map<int, vector<int>> destinationTimestamps;
  map<int, int> processedPacketIndex;
};
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
class Packet implements Comparable<Packet> {
  public int source;
  public int destination;
  public int timestamp;

  public Packet(int source, int destination, int timestamp) {
    this.source = source;
    this.destination = destination;
    this.timestamp = timestamp;
  }

  @Override
  public int compareTo(Packet other) {
    if (source != other.source)
      return Integer.compare(source, other.source);
    if (destination != other.destination)
      return Integer.compare(destination, other.destination);
    return Integer.compare(timestamp, other.timestamp);
  }

  @Override
  public boolean equals(Object o) {
    if (this == o)
      return true;
    if (o == null || getClass() != o.getClass())
      return false;
    Packet packet = (Packet) o;
    return source == packet.source && destination == packet.destination &&
        timestamp == packet.timestamp;
  }

  @Override
  public int hashCode() {
    return Objects.hash(source, destination, timestamp);
  }
}

class Router {
  /**
   * @param memoryLimit maximum number of packets held at once; adding a packet
   *     to a full router forwards (drops) the oldest held packet first
   */
  public Router(int memoryLimit) {
    this.memoryLimit = memoryLimit;
  }

  /**
   * Stores a packet unless an identical (source, destination, timestamp)
   * packet is already held.
   *
   * @return true iff the packet was stored
   */
  public boolean addPacket(int source, int destination, int timestamp) {
    Packet packet = new Packet(source, destination, timestamp);
    if (uniquePackets.contains(packet))
      return false;
    if (packetQueue.size() == memoryLimit)
      forwardPacket();
    packetQueue.add(packet);
    uniquePackets.add(packet);
    // Append-only per destination. Assumes calls arrive with non-decreasing
    // timestamps, which keeps each list sorted for the binary searches below.
    destinationTimestamps.computeIfAbsent(destination, k -> new ArrayList<>()).add(timestamp);
    return true;
  }

  /**
   * Removes the oldest held packet.
   *
   * @return [source, destination, timestamp], or an empty list if none is held
   */
  public List<Integer> forwardPacket() {
    if (packetQueue.isEmpty())
      return Collections.emptyList();
    Packet nextPacket = packetQueue.poll();
    uniquePackets.remove(nextPacket);
    // Packets of one destination leave in FIFO order, so the forwarded ones
    // always form a prefix of that destination's timestamp list.
    processedPacketIndex.merge(nextPacket.destination, 1, Integer::sum);
    return Arrays.asList(nextPacket.source, nextPacket.destination, nextPacket.timestamp);
  }

  /**
   * Counts held packets for {@code destination} whose timestamp lies in
   * [startTime, endTime]. Returns 0 for an empty or inverted range.
   */
  public int getCount(int destination, int startTime, int endTime) {
    List<Integer> timestamps = destinationTimestamps.get(destination);
    if (timestamps == null)
      return 0;
    final int startIndex = processedPacketIndex.getOrDefault(destination, 0);
    final int lowerBoundIndex = firstGreaterEqual(timestamps, startIndex, startTime);
    // Start from lowerBoundIndex so an inverted range can never go negative.
    final int upperBoundIndex = firstGreater(timestamps, lowerBoundIndex, endTime);
    return upperBoundIndex - lowerBoundIndex;
  }

  private final int memoryLimit;
  // Packet defines equals/hashCode, so a hash set gives O(1) duplicate checks.
  private final Set<Packet> uniquePackets = new HashSet<>();
  // Held packets, oldest at the head.
  private final Queue<Packet> packetQueue = new ArrayDeque<>();
  // destination -> timestamps of every packet ever added for it.
  private final Map<Integer, List<Integer>> destinationTimestamps = new HashMap<>();
  // destination -> length of the already-forwarded prefix of its list.
  private final Map<Integer, Integer> processedPacketIndex = new HashMap<>();

  // Returns the first index >= startIndex whose value is >= startTime.
  private int firstGreaterEqual(List<Integer> timestamps, int startIndex, int startTime) {
    int l = startIndex;
    int r = timestamps.size();
    while (l < r) {
      final int m = (l + r) / 2;
      if (timestamps.get(m) >= startTime)
        r = m;
      else
        l = m + 1;
    }
    return l;
  }

  // Returns the first index >= startIndex whose value is > endTime.
  private int firstGreater(List<Integer> timestamps, int startIndex, int endTime) {
    int l = startIndex;
    int r = timestamps.size();
    while (l < r) {
      final int m = (l + r) / 2;
      if (timestamps.get(m) > endTime)
        r = m;
      else
        l = m + 1;
    }
    return l;
  }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import bisect
import collections
from dataclasses import dataclass


@dataclass(frozen=True)
class Packet:
  """Immutable (source, destination, timestamp) record.

  ``frozen=True`` makes instances hashable, so they can be kept in a set for
  duplicate detection.
  """
  source: int
  destination: int
  timestamp: int


class Router:
  """Bounded FIFO packet store with duplicate rejection and range counting.

  At most ``memoryLimit`` packets are held. Adding a packet to a full router
  forwards (drops) the oldest held packet first.
  """

  def __init__(self, memoryLimit: int):
    self.memoryLimit = memoryLimit
    # Packets currently held, for O(1) duplicate checks.
    self.uniquePackets: set[Packet] = set()
    # Packets currently held in arrival order; the left end is the oldest.
    self.packetQueue: collections.deque[Packet] = collections.deque()
    # destination -> timestamps of every packet ever added for it (append-only).
    self.destinationTimestamps = collections.defaultdict(list)
    # destination -> length of the already-forwarded prefix of its list above.
    self.processedPacketIndex = collections.Counter()

  def addPacket(self, source: int, destination: int, timestamp: int) -> bool:
    """Store a packet; return False if an identical packet is already held.

    If the router is full, the oldest packet is forwarded first to make room.
    """
    packet = Packet(source, destination, timestamp)
    if packet in self.uniquePackets:
      return False
    if len(self.packetQueue) == self.memoryLimit:
      self.forwardPacket()
    self.packetQueue.append(packet)
    self.uniquePackets.add(packet)
    # The defaultdict creates the list on first use. Assumes calls arrive with
    # non-decreasing timestamps, which keeps each list sorted for bisect.
    self.destinationTimestamps[destination].append(timestamp)
    return True

  def forwardPacket(self) -> list[int]:
    """Remove the oldest held packet and return [source, destination, timestamp].

    Returns an empty list when no packet is held.
    """
    if not self.packetQueue:
      return []
    nextPacket = self.packetQueue.popleft()
    self.uniquePackets.remove(nextPacket)
    # Packets of one destination leave in FIFO order, so the forwarded ones
    # always form a prefix of destinationTimestamps[destination].
    self.processedPacketIndex[nextPacket.destination] += 1
    return [nextPacket.source, nextPacket.destination, nextPacket.timestamp]

  def getCount(self, destination: int, startTime: int, endTime: int) -> int:
    """Count held packets for ``destination`` with a timestamp in range.

    The range is inclusive: startTime <= timestamp <= endTime. An empty or
    inverted range (startTime > endTime) yields 0.
    """
    # A membership test does not insert a key into the defaultdict.
    if destination not in self.destinationTimestamps:
      return 0
    timestamps = self.destinationTimestamps[destination]
    # Skip the prefix belonging to packets that were already forwarded.
    startIndex = self.processedPacketIndex.get(destination, 0)
    lowerBound = bisect.bisect_left(timestamps, startTime, startIndex)
    # Search only from lowerBound so an inverted range gives 0, not a negative.
    upperBound = bisect.bisect_right(timestamps, endTime, lowerBound)
    return upperBound - lowerBound