Counting On a Tree
Problem Statement :
Taylor loves trees, and this new challenge has him stumped! Consider a tree, $t$, consisting of $n$ nodes. Each node is numbered from $1$ to $n$, and each node $i$ has an integer, $c_i$, attached to it. A query on tree $t$ takes the form `w x y z`. To process a query, you must print the count of ordered pairs of integers $(i, j)$ such that the following four conditions are all satisfied: (1) $i \neq j$; (2) $i \in$ the path from node $w$ to node $x$; (3) $j \in$ the path from node $y$ to node $z$; (4) $c_i = c_j$. Given $t$ and $q$ queries, process each query in order, printing the pair count for each query on a new line. Input Format: The first line contains two space-separated integers describing the respective values of $n$ (the number of nodes) and $q$ (the number of queries). The second line contains $n$ space-separated integers describing the respective values of each node (i.e., $c_1, c_2, \ldots, c_n$). Each of the $n - 1$ subsequent lines contains two space-separated integers, $u$ and $v$, defining a bidirectional edge between nodes $u$ and $v$. Each of the $q$ subsequent lines contains a `w x y z` query, defined above. Constraints: $1 \le n \le 10^5$; $1 \le q \le 50000$; $1 \le c_i \le 10^9$; $1 \le u, v, w, x, y, z \le n$. Scoring for this problem is binary, which means you have to pass all the test cases to get a positive score. Output Format: For each query, print the count of ordered pairs of integers satisfying the four given conditions on a new line.
Solution :
Solutions in C++, Java, C and Python3 :
In C++ :
#include <bits/stdc++.h>
using namespace std;
// Capacity of every per-vertex / per-timestamp array. The Euler tour stores
// two timestamps per vertex, so it needs room for 2 * n entries.
const int kMaxN = 300005;

std::vector<int> G[kMaxN];  // adjacency lists of the tree
int Color[kMaxN];           // compressed color of each vertex (>= 1)
int Linear[kMaxN];          // signed Euler tour: +color on entry, -color on exit
int Beg[kMaxN];             // entry timestamp of each vertex (0 = unvisited)
int End[kMaxN];             // exit timestamp of each vertex
int Depth[kMaxN];           // depth below the DFS root
long long Answer[kMaxN];    // answer of each query, 1-indexed
int n;                      // number of vertices
int colors;                 // number of distinct colors after compression
int timer;                  // Euler-tour clock
int Parent[20][kMaxN];      // binary lifting: Parent[k][v] = 2^k-th ancestor (0 = none)
// Euler-tour DFS of the tree rooted at `node`.
// For every vertex v reachable from the root it records:
//   Beg[v] / End[v]    entry / exit timestamps,
//   Linear[Beg[v]]     = +Color[v]  (color enters the current root path),
//   Linear[End[v]]     = -Color[v]  (color leaves it again),
//   Depth[v]           distance from the root,
//   Parent[k][v]       2^k-th ancestor of v (0 when it does not exist).
// The signed prefix Linear[1..Beg[v]] therefore contains exactly the colors on
// the path root -> v. An explicit stack replaces recursion, because a path-shaped
// tree with n = 1e5 vertices can overflow the call stack.
void DFS(int node) {
    const int kLevels = sizeof(Parent) / sizeof(Parent[0]);
    // Each frame: (vertex, index of the next neighbour to examine).
    std::vector<std::pair<int, std::size_t>> frames;
    Beg[node] = ++timer;
    Linear[timer] = +Color[node];
    frames.emplace_back(node, 0);
    while (!frames.empty()) {
        const int cur = frames.back().first;
        std::size_t &next = frames.back().second;
        if (next < G[cur].size()) {
            const int vec = G[cur][next++];
            if (Beg[vec])
                continue;  // already visited: in a tree this is the parent
            Parent[0][vec] = cur;
            Depth[vec] = Depth[cur] + 1;
            for (int i = 1; i < kLevels && Parent[i - 1][vec]; ++i)
                Parent[i][vec] = Parent[i - 1][Parent[i - 1][vec]];
            Beg[vec] = ++timer;
            Linear[timer] = +Color[vec];
            // May invalidate `next`; it is not used after this point.
            frames.emplace_back(vec, 0);
        } else {
            End[cur] = ++timer;
            Linear[timer] = -Color[cur];
            frames.pop_back();
        }
    }
}
// Lowest common ancestor of a and b, using the binary-lifting table
// (Parent, Depth) filled in by DFS. Both vertices must be in the DFS tree.
int LCA(int a, int b) {
    const int kLevels = sizeof(Parent) / sizeof(Parent[0]);
    if (Depth[a] < Depth[b])
        std::swap(a, b);
    // Lift the deeper vertex to the depth of the shallower one.
    for (int i = 0, d = Depth[a] - Depth[b]; d; ++i, d >>= 1)
        if (d & 1)
            a = Parent[i][a];
    if (a == b)
        return a;
    // Lift both vertices while they remain strictly below the LCA.
    for (int i = kLevels - 1; i >= 0; --i)
        if (Parent[i][a] != Parent[i][b]) {
            a = Parent[i][a];
            b = Parent[i][b];
        }
    assert(Parent[0][a] == Parent[0][b]);
    return Parent[0][a];
}
// One signed Mo's-algorithm query: the same-color pair count between the root
// paths root->a and root->b, multiplied by `sgn` and accumulated into Answer[i].
struct Query {
    int a, b, i, sgn;
};
std::vector<Query> Queries;

// Queue the signed query (a, b) for answer slot i. Vertex 0 is the sentinel
// "parent of the root", which stands for an empty root path, so any query
// touching it contributes nothing and is dropped.
void AddQuery(int a, int b, int i, int sgn) {
    if (a == 0 || b == 0)
        return;
    Queries.push_back(Query{a, b, i, sgn});
}

