About the Problem
Setter(s) | Chaithanya Shyam |
Tester(s) | Vichitr Gandas , Aswin Ashok |
Difficulty | Easy-Medium |
Topics | Tree, Graph, DFS, Expected Value |
Practice | Link |
Contest | Link |
Solution Idea
Let u be the node with the maximum value. Whichever edge is removed, u ends up in one of the two resulting trees, and because its value is the largest overall, value[u] equals f(T_1) or f(T_2) depending on which tree contains it. Since value[u] is therefore always one side of the difference, it is a good idea to root the tree at u.
Once we root the tree at u, let S = \sum_{i=1}^{N} \left(value[u] - mx\_subtree[i]\right), where mx\_subtree[i] is the maximum value of a node in the subtree rooted at i.
Since each edge is equally likely to be selected, the final answer is \frac{S}{N-1} modulo 10^9+7. Because 10^9+7 is prime, (N-1)^{-1} modulo 10^9+7 can be computed by fast exponentiation as (N-1)^{10^9+5}, which takes O(\log(10^9+7)) time.
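This sum is correct because removing the edge between a node i \neq u and its parent splits off exactly the subtree rooted at i, whose maximum is mx\_subtree[i], while the other part still contains u. That edge therefore contributes value[u] - mx\_subtree[i]. For i = u the term is 0, so summing over all N nodes is the same as summing over the N-1 edges.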
Complexity Analysis:
- Time Complexity: O(N)
- Space Complexity: O(N)
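Approach 2: Using an Euler Tour
Compute an Euler tour of the rooted tree (entry time tin[v] and exit time tout[v] for every node), and build prefix and suffix maxima of the node values in tour order. For an edge (u, v) where v is the child, the subtree of v is exactly the contiguous range [tin[v], tout[v]] of the tour. The other component consists of the positions before tin[v] and after tout[v], so its maximum is the larger of the prefix maximum up to tin[v]-1 and the suffix maximum from tout[v]+1. The subtree's own maximum can be taken from a subtree DP (or a range-maximum query over [tin[v], tout[v]]). Add the absolute difference of the two maxima to the sum for every edge. The expected value is sum / (N-1).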
Codes
Setter's Code
#pragma GCC optimize("Ofast")
#pragma GCC target("avx,avx2,fma")
#include <bits/stdc++.h>
//#include <ext/pb_ds/assoc_container.hpp> //required
//#include <ext/pb_ds/tree_policy.hpp> //required
//using namespace __gnu_pbds; //required
using namespace std;
//template <typename T> using ordered_set = tree<T, null_type, less<T>, rb_tree_tag, tree_order_statistics_node_update>;
// ordered_set <int> s;
// s.find_by_order(k); returns the (k+1)th smallest element
// s.order_of_key(k); returns the number of elements in s strictly less than k
#define MOD (1000000000+7) // change as required
#define pb(x) push_back(x)
#define mp(x,y) make_pair(x,y)
#define all(x) x.begin(), x.end()
#define print(vec,l,r) for(int i = l; i <= r; i++) cout << vec[i] <<" "; cout << endl;
#define input(vec,N) for(int i = 0; i < (N); i++) cin >> vec[i];
#define debug(x) cerr << #x << " = " << (x) << endl;
#define leftmost_bit(x) (63-__builtin_clzll(x))
#define rightmost_bit(x) __builtin_ctzll(x) // count trailing zeros
#define set_bits(x) __builtin_popcountll(x)
#define pow2(i) (1LL << (i))
#define is_on(x, i) ((x) & pow2(i)) // state of the ith bit in x
#define set_on(x, i) ((x) | pow2(i)) // returns integer x with ith bit on
#define set_off(x, i) ((x) & ~pow2(i)) // returns integer x with ith bit off
// Time-seeded random generator (not referenced by the code shown here).
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
// 64-bit shorthand. It must be declared BEFORE the `int` macro below:
// afterwards `long long int` would expand to `long long ll` and not compile.
typedef long long int ll;
// highly risky #defines
// From here on every `int` token (loop counters, `vector<int>`, ...) means
// `long long`; this is why main() has to be declared as `signed main()`.
#define int ll // disable when you want to make code a bit faster
// `endl` becomes a plain newline character, so output is no longer flushed per line.
#define endl '\n' // disable when dealing with interactive problems
// Computes x^n mod p for n >= 0 by iterative square-and-multiply.
// Every stored value is reduced mod p, so for p below ~3e9 no product
// overflows ll. n == 0 yields 1.
ll power(ll x, ll n, ll p){
    ll result = 1;
    ll base = x % p;
    for(; n > 0; n >>= 1){
        if(n & 1)
            result = result * base % p;   // this bit of n contributes base
        base = base * base % p;           // base^(2^k) for the next bit
    }
    return result;
}
// Modular inverse of a modulo the prime p via Fermat's little theorem:
// a^(p-2) is congruent to a^(-1) (mod p) when p is prime and p does not divide a.
// For a multiple of p (and p > 2) the result is 0, not an error.
ll inv(ll a, ll p){
    const ll fermat_exponent = p - 2;
    return power(a, fermat_exponent, p);
}
// Per-test-case state shared by solve() and dfs().
vector<vector<int>> adj;   // adj[v]: neighbours of vertex v (0-indexed)
vector<int> dp;            // vertex values on input; subtree maxima after dfs()
int ans = 0;               // running sum of (max_weight - dp[v]) over finished vertices
int max_weight = 0;        // value of the chosen root, i.e. the maximum value
void dfs(int subroot, int par){
for(int child: adj[subroot]){
if(child == par) continue;
dfs(child, subroot);
dp[subroot] = max(dp[subroot], dp[child]);
}
ans += (max_weight-dp[subroot]);
}
void solve(){
// code starts from here
ans = 0;
int N;
cin >> N;
dp.clear();
adj.clear();
dp.resize(N);
adj.resize(N);
for(int i = 0; i < N-1; i++){
int t1, t2;
cin >> t1 >> t2;
t1--;t2--;
adj[t1].pb(t2);
adj[t2].pb(t1);
}
for(int i = 0; i < N; i++){
cin >> dp[i];
}
int root = max_element(all(dp)) - dp.begin();
max_weight = dp[root];
dfs(root, root);
ans %= MOD;
ans *= inv(N-1, MOD);
ans %= MOD;
cout << ans << endl;
}
clock_t startTime;  // reference point for getCurrentTime(); assign clock() before timing
// Processor time consumed since startTime, in seconds.
double getCurrentTime() {
    const clock_t elapsed = clock() - startTime;
    return static_cast<double>(elapsed) / CLOCKS_PER_SEC;
}
signed main(){
    // Decouple C++ streams from C stdio and untie cin from cout for speed.
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);

    // startTime = clock();   // enable together with the cerr line below to time the run

    int T = 1;                // number of test cases, read from the input
    cin >> T;
    while(T-- != 0)
        solve();

    // cerr << getCurrentTime() << endl;
    return 0;
}
Tester's Code
/***************************************************
@author: vichitr
Compiled On: 27 Mar 2021
*****************************************************/
#include<bits/stdc++.h>
#define MAX 9223372036854775807
#define endl "\n"
#define ll long long
#define int long long
// #define double long double
#define pb push_back
#define pf pop_front
#define mp make_pair
#define ip pair<int, int>
#define F first
#define S second
#define loop(i,n) for(int i=0;i<n;i++)
#define loops(i,s,n) for(int i=s;i<=n;i++)
#define fast ios::sync_with_stdio(0); cin.tie(NULL); cout.tie(NULL)
using namespace std;
// #include <ext/pb_ds/assoc_container.hpp> // Common file
// #include <ext/pb_ds/tree_policy.hpp> // Including tree_order_statistics_node_updat
// using namespace __gnu_pbds;
// typedef tree<ip, null_type, less_equal<ip>, rb_tree_tag, tree_order_statistics_node_update> ordered_set;
// order_of_key (k) : Number of items strictly smaller than k .
// find_by_order(k) : K-th element in a set (counting from zero).
const ll MOD = 1000000007;  // prime modulus of the answer (1e9 + 7)
const ll N = 100007;        // array capacity: maximum n (1e5) plus slack
// Computes x^y mod MOD by binary (square-and-multiply) exponentiation.
// x is first reduced into [0, MOD). Without this step the first squaring
// x * x overflows long long (undefined behaviour) once |x| exceeds ~3e9,
// and a negative x would produce negative "residues".
// y is expected to be non-negative. y == 0 yields 1, and a negative y also
// yields 1 instead of looping forever (an arithmetic right shift keeps -1 at -1).
ll pwr(ll x, ll y)
{
    x %= MOD;
    if(x < 0)
        x += MOD;
    ll r = 1LL;
    while(y > 0)
    {
        if(y&1)
            r = (r * x) % MOD;
        y >>= 1;
        x = (x * x) % MOD;
    }
    return r;
}
// Modular inverse of x modulo the prime MOD, using Fermat's little theorem
// (x^(MOD-2) is congruent to x^(-1)). Yields 0 when x is a multiple of MOD.
int inv(int x)
{
    const int exponent = MOD - 2ll;
    return pwr(x, exponent);
}
// Per-test state, indexed 0..n-1 and sized for the largest allowed n.
int a[N];            // vertex values on input; subtree maxima after dfs()
vector<int> adj[N];  // adjacency lists of the tree
bool vis[N];         // visited flags used by dfs()
int mw;              // maximum vertex value (the root's value)
int ans;             // running sum, kept reduced modulo MOD
void dfs(int v){
if(vis[v]) return;
vis[v] = 1;
for(int u: adj[v]){
if(!vis[u]){
dfs(u);
a[v] = max(a[v], a[u]);
}
}
ans += abs(mw - a[v]);
ans %= MOD;
}
void solve(){
int n; cin>>n;
assert(n >= 2 and n <= 1e5);
loop(i, n){
vis[i] = 0;
adj[i].clear();
}
loop(i, n-1){
int u, v; cin >> u >> v;
assert(u >= 1 and u <= n);
assert(v >= 1 and v <= n);
u--, v--;
adj[u].pb(v);
adj[v].pb(u);
}
loop(i, n){
cin>>a[i];
assert(a[i] >= 1 and a[i] <= 1e9);
}
int root = 0;
for(int i=1;i<n;i++){
if(a[i] > a[root])
root = i;
}
mw = a[root];
ans = 0;
dfs(root);
ans *= inv(n-1);
ans %= MOD;
cout << ans <<'\n';
}
signed main()
{
fast;
#ifndef ONLINE_JUDGE
freopen("input.txt", "r", stdin);
freopen("output.txt", "w", stdout);
#endif
int t=1;
cin >>t;
assert(t <= 5);
for(int i=1;i<=t;i++)
{
// cout<<"Case #"<<i<<": ";
solve();
}
return 0;
}
If you have used any other approach, share your approach in comments!
If you have any doubts, ask them in comments!