About the Problem
Setter(s) | Vaibhav Tulsyan |
Tester(s) | Balajiganapathi S, Aswin Ashok |
Difficulty | Hard |
Topics | Graph, Observations, String, DP, DFS |
Practice | Link |
Contest | Link |
Solution Idea
Let's solve the problem for a string of parentheses first. For each index $x$, maintain $cnt_1[i]$, the number of substrings that have $x$ as an endpoint and extend to the left with exactly $i$ unbalanced opening parentheses, and $cnt_2[i]$, the number of substrings that have $x$ as an endpoint and extend to the right with exactly $i$ unbalanced closing parentheses. For each index $x$, the number of well-bracketed substrings is found by combining $cnt_1[i]$ with $cnt_2[i]$ for each $i$. When we move to index $x+1$, depending on whether the bracket there is `(` or `)`, we shift either $cnt_1[]$ or $cnt_2[]$ by one position, and then add the contribution of that index. On the tree, for a node $u$ let $X[i]$ be the number of paths that start at $u$ and end at one of its descendants with exactly $i$ unmatched parentheses, and let $Y[i]$ be the number of paths that start at one of its descendants and end at $u$ with exactly $i$ unmatched parentheses. A well-bracketed path whose highest node is $u$ either is itself an $X[0]$ or $Y[0]$ path, or consists of a $Y[i]$ path from one child's subtree followed by an $X[i]$ path into a different child's subtree. Matching $X[i]$ with $Y[i]$ across all children of $u$ therefore counts the well-bracketed paths through $u$, in time linear in the array lengths. After adding the contribution of all children of $u$ to the answer, the children's $X[]$ and $Y[]$ arrays are merged with the small-to-large technique, always adding the shorter arrays into the longest one. Because the length of each array is bounded by the height of its subtree, this gives an overall amortized time complexity of $O(N)$.
Complexity Analysis:
- Time Complexity: amortized O(n)
- Space Complexity: O(n)
Codes
Setter's Code
Tester's Code
/* string coder = "Balajiganapathi S"; // Never give up! */
//#define LOCAL
#ifdef LOCAL
# define TRACE
# define TEST
#else
# define NDEBUG
//# define FAST
#endif
#include<bits/stdc++.h>
using namespace std;
/* aliases */
using vi = vector<int>;
using pi = pair<int, int>;
using ll = long long int;
/* shortcut macros */
#define mp make_pair
#define fi first
#define se second
#define mt make_tuple
#define gt(t, i) get<i>(t)
#define all(x) (x).begin(), (x).end()
#define ini(a, v) memset((a), (v), sizeof(a))
// rep(i, s, n): i runs over the closed range [s, n]; the bound n is evaluated once.
#define rep(i, s, n) for(int i = (s), _##i = (n); i <= _##i; ++i)
// re(i, s, n): half-open range [s, n).
#define re(i, s, n) rep(i, (s), (n) - 1)
// fo(i, n): 0 <= i < n.
#define fo(i, n) re(i, 0, n)
// Container size as a signed int, so `si(v) - 1` is -1 (not huge) for an empty container.
#define si(x) (int((x).size()))
// Bit i of mask (0 or 1). Both arguments are parenthesised so that an expression
// argument such as is1(m, c ? 2 : 0) is evaluated as a whole before shifting.
#define is1(mask, i) (((mask) >> (i)) & 1)
/* trace macro: prints each argument's source text and value to cerr, prefixed by the
   calling function and line. Expands to nothing unless TRACE is defined, so its
   arguments are not evaluated in normal builds. */
