题意
给定一棵 $N$ ($2\le N\le 2\cdot10^5$) 个点的树,第 $i$ 条边连接 $A_i$ 和 $B_i$,其中 $(A_i,B_i)$ 的边权为 $C_i$,$(B_i,A_i)$ 的边权为 $D_i$。接下来有 $Q$ ($1\le Q\le N$) 个询问,第 $j$ 次询问要求选择 $E_j$ 个点,最小化所有满足下列条件的有向边 $(U,V)$ 的边权和。
- 不存在一个被选择的点 $X$,使得 $X$ 和 $V$ 之间的边数小于 $X$ 和 $U$ 之间的边数。
题解
首先可以 $O(N)$ 树形 dp 求出每个点向外的所有边权和以及距离的最大值,并求出 $E_j=1$ 和 $E_j=2$ 时的答案。可以证明,对于任意的 $E_j=k$ ($2\le k\le N-1$) 时的最优解 $S_k$,存在一组 $E_j=k+1$ 时的最优解 $S_{k+1}$ 使得 $S_k\subset S_{k+1}$。简单贪心即可。
时间复杂度:$O(N\log N)$。
代码
#include <iostream>
#include <vector>
#include <algorithm>
#include <numeric>
#include <functional>
#include <map>
int main() {
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr);
int n;
std::cin >> n;
std::vector<std::vector<int>> e(n);
std::vector<int> tip, len;
auto addEdge = [&](int u, int v, int l) {
e[u].push_back(tip.size());
tip.push_back(v);
len.push_back(l);
};
for (int i = 0; i < n - 1; ++i) {
int a, b, c, d;
std::cin >> a >> b >> c >> d;
--a;
--b;
addEdge(a, b, c);
addEdge(b, a, d);
}
std::vector<std::vector<long long>> max(n, std::vector<long long>(2));
std::vector<long long> sum(n);
auto update = [&](auto &v, auto x) {
if (x > v[0])
std::swap(x, v[0]);
v[1] = std::max(v[1], x);
};
std::function<void(int, int)> dfs1 = [&](int u, int p) {
for (auto i : e[u]) {
int v = tip[i];
if (v == p)
continue;
sum[0] += len[i];
dfs1(v, u);
update(max[u], max[v][0] + len[i]);
}
};
dfs1(0, -1);
std::function<void(int, int)> dfs2 = [&](int u, int p) {
for (auto i : e[u]) {
int v = tip[i];
if (v == p)
continue;
update(max[v], max[u][max[v][0] + len[i] == max[u][0]] + len[i ^ 1]);
sum[v] = sum[u] - len[i] + len[i ^ 1];
dfs2(v, u);
}
};
dfs2(0, -1);
std::vector<long long> ans(n, std::numeric_limits<long long>::max());
int s = -1;
for (int i = 0; i < n; ++i) {
ans[0] = std::min(ans[0], sum[i]);
if (sum[i] - max[i][0] < ans[1]) {
ans[1] = sum[i] - max[i][0];
s = i;
}
}
std::vector<long long> f(n);
std::function<int(int, int)> dfs3 = [&](int u, int p) {
int x = u;
for (auto i : e[u]) {
int v = tip[i];
if (v == p)
continue;
v = dfs3(v, u);
f[v] += len[i];
if (f[v] > f[x])
x = v;
}
return x;
};
dfs3(s, -1);
std::sort(f.begin(), f.end(), std::greater<>());
for (int i = 1; i < n - 1; ++i)
ans[i + 1] = ans[i] - f[i];
int q;
std::cin >> q;
while (q--) {
int x;
std::cin >> x;
--x;
std::cout << ans[x] << "\n";
}
return 0;
}