[HDU4812] D Tree

描述 Description

There is a skyscraping tree standing on the playground of Nanjing University of Science and Technology. On each branch of the tree is an integer (the tree can be treated as a connected graph with N vertices, where each branch is a vertex). Today the students under the tree are considering a problem: can we find a chain on the tree such that the product of all integers on the chain (mod $10^6 + 3$) equals K?

Can you help them in solving this problem?

输入格式 InputFormat

There are several test cases, please process till EOF.
Each test case starts with a line containing two integers N ($1 \le N \le 10^5$) and K ($0 \le K < 10^6 + 3$). The following line contains N numbers $v_i$ ($1 \le v_i < 10^6 + 3$), where $v_i$ is the integer on vertex i. Then follow N - 1 lines. Each line contains two integers x and y, representing an undirected edge between vertex x and vertex y.

输出格式 OutputFormat

For each test case, print a single line containing two integers a and b (where a < b), representing the two endpoints of the chain. If multiple solutions exist, print the lexicographically smallest one. If no solution exists, print “No solution” (without quotes) instead.
For more information, please refer to the Sample Output below.

样例输入 SampleInput

5 60
2 5 2 3 3
1 2
1 3
2 4
2 5
5 2
2 5 2 3 3
1 2
1 3
2 4
2 5

样例输出 SampleOutput

3 4
No solution


Hdu 4812


代码 Code

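点分治 + 逆元 + Hash(其实可以不用)。各种 Debug 后自己都不认识了。(Centroid decomposition + modular inverses + a hash table, which is not strictly necessary. After all the debugging, the author can barely recognize the code.)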

#include <stdio.h>
#include <iostream>
#include <cstring>
#include <algorithm>
#include <cmath>
// Ask the linker for a ~100 MB stack: dfs/findroot/calc/update recurse once per
// tree level, so a path-shaped tree of 1e5 vertices recurses about 1e5 deep.
// NOTE: `#pragma comment(linker, ...)` is an MSVC extension; GCC/Clang ignore it,
// so with those compilers the environment's default stack limit applies.
#pragma comment(linker,"/STACK:102400000,102400000")
using namespace std;
// Sentinel meaning "no answer found yet". Any integer larger than every vertex id
// works; use a plain integer constant rather than a truncated floating-point expression.
const int inf = 0x3f3f3f3f;
const int maxn = 100005;       // N <= 1e5 vertices, plus slack for 1-based indexing
const int maxm = 2 * maxn;     // every undirected edge is stored as two directed edges
const int mod = 1000003;       // prime modulus 10^6 + 3 from the statement
int i, j, t, n, m, l, r, k, z, y, x;  // shared scratch globals (main reads n, k, x, y; r holds a centroid)
struct edge
{
    int to, nx;  // target vertex; index of the next edge from the same source (0 ends the list)
} e[maxm];
int head[maxn], son[maxn], val[maxn];  // adjacency-list heads, subtree sizes, vertex values
// mp[p]: smallest vertex w such that the path (current centroid .. w) has product p, 0 = none.
// use[0..num): keys written into mp, so mp can be reset without clearing the whole table.
int mp[mod + 5], use[mod + 5];
bool vis[maxn];  // true once a vertex has been used as a centroid (removed from the tree)
int cnt, num, ans1, ans2, a, b, c, siz;  // edge counter, size of use[], best pair (ans1, ans2), scratch
long long inv[mod + 5];  // inv[x]: multiplicative inverse of x modulo mod
// Add the directed edge u -> v at the front of u's adjacency list.
// Edge slots are numbered from 1, so slot 0 can terminate every list.
inline void ins(int u, int v)
{
    const int id = ++cnt;
    e[id].nx = head[u];
    e[id].to = v;
    head[u] = id;
}
int ms[maxn], pos;  // ms[u]: largest piece left if u were removed; pos: current centroid candidate
// Fill son[] (subtree sizes) and ms[] (largest child-subtree size) for the part of
// the component below u, skipping vertices already removed as centroids (vis).
// fa is u's parent in this traversal, or -1 at the starting vertex.
void dfs(int u, int fa)
{
    son[u] = 1;
    ms[u] = 0;
    for (int edgeId = head[u]; edgeId != 0; edgeId = e[edgeId].nx)
    {
        const int child = e[edgeId].to;
        if (child != fa && !vis[child])
        {
            dfs(child, u);
            son[u] += son[child];
            if (son[child] > ms[u]) ms[u] = son[child];
        }
    }
}
//================
// Second pass of centroid search. Fold the piece "above" u (all - son[u]) into
// ms[u], and move pos to u whenever u's largest remaining piece is strictly
// smaller. `all` is the size of the whole component.
void findroot(int u, int fa, int all)
{
    const int above = all - son[u];
    if (above > ms[u]) ms[u] = above;
    if (ms[pos] > ms[u]) pos = u;
    for (int edgeId = head[u]; edgeId; edgeId = e[edgeId].nx)
    {
        const int nb = e[edgeId].to;
        if (nb == fa || vis[nb]) continue;
        findroot(nb, u, all);
    }
}
// Return the centroid of the not-yet-removed component that contains u.
int getroot(int u)
{
    dfs(u, -1);                 // subtree sizes relative to u
    const int total = son[u];   // size of the whole component
    pos = u;                    // start with u as the candidate
    findroot(u, -1, total);
    return pos;
}
//=============
void calc(int u, int fa, int t)
{
    t = (long long)t * val[u] % mod;
    int i, v, c = (long long)k * inv[t] % mod;
    if (mp[c])
    {
        a = u; b = mp[c];
        if (a > b) swap(a, b);
        if (a < ans1) ans1 = a, ans2 = b;
        else if (a == ans1 && b < ans2) ans2 = b;
    }
    for (i = head[u]; i; i = e[i].nx)
    {
        v = e[i].to;
        if (!vis[v] && v != fa) calc(v, u, t);
    }
}
// Insert pass over one subtree of the current centroid.
// t starts as the centroid's value, so prod is the product of the whole path
// centroid .. u. mp keeps the smallest vertex id per product; keys seen for the
// first time are appended to use[] so they can be reset later.
void update(int u, int fa, int t)
{
    const int prod = (long long)t * val[u] % mod;
    if (mp[prod] == 0)
    {
        use[num] = prod;
        ++num;
        mp[prod] = u;
    }
    else if (u < mp[prod])
    {
        mp[prod] = u;
    }
    for (int edgeId = head[u]; edgeId; edgeId = e[edgeId].nx)
    {
        const int nb = e[edgeId].to;
        if (nb == fa || vis[nb]) continue;
        update(nb, u, prod);
    }
}
//=============
// One centroid-decomposition step. Examine every chain that passes through the
// centroid s, then remove s and recurse into each remaining component.
// During the step, mp[p] holds the smallest vertex w such that the path s .. w
// (s included) has product p. use[0..num) lists the keys update() wrote.
void solve(int s)
{
    int i, v;
    // The single-vertex path {s} has product val[s], so calc() can also find
    // chains that end exactly at s.
    mp[val[s]] = s;
    vis[s] = true; num = 0;
    for (i = head[s]; i; i = e[i].nx)
    {
        v = e[i].to;
        if (!vis[v])
        {
            // Query before inserting. Each vertex of this subtree is paired only
            // with s or with vertices of earlier subtrees, never with its own subtree.
            calc(v, s, 1);
            update(v, s, val[s]);
        }
    }
    // Reset mp for the next centroid. Key val[s] is cleared on its own line
    // because it was set above without being recorded in use[], and update()
    // never records an already-set key.
    for (i = 0; i < num; i++) mp[use[i]] = 0;
    mp[val[s]] = 0;
    // Recurse on the centroid of each component left after removing s
    // (r is the shared global scratch variable).
    for (i = head[s]; i; i = e[i].nx)
    {
        v = e[i].to;
        if (!vis[v])
        {
            r = getroot(v);
            solve(r);
        }
    }
}
//=============
// Entry point. Precompute all modular inverses once. Then, for every test case:
// read the tree, run centroid decomposition from the centroid of the whole tree,
// and print the lexicographically smallest endpoint pair or "No solution".
int main()
{
    // Linear-time inverse table: inv[i] = -(mod / i) * inv[mod % i] (mod p).
    // This recurrence is valid because mod is prime.
    inv[1] = 1;
    for (i = 2; i < mod; i++) inv[i] = (mod - mod / i) * inv[mod % i] % mod;
    // Start a test case only when BOTH header values were read. Testing
    // `!= EOF` would also accept 0 or 1 conversions. A stray non-numeric token
    // would then loop forever, and a truncated header would run a bogus case.
    while (scanf("%d%d", &n, &k) == 2)
    {
        memset(head, 0, sizeof(head));
        memset(vis, false, sizeof(vis));
        cnt = 0;
        for (i = 1; i <= n; i++) scanf("%d", &val[i]);
        for (i = 1; i < n; i++) scanf("%d%d", &x, &y), ins(x, y), ins(y, x);
        ans1 = ans2 = inf;  // inf marks "no pair found yet"
        r = getroot(1);
        solve(r);
        if (ans1 != inf) printf("%d %d\n", ans1, ans2);
        else printf("No solution\n");
    }
    return 0;
}