Skip to content

点分治

转自Menci的博客:点分治学习笔记 点分治是用来解决树上路径问题的一种方法。

在解决树上路径问题时,我们可以选取一点为根,将树转化为有根树,然后考虑经过根的所有路径(有时将两条从根出发的路径连接为一条)。统计完这些路径的答案后,将根节点标记为删除,对剩下的若干棵树进行同样的操作。

如图,我们可以先考虑经过节点 $$1$$ 的路径,之后将节点 $$1$$ 标记为删除,此时可以认为考虑过的路径均已被删除。继续对其它子树做相同处理即可。

每次确认一个根节点后,共有 $$n$$ 条需要考虑的路径($$n$$ 为当前子树大小)。上图中将 $$1$$ 删除后,剩下左侧的子树较大,和原树大小相当,继续处理这棵子树时仍然需要与前一过程相当的时间。

最严重的情况,当整棵树是一条链时,每次需要考虑的路径数量是 $$O(n)$$ 级别的,如果每条路径需要常数时间进行统计,则总时间复杂度为 $$O(n ^ 2)$$。而对于形态随机的树,则远远小于这个级别。

如果我们选择 $$5$$ 作为这棵树的根节点,情况会好很多 —— 删除 $$5$$ 后剩余的最大一棵子树的大小比删除 $$1$$ 时要小。这说明「科学地」选择点作为根节点可以有效的降低复杂度。

重心

我们定义一棵树的重心为以该点为根时最大子树最小的点。

性质:以重心为根,任意一棵子树的大小都不超过整棵树大小的一半。

证明:从树上任取一点,以它为根,如果最大的一棵子树大小不超过整棵树大小的一半,则它为重心。否则选择最大子树的根节点,继续这个过程,最终会得到一个点,它满足重心的性质,从这个点向任何方向走,最多有一个点同样满足重心的性质。 注意不会出现来回走,两个点都不满足性质的情况。假设有,则删掉这两个点后,剩下的两棵树的大小都至少为 $$n \over 2$$,整棵树至少有 $$n + 2$$ 个点,不成立。

求重心可以用一次 DFS 完成 —— 任选一个点为根做 DFS,记录每个节点的大小 ,记录最大子节点子树的大小 。因为要同时考虑某个点的祖先(以这个点为根时这些点为它的一棵子树),所以使$$ \max { \max(i),\ n - siz(i) }$$ 最小的 $$i$$ 即为重心。

如果在点分治时每次使用重心为根,则最大的子树大小不会超过原树的二分之一,考虑到处理较小子树的代价远小于最大子树,若每个节点需要常数时间,根据主定理有

T(n)=2T(n/2)+O(n)=O(nlogn)

如图,蓝色点为第一次选取的重心,删除蓝色点后,剩余几棵子树的重心为红色点,再向下一层的重心为黄色点,最后剩下一个白色点

POJ1741 Tree(点分治,带详解)

本题是点分治的模板题

题意是给了你一棵树,然后给你一个k,问你在这棵树上有多少对点之间的距离小于等于k

我们要用点分治来解决这道题。

我们可以先想一个小问题,如何求出有多少对经过根的路径且距离小于等于k的节点。我们可以把这一棵树上的每一个点到树根的距离计算出来,然后从小到大排序,利用二分如果出现两个点深度的和小于等于k(deep[l]+deep[r]<=k),那么就有r-l个点符合要求,可以加到答案中去.

点分治就是在解决树上路径问题时,我们可以选取一点为根,将树转化为有根树,然后考虑经过根的所有路径(有时将两条从根出发的路径连接为一条)。统计完这些路径的答案后,将根节点标记为删除,对剩下的若干棵树进行同样的操作。

点分治的时候有一个重要的操作,就是求出树的重心,树的重心指的是以该点为根时这棵树的最大子树最小。求得时候我们可以定义:

  • siz[u]:以u为根的子树(包括自己)的节点数量
  • f[u]:以u为根的最大子树

这两个数组我们可以通过一遍dfs来求出,并且f[i]最大的点就是我们要找的树的重心.

之后我们就可以通过点分治来递归的解决这个问题,计算完一个根就把这个根删去,最后累加完结果.

但是,点分治要求处理的路径是经过root,所以如果一条路径是在同一个子树之内的就不符合要求,所以还要对子树dfs一下,然后去重

具体实现,详见代码.

参考博客:【点分治】的学习笔记和众多例题

