Tree Parenthesis - Editorial

About the Problem

Setter(s) Vaibhav Tulsyan
Tester(s) Balajiganapathi S, Aswin Ashok
Difficulty Hard
Topics Graph, Observations, String, DP, DFS
Practice Link
Contest Link

Solution Idea

Let's first solve the problem for a single string of parentheses. For each index x, maintain two arrays:

- cnt1[i]: the number of substrings ending at x (extending to the left) with exactly i unbalanced opening parentheses.
- cnt2[i]: the number of substrings starting at x (extending to the right) with exactly i unbalanced closing parentheses.

For each index x, the number of well-bracketed substrings is found by combining cnt1[i] with cnt2[i] for every i. When we move to index x+1, we shift either cnt1[] or cnt2[] by one position, depending on whether the character there is ( or ). We then add the contribution of that index.

On the tree, define two arrays for each node u:

- X[i]: the number of paths starting at u and ending at one of its descendants with exactly i unmatched parentheses.
- Y[i]: the number of paths starting at one of its descendants and ending at u with exactly i unmatched parentheses.

X[0] and Y[0] themselves count the balanced paths that have u as an endpoint. Matching X[i] of one child's subtree with Y[i] of a different child's subtree counts the well-bracketed paths that pass through u. This takes time linear in the lengths of the arrays involved. After adding the contribution of all children of u to the answer, the children's X[] and Y[] arrays are merged into the largest one (small-to-large). The length of each array is bounded by the height of its subtree (plus a constant), so only the shorter arrays are ever walked. This gives an overall amortized time complexity of O(N).

Complexity Analysis:

  • Time Complexity: amortized O(n)
  • Space Complexity: O(n)

Codes

Setter's Code

Tester's Code
/* string coder = "Balajiganapathi S"; // Never give up!  */

//#define LOCAL
#ifdef LOCAL
#   define TRACE
#   define TEST
#else
#   define NDEBUG
//#   define FAST
#endif

#include<bits/stdc++.h>

using namespace std;

/* aliases */
using vi  = vector<int>;
using pi  = pair<int, int>;
using ll  = long long int;

/* shortcut macros */
#define mp              make_pair
#define fi              first
#define se              second
#define mt              make_tuple
#define gt(t, i)        get<i>(t)
#define all(x)          (x).begin(), (x).end()
#define ini(a, v)       memset(a, v, sizeof(a))
// rep(i, s, n): i runs over the closed range [s, n]; n is evaluated once.
#define rep(i, s, n)    for(int i = (s), _##i = (n); i <= _##i; ++i)
// re(i, s, n): i runs over the half-open range [s, n).
#define re(i, s, n)     rep(i, (s), (n) - 1)
// fo(i, n): i runs over [0, n).
#define fo(i, n)        re(i, 0, n)
// si(x): size of a container as a signed int.
#define si(x)           (int((x).size()))
// is1(mask, i): value (0 or 1) of bit i of mask. Both arguments are
// parenthesized so expressions like is1(m, a < b) or is1(m, c ? x : y)
// expand with the intended precedence.
#define is1(mask, i)    (((mask) >> (i)) & 1)

/* trace macro */
// trace(a, b, ...) prints "<function>:<line>: a = <a> | b = <b> |" to stderr
// when TRACE is defined and compiles to nothing otherwise. The standard
// __VA_ARGS__ form replaces the GNU-only named variadic `v...`, and the
// do { } while(0) wrapper turns the TRACE expansion into a single statement so
// `if(c) trace(x); else ...` compiles (a bare { } block followed by `;` does not).
#ifdef TRACE
#   define trace(...)  do { cerr << __func__ << ":" << __LINE__ << ": "; _dt(#__VA_ARGS__, __VA_ARGS__); } while(0)
#else
#   define trace(...)
#endif