// Queue the product (path_a - path_b) x (path_c - path_d) of signed root
// paths, expanded by bilinearity into four signed root-path queries.
void AddQuery(int a, int b, int c, int d, int i) {
    AddQuery(a, c, i, 1);
    AddQuery(b, c, i, -1);
    AddQuery(a, d, i, -1);
    AddQuery(b, d, i, 1);
}
// Count[0][c] / Count[1][c]: signed number of occurrences of color c in the
// first / second Mo prefix of the Euler tour.
int Count[2][kMaxN];
// Running value of sum over colors c of Count[0][c] * Count[1][c].
long long global_ans;

// Apply one signed Euler-tour entry to prefix `at` (0 or 1): col > 0 adds an
// occurrence of color col, col < 0 removes one occurrence of color -col.
// global_ans is updated by swapping out the old product for the new one.
void Add(int at, int col) {
    if (col == 0)
        return;  // colors are compressed to >= 1, so 0 carries no color (and col/|col| would divide by zero)
    const int pos = std::abs(col);
    const int delta = col > 0 ? 1 : -1;
    global_ans -= 1LL * Count[0][pos] * Count[1][pos];
    Count[at][pos] += delta;
    global_ans += 1LL * Count[0][pos] * Count[1][pos];
}
void SolveQueries() {
for(auto &q : Queries) {
Answer[q.i] -= q.sgn * (1 + Depth[LCA(q.a, q.b)]);
q.a = Beg[q.a];
q.b = Beg[q.b];
if(q.a > q.b)
swap(q.a, q.b);
const int kMagic = 256;
sort(Queries.begin(), Queries.end(), [](Query &a, Query &b) {
if(a.a / kMagic == b.a / kMagic)
return a.b < b.b;
return a.a < b.a;
int b = 0, e = 0;
for(auto &q : Queries) {
while(b < q.a) Add(0, Linear[++b]);
while(e < q.b) Add(1, Linear[++e]);
while(b > q.a) Add(0, -Linear[b--]);
while(e > q.b) Add(1, -Linear[e--]);
Answer[q.i] += global_ans * q.sgn;
int main() {
int q;
cin >> n >> q;
map<int, int> normMap;
for(int i = 1; i <= n; ++i) {
cin >> Color[i];
auto &norm = normMap[Color[i]];
if(norm == 0)
norm = ++colors;
Color[i] = norm;
for(int i = 2; i <= n; ++i) {
int a, b;
cin >> a >> b;
Depth[0] = -1;
for(int i = 1; i <= timer; ++i)
cerr << Linear[i] << " ";
cerr << endl;
for(int i = 1; i <= q; ++i) {
int a1, b1, a2, b2;
cin >> a1 >> b1 >> a2 >> b2;
int l1 = LCA(a1, b1);
int l2 = LCA(a2, b2);
AddQuery(a1, Parent[0][l1], a2, Parent[0][l2], i);
AddQuery(a1, Parent[0][l1], b2, l2, i);
AddQuery(b1, l1, a2, Parent[0][l2], i);
AddQuery(b1, l1, b2, l2, i);
for(int i = 1; i <= q; ++i)
cout << Answer[i] << "\n";
cout << endl;
return 0;
In Java :
import java.math.*;
import java.text.*;
import java.util.*;
import java.util.regex.*;
// HackerRank-style driver: reads n, q, the node values c, the n-1 tree edges and
// the q queries, calls solve, and writes the results to the file at $OUTPUT_PATH.
// NOTE(review): this listing is incomplete as extracted — solve has no body,
// the Scanner initializer is cut off, and the output loop is truncated. It also
// uses BufferedWriter / FileWriter / IOException, but only java.math, java.text,
// java.util and java.util.regex are imported above; java.io.* is needed.
public class Solution {
// Complete the solve function below.
// Expected to return one pair count per query (int[] of length q).
// NOTE(review): counts can reach roughly n*n for long same-colored paths, which
// would overflow int; a long[] is likely required — confirm.
static int[] solve(int[] c, int[][] tree, int[][] queries) {
// NOTE(review): the constructor argument is missing; presumably `new Scanner(System.in)`.
private static final Scanner scanner = new Scanner(;
public static void main(String[] args) throws IOException {
BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(System.getenv("OUTPUT_PATH")));
// First line: n and q.
String[] nq = scanner.nextLine().split(" ");
int n = Integer.parseInt(nq[0]);
int q = Integer.parseInt(nq[1]);
// Second line: the n node values c_1..c_n (stored 0-based).
int[] c = new int[n];
String[] cItems = scanner.nextLine().split(" ");
for (int cItr = 0; cItr < n; cItr++) {
int cItem = Integer.parseInt(cItems[cItr]);
c[cItr] = cItem;
// Next n-1 lines: one edge (u, v) per row.
int[][] tree = new int[n-1][2];
for (int treeRowItr = 0; treeRowItr < n-1; treeRowItr++) {
String[] treeRowItems = scanner.nextLine().split(" ");
for (int treeColumnItr = 0; treeColumnItr < 2; treeColumnItr++) {
int treeItem = Integer.parseInt(treeRowItems[treeColumnItr]);
tree[treeRowItr][treeColumnItr] = treeItem;
// Next q lines: one query (w, x, y, z) per row.
int[][] queries = new int[q][4];
for (int queriesRowItr = 0; queriesRowItr < q; queriesRowItr++) {
String[] queriesRowItems = scanner.nextLine().split(" ");
for (int queriesColumnItr = 0; queriesColumnItr < 4; queriesColumnItr++) {
int queriesItem = Integer.parseInt(queriesRowItems[queriesColumnItr]);
queries[queriesRowItr][queriesColumnItr] = queriesItem;
int[] result = solve(c, tree, queries);
// Write the results; the remainder of this loop is missing from the listing.
for (int resultItr = 0; resultItr < result.length; resultItr++) {
if (resultItr != result.length - 1) {
In C :
#include <stdio.h>
#include <string.h>
#include <math.h>
#include <stdlib.h>
/* floor(log2(x)) for a nonzero 32-bit unsigned x: the index of the highest set
 * bit is 31 minus the count of leading zeros, and XOR with 31 computes that
 * subtraction for values in 0..31. Undefined for x == 0, because
 * __builtin_clz(0) is undefined. */
#define floor_log2_X86(self) (31U ^ __builtin_clz(self))
#define floor_log2 floor_log2_X86
/* In-place heapsort of the index array self[0..length), keyed by
 * weights[self[k]]. A max-heap is built and its maximum is repeatedly moved to
 * the end, so the result appears to be in ascending order of weight.
 * `self--` lets the heap use 1-based positions (children of p are 2p, 2p+1).
 * NOTE(review): the declaration opening this function is partially lost in
 * this listing (L152 is the tail of something like
 * `unsigned at = length >> 1, member, node;`), and closing braces are missing. */
void heap_sort(unsigned *self, unsigned *weights, unsigned length) {
at = length >> 1,
// Phase 1: build the heap by sifting down every internal node, last to first.
for (self--; at; self[node >> 1] = member) {
member = self[at];
for (node = at-- << 1; node <= length; node <<= 1) {
// Pick the heavier child (node | 1 is the right sibling).
node |= (node < length) && (weights[self[node]] < weights[self[node | 1]]);
if (weights[self[node]] < weights[member])
break ;
self[node >> 1] = self[node];
// Phase 2: swap the root (maximum) to the end, shrink, and re-sift.
for (; length > 1; self[at >> 1] = member) {
member = self[length];
self[length--] = self[1];
for (at = 2; at <= length; at <<= 1) {
at |= (at < length) && (weights[self[at]] < weights[self[at | 1]]);
if (weights[self[at]] < weights[member])
break ;
self[at >> 1] = self[at];
/* Replaces each entry of values[0..length) in place by a dense id, so equal
 * inputs get equal ids.
 * From the counting sort on group sizes below, ids appear to be assigned in
 * order of decreasing multiplicity (the most frequent value gets id 0) —
 * NOTE(review): confirm.
 * NOTE(review): declarations of `order`, `at`, `indices` and `ranks` are
 * partially lost in this listing. */
void compress(unsigned length, unsigned values[length]) {
// Initialise order[k] = k, two entries per 64-bit store: (2k, 2k+1) packed
// low/high. This relies on a little-endian layout with 32-bit unsigned and
// 64-bit unsigned long.
unsigned long sum = 0x0000000100000000UL;
for (at = 0; at < (length >> 1); sum += 0x0000000200000002UL)
((unsigned long *)order)[at++] = sum;
order[length - 1] = length - 1;
heap_sort(order, values, length);
// roots[g] = last sorted position of group g (roots[0] = -1 sentinel);
// max = largest group size.
unsigned roots[length], seen = 1, max = 0, others;
for (roots[at = 0] = -1U; at < length; roots[seen++] = at - 1) {
for (others = at; (at < length) && values[order[at]] == values[order[others]]; at++);
if (max < (at - others))
max = (at - others);
// Counting sort of the groups by size, largest first, into ranks[].
indices[max + 1],
memset(indices, 0, sizeof(indices));
for (at = 0; ++at < seen; indices[roots[at] - roots[at - 1]]++);
for (at = max; at--; indices[at] += indices[at + 1]);
for (at = seen; --at; ranks[--indices[roots[at] - roots[at - 1]]] = at);
// Write back: every member of the at-th ranked group gets id `at`.
for (; at < (seen - 1); at++)
for (others = roots[ranks[at] - 1]; ++others <= roots[ranks[at]]; values[order[others]] = at);
/* Lowest common ancestor of two relabelled vertices with lower <= upper.
 * Vertices are numbered so that the subtree of u is the id range
 * [u, u + weights[u]) (see the id assignment in main), so the ancestor test is
 * a single comparison.
 * bases[b][k] appears to hold the k-th ancestor chain across heavy paths for
 * heavy path b, and depths[] the number of light edges above each vertex;
 * they are used to climb from `upper` to `lower`'s heavy path, followed by a
 * binary search — NOTE(review): confirm against the table built in main.
 * `depth` is the width of each bases[] row. */
static inline unsigned nearest_common_ancestor(
unsigned depth,
unsigned base_cnt,
unsigned vertex_cnt,
unsigned base_ids[vertex_cnt],
unsigned bases[base_cnt][depth],
unsigned char depths[base_cnt],
unsigned weights[vertex_cnt],
unsigned lower,
unsigned upper
) {
// upper lies inside lower's subtree: lower is the answer.
if (upper < (lower + weights[lower]))
return lower;
// Jump upper up to lower's light-edge depth.
if (depths[upper] > depths[lower])
upper = bases[base_ids[upper]][depths[upper] - depths[lower] - 1];
if (upper < lower)
return upper;
// Binary search along upper's ancestor chain for the last entry above lower.
unsigned *others = bases[base_ids[upper]];
for (; depth > 1; depth >>= 1)
if (others[depth >> 1] > lower) {
others += depth >> 1;
depth += depth & 1U;
return others[others[0] > lower];
/* An inclusive id range [low, high]. It is also viewable as one 64-bit word
 * (`packd`), so ranges can be swapped or shifted with single integer ops.
 * This assumes unsigned long is 64 bits (LP64). */
typedef union {
unsigned long packd;
struct {
int low, high;
} range_t;
/* Read-only view of the color buckets used by query_all.
 * NOTE(review): the member list is missing from this listing. From the
 * initializer in main and the uses in query_all, it holds the unsigned*
 * fields members, colors, indices and locations. */
typedef struct {
} colored_tree_t;
/* Counts vertices other than `at` that have the same color as `at` and whose
 * relabelled id lies in [other.low, other.high].
 * members[indices[c] .. indices[c + 1]) lists the vertices of color c
 * (presumably in increasing id order, given the counting sort in main), and
 * locations[v] is v's position in members. Both ends of the range are located
 * by binary search unless an endpoint itself has the wanted color.
 * NOTE(review): the declaration opening this function (`unsigned color, length,`
 * ...) and closing braces are missing from this listing. */
unsigned long query_all(colored_tree_t *self, unsigned at, range_t other) {
color = self->colors[at],
length = self->indices[color + 1] - self->indices[color],
*base = &self->members[self->indices[color]];
// Range entirely outside this color's id span.
if (other.high < base[0] || other.low > base[length - 1])
return 0;
// Lower bound: first member with id >= other.low.
if (self->colors[other.low] != color) {
// Narrow the search to the side of `at` that can contain other.low.
if (at < other.low) {
base += self->locations[at] - self->indices[color];
length = self->indices[color + 1] - self->locations[at];
} else
length = self->locations[at] - self->indices[color]; // at > other.low
for (; length > 1; length >>= 1)
if (base[length >> 1] < other.low) {
base += length >> 1;
length += length & 1;
base += (base[0] < other.low);
} else
base += (self->locations[other.low] - self->indices[color]);
if (base[0] > other.high)
return 0;
// Upper bound: last member with id <= other.high.
unsigned *ceil;
if (self->colors[other.high] != color) {
ceil = (at > base[0] && at < other.high) ? &self->members[self->locations[at]] : base;
for (length = self->indices[color + 1] - self->locations[ceil[0]]; length > 1; length >>= 1)
if (ceil[length >> 1] <= other.high) {
ceil += length >> 1;
length += length & 1;
ceil -= (ceil[0] > other.high);
} else
ceil = &self->members[self->locations[other.high]];
// Members in [base, ceil], excluding `at` itself when it is inside the range.
return ceil - base + 1 - (at >= other.low && at <= other.high);
/* Counts ordered same-color pairs (i, j), i != j, with i in the id range `self`
 * and j in the id range `other`.
 * Ragged ends not aligned to `length`-sized blocks are handled vertex by vertex
 * with query_all. The aligned middle is answered from block-level tables:
 * `pairs` looks like a 2D prefix sum of cross-block pair counts, and
 * `overlapping` a prefix sum of within-block counts — NOTE(review): confirm
 * against how main fills them.
 * Indexing with `x - 1UL` relies on `pairs` having been offset by one row and
 * one column in main, so index -1 (as a wrapped unsigned long) addresses the
 * zero border. This is formally out-of-range pointer arithmetic.
 * NOTE(review): closing braces and part of the final expression are missing
 * from this listing. */
unsigned long count_pairs(
unsigned cnt,
unsigned length,
unsigned long pairs[cnt][cnt],
unsigned *overlapping,
colored_tree_t *tree,
range_t self,
range_t other
) {
unsigned long count = 0;
// Peel unaligned vertices off both ends of `self`.
for (; (self.low % length) && (self.low <= self.high); count += query_all(tree, self.low++, other));
for (; ((self.high + 1) % length) && (self.low <= self.high); count += query_all(tree, self.high--, other));
if (self.low <= self.high) {
// Peel unaligned vertices off both ends of `other` (against the aligned `self`).
for (; (other.low % length) && (other.low <= other.high); count += query_all(tree, other.low++, self));
for (; ((other.high + 1) % length) && (other.low <= other.high); count += query_all(tree, other.high--, self));
if (other.low <= other.high) {
// Both ranges are now block-aligned: switch to block indices.
self.low /= length;
self.high /= length;
other.low /= length;
other.high /= length;
// Ensure self starts first (XOR swap of the packed ranges).
if (self.low > other.low) {
self.packd ^= other.packd;
other.packd ^= self.packd;
self.packd ^= other.packd;
// Part of `self` that lies strictly before `other`.
unsigned high = (self.high < other.low) ? self.high : (other.low - 1);
count +=
- pairs[high][other.low - 1UL]
- pairs[self.low - 1UL][other.high]
+ pairs[self.low - 1UL][other.low - 1UL];
self.low = high + 1;
if (self.high > other.high) {
self.packd ^= other.packd;
other.packd ^= self.packd;
self.packd ^= other.packd;
// Overlapping blocks: within-block pairs plus cross-block pairs, counted twice.
if (self.low <= self.high)
count +=
(overlapping[self.high] - overlapping[self.low - 1UL])
+ ((
- pairs[self.high][self.low - 1UL]
- pairs[self.low - 1UL][self.high]
+ pairs[self.low - 1UL][self.low - 1UL]
) << 1) + (
- pairs[self.high][self.high]
- pairs[self.low - 1UL][other.high]
+ pairs[self.low - 1UL][self.high]
return count;
/* Driver for the C solution. As far as this listing shows, it:
 *   1. reads the colors and compresses them (with a sentinel at index n);
 *   2. orients the undirected edges into parent pointers rooted at vertex 0;
 *   3. computes subtree sizes and relabels vertices in heavy-child-first
 *      preorder, so subtrees and heavy paths are contiguous id ranges;
 *   4. buckets vertices by color and precomputes block-level pair counts;
 *   5. for each query, splits both paths into heavy-path id ranges and sums
 *      count_pairs over every pair of ranges.
 * NOTE(review): many declaration lines and closing braces are missing from
 * this listing (e.g. the `unsigned` before L320, L324, L341, L372, L379). */
int main() {
unsigned at, vertex_cnt;
// q <= 50000 fits in 16 bits.
unsigned short query_cnt;
scanf("%u %hu", &vertex_cnt, &query_cnt);
unsigned colors[vertex_cnt + 1];
for (at = 0; at < vertex_cnt; scanf("%u", &colors[at++]));
// Sentinel largest color at index n, so it compresses to its own id.
colors[at] = 0xFFFFFFFFU;
compress(at + 1, colors);
// Orient the edges: ancestors[v] = parent of v (0xFFFFFFFF = none). When v
// already has a parent, the chains are re-rooted by reversing parent pointers.
unsigned ancestors[at + 1];
unsigned ancestor, descendant;
for (memset(ancestors, 0xFFU, sizeof(ancestors)); --at; ancestors[descendant] = ancestor) {
scanf("%u %u", &ancestor, &descendant);
if (ancestors[--descendant] != 0xFFFFFFFFU) {
unsigned root = descendant, next;
for (; ancestor != 0xFFFFFFFFU; ancestor = next) {
next = ancestors[ancestor];
ancestors[ancestor] = root;
root = ancestor;
for (; ancestors[descendant] != 0xFFFFFFFFU; descendant = next) {
next = ancestors[descendant];
ancestors[descendant] = ancestor;
ancestor = descendant;
// `at` is 0 here: reverse the chain above vertex 0 so that 0 becomes the root.
for (ancestor = 0xFFFFFFFFU; at != 0xFFFFFFFFU; at = descendant) {
descendant = ancestors[at];
ancestors[at] = ancestor;
ancestor = at;
ids[vertex_cnt + 1],
bases[vertex_cnt + 1],
unsigned char
dist = 0;
// CSR child lists: descendants[indices[p] .. indices[p + 1]) = children of p.
indices[vertex_cnt + 1],
memset(indices, 0, sizeof(indices));
for (ancestors[vertex_cnt] = (at = vertex_cnt); at; indices[ancestors[at--]]++);
for (; ++at <= vertex_cnt; indices[at] += indices[at - 1]);
for (; --at; descendants[--indices[ancestors[at]]] = at);
// Iterative post-order over an explicit stack (`history`): a vertex is first
// pushed with its children above it; on its second visit its subtree size is
// summed into weights[].
history[0] = 0;
memset(weights, 0, sizeof(weights));
for (at = 1; at--; )
if (weights[history[at]])
for (others = indices[history[at]];
others < indices[history[at] + 1];
weights[history[at]] += weights[descendants[others++]]);
else {
weights[history[at]] = 1;
&history[at + 1],
(indices[history[at] + 1] - indices[history[at]]) * sizeof(descendants[0])
at += indices[history[at] + 1] - indices[history[at]] + 1;
// Keep copies in the original vertex numbering before relabelling.
orig_ancestors[vertex_cnt + 1],
orig_colors[vertex_cnt + 1],
memcpy(orig_ancestors, ancestors, sizeof(ancestors));
memcpy(orig_weights, weights, sizeof(weights));
memcpy(orig_colors, colors, sizeof(colors));
// Relabel in preorder, visiting children by weight via heap_sort so the heaviest
// child follows its parent. bases[id] = head of id's heavy path; base_depths[id]
// = number of light edges above id; dist = maximum such depth.
base_depths[0] = (bases[0] = (ids[0] = 0));
bases[vertex_cnt] = (ids[vertex_cnt] = vertex_cnt);
for (at = 1; at--;) {
id = ids[history[at]],
base = bases[id++],
branches = indices[history[at] + 1] - indices[history[at]];
heap_sort(&descendants[indices[history[at]]], orig_weights, branches);
memcpy(&history[at], &descendants[indices[history[at]]], branches * sizeof(descendants[0]));
for (others = (at += branches); branches--; base = id) {
ids[history[--others]] = id;
ancestors[id] = ids[orig_ancestors[history[others]]];
weights[id] = orig_weights[history[others]];
colors[id] = orig_colors[history[others]];
bases[id] = base;
base_depths[id] = base_depths[ancestors[id]] + (base == id);
if (dist < base_depths[id])
dist = base_depths[id];
id += weights[id];
// base_ids[v] = index of v's heavy path (consecutive ids sharing a base).
unsigned base_ids[vertex_cnt + 1];
for (base_ids[0] = (others = (at = 0)); others < vertex_cnt; base_ids[others] = base_ids[at] + 1)
for (at = others; bases[at] == bases[others]; base_ids[others++] = base_ids[at]);
// ancestral_bases[p][k]: the vertex reached after climbing k + 1 light edges
// from heavy path p.
unsigned ancestral_bases[base_ids[vertex_cnt]][dist];
for (ancestors[0] = 0; others--; ancestral_bases[base_ids[others]][0] = ancestors[others]);
while ((++others + 1) < dist)
for (at = 0; ++at < base_ids[vertex_cnt];
ancestral_bases[at][others + 1] = ancestors[bases[ancestral_bases[at][others]]]);
// Bucket relabelled vertices by color: members[indexed_colors[c] ..
// indexed_colors[c + 1]) = vertices of color c, in increasing id order.
indexed_colors[colors[vertex_cnt] + 2],
members[vertex_cnt + 1];
memset(indexed_colors, 0, sizeof(indexed_colors));
for (at = vertex_cnt + 1; at--; indexed_colors[colors[at]]++);
for (; ++at < colors[vertex_cnt]; indexed_colors[at + 1] += indexed_colors[at]);
for (at = vertex_cnt + 1; at--; members[--indexed_colors[colors[at]]] = at);
indexed_colors[colors[vertex_cnt] + 1] = indexed_colors[colors[vertex_cnt]];
// Block decomposition with block size `levels` = floor(log2 n) + 1.
levels = floor_log2(vertex_cnt) + 1,
block_cnt = (vertex_cnt / levels) + 1,
locations[vertex_cnt + 1],
// NOTE(review): the element-size argument of this calloc is missing from the
// listing, and the buffer is never freed.
unsigned long (*pairs)[block_cnt][block_cnt] = calloc(
(1 + block_cnt) * (1 + block_cnt),
// Offset by one row and column so index -1 hits the zeroed border.
pairs = (void *)&pairs[0][1][1];
// locations[v] = position of v inside members[].
for (at = vertex_cnt + 1; at--; locations[members[at]] = at);
memset(overlapping, 0, sizeof(overlapping));
// For each color with more than one vertex (colors are visited in id order,
// stopping at the first singleton), accumulate same-color pair counts per
// block (overlapping) and per ordered block pair (pairs).
for (at = 0; (indexed_colors[at + 1] - indexed_colors[at]) > 1; at++) {
others = indexed_colors[at];
block_bases[indexed_colors[at + 1] - others + 1],
cnt = 1;
for (block_bases[0] = members[others]; at == colors[members[++others]]; )
if ((members[others] / levels) != (block_bases[cnt - 1] / levels))
block_bases[cnt++] = members[others];
block_bases[cnt] = members[others];
for (others = 0; others < cnt; others++) {
unsigned long density = locations[block_bases[others + 1]] - locations[block_bases[others]];
overlapping[block_bases[others] / levels] += density * (density - 1);
unsigned block = others;
for (; ++block < cnt; pairs[0][block_bases[others] / levels][block_bases[block] / levels]
+= density * (locations[block_bases[block + 1]] - locations[block_bases[block]]));
// Turn the block tables into prefix sums (1D for overlapping, 2D for pairs).
for (at = 0; ++at < block_cnt; overlapping[at] += overlapping[at - 1])
pairs[0][0][at] += pairs[0][0][at - 1];
for (at = 0; ++at < block_cnt; )
for (others = 0; ++others < block_cnt; pairs[0][at][others] += pairs[0][at][others - 1]);
for (at = 0; ++at < block_cnt; )
for (others = 0; others < block_cnt; others++)
pairs[0][at][others] += pairs[0][at - 1][others];
colored_tree_t *tree = &(colored_tree_t) {
.members = members,
.colors = colors,
.indices = indexed_colors,
.locations = locations
while (query_cnt--) {
range_t left, right;
// NOTE(review): %u is paired with the int fields low/high here.
scanf("%u %u %u %u", &left.low, &left.high, &right.low, &right.high);
// Convert both 1-based endpoints to 0-based with one packed subtraction.
left.packd -= 0x0000000100000001UL;
right.packd -= 0x0000000100000001UL;
left.low = ids[left.low];
left.high = ids[left.high];
right.low = ids[right.low];
right.high = ids[right.high];
// Normalise each path so that low <= high (swap the packed halves).
if (left.high < left.low)
left.packd = (left.packd << 32) | (left.packd >> 32);
if (right.high < right.low)
right.packd = (right.packd << 32) | (right.packd >> 32);
if (right.high < left.low) {
left.packd ^= right.packd;
right.packd ^= left.packd;
left.packd ^= right.packd;
// Id ranges covering each path, one per heavy path crossed.
// NOTE(review): fixed capacity of 32 ranges per path; confirm this bound
// holds for n up to 1e5.
struct {
range_t members[32];
unsigned cnt;
a = {.cnt = 0},
b = {.cnt = 0};
// Decompose the left path into heavy-path ranges up to its LCA.
unsigned common = nearest_common_ancestor(
dist, base_ids[vertex_cnt], vertex_cnt,
base_ids, ancestral_bases,
base_depths, weights,
left.low, left.high
for (at = left.low; bases[at] != bases[common]; at = ancestral_bases[base_ids[at]][0])
a.members[a.cnt++].packd = bases[at] | ((unsigned long)at << 32);
for (others = left.high; bases[others] != bases[common]; others = ancestral_bases[base_ids[others]][0])
a.members[a.cnt++].packd = bases[others] | ((unsigned long)others << 32);
a.members[a.cnt++].packd = common | ((unsigned long)((at != common) ? at : others) << 32);
// Same decomposition for the right path.
common = nearest_common_ancestor(
dist, base_ids[vertex_cnt], vertex_cnt,
base_ids, ancestral_bases,
base_depths, weights,
right.low, right.high
for (at = right.low; bases[at] != bases[common]; at = ancestral_bases[base_ids[at]][0])
b.members[b.cnt++].packd = bases[at] | ((unsigned long)at << 32);
for (others = right.high; bases[others] != bases[common]; others = ancestral_bases[base_ids[others]][0])
b.members[b.cnt++].packd = bases[others] | ((unsigned long)others << 32);
b.members[b.cnt++].packd = common | ((unsigned long)((at != common) ? at : others) << 32);
// Sum the pair counts over every (left range, right range) combination.
unsigned long total = 0;
for (at = 0; at < a.cnt; at++)
for (others = 0; others < b.cnt;
total += count_pairs(
block_cnt, levels, pairs[0], overlapping, tree,
a.members[at], b.members[others++]
printf("%lu\n", total);
return 0;
In Python3 :
import sys
from filecmp import cmp
from os import linesep
from time import time
from collections import Counter
# memoized_BIT_prevs[h] lists every index i (1 <= i <= 10**5) whose Fenwick
# successor i + lowbit(i) equals h, i.e. the Fenwick-tree children of h.
# geodcounter.__init__ reads it as memoized_BIT_prevs[height].
memoized_BIT_prevs = [list() for _ in range(10**5+1)]
for i in range(1, 10**5+1):
    nxt = i + (i & -i)  # named `nxt` so the builtin `next` is not shadowed
    if nxt >= len(memoized_BIT_prevs):
        continue  # successor lies beyond the table
    memoized_BIT_prevs[nxt].append(i)
class UF(object):
    """Union-find (disjoint sets) with union by rank and path halving.

    Elements must be registered with UFadd before they are passed to UFfind
    or UFcombine.
    """

    def __init__(self):
        self.uf = {}     # element -> parent element (roots map to themselves)
        self.ranks = {}  # root -> rank (upper bound on its tree height)

    def UFadd(self, a):
        """Register `a` as a new singleton set."""
        self.uf[a] = a
        self.ranks[a] = 0

    def UFfind(self, a):
        """Return the representative of the set containing `a`."""
        uf = self.uf
        curr = a
        while curr != uf[curr]:
            # Path halving: point curr at its grandparent, then step there.
            uf[curr] = uf[uf[curr]]
            curr = uf[curr]
        return curr

    def UFcombine(self, a, b):
        """Merge the sets containing `a` and `b` (no-op if already merged)."""
        uf = self.uf
        ranks = self.ranks
        a_top = self.UFfind(a)
        b_top = self.UFfind(b)
        if a_top == b_top:
            return  # already one set; bumping the rank would be wrong
        rank_a = ranks[a_top]
        rank_b = ranks[b_top]
        if rank_a < rank_b:
            uf[a_top] = b_top
        elif rank_a > rank_b:
            uf[b_top] = a_top
        else:
            uf[b_top] = a_top
            ranks[a_top] += 1
class mycounter(object):
    """Minimal multiset of hashable keys, backed by a dict of counts."""
    __slots__ = ['c']

    def __init__(self, key=None):
        # Empty counter, or a counter holding `key` exactly once.
        if key is None:
            self.c = {}
        else:
            self.c = {key: 1}

    def inc(self, key):
        """Add one occurrence of `key`."""
        self.c[key] = self.c.get(key, 0) + 1

    def addto(self, other):
        """Add every count of `other` into this counter, in place."""
        for key, otherval in other.c.items():
            self.c[key] = self.c.get(key, 0) + otherval

    def subfrom(self, other, mult=1):
        """Subtract `mult` times `other` from this counter, in place.

        Every key of `other` must already be present here (asserted). Counts
        are not clamped, so the caller must keep them non-negative.
        """
        for key, otherval in other.c.items():
            assert key in self.c
            self.c[key] = self.c[key] - mult * otherval

    def innerprodwith(self, other):
        """Return the sum over shared keys k of self.c[k] * other.c[k]."""
        X = self.c
        Y = other.c
        if len(X) > len(Y):
            X, Y = Y, X  # iterate over the smaller dict
        return sum(X[i] * Y[i] for i in X if i in Y)
# List-backed alternative to mycounter: keys are stored as a list of
# occurrences (q), and one subtracted counter is remembered as `doubleneg`.
# NOTE(review): inc and addto have no body in this listing (presumably
# appending to / extending self.q). The class is not referenced by
# geodcounter, which uses mycounter.
class mycounter2(object):
__slots__ = ['q','doubleneg']
def __init__(self):
self.q = []
self.doubleneg = None #will be list
def inc(self,key):
def addto(self,other):
# Records other's occurrence list as the subtracted part; `mult` is ignored.
# Only one subtraction is supported (asserted).
def subfrom(self,other,mult = 1): #subtract other from self, never neg
assert self.doubleneg is None and other.doubleneg is None
self.doubleneg = other.q
# Sum over distinct keys of (count in q - 2 * count in doubleneg) for both sides.
# NOTE(review): if subfrom was never called, doubleneg is None and .count raises
# AttributeError.
def innerprodwith(self,other):
X = self
Y = other
if len(X.q) > len(Y.q):
X,Y = Y,X
Xq = X.q
Yq = Y.q
Xn = X.doubleneg
Yn = Y.doubleneg
S = set(Xq)
return sum((Xq.count(s)-2*Xn.count(s))*(Yq.count(s)-2*Yn.count(s)) for s in S)
# Answers path-vs-path same-value pair queries on a rooted tree.
# __init__ walks the tree once, building:
#   - a Fenwick-style ancestor structure (above / below / C);
#   - depths (dists);
#   - the LCA of every requested vertex pair (lca, via Tarjan's offline LCA
#     with the UF class).
# NOTE(review): many lines are missing from this listing (the main traversal
# loop, child pushes, the per-node counter updates in notice/find, and the
# final combination in get_incidence), so the control flow below is incomplete.
class geodcounter(object):
__slots__ = ['size','C','above','below','values','lca','dists']
# neb: adjacency lists (1-indexed); root: starting vertex; vals: value of each
# vertex; pairs[v]: set of vertices whose LCA with v will be requested.
def __init__(self,neb,root,vals,pairs):
self.size = len(neb)
self.C = [mycounter() for _ in range(self.size)]
self.above = [None]*self.size
self.below = [list() for _ in range(self.size)]
self.values = vals
self.lca = {}
self.dists = [None]*self.size
above = self.above
below = self.below
lca = self.lca
dists = self.dists
# geod is the current root-to-node path, indexed by height (1-based).
geod = [None] #one-indexed to fit BIT
gpush = geod.append
gpop = geod.pop
height = 0
traverse_stack = [(root,None)]
tpush = traverse_stack.append
tpop = traverse_stack.pop
visited = set()
ancestors = {}
lca_done = set()
uf = UF()
curr,parent = tpop()
if curr is None:
# First visit of curr: record its depth and link it to its Fenwick "above" node.
if curr not in visited:
dists[curr] = height
height += 1
aboveht = height - (height & -height)
above[curr] = geod[aboveht]
for nextht in memoized_BIT_prevs[height]:
ancestors[curr] = curr
for child in neb[curr]:
if child in visited:
# Second visit of curr (all children done): pop it off the current path.
gcurr = gpop()
assert gcurr == curr
# tho self.below not complete yet,
# it is for subtree rooted at curr, so OK to call notice
height -= 1
#this portion implements Tarjan's LCA alg
for v in pairs[curr]:
if v in lca_done:
rep_v = uf.UFfind(v)
anc = ancestors[rep_v]
assert (curr,v) not in lca and (v,curr) not in lca
lca[(curr,v)] = anc
lca[(v,curr)] = anc
if parent is not None:
ancestors[uf.UFfind(curr)] = parent
# Visits node and everything reachable through `below` links (stack-based).
# NOTE(review): `val` and `C` are read but the statement that uses them is
# missing from this listing.
def notice(self,node):
val = self.values[node]
C = self.C
below = self.below
stack = [node]
pop = stack.pop
extend = stack.extend
while stack:
curr = pop()
extend(below[curr]) #note that the ORDER we visit them doesn't matter
# Walks the `above` chain from node to None and returns an accumulator.
# NOTE(review): the line that adds C[node] into acc is missing here, so as
# listed this returns an empty counter.
def find(self,node):
acc = mycounter()
C = self.C
above = self.above
while node is not None:
node = above[node]
return acc
# Intended to return the value counter of the path a..b (see the commented
# formula). NOTE(review): as listed it returns only find(a); the combination
# with find(b), the LCA value and find(l) is commented out or missing.
def get_incidence(self,a,b):
l = self.lca[(a,b)]
find_a = self.find(a)
#find_b = self.find(b)
#delta = mycounter(self.values[l])
#pos_part = find_a + find_b + delta
#answer = pos_part - find_l - find_l
#order matters b/c of how subtraction of counters works
return find_a
# Number of vertices shared by path(a, b) and path(c, d). The answer is
# classified by the multiplicity pattern of the six pairwise LCAs of the
# endpoints. Raises RuntimeError for a pattern not handled below.
def size_intersection(self,a,b,c,d):
prs = ((a,b),(c,d),(a,c),(a,d),(b,c),(b,d))
verts = list(self.lca[x] for x in prs)
key1,key2 = verts[0],verts[1]
check = Counter(verts)
if check[key1] == 1 or check[key2] == 1:
return 0 #empty intersection
most = check.most_common()
counts = list(freq for val,freq in most)
if counts == [6] or counts == [3,3]:
return 1 #intersect in one point
elif counts == [5,1] or counts == [3,2,1]: #root meets at or beyond endpt of intersection
close = most[-2][0]
far = most[-1][0]
return (self.dists[far] - self.dists[close])+1 #size, not length
elif counts == [4,1,1]:
left = most[1][0]
right = most[2][0]
mid = most[0][0]
return (self.dists[left] + self.dists[right] - 2*self.dists[mid])+ 1 #size, not length
raise RuntimeError
# Generator: for each (w, x, y, z) yields the sum over values of the product
# of their counts on the two paths, minus the shared vertices (pairs with i == j).
def process_commands(self,commands):
measure_int = self.size_intersection
get_incidence = self.get_incidence
timer = 0
for w,x,y,z in commands:
A = get_incidence(w,x)
B = get_incidence(y,z)
innerprod = A.innerprodwith(B)
len_inter = measure_int(w,x,y,z)
ans = innerprod - len_inter
assert ans >= 0
yield ans
# Reads one problem instance (stdin, or ./input<num>.txt), answers it with
# geodcounter, and, for a numbered run, compares ./myoutput<num>.txt against
# output<num>.txt and prints Success/Failure with the elapsed time.
# NOTE(review): lines are missing from this listing:
#   - the `else:` between the stdin branch and the file branch;
#   - appending each edge to neb;
#   - filling pairs and commands from each query;
#   - writing each answer to `out`.
# The opened files are also never closed.
def test(num = None):
if num is None:
inp = sys.stdin
out = sys.stdout
inp = open('./input'+num+'.txt')
out = open('./myoutput'+num+'.txt','w')
start_time = time()
N,Q = tuple(map(int,inp.readline().strip().split(' ')))
vals = [0] #one-indexed so c_i = C[i]
vals.extend(map(int,inp.readline().strip().split(' ')))
neb = [list() for x in range(N+1)]
for _ in range(N-1):
a, b = tuple(map(int,inp.readline().strip().split(' ')))
pairs = [set() for _ in range(N+1)]
commands = []
for _ in range(Q):
w,x,y,z = tuple(map(int,inp.readline().strip().split(' ')))
count_geod = geodcounter(neb,1,vals,pairs) #make 1 root arbitrarily
for answer in count_geod.process_commands(commands):#,desireds):
end_time = time()
if num is not None:
if num != '00':
remove_chars = len(linesep)
out.truncate(out.tell() - remove_chars) #strip trailing newline
succeeded = cmp('myoutput'+num+'.txt','output'+num+'.txt')
outcome = " Success" if succeeded else " Failure"
print("#"+num+outcome,"in {:f}s".format(end_time-start_time))
