QOJ.ac

QOJ

Type: Editorial

Status: Open

Posted by: jiangly

Posted at: 2025-12-14 07:19:36

Last updated: 2025-12-14 07:19:49

Back to Problem

题解

题意

给出一个环形序列,被分为 $m$ 段。有 $n$ 个国家,序列的第 $i$ 段属于国家 $o_i$。接下来有 $k$ 次事件,每次给环形序列上的一个区间加上一个正整数。每个国家有一个期望 $p_i$,求出每个国家在序列上所有位置的值的和到达 $p_i$ 的最早时间(或报告无法达到)。

限制

$1\le n,m,k\le 3\cdot10^5$。

题解

下面认为 $n,m,k$ 同阶,时间复杂度均以 $n$ 表示。

忽略环的条件,跨越首尾的事件可以拆成两个。

考虑二分每个国家的时间,每次二分后执行二分位置之前的所有事件,判断是否达到期望。然而这样的时间复杂度太高。但是我们可以将所有国家一起二分,这样每一轮只用完整地执行一次操作序列,时间复杂度 $O(n(\log n)^2)$。

用树状数组维护操作再结合一定的常数优化已经足以通过该题,但还有复杂度更优秀的做法。

我们考虑扫描序列,用线段树维护时间。对于第 $i$ 个操作 $([l_i,r_i),a_i)$,扫到 $l_i$ 时在位置 $i$ 加 $a_i$,扫到 $r_i$ 时减 $a_i$。这样直接用整体二分的方法并不能改善复杂度,但是在将线段树可持久化后把一个国家的所有位置一起在线段树上二分就能做到 $O(n\log n)$ 的时空复杂度。然而本题的空间限制很紧,这种做法并不能通过。

沿用之前整体二分和差分操作的方法,假设目前在区间 $[l_v,r_v)$ 内二分,在保持操作和序列位置有序的情况下,我们可以利用 two-pointers 在线性时间内求出所有国家在给定操作集合(即 $[l_v,\lfloor\frac{l_v+r_v}{2}\rfloor)$ 中的操作)中的权值 $\mathrm{val}_i$。如果 $\mathrm{val}_i\ge p_i$,向左递归;如果 $\mathrm{val}_i < p_i$,将 $p_i$ 减去 $\mathrm{val}_i$,向右递归。由于递归层数为 $O(\log n)$,该算法的时间复杂度为 $O(n\log n)$,空间复杂度为 $O(n)$。

代码

#include <cstdio>
#include <vector>
#include <algorithm>
#include <numeric>
#include <cctype>
constexpr int N = 300'000;
int n, m, q, qn;
int o[N], p[N], s[N], ans[N], tmp[N];
long long sum[N];
struct Query {
    int x, y, v;
    Query() {}
    Query(int x, int y, int v) : x(x), y(y), v(v) {}
    friend bool operator<(const Query &lhs, const Query &rhs) {
        return lhs.x < rhs.x;
    }
};
Query queries[3 * N], temp[3 * N];
char buf[1 << 22], *p1, *p2;
char get() {
    if (p1 == p2) {
        p1 = buf;
        p2 = buf + fread(buf, 1, sizeof(buf), stdin);
    }
    if (p1 == p2)
        return EOF;
    return *p1++;
}
int readInt() {
    int x = 0;
    char c = get();
    while (!std::isdigit(c))
        c = get();
    while (std::isdigit(c)) {
        x = 10 * x + c - '0';
        c = get();
    }
    return x;
}
void print(int x) {
    static char stk[20];
    int top = 0;
    while (x > 0) {
        stk[top++] = x % 10 + '0';
        x /= 10;
    }
    for (int i = top - 1; i >= 0; --i)
        *p2++ = stk[i];
    *p2++ = '\n';
}
void nie() {
    *p2++ = 'N';
    *p2++ = 'I';
    *p2++ = 'E';
    *p2++ = '\n';
}
void solve(int vl, int vr, int sl, int sr, int ql, int qr) {
    if (sl == sr)
        return;
    for (int i = sl; i < sr; ++i)
        sum[o[s[i]]] = 0;
    if (vr - vl == 1) {
        int qi = ql;
        long long curSum = 0;
        for (int i = sl; i < sr; ++i) {
            while (qi < qr && queries[qi].x <= s[i]) {
                curSum += queries[qi].v;
                ++qi;
            }
            if (sum[o[s[i]]] < p[o[s[i]]])
                sum[o[s[i]]] += curSum;
        }
        for (int i = sl; i < sr; ++i) {
            if (sum[o[s[i]]] >= p[o[s[i]]]) {
                ans[o[s[i]]] = vr;
            } else {
                ans[o[s[i]]] = -1;
            }
        }
        return;
    }
    int vm = (vl + vr) / 2;
    int qm = ql, nqr = qr;
    std::copy(queries + ql, queries + qr, temp + ql);
    for (int i = ql; i < qr; ++i) {
        if (temp[i].y < vm) {
            queries[qm++] = temp[i];
        } else {
            queries[--nqr] = temp[i];
        }
    }
    std::reverse(queries + qm, queries + qr);
    int qi = ql;
    long long curSum = 0;
    for (int i = sl; i < sr; ++i) {
        while (qi < qm && queries[qi].x <= s[i]) {
            curSum += queries[qi].v;
            ++qi;
        }
        if (sum[o[s[i]]] < p[o[s[i]]])
            sum[o[s[i]]] += curSum;
    }
    std::copy(s + sl, s + sr, tmp + sl);
    int sm = sl, nsr = sr;
    for (int i = sl; i < sr; ++i) {
        if (sum[o[tmp[i]]] >= p[o[tmp[i]]]) {
            s[sm++] = tmp[i];
        } else {
            s[--nsr] = tmp[i];
        }
    }
    std::reverse(s + sm, s + sr);
    for (int i = sm; i < sr; ++i) {
        if (sum[o[s[i]]] != -1) {
            p[o[s[i]]] -= sum[o[s[i]]];
            sum[o[s[i]]] = -1;
        }
    }
    solve(vl, vm, sl, sm, ql, qm);
    solve(vm, vr, sm, sr, qm, qr);
}
int main() {
    n = readInt();
    m = readInt();
    std::fill(ans, ans + n, -1);
    std::iota(s, s + m, 0);
    for (int i = 0; i < m; ++i)
        o[i] = readInt() - 1;
    for (int i = 0; i < n; ++i)
        p[i] = readInt();
    q = readInt();
    for (int i = 0; i < q; ++i) {
        int l, r, a;
        l = readInt() - 1;
        r = readInt();
        a = readInt();
        if (l < r) {
            queries[qn++] = Query(l, i, a);
            if (r != m)
                queries[qn++] = Query(r, i, -a);
        } else {
            queries[qn++] = Query(0, i, a);
            if (l != r) {
                queries[qn++] = Query(r, i, -a);
                queries[qn++] = Query(l, i, a);
            }
        }
    }
    std::sort(queries, queries + qn);
    solve(0, q, 0, m, 0, qn);
    p1 = p2 = buf;
    for (int i = 0; i < n; ++i) {
        if (ans[i] == -1) {
            nie();
        } else{
            print(ans[i]);
        }
    }
    fwrite(buf, 1, p2 - p1, stdout);
    return 0;
}

Comments

No comments yet.