描述 Description
Given a tree with n vertices, where each edge has a length (a positive integer less than 1001).
Define dist(u, v) as the minimum distance between vertices u and v.
Given an integer k, a pair (u, v) of distinct vertices is called valid if and only if dist(u, v) does not exceed k (pairs are unordered: (u, v) and (v, u) are counted once).
Write a program that counts the number of valid pairs in the given tree.
输入格式 InputFormat
The input contains several test cases. The first line of each test case contains two integers n and k (n <= 10000). Each of the following n-1 lines contains three integers u, v, l, meaning there is an edge of length l between vertices u and v.
The last test case is followed by two zeros.
输出格式 OutputFormat
For each test case output the answer on a single line.
样例输入 SampleInput
5 4
1 2 3
1 3 1
1 4 2
3 5 1
0 0
样例输出 SampleOutput
8
代码 Code
点分治 (centroid decomposition).
#include <stdio.h>
#include <iostream>
#include <cstring>
#include <algorithm>
#include <cmath>
using namespace std;
// Sentinel larger than any possible component size; only used to seed siz
// before a centroid search.
constexpr int inf = static_cast<int>(0x7fffffff / 27.11);
// Vertex capacity (n <= 10000 plus slack) and arc capacity (two arcs per edge).
constexpr int maxn = 10005;
constexpr int maxm = maxn * 2;

// General-purpose global scratch variables shared by several functions.
int i, j, t, n, m, l, r, k, z, y, x;

// One directed arc of the adjacency list: target vertex, edge length, and the
// id of the next arc leaving the same vertex (0 terminates the list).
struct edge
{
    int to;
    int vl;
    int nx;
};
edge e[maxm];

int head[maxn];  // first arc id for each vertex (0 = no arcs)
int dis[maxn];   // distance of a vertex from the current reference point
int a[maxn];     // a[1..len]: distances collected by getarray
int son[maxn];   // subtree sizes computed by dfs
bool vis[maxn];  // true once a vertex has been removed as a centroid

// Arc counter, unused slot, running answer, length of a[], best balance found
// so far and the vertex that achieved it.
int cnt, num, ans, len, siz, c;
// Depth-first pass over the not-yet-removed component containing u.
// Fills son[] with subtree sizes (rooted at the first call's vertex, never
// crossing vertices with vis set nor going back to fa). For each vertex it
// also measures the heaviest piece left if that vertex were removed: its
// largest child subtree or the remaining m - son[u] vertices above it, with
// the global m taken as the component size. The vertex with the smallest such
// piece seen so far is recorded in c, its weight in siz.
void dfs(int u, int fa)
{
    int heaviest = 0;
    son[u] = 1;
    for (int eid = head[u]; eid != 0; eid = e[eid].nx)
    {
        const int nxt = e[eid].to;
        if (vis[nxt] || nxt == fa)
            continue;
        dfs(nxt, u);
        son[u] += son[nxt];
        if (son[nxt] > heaviest)
            heaviest = son[nxt];
    }
    const int above = m - son[u];
    if (above > heaviest)
        heaviest = above;
    if (heaviest < siz)
    {
        c = u;
        siz = heaviest;
    }
}
int getcenter(int s)
{
c = 0; siz = inf;
dfs(s, -1);
return c;
}
// Adds the undirected edge u-v of length w as two directed arcs, pushed onto
// the front of u's and v's adjacency lists. Arc ids start at 1 (pre-increment
// of cnt) so that 0 can mark the end of a list in head[] / nx.
inline void ins(int u, int v, int w)
{
    // Standard brace initialization; a C-style compound literal `(edge){...}`
    // is only a compiler extension in C++.
    e[++cnt] = edge{v, w, head[u]};
    head[u] = cnt;
    e[++cnt] = edge{u, w, head[v]};
    head[v] = cnt;
}
// Collects the distance of every vertex reachable from u (without crossing
// removed vertices or returning to fa) into a[], appending after index len.
// dis[u] must already be set; each child's dis[] becomes its parent's
// distance plus the connecting edge length.
void getarray(int u, int fa)
{
    a[++len] = dis[u];
    for (int eid = head[u]; eid != 0; eid = e[eid].nx)
    {
        const int nxt = e[eid].to;
        if (vis[nxt] || nxt == fa)
            continue;
        dis[nxt] = dis[u] + e[eid].vl;
        getarray(nxt, u);
    }
}
// Counts unordered pairs of vertices in u's remaining component whose
// distances sum to at most k, where u itself is placed at distance `now`
// from the reference point. solve uses calc(s, 0) to count pairs through the
// centroid s, and calc(v, w) to subtract pairs lying inside child v's subtree.
// Overwrites the globals dis[], a[] and len.
int calc(int u, int now)
{
    dis[u] = now;
    len = 0;
    getarray(u, -1);  // a[1..len] = distances of all vertices reached
    std::sort(a + 1, a + len + 1);

    // Two pointers over the sorted distances: when a[lo] + a[hi] <= k, a[lo]
    // also fits with every index in (lo, hi], giving hi - lo partners.
    // The cursors are locals so no global state is disturbed.
    int pairs = 0;
    int lo = 1;
    int hi = len;
    while (lo < hi)
    {
        if (a[lo] + a[hi] <= k)
        {
            pairs += hi - lo;
            ++lo;
        }
        else
        {
            --hi;
        }
    }
    return pairs;
}
// One centroid-decomposition step for the component whose centroid is s.
// Adds to the global ans every valid pair whose path lies in this component:
//  * calc(s, 0) counts pairs with dis(x, s) + dis(s, y) <= k, which also
//    includes pairs from the same child subtree whose real path avoids s;
//  * calc(v, w) removes exactly those, because seeding v with the edge
//    length w reproduces the distances measured from s inside v's subtree.
// s is then marked removed and each child component is solved recursively
// around its own centroid.
void solve(int s)
{
    ans += calc(s, 0);
    vis[s] = true;
    for (int i = head[s]; i; i = e[i].nx)
    {
        const int v = e[i].to;
        if (vis[v])
            continue;
        ans -= calc(v, e[i].vl);
        // Component size hint for getcenter. son[] still holds sizes from the
        // DFS that located s, so this is exact when v was s's child in that
        // DFS and an overestimate when v was its parent.
        m = son[v];
        // Pass the centroid straight to the recursion rather than through a
        // shared global scratch variable.
        solve(getcenter(v));
    }
}
// Reads test cases ("n k", then n-1 lines "u v l") until "0 0" and prints the
// number of valid pairs for each one.
int main()
{
    // Require both fields to convert: comparing against EOF alone spins forever
    // on a malformed header, where scanf returns 0 without reaching EOF.
    while (scanf("%d%d", &n, &k) == 2)
    {
        if (n == 0 && k == 0)
            break;

        // Reset the adjacency lists and removed-vertex marks from the last case.
        cnt = 0;
        memset(head, 0, sizeof(head));
        memset(vis, 0, sizeof(vis));
        for (int read = 1; read < n; read++)
        {
            int u, v, w;
            // Truncated input: stop rather than insert indeterminate values.
            if (scanf("%d%d%d", &u, &v, &w) != 3)
                return 1;
            ins(u, v, w);
        }

        ans = 0;
        m = n;  // the whole tree is the first component
        // Assumes vertices are numbered from 1 — TODO confirm against the spec.
        solve(getcenter(1));
        printf("%d\n", ans);
    }
    return 0;
}