#ifdef TRACE
// Given the stringified argument list of a trace(...) call, returns the
// [first, last] character indices of the first top-level argument, after
// skipping leading spaces. Commas nested inside (), [] or {} and inside
// '...' / "..." literals are not treated as separators. If there is no
// top-level comma, the span runs to the end of the string.
pair<int, int> _gp(string s) {
    const int len = int(s.size());
    pair<int, int> r(0, len - 1);
    int depth = 0;          // nesting depth of (), [] and {}
    bool inSingle = false;  // inside a '...' literal
    bool inDouble = false;  // inside a "..." literal
    bool leading = true;    // still skipping leading spaces
    for(int i = 0; i < len; ++i) {
        const char c = s[i];
        const bool quoted = inSingle || inDouble;
        if(leading && c == ' ') {
            ++r.first;
            continue;
        }
        leading = false;
        if(c == ',' && depth == 0 && !quoted) {
            r.second = i - 1;
            return r;
        }
        if(quoted && c == '\\') ++i;  // skip the escaped character
        else if(!quoted && (c == '(' || c == '[' || c == '{')) ++depth;
        else if(!quoted && (c == ')' || c == ']' || c == '}')) --depth;
        else if(!inDouble && c == '\'') inSingle = !inSingle;
        else if(!inSingle && c == '"') inDouble = !inDouble;
    }
    return r;
}

template<typename H> void _dt(string u, H&& v) {
    pi p = _gp(u);
    cerr << u.substr(p.fi, p.se - p.fi + 1) << " = " << forward<H>(v) << " |" << endl;
}

template<typename H, typename ...T> void _dt(string u, H&& v, T&&... r) {
    pi p = _gp(u);
    cerr << u.substr(p.fi, p.se - p.fi + 1) << " = " << forward<H>(v) << " | ";
    _dt(u.substr(p.se + 2), forward<T>(r)...);
}

// Stream a vector as "[a, b, c]". Taken by const reference: the previous
// by-value parameter copied the whole container on every print.
template<typename T>
ostream &operator <<(ostream &o, const vector<T>& v) {
    o << "[";
    for(size_t i = 0; i < v.size(); ++i) {
        if(i > 0) o << ", ";
        o << v[i];
    }
    o << "]";
    return o;
}

// Stream a deque as "[a, b, c]" (by const reference, no copy).
template<typename T>
ostream &operator <<(ostream &o, const deque<T>& v) {
    o << "[";
    for(size_t i = 0; i < v.size(); ++i) {
        if(i > 0) o << ", ";
        o << v[i];
    }
    o << "]";
    return o;
}

// Stream a map as "{ (k1 -> v1) (k2 -> v2) }" (by const reference, no copy).
template<typename T1, typename T2>
ostream &operator <<(ostream &o, const map<T1, T2>& m) {
    o << "{";
    for(const auto& p: m) {
        o << " (" << p.first << " -> " << p.second << ")";
    }
    o << " }";
    return o;
}

// Stream a set as "{a, b, c}" (by const reference, no copy).
template<typename T>
ostream &operator <<(ostream &o, const set<T>& s) {
    o << "{";
    bool first = true;
    for(const auto& entry: s) {
        if(!first) o << ", ";
        o << entry;
        first = false;
    }
    o << "}";
    return o;
}

// Terminal case of the tuple printer: index n is past the last element.
template <size_t n, typename... T>
typename enable_if<(n >= sizeof...(T))>::type
print_tuple(ostream&, const tuple<T...>&) {}

// Print element n (preceded by ", " unless it is the first one), then
// continue with element n + 1.
template <size_t n, typename... T>
typename enable_if<(n < sizeof...(T))>::type
print_tuple(ostream& out, const tuple<T...>& items) {
    out << (n == 0 ? "" : ", ") << get<n>(items);
    print_tuple<n + 1>(out, items);
}

// Stream a tuple as "(a, b, c)".
template <typename... T>
ostream& operator<<(ostream& out, const tuple<T...>& items) {
    out << "(";
    print_tuple<0>(out, items);
    out << ")";
    return out;
}

// Stream a pair as "(first, second)".
template <typename T1, typename T2>
ostream& operator<<(ostream& out, const pair<T1, T2>& pr) {
    out << "(" << pr.first << ", " << pr.second << ")";
    return out;
}
#endif

/* util functions */
// Computes (_a ^ p) mod `mod` by binary exponentiation.
// Requires p >= 0 and mod >= 1. Unless FAST is defined, _a is first reduced
// into [0, mod), so negative or oversized bases are fine; with FAST the caller
// must already pass 0 <= _a < mod. Products are formed in long long, so
// (mod - 1)^2 must fit in long long.
// The result is always reduced: ret starts at 1 % mod, so modpow(x, 0, 1) == 0
// (previously it returned 1).
template<typename T1, typename T2, typename T3>
T1 modpow(T1 _a, T2 p, T3 mod) {
    assert(p >= 0);
    long long ret = 1 % mod, a = _a;

#ifndef FAST
    if(a < 0) {
        a %= mod;
        a += mod;
    }

    if(a >= mod) {
        a %= mod;
    }
#endif

    for(; p > 0; p /= 2) {
        if(p & 1) ret = ret * a % mod;
        a = a * a % mod;
    }

    return ret;
}

