Beauty of the Tree - Editorial

About the Problem

Setter(s) Chaithanya Shyam
Tester(s) Vichitr Gandas , Aswin Ashok
Difficulty Easy-Medium
Topics Tree, Graph, DFS, Expected Value
Practice Link
Contest Link

Solution Idea

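Let us look at the node u with the maximum value. Removing an edge splits the tree into two trees T_1 and T_2, and the beauty of that removal is |f(T_1) - f(T_2)|, where f(T) is the maximum node value in T. We observe that the value of u is always either f(T_1) or f(T_2), depending on which of the two trees contains u. Since it is always part of the answer, it's a good idea to root the tree at this node.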

Once we root the tree at u, let S = \sum_{i=1}^{N} (value[u] - mx\_subtree[i]), where mx\_subtree[i] is the maximum node value in the subtree rooted at i.

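Since each of the N-1 edges is selected with equal probability, the final answer is just \frac{S}{N-1} modulo 10^9+7. Because 10^9+7 is prime, (N-1)^{-1} modulo 10^9+7 can be computed with Fermat's little theorem as (N-1)^{10^9+5} using fast exponentiation, which takes O(\log(10^9+7)) time.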

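Why this sum over nodes is the total beauty over all edges: after rooting the tree at u, every edge joins some node i \neq u to its parent. Removing that edge separates the subtree rooted at i, whose maximum is mx\_subtree[i], from the rest of the tree. The rest still contains u and therefore has maximum value[u], so the beauty of this removal is value[u] - mx\_subtree[i]. The term for i = u is value[u] - mx\_subtree[u] = 0, so summing over all N nodes gives exactly S = \sum_{i=1}^{N} (value[u] - mx\_subtree[i]).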

Complexity Analysis:

  • Time Complexity: O(N) for the DFS, plus O(\log(10^9+7)) for the modular inverse
  • Space Complexity: O(N)

Approach 2 Using Euler Tour:

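Root the tree at any node and compute an Euler tour (entry/exit times), so that the subtree of every node v occupies a contiguous range [tin_v, tout_v] of the tour. Then precompute prefix and suffix maxima of the node values in tour order. Removing an edge (u, v), where u is the parent of v, separates the subtree of v from the rest of the tree. The rest has maximum equal to the larger of the prefix maximum before tin_v and the suffix maximum after tout_v. The subtree's maximum is the maximum over [tin_v, tout_v], which the same DFS can compute. Add the absolute difference of these two maxima to the sum, and do this for all edges. The expected value is sum / (n-1), computed modulo 10^9+7 with a modular inverse.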

Codes

Setter's Code
#pragma GCC optimize("Ofast")
#pragma GCC target("avx,avx2,fma")

#include <bits/stdc++.h>
//#include <ext/pb_ds/assoc_container.hpp> //required
//#include <ext/pb_ds/tree_policy.hpp> //required

//using namespace __gnu_pbds; //required 
using namespace std;
//template <typename T> using ordered_set =  tree<T, null_type, less<T>, rb_tree_tag, tree_order_statistics_node_update>; 

// ordered_set <int> s;
// s.find_by_order(k); returns the (k+1)th smallest element
// s.order_of_key(k); returns the number of elements in s strictly less than k

// Competitive-programming shorthand macros. Only MOD, pb and all are used by
// this solution; the others are template leftovers.
// Modulus for the final answer (10^9 + 7).
#define MOD              (1000000000+7) // change as required
#define pb(x)            push_back(x)
#define mp(x,y)          make_pair(x,y)
#define all(x)           x.begin(), x.end()
// NOTE(review): print expands to two statements (a for loop, then
// `cout << endl;`) with no do { } while (0) wrapper, so under an unbraced `if`
// only the loop is conditional. print/input/debug all end in `;`, so
// `if (c) debug(x); else ...` does not compile. vec, l and r are not
// parenthesized.
#define print(vec,l,r)   for(int i = l; i <= r; i++) cout << vec[i] <<" "; cout << endl;
#define input(vec,N)     for(int i = 0; i < (N); i++) cin >> vec[i];
#define debug(x)         cerr << #x << " = " << (x) << endl;
// 64-bit bit helpers. __builtin_clzll / __builtin_ctzll are undefined for an
// argument of 0, so leftmost_bit(0) and rightmost_bit(0) are undefined too.
#define leftmost_bit(x)  (63-__builtin_clzll(x))
#define rightmost_bit(x) __builtin_ctzll(x) // count trailing zeros
#define set_bits(x)      __builtin_popcountll(x) 
#define pow2(i)          (1LL << (i))
#define is_on(x, i)      ((x) & pow2(i)) // state of the ith bit in x
#define set_on(x, i)     ((x) | pow2(i)) // returns integer x with ith bit on
#define set_off(x, i)    ((x) & ~pow2(i)) // returns integer x with ith bit off
  
// Mersenne Twister seeded from the steady clock; not used by this solution.
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());

typedef long long int ll;

// highly risky #defines
// From here on every `int` token is replaced by `ll` (long long), which is
// why the entry point is declared `signed main()`.
#define int ll // disable when you want to make code a bit faster
// `endl` becomes a plain newline character, so output is not flushed per line.
#define endl '\n' // disable when dealing with interactive problems

ll power(ll x, ll n, ll p){
	// Modular exponentiation by recursive squaring: x^n mod p.
	// Base cases: n == 0 gives 1 and n == 1 gives x % p. Otherwise the result
	// is (x^(n/2))^2, with one extra factor of x when n is odd.
	if(n == 0) return 1;
	if(n == 1) return x%p;

	const ll half = power(x, n/2, p);
	const ll squared = (half*half)%p;
	const ll extra = (n%2 == 1) ? x%p : 1;
	return extra*squared%p;
}
ll inv(ll a, ll p){
	// Modular inverse via Fermat's little theorem: for prime p and a not
	// divisible by p, a^(p-2) * a == 1 (mod p).
	const ll fermat_exponent = p - 2;
	return power(a, fermat_exponent, p);
}

// dp[v] holds v's value on input; after dfs it holds the maximum value in
// v's subtree (for every node dfs visits).
std::vector<int> dp;
// Tree adjacency lists, 0-indexed.
std::vector<std::vector<int>> adj;
// ans: running sum of (max_weight - dp[v]) over visited nodes.
// max_weight: the largest node value (the value of the chosen root).
int ans = 0, max_weight = 0;

// Post-order pass over the subtree of `subroot`, entered from `par` (pass
// par == subroot for the root). For every node v of that subtree it folds the
// children's maxima into dp[v], then adds (max_weight - dp[v]) to ans.
// Uses an explicit stack: recursion on a path-shaped tree with ~1e5 nodes
// would go ~1e5 frames deep and can overflow the call stack.
void dfs(int subroot, int par){
	// (node, parent) pairs in DFS pre-order. Each node appears before all of
	// its descendants, so walking this list backwards is a valid post-order.
	std::vector<std::pair<int, int>> order;
	std::vector<std::pair<int, int>> pending;
	pending.push_back(std::make_pair(subroot, par));
	while(!pending.empty()){
		const std::pair<int, int> cur = pending.back();
		pending.pop_back();
		order.push_back(cur);
		for(int child: adj[cur.first]){
			if(child != cur.second) pending.push_back(std::make_pair(child, cur.first));
		}
	}

	for(int k = (int)order.size() - 1; k >= 0; k--){
		const int v = order[k].first;
		// All of v's children were handled earlier in this loop, so dp[v] is final.
		ans += (max_weight - dp[v]);
		// order[0] is subroot itself; its parent `par` lies outside the subtree.
		if(k > 0) dp[order[k].second] = std::max(dp[order[k].second], dp[v]);
	}
}

void solve(){
	// code starts from here
	ans = 0; 

	int N;
	cin >> N;

	dp.clear();
	adj.clear();

	dp.resize(N);
	adj.resize(N);

	for(int i = 0; i < N-1; i++){
		int t1, t2;
		cin >> t1 >> t2;

		t1--;t2--;

		adj[t1].pb(t2);
		adj[t2].pb(t1);
	}

	for(int i = 0; i < N; i++){
		cin >> dp[i];
	}

	int root = max_element(all(dp)) - dp.begin();
	max_weight = dp[root];

	dfs(root, root);

	ans %= MOD;
	ans *= inv(N-1, MOD);
	ans %= MOD;

	cout << ans << endl;
}


// Reference point for getCurrentTime(); assign it with `startTime = clock();`.
clock_t startTime;
// Processor time, in seconds, elapsed since startTime was recorded.
double getCurrentTime() {
	const clock_t elapsedTicks = clock() - startTime;
	return static_cast<double>(elapsedTicks) / CLOCKS_PER_SEC;
}

signed main(){
	// Unsynchronized, untied C++ streams for faster bulk I/O.
	ios_base::sync_with_stdio(false);
	cin.tie(nullptr);
	// Profiling hooks (disabled): uncomment to report CPU time on stderr.
	//startTime = clock();
	// mt19937_64 rnd(time(NULL));

	int T = 1;
	cin >> T;
	// One solve() call per test case.
	for(; T != 0; --T){
		solve();
	}

	//cerr << getCurrentTime() << endl;
	return 0;
}
Tester's Code
/***************************************************

@author: vichitr
Compiled On: 27 Mar 2021

*****************************************************/
#include<bits/stdc++.h>
// Template macros. NOTE(review): several rewrite common identifiers for the
// rest of the file. `int` becomes `long long` (hence `signed main`), and
// `endl` becomes a plain "\n" (no flush). The single letters F and S become
// `first`/`second`, so any later variable or template parameter named F or S
// silently changes meaning.
// MAX is the largest long long value (2^63 - 1); it is not used below.
#define MAX 9223372036854775807
#define endl "\n"
#define ll long long
#define int long long
// #define double long double
#define pb push_back
#define pf pop_front
#define mp make_pair
// Because `int` is itself a macro, ip expands to pair<long long, long long>.
#define ip pair<int, int>
#define F first
#define S second

// loop(i,n): i = 0..n-1; loops(i,s,n): i = s..n inclusive. The bound n is
// not parenthesized.
#define loop(i,n) for(int i=0;i<n;i++)
#define loops(i,s,n) for(int i=s;i<=n;i++)
// Fast stream I/O. Expands to three statements, so use it only as a
// standalone statement.
#define fast ios::sync_with_stdio(0); cin.tie(NULL); cout.tie(NULL)
using namespace std;

// #include <ext/pb_ds/assoc_container.hpp> // Common file
// #include <ext/pb_ds/tree_policy.hpp>     // Including tree_order_statistics_node_updat
// using namespace __gnu_pbds;
// typedef tree<ip, null_type, less_equal<ip>, rb_tree_tag, tree_order_statistics_node_update> ordered_set;

// order_of_key (k) : Number of items strictly smaller than k .
// find_by_order(k) : K-th element in a set (counting from zero).

// Prime modulus used for all modular arithmetic below.
const ll MOD = 1000000007LL;
// Size of the global per-node arrays: the largest allowed n (1e5) plus slack.
const ll N = 100007LL;

ll pwr(ll x, ll y)
{
    // Returns x^y mod MOD by binary exponentiation (O(log y) iterations).
    // The exponent must be non-negative: with a negative y the right shift
    // never reaches 0 and the loop below would not terminate.
    assert(y >= 0);
    // Reduce the base into [0, MOD) first. Without this, x * x overflows
    // signed 64-bit arithmetic (undefined behaviour) once |x| exceeds about
    // 3.04e9, and a negative base would produce a negative result.
    x %= MOD;
    if(x < 0) x += MOD;
    ll r = 1LL;
    while(y)
    {
        if(y&1)
            r = (r * x) % MOD;
        y >>= 1;
        x = (x * x) % MOD;
    }
    return r;
}

int inv(int x)
{
    // Modular inverse via Fermat's little theorem (MOD = 1e9+7 is prime):
    // for x not a multiple of MOD, x^(MOD-2) * x == 1 (mod MOD).
    const int fermatExponent = MOD - 2;
    return pwr(x, fermatExponent);
}

int a[N];
vector<int> adj[N];
bool vis[N];
int mw, ans;

void dfs(int v){
  	if(vis[v]) return;
    vis[v] = 1;
    for(int u: adj[v]){
        if(!vis[u]){
            dfs(u);
            a[v] = max(a[v], a[u]);
        }
    }
    ans += abs(mw - a[v]);
  	ans %= MOD;
}

void solve(){
    int n; cin>>n;
    assert(n >= 2 and n <= 1e5);
    loop(i, n){
        vis[i] = 0;
        adj[i].clear();
    }
    loop(i, n-1){
        int u, v; cin >> u >> v;
        assert(u >= 1 and u <= n);
        assert(v >= 1 and v <= n);
        u--, v--;
        adj[u].pb(v);
        adj[v].pb(u);
    }

    loop(i, n){
        cin>>a[i];
        assert(a[i] >= 1 and a[i] <= 1e9);
    }
    int root = 0;
    for(int i=1;i<n;i++){
        if(a[i] > a[root])
            root = i;
    }
    mw = a[root];
    ans = 0;
    dfs(root);
    ans *= inv(n-1);
    ans %= MOD;
    cout << ans <<'\n';
}

signed main()
{
    fast;
#ifndef ONLINE_JUDGE
    // Local runs only: redirect stdin/stdout to files.
    // NOTE(review): freopen closes the original stream before opening the
    // file, so if input.txt is missing, stdin is left unusable.
    freopen("input.txt", "r", stdin);
    freopen("output.txt", "w", stdout);
#endif

    int t = 1;
    cin >> t;
    assert(t <= 5);
    // One solve() call per test case.
    while(t-- > 0)
        solve();
    return 0;
}

If you have used any other approach, share your approach in comments!
If you have any doubts, ask them in comments!

3 Likes