题意
有 $N$ ($3\le N\le 2\cdot10^5$) 块蛋糕,第 $i$ 块蛋糕的价值为 $V_i$,颜色为 $C_i$。选择 $M$ ($3\le M\le N$) 块蛋糕排成一个环 $k_1,k_2,\ldots,k_M$,最大化 $$ \sum_{j=1}^{M}V_{k_j}-\sum_{j=1}^{M}\left|C_{k_j}-C_{k_{(j\bmod M)+1}}\right| $$
题解
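显然对于给定的蛋糕集合,最优排法是按颜色排序后首尾相接,此时颜色差之和恰为 $2(\max\{C_j\}-\min\{C_j\})$,所以该集合能取得的最大值为 $\sum V_j+2\min \{C_j\}-2\max\{C_j\}$。将蛋糕按 $C_i$ 排序后,固定左、右端点 $l\le r$,取区间 $[l,r]$ 中前 $M$ 大的 $V_i$ 之和再加上 $2C_l-2C_r$ 即可:若选出的集合不含端点,其真实颜色差只会更小、真实值只会更大,因此对所有 $(l,r)$ 取最大值不影响答案。枚举右端点,可以发现最优的左端点是不减的,因此可以用决策单调性(分治)优化。用主席树计算区间前 $M$ 大的 $V_i$ 之和即可。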
时间复杂度:$O(n(\log n)^2)$。
代码
#include <algorithm>
#include <functional>
#include <iostream>
#include <limits>
#include <numeric>
#include <vector>
// Node of a persistent (path-copying) segment tree indexed by compressed value rank.
// For the rank interval it covers, a node records how many inserted cakes fall inside it
// and the sum of their (uncompressed) values.
// Members carry default initializers so a freshly allocated node is never in an
// indeterminate state, even before its fields are explicitly assigned.
struct Node {
    Node *lc = nullptr;  // child covering the lower half of the rank interval
    Node *rc = nullptr;  // child covering the upper half of the rank interval
    long long sum = 0;   // sum of values of the cakes counted in this subtree
    int cnt = 0;         // number of cakes counted in this subtree
};
// Shared sentinel standing for an empty subtree; main() sets its children to itself and
// its counters to zero before any tree is built.
Node *null{new Node};
// Sorted, de-duplicated cake values; a value rank i in the tree corresponds to values[i].
std::vector<int> values{};
// Returns the root of a new tree version equal to version `p` plus one more cake of value
// rank `v`. [l, r) is the rank interval covered by `p`. Only the nodes on the path down to
// leaf `v` are freshly allocated; every other subtree is shared with `p`.
Node *insert(Node *p, int l, int r, int v) {
    Node *fresh = new Node;
    if (r - l > 1) {
        const int mid = (l + r) / 2;
        // Reuse both old children, then rebuild only the half that contains v.
        fresh -> lc = p -> lc;
        fresh -> rc = p -> rc;
        if (v < mid)
            fresh -> lc = insert(p -> lc, l, mid, v);
        else
            fresh -> rc = insert(p -> rc, mid, r, v);
        fresh -> cnt = fresh -> lc -> cnt + fresh -> rc -> cnt;
        fresh -> sum = fresh -> lc -> sum + fresh -> rc -> sum;
        return fresh;
    }
    // Leaf for a single rank: one more cake whose value is values[v].
    fresh -> lc = null;
    fresh -> rc = null;
    fresh -> cnt = p -> cnt + 1;
    fresh -> sum = p -> sum + values[v];
    return fresh;
}
// Returns the sum of the k largest values among the cakes present in version `q` but not in
// version `p` (the difference q - p). [l, r) is the rank interval both roots cover.
// Precondition (not checked here): that difference contains at least k cakes.
long long rangeQuery(Node *p, Node *q, int l, int r, int k) {
    long long taken = 0;
    while (r - l > 1) {
        const int mid = (l + r) / 2;
        const int upperCount = q -> rc -> cnt - p -> rc -> cnt;
        if (upperCount >= k) {
            // All k largest live in the upper half.
            p = p -> rc;
            q = q -> rc;
            l = mid;
        } else {
            // Take the whole upper half, then look for the remainder in the lower half.
            taken += q -> rc -> sum - p -> rc -> sum;
            k -= upperCount;
            p = p -> lc;
            q = q -> lc;
            r = mid;
        }
    }
    // Single rank left: every remaining pick has value values[l].
    return taken + 1ll * k * values[l];
}
// Reads n, m and n pairs (value, colour), and prints the best achievable score.
// Cakes are sorted by colour. For every pair of colour-sorted endpoints (l, r) the candidate
// score is the sum of the m largest values in [l, r] plus 2*C_l - 2*C_r. The optimal l is
// found for every r by divide and conquer over monotone decisions.
int main() {
    // The sentinel is an empty subtree whose children point back to itself, so descending
    // below an absent subtree always lands on valid zero counters.
    null -> lc = null -> rc = null;
    null -> sum = null -> cnt = 0;
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    int n, m;
    std::cin >> n >> m;
    std::vector<int> v(n), c(n);
    for (int i = 0; i < n; ++i)
        std::cin >> v[i] >> c[i];
    // p lists cake indices in non-decreasing colour order.
    std::vector<int> p(n);
    std::iota(p.begin(), p.end(), 0);
    std::sort(p.begin(), p.end(), [&](int i, int j) {
        return c[i] < c[j];
    });
    // Coordinate-compress values; v[i] becomes the rank of the original value in `values`.
    values = v;
    std::sort(values.begin(), values.end());
    values.erase(std::unique(values.begin(), values.end()), values.end());
    for (int i = 0; i < n; ++i)
        v[i] = std::lower_bound(values.begin(), values.end(), v[i]) - values.begin();
    const int valueCount = static_cast<int>(values.size());
    long long ans = std::numeric_limits<long long>::min();
    // root[i] is the tree holding the first i cakes in colour order.
    std::vector<Node *> root(n + 1);
    root[0] = null;
    for (int i = 0; i < n; ++i)
        root[i + 1] = insert(root[i], 0, valueCount, v[p[i]]);
    // Candidate score with colour-sorted positions l..r: the m largest values in that window
    // minus twice the colour span. The colour term is computed in 64-bit arithmetic so that
    // large colours cannot overflow int before being added to the long long sum.
    auto get = [&](int l, int r) {
        return rangeQuery(root[l], root[r + 1], 0, valueCount, m)
             + 2LL * c[p[l]] - 2LL * c[p[r]];
    };
    // Handles right endpoints in [l, r) whose optimal left endpoint lies in [x, y).
    // Relies on the optimal left endpoint being non-decreasing in the right endpoint.
    std::function<void(int, int, int, int)> solve = [&](int l, int r, int x, int y) {
        if (l >= r)
            return;  // empty range (only reachable when m > n); avoids endless recursion
        if (r - l == 1) {
            for (int i = x; i < y && l - i + 1 >= m; ++i)
                ans = std::max(ans, get(i, l));
        } else {
            int mid = (l + r) / 2;
            int best = -1;
            long long bestScore = std::numeric_limits<long long>::min();
            for (int i = x; i < y && mid - i + 1 >= m; ++i) {
                long long score = get(i, mid);
                // `>=` keeps the largest optimal left endpoint on ties.
                if (score >= bestScore) {
                    best = i;
                    bestScore = score;
                }
            }
            solve(l, mid, x, best + 1);
            solve(mid, r, best, y);
        }
    };
    // The first right endpoint with a window of size m is position m - 1.
    solve(m - 1, n, 0, n);
    std::cout << ans << "\n";
    return 0;
}