// Rename any later x1 / y1 identifiers, presumably so that variables with these
// names cannot clash with the y1() function that <math.h> may declare.
#define x1 _asdfzx1
#define y1 _ysfdzy1

/* constants */
// Neighbour offsets on a grid: entries 0-3 are the four axis directions,
// entries 4-7 the diagonals.
constexpr int dx[] = {-1, 0, 1, 0, 1, 1, -1, -1};
constexpr int dy[] = {0, -1, 0, 1, 1, -1, 1, -1};
constexpr long double PI  = 3.14159265358979323846L;
// Large "infinity"; oo + oo still fits in an int.
constexpr int         oo  = numeric_limits<int>::max() / 2 - 2;
constexpr double      eps = 1e-6;
constexpr int         mod = 1000000007;

/* code */
// Capacity of the per-node arrays below (presumably the maximum N plus slack).
constexpr int mx_n = 1000006;

// Number of nodes in the current test case (set by main()).
int n = 0;

// ch[u]: (child, bracket character on the edge u -> child) for each child of u.
vector<pair<int, int>> ch[mx_n];
// par[i]: parent of node i as read by main(); -1 means no edge is added.
int par[mx_n];
// chars[i - 1]: bracket on the edge between par[i] and node i.
string chars;

// DP state for one rooted subtree, as built by solve() and merge().
//   fwd, rev - count arrays indexed by the number of unmatched parentheses,
//              for paths in the two directions between the subtree's top node
//              and its descendants (the editorial's X[] and Y[];
//              NOTE(review): confirm which of the two is which).
//   total    - number of well-bracketed paths already counted in the subtree.
// Copies share the same arrays on purpose: merge() reuses the largest child's
// arrays instead of copying them (small-to-large). shared_ptr keeps that
// sharing but frees each array once its last owner is gone. The previous raw
// `new` pointers were never deleted, so every array allocated was leaked.
class Result {
    public:
    shared_ptr<deque<long long>> fwd, rev;
    long long total;
    Result(): fwd(make_shared<deque<long long>>()),
              rev(make_shared<deque<long long>>()),
              total(0) {}
};

Result merge(vector<Result>& ds) {
    trace(ds.size());
    Result res;
    if(ds.empty()) return res;
    // Sum up total
    res.total = 0;
    for(Result&r: ds) {
        res.total += r.total;
        ll fc = 0, rc = 0;
        if(!r.fwd->empty()) res.total += r.fwd->at(0);
        if(!r.rev->empty()) res.total += r.rev->at(0);
    }

    // Count across node
    int fwdMax = 0, revMax = 0;
    for(int i = 0; i < ds.size(); ++i) {
        fwdMax = max(fwdMax, (int)ds[i].fwd->size());
        revMax = max(revMax, (int)ds[i].rev->size());
    }
    //trace(fwdMax, fwdMax2, revMax, revMax2);

    int calcTill = min(fwdMax, revMax);
    int till = ds.size();
    for(int j = 0; j < calcTill; ++j) {
        ll fsum = 0, rsum = 0, frsum = 0;
        for(int i = 0; i < till; ++i) {
            ll rc = 0, fc = 0;
            bool reach = false;
            if(ds[i].fwd->size() > j) {
                fc = (*ds[i].fwd)[j];
                fsum += fc;
                reach = true;
            }
            if(ds[i].rev->size() > j) {
                rc = (*ds[i].rev)[j];
                rsum += rc;
                reach = true;
            }
            if(!reach) {
                swap(ds[i], ds[--till]);
                --i;
                continue;
            }
            frsum += fc * rc;
        }
        res.total += fsum * rsum - frsum;
    }
    //trace(till);

    // Merge fwd
    int f2max = 0;
    for(int i = 1; i < ds.size(); ++i) {
        if(ds[i].fwd->size() > ds[0].fwd->size()) {
            f2max = ds[0].fwd->size();
            swap(ds[i], ds[0]);
        } else if(ds[i].fwd->size() > f2max) {
            f2max = ds[i].fwd->size();
        }
    }
    //trace(f2max);

    res.fwd = ds[0].fwd;
    #ifdef TRACE
    cerr << "res.fwd" << (*res.fwd) << endl;
    #endif
    till = ds.size();
    for(int j = 0; j < f2max; ++j) {
        for(int i = 1; i < till; ++i) {
            if(ds[i].fwd->size() <= j) {
                swap(ds[i], ds[--till]);
                --i;
                continue;
            }
            res.fwd->at(j) += ((*ds[i].fwd)[j]);
        }
        #ifdef TRACE
        cerr << "res.fwd->" << (*res.fwd) << endl;
        #endif
    }

    // Merge rev
    int r2max = 0;
    for(int i = 1; i < ds.size(); ++i) {
        if(ds[i].rev->size() > ds[0].rev->size()) {
            r2max = ds[0].rev->size();
            swap(ds[i], ds[0]);
        } else if(ds[i].rev->size() > r2max) {
            r2max = ds[i].rev->size();
        }
    }
    //trace(r2max);

    res.rev = ds[0].rev;
    till = ds.size();
    for(int j = 0; j < r2max; ++j) {
        for(int i = 1; i < till; ++i) {
            if(ds[i].rev->size() <= j) {
                swap(ds[i], ds[--till]);
                --i;
                continue;
            }
            res.rev->at(j) += (*ds[i].rev)[j];
        }
    }
    //--res.total;
    //trace("Done");

    return res;
}


