class Solution:
def findMaximumElegance(self, items: list[list[int]], k: int) -> int:
ans = 0
totalProfit = 0
seenCategories = set()
decreasingDuplicateProfits = []
items.sort(reverse=True)
for i in range(k):
profit, category = items[i]
totalProfit += profit
if category in seenCategories:
decreasingDuplicateProfits.append(profit)
else:
seenCategories.add(category)
ans = totalProfit + len(seenCategories)**2
for i in range(k, len(items)):
profit, category = items[i]
if category not in seenCategories and decreasingDuplicateProfits:
# If this is a new category we haven't seen before, it's worth
# considering taking it and replacing the one with the least profit
# since it will increase the distinct_categories and potentially result
# in a larger total_profit + distinct_categories^2.
totalProfit -= decreasingDuplicateProfits.pop()
totalProfit += profit
seenCategories.add(category)
ans = max(ans, totalProfit + len(seenCategories)**2)
return ans