二分 - GESP C++ 五级

发表于
更新于
28 2.3~3.0 分钟 1033

引入

如何查找一个有序数组中某个元素的位置呢?你当然可以从头到尾一个一个对比,但这样会耗费大量时间复杂度,如果数组的长度为 nn,那么这样做的最坏时间复杂度为 O(n)O(n)。由于数组是有序的,我们可以每次将数组分成一半,判断元素再哪一半里,以此往复,直至查找到元素为止。这种方法利用了分治的思想,叫做二分法。

二分查找

二分查找的核心思想是将当前查找区间的中间元素与要查找的元素进行比较,以得出新的查找区间。下面给出一种二分查找的模板代码:

// 省略,N为数组长度
int arr[N];

int binary_search(int n) {
    int l = 0, r = N - 1;
    while (l <= r) {
        int mid = l + ((r - l) >> 1); // 防止数值溢出。位运算速度较算术运算更快
        if (arr[mid] < n) 
            l = mid + 1;
        else if (arr[mid] > n) 
            r = mid - 1
        else 
            return mid
    }
    return -1 // 未查找到数据返回-1
}

在这段代码中,我们判断arr的第mid 项与n 的数量关系。如果小于,我们就去查找右半段;如果大于,我们就去查找左半段;如果等于,我们直接返回查找到的mid 。如果 r 大于l 了我们仍然没有查找到n ,那就说明数组内不存在n ,这里我们返回-1 。二分查找的时间复杂度为 O(logn)O(\log n)

二分查找函数

C++ 的 STL 库提供了两个二分查找函数:lower_boundupper_bound (其实并不止这两个,只是这两个在信息学竞赛中比较常用)。这两个函数包含在 <algorithm> 头文件中。

lower_bound 函数查找首个不小于给定值的元素的位置,返回地址。使用格式为 lower_bound(首元素地址, 尾元素地址, 给定值)

upper_bound 函数查找首个大于给定值的元素的位置,返回地址。使用格式为 upper_bound(首元素地址, 尾元素地址, 给定值)

二分答案

对于一些算法题目,我们可以通过枚举的方式来求解。如果答案与某个判断条件满足单调性,就可以使用二分,只需要将枚举换为二分即可。这种利用二分查找答案的方法叫作二分答案。

二分答案是一类题目,不存在通用的解题方法,需要结合具体题目分析。不过,下面的代码给出了一个常见的二分答案模板,不保证通用:

// 省略

bool check(int x) {
    // 检查答案是否满足条件(不一定要完全相等,只需要大于等于/小于等于即可)
}

int main() {
    
    // 省略
    
    int l  = 0, r = 1e8; // 这里的l和r要根据实际题目的答案范围调整
    while (l <= r) {  // 也有可能是l < r,依具体题目
        int mid = (l + r) >> 1; // 括号内也有可能是l + r + 1,依具体题目
        if (check(mid)) 
            l = mid; // 也有可能是mid - 1,依具体题目
        else
            r = mid + 1 // 也有可能是mid,依具体题目
    }
    // 这里的答案为l,在一些时候也可能为r,依具体题目

    // 省略

}

例题

洛谷 P1873:砍树

题面

P1873 [COCI 2011/2012 #5] EKO / 砍树 - 洛谷

分析

设最高的树的高度为hh,那么答案就在[1,h][1, h]之间,所以可以用枚举解决。但这样时间复杂度过高。我们很容易发现高度HH越高,砍伐的树木越少,满足单调性,因此这道题可以用二分答案解决。

参考答案

#include <iostream>
using namespace std;

const int N = 1e6+10;
long long a[N], n, m, l, r;

bool check(long long x) {
    long long sum = 0;
    for (int i=1; i<=n; i++) {
        if (a[i] > x) {
            sum += a[i] - x;
            if (sum >= m) return 1;
        }
    }
    return 0;
}

int main() {

    cin >> n >> m;

    for (int i=1; i<=n; i++) {
        cin >> a[i];
        r = max(r, a[i]);
    }

    while (l < r) {
        long long mid = (l + r + 1) / 2;
        if (check(mid)) l = mid;
        else r = mid - 1;
    }

    cout << l;

    return 0;
}