/*
Result bruteSolve(int x) {
    Result res;
    bruteFwd(x, res.fwd, 0);
    bruteRev(x, res.rev, 0);
}
*/

Result solve(int x) {
    trace(x);
    vector<Result> chs;
    for(auto edge: ch[x]) {
        trace(edge.fi, edge.se);
        Result tmp = solve(edge.first);
        if(edge.second == ')') {
            if(!tmp.fwd->empty()) tmp.fwd->pop_front();
            if(tmp.rev->empty()) tmp.rev->push_front(0);
            ++tmp.rev->front();
            tmp.rev->push_front(0);
        } else {
            if(tmp.fwd->empty()) tmp.fwd->push_front(0);
            ++tmp.fwd->front();
            tmp.fwd->push_front(0);
            if(!tmp.rev->empty()) tmp.rev->pop_front();
        }

            #ifdef TRACE
            cerr << "fwd[" << x << " " << edge.fi << "]: " << (*tmp.fwd) << endl;
            cerr << "rev[" << x << " " << edge.fi << "]: " << (*tmp.rev) << endl;
            #endif
        chs.push_back(tmp);
    }

    Result res = merge(chs);
    trace(x, res.total);
    #ifdef TRACE
    cerr << "fwd: " << (*res.fwd) << endl;
    cerr << "rev: " << (*res.rev) << endl;
    #endif
    //trace((*res.fwd));
    return res;
}

// Answer for the current test: the total counted for the tree rooted at node 0
// (main() attaches every node i >= 1 below par[i]).
ll solve() {
    const Result whole = solve(0);
    return whole.total;
}


/*
For single string:
ll cntStr(string s) {
    ll ans = 0;
    deque<ll> q;
    q.push_front(1);
    for(char c: s) {
        if(c == '(') {
            q.push_front(0);
        } else {
            q.pop_front();
        }
        if(q.empty()) q.push_front(0);
        ++q.front();
        ans += q[0]-1;
    }

    return ans;

}

ll bruteCnt(string s) {
    ll cnt = 0;
    for(int i = 0; i < s.size(); ++i) {
        int open = 0;
        for(int j = i; j < s.size(); ++j) {
            if(s[j] == '(') ++open;
            else --open;
            if(open < 0) break;
            if(open == 0) ++cnt;
        }
    }
    return cnt;
}
*/

int main() {
    int t;
    cin >> t;
    while(t--) {
        cin >> n;
        trace(n);
        for(int i = 0; i < n; ++i) {
            cin >> par[i];
            ch[i].clear();
        }
        cin >> chars;
        trace(chars);
        //int n = rand() % 1000 + 1;
        //for(int i = 0; i < n; ++i) chars += ((rand() % 2)? '(': ')');
        //ll brute = bruteCnt();
        //cerr << ans << " " << brute << endl;
        //assert(ans == brute);
        for(int i = 1; i < n; ++i) if(par[i] != -1) {
            ch[par[i]].emplace_back(i, chars[i-1]);
        }
        ll ans = solve();
        cout << ans << endl;

    }
    return 0;
}

If you have used any other approach, share your approach in comments!
If you have any doubts, ask them in comments!

2 Likes