#ifdef TRACE
// Standard __VA_ARGS__ (not the GNU-only named `v...` form); do/while(0) makes
// `if (c) trace(x); else ...` parse as intended.
# define trace(...) do { cerr << __func__ << ":" << __LINE__ << ": "; _dt(#__VA_ARGS__, __VA_ARGS__); } while(0)
#else
# define trace(...)
#endif
#ifdef TRACE
// Finds the first top-level argument in `s`, the stringized argument list that trace()
// passes in. Returns the inclusive [first, last] character range of that argument:
// leading spaces are skipped, and a comma only ends the argument when it is outside
// parentheses and outside '...' / "..." literals (inside which a backslash skips the
// next character). If no such comma exists, the range runs to the end of `s`.
pi _gp(const string& s) {
	pi r(0, si(s) - 1);
	int p = 0, s1 = 0, s2 = 0, start = 1;  // paren depth, inside '...', inside "...", still in leading spaces
	fo(i, si(s)) {
		int x = (s1 | s2);  // currently inside a quoted literal?
		if(s[i] == ' ' && start) {
			++r.fi;
		} else {
			start = 0;
			if(s[i] == ',' && !p && !x) {
				r.se = i - 1;
				return r;
			}
			if(x && s[i] == '\\') ++i;
			else if(!x && s[i] == '(') ++p;
			else if(!x && s[i] == ')') --p;
			else if(!s2 && s[i] == '\'') s1 ^= 1;
			else if(!s1 && s[i] == '"') s2 ^= 1;
		}
	}
	return r;
}
// Last argument: prints "<source text> = <value> |" and ends the line.
template<typename H> void _dt(const string& u, H&& v) {
	pi p = _gp(u);
	cerr << u.substr(p.fi, p.se - p.fi + 1) << " = " << forward<H>(v) << " |" << endl;
}
// Prints the first argument with its source text, then recurses on the remainder of
// the stringized list (skipping the comma that ended the first argument).
template<typename H, typename ...T> void _dt(const string& u, H&& v, T&&... r) {
	pi p = _gp(u);
	cerr << u.substr(p.fi, p.se - p.fi + 1) << " = " << forward<H>(v) << " | ";
	_dt(u.substr(p.se + 2), forward<T>(r)...);
}
// Containers are taken by const reference: printing must not copy them.
template<typename T>
ostream &operator <<(ostream &o, const vector<T>& v) { // print a vector
	o << "[";
	fo(i, si(v) - 1) o << v[i] << ", ";
	if(si(v)) o << v.back();
	o << "]";
	return o;
}
template<typename T>
ostream &operator <<(ostream &o, const deque<T>& v) { // print a deque
	o << "[";
	fo(i, si(v) - 1) o << v[i] << ", ";
	if(si(v)) o << v.back();
	o << "]";
	return o;
}
template<typename T1, typename T2>
ostream &operator <<(ostream &o, const map<T1, T2>& m) { // print a map
	o << "{";
	for(const auto &p: m) {
		o << " (" << p.fi << " -> " << p.se << ")";
	}
	o << " }";
	return o;
}
template<typename T>
ostream &operator <<(ostream &o, const set<T>& s) { // print a set
	o << "{";
	bool first = true;
	for(const auto &entry: s) {
		if(!first) o << ", ";
		o << entry;
		first = false;
	}
	o << "}";
	return o;
}
// Recursion terminator: every tuple element has been printed.
template <size_t n, typename... T>
typename enable_if<(n >= sizeof...(T))>::type
print_tuple(ostream&, const tuple<T...>&) {}
// Prints element n (comma-separated after the first), then the rest.
template <size_t n, typename... T>
typename enable_if<(n < sizeof...(T))>::type
print_tuple(ostream& os, const tuple<T...>& tup) {
	if (n != 0)
		os << ", ";
	os << get<n>(tup);
	print_tuple<n+1>(os, tup);
}
template <typename... T>
ostream& operator<<(ostream& os, const tuple<T...>& tup) { // print a tuple
	os << "(";
	print_tuple<0>(os, tup);
	return os << ")";
}
template <typename T1, typename T2>
ostream& operator<<(ostream& os, const pair<T1, T2>& p) { // print a pair
	return os << "(" << p.fi << ", " << p.se << ")";
}
#endif
/* util functions */
// Returns (_a ^ p) mod `mod` by binary exponentiation; p must be non-negative.
// Unless FAST is defined, `_a` is first reduced into [0, mod), so negative or large
// bases are handled. Products are computed in long long, so (mod - 1)^2 must fit.
// The result is always reduced, so e.g. modpow(x, 0, 1) == 0.
template<typename T1, typename T2, typename T3>
T1 modpow(T1 _a, T2 p, T3 mod) {
	assert(p >= 0);
	long long ret = 1 % mod, a = _a;  // `1 % mod` keeps the p == 0 result reduced (0 when mod == 1)
#ifndef FAST
	if(a < 0) {
		a %= mod;
		a += mod;
	}
	if(a >= mod) {
		a %= mod;
	}
#endif
	for(; p > 0; p /= 2) {
		if(p & 1) ret = ret * a % mod;
		a = a * a % mod;
	}
	return ret;
}
// Rename the identifiers x1 / y1 for the rest of the file; presumably this avoids
// clashing with the POSIX y1() Bessel function that <cmath> may declare globally.
#define x1 _asdfzx1
#define y1 _ysfdzy1
/* constants */
// Grid offsets (dx[k], dy[k]): the four axis-aligned moves first, then the four diagonals.
constexpr int dx[8] = {-1, 0, 1, 0, 1, 1, -1, -1};
constexpr int dy[8] = {0, -1, 0, 1, 1, -1, 1, -1};
constexpr long double PI = 3.14159265358979323846L;
// "Infinity" with headroom: oo + oo still fits in an int.
constexpr int oo = std::numeric_limits<int>::max() / 2 - 2;
constexpr double eps = 1e-6;
constexpr int mod = 1000000007;
/* code: global state shared by solve() and main() */
constexpr int mx_n{1000006};  // capacity of the per-node arrays below
int n{};                      // number of nodes in the current test case
vector<pi> ch[mx_n];          // ch[u]: (child, edge character) pairs for node u
int par[mx_n]{};              // par[i]: parent of node i as read from input (-1 = none)
string chars{};               // edge labels; chars[i-1] labels the edge par[i] -> i
// Per-subtree summary produced by solve() and merge().
//   fwd, rev: per-index path counts for the subtree. An edge labelled '(' above the
//             subtree shifts fwd up one index and rev down one index (see solve());
//             ')' does the opposite. merge() adds each child's fwd[0] and rev[0] to
//             the total and pairs fwd[i] of one child with rev[i] of another.
//             Presumably these are the editorial's Y[] / X[] arrays (paths ending at /
//             starting from the subtree root) - TODO confirm which is which.
//   total:    balanced paths counted so far within the subtree.
class Result {
public:
	// Shared handles (Rule of Zero): copying a Result aliases the same deques, which
	// merge() relies on to reuse a child's deque as the parent's. Each deque is freed
	// when the last Result referring to it is destroyed, instead of leaking like the
	// previous raw `new` pointers did.
	shared_ptr<deque<long long>> fwd, rev;
	long long total;
	Result(): fwd(make_shared<deque<long long>>()), rev(make_shared<deque<long long>>()), total(0) {}
};
Result merge(vector<Result>& ds) {
trace(ds.size());
Result res;
if(ds.empty()) return res;
// Sum up total
res.total = 0;
for(Result&r: ds) {
res.total += r.total;
ll fc = 0, rc = 0;
if(!r.fwd->empty()) res.total += r.fwd->at(0);
if(!r.rev->empty()) res.total += r.rev->at(0);
}
// Count across node
int fwdMax = 0, revMax = 0;
for(int i = 0; i < ds.size(); ++i) {
fwdMax = max(fwdMax, (int)ds[i].fwd->size());
revMax = max(revMax, (int)ds[i].rev->size());
}
//trace(fwdMax, fwdMax2, revMax, revMax2);
int calcTill = min(fwdMax, revMax);
int till = ds.size();
for(int j = 0; j < calcTill; ++j) {
ll fsum = 0, rsum = 0, frsum = 0;
for(int i = 0; i < till; ++i) {
ll rc = 0, fc = 0;
bool reach = false;
if(ds[i].fwd->size() > j) {
fc = (*ds[i].fwd)[j];
fsum += fc;
reach = true;
}
if(ds[i].rev->size() > j) {
rc = (*ds[i].rev)[j];
rsum += rc;
reach = true;
}
if(!reach) {
swap(ds[i], ds[--till]);
--i;
continue;
}
frsum += fc * rc;
}
res.total += fsum * rsum - frsum;
}
//trace(till);
// Merge fwd
int f2max = 0;
for(int i = 1; i < ds.size(); ++i) {
if(ds[i].fwd->size() > ds[0].fwd->size()) {
f2max = ds[0].fwd->size();
swap(ds[i], ds[0]);
} else if(ds[i].fwd->size() > f2max) {
f2max = ds[i].fwd->size();
}
}
//trace(f2max);
res.fwd = ds[0].fwd;
#ifdef TRACE
cerr << "res.fwd" << (*res.fwd) << endl;
#endif
till = ds.size();
for(int j = 0; j < f2max; ++j) {
for(int i = 1; i < till; ++i) {
if(ds[i].fwd->size() <= j) {
swap(ds[i], ds[--till]);
--i;
continue;
}
res.fwd->at(j) += ((*ds[i].fwd)[j]);
}
#ifdef TRACE
cerr << "res.fwd->" << (*res.fwd) << endl;
#endif
}
// Merge rev
int r2max = 0;
for(int i = 1; i < ds.size(); ++i) {
if(ds[i].rev->size() > ds[0].rev->size()) {
r2max = ds[0].rev->size();
swap(ds[i], ds[0]);
} else if(ds[i].rev->size() > r2max) {
r2max = ds[i].rev->size();
}
}
//trace(r2max);
res.rev = ds[0].rev;
till = ds.size();
for(int j = 0; j < r2max; ++j) {
for(int i = 1; i < till; ++i) {
if(ds[i].rev->size() <= j) {
swap(ds[i], ds[--till]);
--i;
continue;
}
res.rev->at(j) += (*ds[i].rev)[j];
}
}
//--res.total;
//trace("Done");
return res;
}
/*
Result bruteSolve(int x) {
Result res;
bruteFwd(x, res.fwd, 0);
bruteRev(x, res.rev, 0);
}
*/
Result solve(int x) {
trace(x);
vector<Result> chs;
for(auto edge: ch[x]) {
trace(edge.fi, edge.se);
Result tmp = solve(edge.first);
if(edge.second == ')') {
if(!tmp.fwd->empty()) tmp.fwd->pop_front();
if(tmp.rev->empty()) tmp.rev->push_front(0);
++tmp.rev->front();
tmp.rev->push_front(0);
} else {
if(tmp.fwd->empty()) tmp.fwd->push_front(0);
++tmp.fwd->front();
tmp.fwd->push_front(0);
if(!tmp.rev->empty()) tmp.rev->pop_front();
}
#ifdef TRACE
cerr << "fwd[" << x << " " << edge.fi << "]: " << (*tmp.fwd) << endl;
cerr << "rev[" << x << " " << edge.fi << "]: " << (*tmp.rev) << endl;
#endif
chs.push_back(tmp);
}
Result res = merge(chs);
trace(x, res.total);
#ifdef TRACE
cerr << "fwd: " << (*res.fwd) << endl;
cerr << "rev: " << (*res.rev) << endl;
#endif
//trace((*res.fwd));
return res;
}
// Number of balanced paths in the whole tree, taking node 0 as the root.
ll solve() {
	const Result whole = solve(0);
	return whole.total;
}
/*
For single string:
ll cntStr(string s) {
ll ans = 0;
deque<ll> q;
q.push_front(1);
for(char c: s) {
if(c == '(') {
q.push_front(0);
} else {
q.pop_front();
}
if(q.empty()) q.push_front(0);
++q.front();
ans += q[0]-1;
}
return ans;
}
ll bruteCnt(string s) {
ll cnt = 0;
for(int i = 0; i < s.size(); ++i) {
int open = 0;
for(int j = i; j < s.size(); ++j) {
if(s[j] == '(') ++open;
else --open;
if(open < 0) break;
if(open == 0) ++cnt;
}
}
return cnt;
}
*/
int main() {
int t;
cin >> t;
while(t--) {
cin >> n;
trace(n);
for(int i = 0; i < n; ++i) {
cin >> par[i];
ch[i].clear();
}
cin >> chars;
trace(chars);
//int n = rand() % 1000 + 1;
//for(int i = 0; i < n; ++i) chars += ((rand() % 2)? '(': ')');
//ll brute = bruteCnt();
//cerr << ans << " " << brute << endl;
//assert(ans == brute);
for(int i = 1; i < n; ++i) if(par[i] != -1) {
ch[par[i]].emplace_back(i, chars[i-1]);
}
ll ans = solve();
cout << ans << endl;
}
return 0;
}
If you have used any other approach, share your approach in comments!
If you have any doubts, ask them in comments!