cpp
#include <cstdio>
#include <algorithm>
#include <cstring>
using namespace std;
typedef long long ll;
#define mem(a, b) memset(a, b, sizeof(a))
const int inf = 1e9 + 10;
const int N = 10010 + 10;
int root, n, k, ans, sum;
int siz[N], f[N]; //siz[i]表示以i为根的节点数量,f[i]表示以i为根的最大子树大小
int first[N], tot;
int vis[N]; //标记这个点有没有被删过
int d[N], deep[N];
struct edge
{
    int v, w, next;
} e[N * 2];
void add_edge(int u, int v, int w)
{
    e[tot].v = v, e[tot].w = w;
    e[tot].next = first[u];
    first[u] = tot++;
}
void getroot(int u, int fa)
{
    siz[u] = 1;
    f[u] = 0;
    for (int i = first[u]; ~i; i = e[i].next)
    {
        int v = e[i].v;
        if (v == fa || vis[v])
            continue;
        getroot(v, u);
        siz[u] += siz[v];
        f[u] = max(f[u], siz[v]);
    }
    f[u] = max(f[u], sum - siz[u]); //以u的父节点为根的子树
    if (f[u] < f[root])
        root = u;
}
void getdeep(int u, int fa)
{
    deep[++deep[0]] = d[u];
    for (int i = first[u]; ~i; i = e[i].next)
    {
        int v = e[i].v, w = e[i].w;
        if (v != fa && !vis[v])
        {
            d[v] = d[u] + w;
            getdeep(v, u);
        }
    }
}
int cal(int u, int cost)
{
    d[u] = cost;
    deep[0] = 0;                        //deep[0]表示深度
    getdeep(u, 0);                      //处理以u为根的树深度
    sort(deep + 1, deep + deep[0] + 1); //对所有的深度进行排序
    int l = 1, r = deep[0], res = 0;
    while (l < r)
    {
        if (deep[l] + deep[r] <= k) //判断是否符合条件.
        {
            res += r - l;
            l++;
        }
        else
            r--;
    }
    return res;
}

void solve(int u)
{
    ans += cal(u, 0); //处理以u点为根的树
    vis[u] = 1;
    for (int i = first[u]; ~i; i = e[i].next)
    {
        int v = e[i].v, w = e[i].w;
        if (!vis[v])
        {
            ans -= cal(v, w); //减去同一个子树内不满足要求的
            sum = siz[v];
            root = 0;
            getroot(v, 0);
            solve(root);
        }
    }
}
void init()
{
    ans = root = tot = 0;
    mem(first, -1);
    mem(vis, 0);
}
int main()
{
    //freopen("in.txt", "r", stdin);
    int u, v, w;
    while (scanf("%d%d", &n, &k) && (n || k))
    {
        init();
        for (int i = 1; i <= n - 1; i++)
        {
            scanf("%d%d%d", &u, &v, &w);
            add_edge(u, v, w);
            add_edge(v, u, w);
        }
        f[0] = inf;
        sum = n;
        getroot(1, 0);
        solve(root);
        printf("%d\n", ans);
    }
    return 0;
}

BZOJ2152 聪聪可可(点分治)

利用点分治,要求出两个点路径上的和是3的倍数的种类数。

直接找到重心V,然后从V出发,搜索与V相邻的点,计算边长的余数分别是0,1,2的情况数,用t[0],t[1],t[2]分别表示。 显然答案就是 t[1]*t[2]*2+t[0]*t[0]

cpp
#include <cstdio>
#include <algorithm>
#include <cstring>
using namespace std;
typedef long long ll;
#define mem(a, b) memset(a, b, sizeof(a))
const int inf = 1e9 + 10;
const int N = 1e5 + 10;
int root, n, ans, sum;
int siz[N], f[N];
int first[N], tot;
int vis[N];
int d[N], t[5];
struct edge
{
    int v, w, next;
} e[N * 2];
void add_edge(int u, int v, int w)
{
    e[tot].v = v, e[tot].w = w;
    e[tot].next = first[u];
    first[u] = tot++;
}
void getroot(int u, int fa)
{
    siz[u] = 1;
    f[u] = 0;
    for (int i = first[u]; ~i; i = e[i].next)
    {
        int v = e[i].v;
        if (v != fa && !vis[v])
        {
            getroot(v, u);
            siz[u] += siz[v];
            f[u] = max(f[u], siz[v]);
        }
    }
    f[u] = max(f[u], sum - siz[u]);
    if (f[u] < f[root])
        root = u;
}
void getdeep(int u, int fa)
{

    t[d[u]]++;
    for (int i = first[u]; ~i; i = e[i].next)
    {
        int v = e[i].v, w = e[i].w;
        if (v != fa && !vis[v])
        {
            d[v] = (d[u] + w) % 3;
            getdeep(v, u);
        }
    }
}
int cal(int u, int w)
{
    t[0] = t[1] = t[2] = 0;
    d[u] = w;
    getdeep(u, 0);
    return 2 * t[1] * t[2] + t[0] * t[0];
}
void solve(int u)
{
    ans += cal(u, 0);
    vis[u] = 1;
    for (int i = first[u]; ~i; i = e[i].next)
    {
        int v = e[i].v, w = e[i].w;
        if (!vis[v])
        {
            ans -= cal(v, w);
            sum = siz[v];
            root = 0;
            getroot(v, 0);
            solve(root);
        }
    }
}
void init()
{
    ans = root = tot = 0;
    f[0] = n;
    sum = n;
    mem(first, -1);
    mem(vis, 0);
}
int main()
{
    //freopen("in.txt", "r", stdin);
    int u, v, w;
    scanf("%d", &n);
    init();
    for (int i = 1; i <= n - 1; i++)
    {
        scanf("%d%d%d", &u, &v, &w);
        w %= 3;
        add_edge(u, v, w);
        add_edge(v, u, w);
    }
    getroot(1, 0);
    solve(root);
    int x = __gcd(ans, n * n);
    printf("%d/%d\n", ans / x, n * n / x);
    return 0;
}

