• 《算法竞赛·快冲300题》每日一题:“直径点对”


    算法竞赛·快冲300题》将于2024年出版,是《算法竞赛》的辅助练习册。
    所有题目放在自建的OJ New Online Judge
    用C/C++、Java、Python三种语言给出代码,以中低档题为主,适合入门、进阶。


    直径点对” ,链接: http://oj.ecustacm.cn/problem.php?id=1736

    题目描述

    【题目描述】 给你一个n个节点的树,编号为1到n。求存在多少对节点,使得u到v的距离等于这棵树的直径
       树的直径:树上最远的两个点的距离
       树上两点的距离:两点之间边的数量
       <1,2>和<2,1>属于两对节点
    【输入格式】 第一行为正整数n(n≤300000)。接下来n-1行,每行两个数字u和v,表示点u和点v之间存在边
    【输出格式】 输出一个数字表示答案。
    【输入样例】

    4
    1 2
    1 3
    1 4
    
    • 1
    • 2
    • 3
    • 4

    【输出样例】

    6
    
    • 1

    题解

       求树的直径有两种方法[ 《算法竞赛》,清华大学出版社,罗勇军、郭卫斌著,231页,“4.7.2 树的直径”。]:
       (1)做两次DFS,第一次求一个任意点的最远点s,第二次求s的最远点t,s和t之间的距离就是树的直径。
       (2)树形DP。算法不难,但是解释有点长,见《算法竞赛》233页的说明。请仔细理解如何用树形DP求树的直径。
       以上两种方法,算法复杂度都为O(n),即只对每个节点处理O(1)次。
       本题用这两种方法都能求解。下面用树形DP求树的直径,求直径的同时统计距离为直径的节点对数量。
    【重点】 树形DP。

    C++代码

       定义状态dp[],dp[u]表示从u出发的最长路径的长度,这条路径的终点是u的一个叶子节点。
       定义num[],num[u]表示从u出发的最长路径的数量。
       定义maxlen,表示直径的长度,k是经过直径的节点对数量。
       详细解释见代码的注释。如果仍然不能理解,可以把代码中与num[]和k有关的第8、20、22、23、26、28、29行删除,剩下的代码就是《算法竞赛》233页的树形DP求直径的模板。然后再加上num[]和k的代码并理解。

    #include
    using namespace std;
    typedef long long ll;
    const int N = 300010;
    vector<int>e[N];
    int dp[N];                   //dp[u]:从u出发的最长路径
    int num[N];                  //num[u]:从u出发的最长路径数量
    ll maxlen = 0, k = 0;        //直径的长度maxlen,经过直径的节点对数量k
    void dfs(int u, int fa){
        dp[u] = 0;
        num[u] = 1;
        for(auto v : e[u]){
            if(v == fa)  continue;
            dfs(v, u);                   //继续深入,回溯时带回算好的dp[v]
            int now = dp[v] + 1;         //从u出发,经过子节点v的最长路径
            if(now + dp[u] > maxlen){    //此时dp[u]是不经过v,而经过其他子节点的最长路径
                                         //now+dp[u]是经过u的最长路径
                maxlen = now + dp[u];    //更新maxlen为经过u的最长路径
                                         //此时u、v可能在树的直径上。比较所有的maxlen,最大的就是树的直径
                k = num[u] * num[v];  //计算k,这个k可能重新赋值
            }
            else if(now + dp[u] == maxlen)  //把此时的len看成树的直径,如果20行的k更新,这里也会重算
                k += num[u] * num[v];
            if(now > dp[u]){            //u经过v的路径更长
                dp[u] = now;            //更新dp[u]为经过v的路径
                num[u] = num[v];        //v更可能在直径上,把经过u的最长路径数量更新为经过v的数量
            }
            else if(now == dp[u])       //相等,这也是最长路径
                num[u] += num[v];
        }
    }
    int main(){
        int n;   scanf("%d", &n);
        for(int i = 1; i < n; i++){
            int u, v;    scanf("%d%d", &u, &v);
            e[u].push_back(v);  //加边
            e[v].push_back(u);
        }
        dfs(1, 0);
        cout << k * 2 << endl;        //按题意u-v和v-u不同,所以乘以2
        return 0;
    }
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42

    Java代码

    import java.util.*;
    import java.io.*;
     
    public class Main {
        static FastReader scanner = new FastReader();
        static int N = 300010;
        static ArrayList<Integer>[] e = new ArrayList[N];
        static int[] dp = new int[N]; // dp[u]:从u出发的最长路径
        static int[] num = new int[N]; // num[u]:从u出发的最长路径数量
        static long maxlen = 0, k = 0; // 直径的长度maxlen,经过直径的节点对数量k 
        public static void main(String[] args) throws IOException {         
            int n = scanner.nextInt();
            for (int i = 1; i <= n; i++)   e[i] = new ArrayList<>();        
            for (int i = 1; i < n; i++) {
                int u = scanner.nextInt();
                int v = scanner.nextInt();
                e[u].add(v); // 加边
                e[v].add(u);
            }
            dfs(1, 0);
            System.out.println(k * 2); // 按题意u-v和v-u不同,所以乘以2
        } 
        public static void dfs(int u, int fa) {
            dp[u] = 0;
            num[u] = 1;
            for (int v : e[u]) {
                if (v == fa)  continue;
                dfs(v, u); // 继续深入,回溯时带回算好的dp[v]
                int now = dp[v] + 1; // 从u出发,经过子节点v的最长路径
                if (now + dp[u] > maxlen) { // 此时dp[u]是不经过v,而经过其他子节点的最长路径
                    // now+dp[u]是经过u的最长路径
                    maxlen = now + dp[u]; // 更新maxlen为经过u的最长路径
                    // 此时u、v可能在树的直径上。比较所有的maxlen,最大的就是树的直径
                    k = num[u] * num[v]; // 计算k,这个k可能重新赋值
                } else if (now + dp[u] == maxlen) 
    // 把此时的len看成树的直径,如果34行的k更新,这里也会重算
                    k += num[u] * num[v];
                if (now > dp[u]) { // u经过v的路径更长
                    dp[u] = now; // 更新dp[u]为经过v的路径
                    // v更可能在直径上,把经过u的最长路径数量更新为经过v的数量
                    num[u] = num[v];
                } else if (now == dp[u]) // 相等,这也是最长路径
                    num[u] += num[v];
            }
        } 
        static class FastReader {
            BufferedReader br;
            StringTokenizer st; 
            public FastReader() { br = new BufferedReader(new InputStreamReader(System.in));  }
             String next() {
                while (st == null || !st.hasMoreElements()) {
                    try {st = new StringTokenizer(br.readLine());} 
                    catch (IOException e) { e.printStackTrace();}
                }
                return st.nextToken();
            } 
            int nextInt() {  return Integer.parseInt(next()); } 
            String nextLine() {
                String str = "";
                try { str = br.readLine();} 
                catch (IOException e) { e.printStackTrace(); }
                return str;
            }
        }
    }
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65

    Python代码

    #pypy
    from collections import defaultdict
    import sys
    sys.setrecursionlimit(300000)
     
    e = defaultdict(list)
    dp = [0] * 300010
    num = [0] * 300010
    maxlen = 0
    k = 0
     
    def dfs(u, fa):
        global maxlen, k
        dp[u] = 0
        num[u] = 1
        for v in e[u]:
            if v == fa:   continue
            dfs(v, u)
            now = dp[v] + 1
            if now + dp[u] > maxlen:
                maxlen = now + dp[u]
                k = num[u] * num[v]
            elif now + dp[u] == maxlen:  k += num[u] * num[v]
            if now > dp[u]:
                dp[u] = now
                num[u] = num[v]
            elif now == dp[u]:  num[u] += num[v]
     
    n = int(input())
    for i in range(1, n):
        u, v = map(int, input().split())
        e[u].append(v)
        e[v].append(u)
     
    dfs(1, 0)
    print(k * 2)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
  • 相关阅读:
    图像增强方法资料汇总
    ArcGIS:如何迭代Shp文件所有要素并分别导出为Shp文件?
    零基础小白应该如何快速入门软件测试
    linux篇【11】:linux下的线程<后序>
    什么人适合学NPDP产品经理认证?
    Pandas与数据库交互详解
    GoF 23 备忘录模式
    java基本数据类型Char
    Mybatis连接数据库
    代码随想录——比较含退格的字符串
  • 原文地址:https://blog.csdn.net/weixin_43914593/article/details/132715121