問題描述

你得到一個長度為 $n$ 的整數陣列

從左到右計算每個長度為 $k$ 的視窗的中位數並輸出

練習題

連結:https://cses.fi/problemset/task/1076

這題主要問題是視窗滑動時,如何對加入跟捨去判斷中位數如何變動,如果我們每次都從排序好的結構去找第 $(n-1) / 2$ 個,這樣不夠效率。

可以思考一下如果加入跟捨去的數字在排序好的結構上的關係:

  1. 加入在中位數前,捨去在中位數後,中位數往前一個

  2. 加入在中位數前,捨去在中位數前,中位數不變

  3. 加入在中位數後,捨去在中位數後,中位數不變

  4. 加入在中位數後,捨去在中位數前,中位數往後一個

如果是插入跟中位數相同的數字,根據 multiset 特性,會被放在相同數字的最後,也就是放在中位數後

根據以上關係,我們可以分別整理出插入跟捨去對中位數的改變:

  • 若插入的值小於中位數,中位數往前一個

  • 若捨去的值小於等於中位數,中位數往後一個

方法一完整程式碼

#include <bits/stdc++.h>
using namespace std;

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int n, k;
    cin >> n >> k;
    vector<int> a(n);
    for (int& x : a) cin >> x;

    multiset<int> window;
    auto median = window.begin();

    for (int i = 0; i < k; ++i) {
        window.insert(a[i]);
    }

    median = next(window.begin(), (k - 1) / 2);
    cout << *median;

    for (int i = k; i < n; ++i) {
        int out = a[i - k], in = a[i];

        window.insert(in);
        if (in < *median) median--;
        if (out <= *median) median++;
        window.erase(window.find(out));

        cout << ' ' << *median;
    }
    cout << '\n';
}
  1. 維護兩個 multiset (lo, hi)

    • lo:存「較小的一半」,且 一定包含中位數

    • hi:存「較大的一半」

  2. 平衡規則 令 |S| 表集合大小:

    保持 |lo| = |hi|,或 |lo| = |hi| + 1

    也就是 lo 大小永遠不會比 hi 小,而且最多只多 1

  3. 取得中位數 因為 lo 一定 ≥ hi,且多的那一個必定在 lo

    所以中位數永遠是 *prev(lo.end())(lo 中的最大值)

方法二完整程式碼

#include <bits/stdc++.h>
using namespace std;

multiset<long long> lo, hi;

void rebalance() {
    while (lo.size() < hi.size()) {
        lo.insert(*hi.begin());
        hi.erase(hi.begin());
    }
    while (lo.size() > hi.size() + 1) {
        auto it = prev(lo.end());
        hi.insert(*it);
        lo.erase(it);
    }
}

void add(long long x) {
    if (lo.empty() || x <= *prev(lo.end()))
        lo.insert(x);
    else
        hi.insert(x);
    rebalance();
}

void remove(long long x) {
    auto it = lo.find(x);
    if (it != lo.end())
        lo.erase(it);
    else
        hi.erase(hi.find(x));
    rebalance();
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int n, k;
    if (!(cin >> n >> k)) return 0;
    vector<long long> a(n);
    for (auto& x : a) cin >> x;

    for (int i = 0; i < k; ++i) add(a[i]);
    cout << *prev(lo.end());
    for (int i = k; i < n; ++i) {
        add(a[i]);
        remove(a[i - k]);
        cout << ' ' << *prev(lo.end());
    }
    cout << '\n';
    return 0;
}