BZOJ1316 树上的询问(点分治)

一棵n个点的带权有根树,有p个询问,每次询问树中是否存在一条长度为Len的路径,如果是,输出Yes否输出No

利用点分治,首先把树的重心找出来之后,然后以这一点为根,递归的解决,点分治处理的是经过根节点的路径,所以在计算完当前节点之后要去重(儿子),每次把处理的所有深度排序,然后遍历每一个深度值,看一下存在k-x的有多少个累加这个值就是答案,利用二分实现。先把所有的查询存下来,然后离线处理。

cpp
#include <bits/stdc++.h>
using namespace std;
#define mem(a, b) memset(a, b, sizeof(a))
const int N = 1e4 + 10;
int first[N], tot, n, p;
int q[N], siz[N], f[N], d[N], deep[N], vis[N];
int sum, root;
struct edge
{
    int v, w, next;
} e[N * 2];
void add_edge(int u, int v, int w)
{
    e[tot].v = v, e[tot].w = w;
    e[tot].next = first[u];
    first[u] = tot++;
}
void init()
{
    mem(first, -1);
    tot = 0;
    sum = f[0] = n;
    root = 0;
}
void getroot(int u, int fa)
{
    siz[u] = 1, f[u] = 0;
    for (int i = first[u]; ~i; i = e[i].next)
    {
        int v = e[i].v;
        if (!vis[v] && v != fa)
        {
            getroot(v, u);
            siz[u] += siz[v];
            f[u] = max(f[u], siz[v]);
        }
    }
    f[u] = max(f[u], sum - siz[u]);
    if (f[u] < f[root])
        root = u;
}
void getdeep(int u, int fa)
{
    deep[++deep[0]] = d[u];
    for (int i = first[u]; ~i; i = e[i].next)
    {
        int v = e[i].v, w = e[i].w;
        if (!vis[v] && v != fa)
        {
            d[v] = d[u] + w;
            getdeep(v, u);
        }
    }
}
int findl(int l, int r, int k)
{
    int ans = 0;
    while (l <= r)
    {
        int mid = (l + r) >> 1;
        if (deep[mid] == k)
        {
            ans = mid;
            r = mid - 1;
        }
        else if (deep[mid] < k)
            l = mid + 1;
        else
            r = mid - 1;
    }
    return ans;
}
int findr(int l, int r, int k)
{
    int ans = -1;
    while (l <= r)
    {
        int mid = (l + r) >> 1;
        if (deep[mid] == k)
        {
            ans = mid;
            l = mid + 1;
        }
        else if (deep[mid] < k)
            l = mid + 1;
        else
            r = mid - 1;
    }
    return ans;
}

int cal(int u, int cost, int k)
{
    d[u] = cost;
    deep[0] = 0;
    getdeep(u, 0);
    sort(deep + 1, deep + deep[0] + 1);
    int t = 0;
    for (int i = 1; i <= deep[0]; i++)
    {
        if (deep[i] + deep[i] > k)
            break;
        int l = findl(i, deep[0], k - deep[i]);
        int r = findr(i, deep[0], k - deep[i]);
        /*二分的部分也可以用lower_bound和upper_bound实现
        int l = lower_bound(deep + 1, deep + deep[0] + 1, k - deep[i]) - deep;
        int r = upper_bound(deep + 1, deep + deep[0] + 1, k - deep[i]) - deep - 1;
        */
        t += r - l + 1;
    }
    return t;
}
int ans[110];
void solve(int u)
{
    for (int i = 1; i <= p; i++)
        ans[i] += cal(u, 0, q[i]);
    vis[u] = 1;
    for (int i = first[u]; ~i; i = e[i].next)
    {
        int v = e[i].v, w = e[i].w;
        if (!vis[v])
        {
            for (int j = 1; j <= p; j++)
                ans[j] -= cal(v, w, q[j]);
            sum = siz[v];
            root = 0;
            getroot(v, 0);
            solve(root);
        }
    }
}
int main()
{
    // freopen("in.txt", "r", stdin);
    int u, v, w;
    scanf("%d%d", &n, &p);
    init();
    for (int i = 1; i <= n - 1; i++)
    {
        scanf("%d%d%d", &u, &v, &w);
        add_edge(u, v, w);
        add_edge(v, u, w);
    }
    for (int i = 1; i <= p; i++)
        scanf("%d", &q[i]);
    getroot(1, 0);
    solve(root);
    for (int i = 1; i <= p; i++)
        puts(ans[i] ? "Yes" : "No");
    return 0;
}

基于 MIT 协议发布 · 使用 VitePress 构建