Counting On a Tree
Problem Statement :
Taylor loves trees, and this new challenge has him stumped! Consider a tree, $t$, consisting of $n$ nodes. Each node is numbered from $1$ to $n$, and each node $i$ has an integer, $c_i$, attached to it. A query on tree $t$ takes the form `w x y z`. To process a query, you must print the count of ordered pairs of integers $(i, j)$ such that the following four conditions are all satisfied: $i \neq j$; $i \in$ the path from node $w$ to node $x$; $j \in$ the path from node $y$ to node $z$; and $c_i = c_j$. Given $t$ and $q$ queries, process each query in order, printing the pair count for each query on a new line.
Input Format: The first line contains two space-separated integers describing the respective values of $n$ (the number of nodes) and $q$ (the number of queries). The second line contains $n$ space-separated integers describing the respective values of each node (i.e., $c_1, c_2, \ldots, c_n$). Each of the $n - 1$ subsequent lines contains two space-separated integers, $u$ and $v$, defining a bidirectional edge between nodes $u$ and $v$. Each of the $q$ subsequent lines contains a `w x y z` query, defined above.
Constraints: $1 \le n \le 10^5$, $1 \le q \le 50000$, $1 \le c_i \le 10^9$, $1 \le u, v, w, x, y, z \le n$. Scoring for this problem is binary, which means you have to pass all the test cases to get a positive score.
Output Format: For each query, print the count of ordered pairs of integers satisfying the four given conditions on a new line.
Solution :
Solutions (C++, Java, C, and Python3) :
In C++ :
#include <bits/stdc++.h>
using namespace std;
// Capacity of the per-vertex arrays. Linear[] receives two tokens per vertex
// (one on enter, one on leave), so it needs at least 2n + 1 slots; 300005
// covers n <= 1e5 with room to spare.
const int kMaxN = 300005;
// Adjacency lists of the tree, 1-based vertices.
vector<int> G[kMaxN];
// Color[v]  : compressed colour of v (1..colors).
// Linear[t] : Euler-tour token at time t: +Color[v] when v is entered,
//             -Color[v] when v is left.
// Beg[v]    : enter time of v (0 means "not visited yet").
// End[v]    : leave time of v.
// Depth[v]  : distance from the root (root has depth 0).
int Color[kMaxN], Linear[kMaxN], Beg[kMaxN],
    End[kMaxN], Depth[kMaxN];
// Answer[i] : accumulated result of query i (1-based).
long long Answer[kMaxN];
// n: vertex count; colors: number of distinct colours; timer: last Euler time used.
int n, colors, timer;
// Parent[k][v] : the 2^k-th ancestor of v, or 0 when it does not exist.
int Parent[20][kMaxN];

// Walks the tree rooted at `node` and fills Beg/End/Linear (Euler tour),
// Depth and the binary-lifting table Parent[][].
//
// An explicit stack of (vertex, next-neighbour index) replaces recursion, so a
// deep tree (e.g. a path of 1e5 vertices) cannot overflow the call stack.
// Neighbours are inspected in the same order as a recursive DFS would, so the
// resulting Euler tour is identical.
void DFS(int node) {
    vector<pair<int, size_t>> stack;
    Beg[node] = ++timer;
    Linear[timer] = Color[node];
    stack.emplace_back(node, 0);
    while (!stack.empty()) {
        const int v = stack.back().first;
        size_t &cursor = stack.back().second;
        if (cursor < G[v].size()) {
            const int child = G[v][cursor++];
            if (Beg[child])
                continue;  // already entered (the parent, in a tree)
            Parent[0][child] = v;
            Depth[child] = Depth[v] + 1;
            // Bounded by the table height as well as by running out of ancestors.
            for (int i = 1; i < 20 && Parent[i - 1][child]; ++i)
                Parent[i][child] = Parent[i - 1][Parent[i - 1][child]];
            Beg[child] = ++timer;
            Linear[timer] = Color[child];
            // `cursor` may dangle after this push; it is not used again.
            stack.emplace_back(child, 0);
        } else {
            End[v] = ++timer;
            Linear[timer] = -Color[v];
            stack.pop_back();
        }
    }
}
// Lowest common ancestor of a and b by binary lifting over Parent[][].
// Requires DFS() to have filled Depth[] and Parent[][]; vertex 0 plays the
// role of "no ancestor".
int LCA(int a, int b) {
    // Make `a` the deeper endpoint.
    if (Depth[b] > Depth[a]) {
        const int tmp = a;
        a = b;
        b = tmp;
    }
    // Lift `a` by the depth gap, one set bit at a time.
    int level = 0;
    for (int gap = Depth[a] - Depth[b]; gap != 0; gap >>= 1, ++level) {
        if (gap & 1)
            a = Parent[level][a];
    }
    if (a == b)
        return a;
    // Climb both in lock-step as long as their ancestors still differ.
    for (int jump = 19; jump >= 0; --jump) {
        if (Parent[jump][a] == Parent[jump][b])
            continue;
        a = Parent[jump][a];
        b = Parent[jump][b];
    }
    assert(Parent[0][a] == Parent[0][b]);
    return Parent[0][a];
}
// One signed inclusion-exclusion term of query i: SolveQueries() adds
// sgn * F(a, b) to Answer[i], where F(a, b) is the number of same-coloured
// pairs (u, v), u on the root..a path, v on the root..b path, u != v.
// a/b start as vertices; SolveQueries() rewrites them to Beg[] positions.
struct Query {
int a, b, i, sgn;
};
// All pending terms, filled by AddQuery() and sorted/evaluated by SolveQueries().
vector<Query> Queries;
// Records the term sgn * F(a, b) for query i. Vertex 0 stands for "the parent
// of the root", whose root path is empty, so such a term is zero and dropped.
void AddQuery(int a, int b, int i, int sgn) {
    const bool touches_sentinel = (a == 0) || (b == 0);
    if (!touches_sentinel) {
        Queries.push_back(Query{a, b, i, sgn});
    }
}
// Records the expansion of ([root..a] - [root..b]) x ([root..c] - [root..d])
// for query i as four signed terms:
//   +F(a, c) - F(b, c) - F(a, d) + F(b, d).
// main() passes (endpoint, parent-of-LCA) or (endpoint, LCA) pairs here, so
// each difference describes one half of a query path.
void AddQuery(int a, int b, int c, int d, int i) {
AddQuery(a, c, i, 1);
AddQuery(b, c, i, -1);
AddQuery(a, d, i, -1);
AddQuery(b, d, i, 1);
}
// Count[side][c] : multiplicity of colour c in the Euler-tour prefix currently
// held by `side` (0 or 1) of the Mo window.
int Count[2][kMaxN];
// Running value of sum over colours c of Count[0][c] * Count[1][c].
long long global_ans;
// Applies one Euler-tour token to side `at`: a positive token adds colour
// |col|, a negative one removes it. Keeps global_ans equal to the inner
// product of the two count vectors.
void Add(int at, int col) {
    const int colour = col < 0 ? -col : col;
    const int step = col < 0 ? -1 : 1;
    long long &total = global_ans;
    total -= 1LL * Count[0][colour] * Count[1][colour];
    Count[at][colour] += step;
    total += 1LL * Count[0][colour] * Count[1][colour];
}
void SolveQueries() {
for(auto &q : Queries) {
Answer[q.i] -= q.sgn * (1 + Depth[LCA(q.a, q.b)]);
q.a = Beg[q.a];
q.b = Beg[q.b];
if(q.a > q.b)
swap(q.a, q.b);
}
const int kMagic = 256;
sort(Queries.begin(), Queries.end(), [](Query &a, Query &b) {
if(a.a / kMagic == b.a / kMagic)
return a.b < b.b;
return a.a < b.a;
});
int b = 0, e = 0;
for(auto &q : Queries) {
while(b < q.a) Add(0, Linear[++b]);
while(e < q.b) Add(1, Linear[++e]);
while(b > q.a) Add(0, -Linear[b--]);
while(e > q.b) Add(1, -Linear[e--]);
Answer[q.i] += global_ans * q.sgn;
}
}
// Reads the tree and queries, decomposes every query into signed root-path
// terms, evaluates them offline, and prints one answer per line.
int main() {
    // Up to ~1.5e5 input lines: decouple iostreams from C stdio for speed.
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);

    int q;
    cin >> n >> q;
    // Compress colours (up to 1e9) to 1..colors so they can index Count[][].
    map<int, int> normMap;
    for (int i = 1; i <= n; ++i) {
        cin >> Color[i];
        auto &norm = normMap[Color[i]];
        if (norm == 0)
            norm = ++colors;
        Color[i] = norm;
    }
    for (int i = 2; i <= n; ++i) {
        int a, b;
        cin >> a >> b;
        G[a].push_back(b);
        G[b].push_back(a);
    }
    DFS(1);
    // Vertex 0 is the "parent of the root": depth -1 makes LCA() climb past the
    // root correctly and gives 1 + Depth == 0 for empty shared prefixes.
    Depth[0] = -1;
    for (int i = 1; i <= q; ++i) {
        int a1, b1, a2, b2;
        cin >> a1 >> b1 >> a2 >> b2;
        int l1 = LCA(a1, b1);
        int l2 = LCA(a2, b2);
        // path(a1, b1) = ([root..a1] - [root..parent(l1)]) + ([root..b1] - [root..l1]),
        // and likewise for the second path; expand the product term by term.
        AddQuery(a1, Parent[0][l1], a2, Parent[0][l2], i);
        AddQuery(a1, Parent[0][l1], b2, l2, i);
        AddQuery(b1, l1, a2, Parent[0][l2], i);
        AddQuery(b1, l1, b2, l2, i);
    }
    SolveQueries();
    // One answer per line; no extra blank line at the end.
    for (int i = 1; i <= q; ++i)
        cout << Answer[i] << '\n';
    return 0;
}
In Java :
import java.io.*;
import java.math.*;
import java.text.*;
import java.util.*;
import java.util.regex.*;
public class Solution {

    /**
     * Template entry point, kept for compatibility. Answers can reach about
     * n * (n - 1) (~1e10), which does not fit in an int, so this delegates to
     * {@link #solveLong} and throws ArithmeticException instead of silently
     * wrapping when a value is out of int range.
     */
    static int[] solve(int[] c, int[][] tree, int[][] queries) {
        long[] wide = solveLong(c, tree, queries);
        int[] narrow = new int[wide.length];
        for (int i = 0; i < wide.length; i++) {
            narrow[i] = Math.toIntExact(wide[i]);
        }
        return narrow;
    }

    /**
     * For each query {w, x, y, z} counts ordered pairs (i, j) with i != j,
     * i on the path w..x, j on the path y..z and equal colours.
     *
     * Method: every path is written as a signed combination of root paths
     * ([root..w] - [root..parent(lca)] + [root..x] - [root..lca]), so each query
     * becomes up to 16 terms F(u, v) = number of same-coloured pairs between
     * root..u and root..v with the two vertices distinct. F(u, v) is the inner
     * product of the colour counts of two Euler-tour prefixes (evaluated with
     * Mo's algorithm) minus depth(lca(u, v)) + 1 for the shared vertices.
     *
     * @param c       colour of node k at c[k - 1]
     * @param tree    the n - 1 edges, 1-based endpoints
     * @param queries rows {w, x, y, z}, 1-based nodes
     * @return the count for each query, in order
     */
    static long[] solveLong(int[] c, int[][] tree, int[][] queries) {
        int n = c.length;
        long[] answers = new long[queries.length];
        if (n == 0 || queries.length == 0) {
            return answers;
        }

        // Compress colours to 1..colorCount.
        int[] color = new int[n + 1];
        Map<Integer, Integer> ids = new HashMap<>();
        for (int v = 1; v <= n; v++) {
            Integer id = ids.get(c[v - 1]);
            if (id == null) {
                id = ids.size() + 1;
                ids.put(c[v - 1], id);
            }
            color[v] = id;
        }
        int colorCount = ids.size();

        // Compressed adjacency: neighbours of v are adj[adjStart[v] .. adjStart[v + 1]).
        int[] adjStart = new int[n + 2];
        for (int[] e : tree) {
            adjStart[e[0] + 1]++;
            adjStart[e[1] + 1]++;
        }
        for (int v = 1; v <= n + 1; v++) {
            adjStart[v] += adjStart[v - 1];
        }
        int[] adj = new int[2 * tree.length];
        int[] fill = Arrays.copyOf(adjStart, n + 2);
        for (int[] e : tree) {
            adj[fill[e[0]]++] = e[1];
            adj[fill[e[1]]++] = e[0];
        }

        // Iterative DFS from node 1: Euler tour (+colour on enter, -colour on
        // leave), depths and the binary-lifting table. Node 0 is the sentinel
        // "parent of the root" with depth -1.
        int log = 1;
        while ((1 << log) <= n) {
            log++;
        }
        int[][] up = new int[log][n + 1];
        int[] depth = new int[n + 1];
        depth[0] = -1;
        int[] beg = new int[n + 1];
        int[] linear = new int[2 * n + 1];
        int[] stack = new int[n];
        int[] cursor = new int[n + 1];
        int timer = 0;
        int top = 0;
        beg[1] = ++timer;
        linear[timer] = color[1];
        stack[top++] = 1;
        while (top > 0) {
            int v = stack[top - 1];
            if (adjStart[v] + cursor[v] < adjStart[v + 1]) {
                int u = adj[adjStart[v] + cursor[v]++];
                if (beg[u] == 0) {
                    up[0][u] = v;
                    depth[u] = depth[v] + 1;
                    for (int k = 1; k < log; k++) {
                        up[k][u] = up[k - 1][up[k - 1][u]];
                    }
                    beg[u] = ++timer;
                    linear[timer] = color[u];
                    stack[top++] = u;
                }
            } else {
                top--;
                linear[++timer] = -color[v];
            }
        }

        // Expand every query into signed F(u, v) terms.
        int q = queries.length;
        int cap = 16 * q;
        int[] lo = new int[cap];
        int[] hi = new int[cap];
        int[] owner = new int[cap];
        int[] sign = new int[cap];
        int m = 0;
        int[] termSign = {1, -1, -1, 1};
        for (int t = 0; t < q; t++) {
            int w = queries[t][0], x = queries[t][1], y = queries[t][2], z = queries[t][3];
            int l1 = lca(w, x, up, depth);
            int l2 = lca(y, z, up, depth);
            int[] plus1 = {w, x};
            int[] minus1 = {up[0][l1], l1};
            int[] plus2 = {y, z};
            int[] minus2 = {up[0][l2], l2};
            for (int s = 0; s < 2; s++) {
                for (int r = 0; r < 2; r++) {
                    // (plus1 - minus1) x (plus2 - minus2)
                    int[] left = {plus1[s], minus1[s], plus1[s], minus1[s]};
                    int[] right = {plus2[r], plus2[r], minus2[r], minus2[r]};
                    for (int k = 0; k < 4; k++) {
                        int a = left[k];
                        int b = right[k];
                        if (a == 0 || b == 0) {
                            continue; // empty root path contributes nothing
                        }
                        answers[t] -= (long) termSign[k] * (1 + depth[lca(a, b, up, depth)]);
                        lo[m] = Math.min(beg[a], beg[b]);
                        hi[m] = Math.max(beg[a], beg[b]);
                        owner[m] = t;
                        sign[m] = termSign[k];
                        m++;
                    }
                }
            }
        }

        // Mo's ordering of the (lo, hi) prefix pairs.
        final int block = 256;
        Integer[] order = new Integer[m];
        for (int i = 0; i < m; i++) {
            order[i] = i;
        }
        Arrays.sort(order, (i, j) -> {
            int bi = lo[i] / block;
            int bj = lo[j] / block;
            if (bi != bj) {
                return Integer.compare(bi, bj);
            }
            return Integer.compare(hi[i], hi[j]);
        });

        int[][] count = new int[2][colorCount + 1];
        long current = 0;
        int bPos = 0;
        int ePos = 0;
        for (int id : order) {
            while (bPos < lo[id]) current += shift(count, 0, linear[++bPos]);
            while (ePos < hi[id]) current += shift(count, 1, linear[++ePos]);
            while (bPos > lo[id]) current += shift(count, 0, -linear[bPos--]);
            while (ePos > hi[id]) current += shift(count, 1, -linear[ePos--]);
            answers[owner[id]] += current * sign[id];
        }
        return answers;
    }

    /**
     * Applies an Euler-tour token (+colour adds, -colour removes) to one side
     * and returns the resulting change of sum_c count[0][c] * count[1][c].
     */
    private static long shift(int[][] count, int side, int col) {
        int colour = Math.abs(col);
        int delta = col > 0 ? 1 : -1;
        count[side][colour] += delta;
        return (long) delta * count[1 - side][colour];
    }

    /** Lowest common ancestor by binary lifting; node 0 is the sentinel above the root. */
    private static int lca(int a, int b, int[][] up, int[] depth) {
        if (depth[a] < depth[b]) {
            int t = a;
            a = b;
            b = t;
        }
        for (int k = 0, d = depth[a] - depth[b]; d > 0; k++, d >>= 1) {
            if ((d & 1) != 0) {
                a = up[k][a];
            }
        }
        if (a == b) {
            return a;
        }
        for (int k = up.length - 1; k >= 0; k--) {
            if (up[k][a] != up[k][b]) {
                a = up[k][a];
                b = up[k][b];
            }
        }
        return up[0][a];
    }

    private static final Scanner scanner = new Scanner(System.in);

    public static void main(String[] args) throws IOException {
        BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(System.getenv("OUTPUT_PATH")));
        String[] nq = scanner.nextLine().split(" ");
        int n = Integer.parseInt(nq[0]);
        int q = Integer.parseInt(nq[1]);
        int[] c = new int[n];
        String[] cItems = scanner.nextLine().split(" ");
        scanner.skip("(\r\n|[\n\r\u2028\u2029\u0085])?");
        for (int cItr = 0; cItr < n; cItr++) {
            int cItem = Integer.parseInt(cItems[cItr]);
            c[cItr] = cItem;
        }
        int[][] tree = new int[n-1][2];
        for (int treeRowItr = 0; treeRowItr < n-1; treeRowItr++) {
            String[] treeRowItems = scanner.nextLine().split(" ");
            scanner.skip("(\r\n|[\n\r\u2028\u2029\u0085])?");
            for (int treeColumnItr = 0; treeColumnItr < 2; treeColumnItr++) {
                int treeItem = Integer.parseInt(treeRowItems[treeColumnItr]);
                tree[treeRowItr][treeColumnItr] = treeItem;
            }
        }
        int[][] queries = new int[q][4];
        for (int queriesRowItr = 0; queriesRowItr < q; queriesRowItr++) {
            String[] queriesRowItems = scanner.nextLine().split(" ");
            scanner.skip("(\r\n|[\n\r\u2028\u2029\u0085])?");
            for (int queriesColumnItr = 0; queriesColumnItr < 4; queriesColumnItr++) {
                int queriesItem = Integer.parseInt(queriesRowItems[queriesColumnItr]);
                queries[queriesRowItr][queriesColumnItr] = queriesItem;
            }
        }
        // long results: counts can exceed Integer.MAX_VALUE.
        long[] result = solveLong(c, tree, queries);
        for (int resultItr = 0; resultItr < result.length; resultItr++) {
            bufferedWriter.write(String.valueOf(result[resultItr]));
            if (resultItr != result.length - 1) {
                bufferedWriter.write("\n");
            }
        }
        bufferedWriter.newLine();
        bufferedWriter.close();
        scanner.close();
    }
}
In C :
#include <stdio.h>
#include <string.h>
#include <math.h>
#include <stdlib.h>
// floor(log2(self)) for a nonzero value, assuming a 32-bit unsigned int:
// __builtin_clz counts leading zero bits, and for clz in 0..31, clz ^ 31 == 31 - clz.
// __builtin_clz(0) is undefined, so callers must pass a nonzero argument.
#define floor_log2_X86(self) (__builtin_clz(self) ^ 31U)
#define floor_log2 floor_log2_X86
// In-place heap sort of the index array self[0 .. length-1] into ascending order
// of weights[self[k]]; only `self` is permuted, `weights` is read-only.
// Not stable: indices with equal weights may end up in any relative order.
// `self--` switches to 1-based heap indexing (children of k are 2k and 2k + 1).
// NOTE(review): forming self - 1 is technically out-of-bounds pointer arithmetic in C.
void heap_sort(unsigned *self, unsigned *weights, unsigned length) {
unsigned
at = length >> 1,
member,
node;
// Phase 1: build a max-heap by sifting down every internal node, last one first.
for (self--; at; self[node >> 1] = member) {
member = self[at];
for (node = at-- << 1; node <= length; node <<= 1) {
// Select the heavier of the two children.
node |= (node < length) && (weights[self[node]] < weights[self[node | 1]]);
if (weights[self[node]] < weights[member])
break ;
self[node >> 1] = self[node];
}
}
// Phase 2: move the current maximum to the end of the shrinking heap and
// sift the displaced last element down from the root.
for (; length > 1; self[at >> 1] = member) {
member = self[length];
self[length--] = self[1];
for (at = 2; at <= length; at <<= 1) {
at |= (at < length) && (weights[self[at]] < weights[self[at | 1]]);
if (weights[self[at]] < weights[member])
break ;
self[at >> 1] = self[at];
}
}
}
// Replaces every entry of values[0 .. length-1] with a dense 0-based id.
// Ids are assigned by decreasing frequency of the original value; among values
// with the same frequency, smaller original values receive smaller ids.
// NOTE(review): the identity permutation is written two entries at a time through
// an (unsigned long *) cast of an unsigned array. That assumes a 64-bit,
// little-endian unsigned long, suitable alignment, and ignores strict aliasing —
// TODO confirm on the target toolchain.
void compress(unsigned length, unsigned values[length]) {
unsigned
at,
order[length];
unsigned long sum = 0x0000000100000000UL;
// order[k] = k: pairs (2j, 2j + 1) are written at once; the odd tail is fixed below.
for (at = 0; at < (length >> 1); sum += 0x0000000200000002UL)
((unsigned long *)order)[at++] = sum;
order[length - 1] = length - 1;
// order := indices of `values`, ascending by value.
heap_sort(order, values, length);
// Split `order` into runs of equal values: roots[g] is the position of the last
// element of run g (roots[0] = -1 as a sentinel); max is the longest run length.
unsigned roots[length], seen = 1, max = 0, others;
for (roots[at = 0] = -1U; at < length; roots[seen++] = at - 1) {
for (others = at; (at < length) && values[order[at]] == values[order[others]]; at++);
if (max < (at - others))
max = (at - others);
}
unsigned
indices[max + 1],
ranks[seen];
memset(indices, 0, sizeof(indices));
// Counting sort of runs by length, longest first:
// indices[k] = number of runs of length k, then number of runs of length >= k.
for (at = 0; ++at < seen; indices[roots[at] - roots[at - 1]]++);
for (at = max; at--; indices[at] += indices[at + 1]);
// ranks[r] = run placed at rank r (ties keep ascending run index, i.e. ascending value).
for (at = seen; --at; ranks[--indices[roots[at] - roots[at - 1]]] = at);
// Write each rank back into every member of the corresponding run.
for (; at < (seen - 1); at++)
for (others = roots[ranks[at] - 1]; ++others <= roots[ranks[at]]; values[order[others]] = at);
}
// Lowest common ancestor of vertices `lower` and `upper` (new ids from main's
// heavy-child-first preorder relabelling, where the subtree of v is the id range
// [v, v + weights[v])). The caller passes lower <= upper.
//   depth     : number of columns of `bases` (main passes `dist`)
//   base_ids  : chain index of each vertex
//   bases     : per chain, the vertices reached after 1, 2, ... light-edge jumps
//               above the chain (main's ancestral_bases)
//   depths    : light-edge depth per vertex. NOTE(review): indexed by vertex id,
//               although declared with extent base_cnt (VLA parameter extents are
//               not enforced).
static inline unsigned nearest_common_ancestor(
unsigned depth,
unsigned base_cnt,
unsigned vertex_cnt,
unsigned base_ids[vertex_cnt],
unsigned bases[base_cnt][depth],
unsigned char depths[base_cnt],
unsigned weights[vertex_cnt],
unsigned lower,
unsigned upper
) {
// `upper` lies inside lower's subtree interval: lower is the LCA.
if (upper < (lower + weights[lower]))
return lower;
// Lift `upper` by the difference in light-edge depth.
if (depths[upper] > depths[lower])
upper = bases[base_ids[upper]][depths[upper] - depths[lower] - 1];
if (upper < lower)
return upper;
// Binary search along upper's list of light-edge ancestors (ids decrease as the
// list climbs) for the entry to return — presumably the first one not above
// `lower` in id order; TODO confirm.
unsigned *others = bases[base_ids[upper]];
for (; depth > 1; depth >>= 1)
if (others[depth >> 1] > lower) {
others += depth >> 1;
depth += depth & 1U;
}
return others[others[0] > lower];
}
// A closed range [low, high] of vertex ids. `packd` overlays both fields as one
// word so ranges can be swapped or offset with a single operation.
// NOTE(review): relies on unsigned long being 64 bits and little-endian layout
// (low in the low half, as in main's `bases[at] | (at << 32)`), and on the C11
// anonymous struct — TODO confirm on the target.
typedef union {
unsigned long packd;
struct {
int low, high;
};
} range_t;
// Per-colour index over vertex ids, built in main:
//   members   : vertex ids grouped by colour, ascending within each colour
//   colors    : colour of each vertex id
//   indices   : members[indices[c] .. indices[c + 1]) are the vertices of colour c
//   locations : position of each vertex id inside `members`
typedef struct {
unsigned
*members,
*colors,
*indices,
*locations;
} colored_tree_t;
// Counts vertex ids j in [other.low, other.high] with colors[j] == colors[at],
// not counting j == at itself. Works on the sorted member list of at's colour:
// finds the first member >= other.low and the last member <= other.high
// (binary search, or a direct locations[] lookup when the range endpoint has
// at's colour) and returns the size of that slice.
unsigned long query_all(colored_tree_t *self, unsigned at, range_t other) {
unsigned
color = self->colors[at],
length = self->indices[color + 1] - self->indices[color],
*base = &self->members[self->indices[color]];
// The range misses every member of this colour.
if (other.high < base[0] || other.low > base[length - 1])
return 0;
// base := first member >= other.low.
if (self->colors[other.low] != color) {
// `at` is itself a member, so search only on the side of it that can hold other.low.
if (at < other.low) {
base += self->locations[at] - self->indices[color];
length = self->indices[color + 1] - self->locations[at];
} else
length = self->locations[at] - self->indices[color]; // at > other.low
for (; length > 1; length >>= 1)
if (base[length >> 1] < other.low) {
base += length >> 1;
length += length & 1;
}
base += (base[0] < other.low);
} else
base += (self->locations[other.low] - self->indices[color]);
if (base[0] > other.high)
return 0;
// ceil := last member <= other.high.
unsigned *ceil;
if (self->colors[other.high] != color) {
ceil = (at > base[0] && at < other.high) ? &self->members[self->locations[at]] : base;
for (length = self->indices[color + 1] - self->locations[ceil[0]]; length > 1; length >>= 1)
if (ceil[length >> 1] <= other.high) {
ceil += length >> 1;
length += length & 1;
}
ceil -= (ceil[0] > other.high);
} else
ceil = &self->members[self->locations[other.high]];
// Slice size, minus `at` itself when it lies inside the range.
return ceil - base + 1 - (at >= other.low && at <= other.high);
}
// Counts same-coloured pairs (i, j), i in id range `self`, j in id range `other`,
// i != j. Elements outside whole blocks of `length` ids are handled one by one
// with query_all(); the remaining whole-block ranges are answered from
// precomputed tables:
//   pairs[x][y]    : 2D prefix sums of cross-block pair counts (blocks A < B)
//   overlapping[x] : prefix sums of ordered within-block pair counts
// Both tables are built in main and accept index -1 (see main's pairs setup).
unsigned long count_pairs(
unsigned cnt,
unsigned length,
unsigned long pairs[cnt][cnt],
unsigned *overlapping,
colored_tree_t *tree,
range_t self,
range_t other
) {
unsigned long count = 0;
// Peel partial blocks off both ends of `self`, counting each against `other`.
for (; (self.low % length) && (self.low <= self.high); count += query_all(tree, self.low++, other));
for (; ((self.high + 1) % length) && (self.low <= self.high); count += query_all(tree, self.high--, other));
if (self.low <= self.high) {
// Peel partial blocks off `other`, counting each against the trimmed `self`.
for (; (other.low % length) && (other.low <= other.high); count += query_all(tree, other.low++, self));
for (; ((other.high + 1) % length) && (other.low <= other.high); count += query_all(tree, other.high--, self));
if (other.low <= other.high) {
// Both ranges now cover whole blocks: switch to block indices.
self.low /= length;
self.high /= length;
other.low /= length;
other.high /= length;
// Ensure self.low <= other.low (the count is symmetric in the two ranges).
if (self.low > other.low) {
self.packd ^= other.packd;
other.packd ^= self.packd;
self.packd ^= other.packd;
}
// Part of `self` strictly before other.low: plain rectangle query on pairs.
// NOTE(review): if other.low == 0 (so self.low == 0 too), high becomes UINT_MAX
// and self.low wraps to 0; the subscripts pairs[high] and later
// overlapping[self.low - 1UL] then only behave like index -1 if the address
// arithmetic wraps, which looks out of bounds on LP64 — TODO confirm.
unsigned high = (self.high < other.low) ? self.high : (other.low - 1);
count +=
pairs[high][other.high]
- pairs[high][other.low - 1UL]
- pairs[self.low - 1UL][other.high]
+ pairs[self.low - 1UL][other.low - 1UL];
self.low = high + 1;
// Make `self` the range that ends first among the overlapping parts.
if (self.high > other.high) {
self.packd ^= other.packd;
other.packd ^= self.packd;
self.packd ^= other.packd;
}
// Overlap: within-block pairs, cross-block pairs inside the overlap (counted
// in both orders, hence << 1), plus overlap x remainder of the longer range.
if (self.low <= self.high)
count +=
(overlapping[self.high] - overlapping[self.low - 1UL])
+ ((
pairs[self.high][self.high]
- pairs[self.high][self.low - 1UL]
- pairs[self.low - 1UL][self.high]
+ pairs[self.low - 1UL][self.low - 1UL]
) << 1) + (
pairs[self.high][other.high]
- pairs[self.high][self.high]
- pairs[self.low - 1UL][other.high]
+ pairs[self.low - 1UL][self.high]
);
}
}
return count;
}
// Reads the tree and queries from stdin and prints one pair count per query.
// Pipeline as implemented below:
//   1. compress colours (frequency-descending ids) with a sentinel entry;
//   2. build parent pointers from the edge list and root the tree at vertex 0;
//   3. compute subtree sizes, then relabel vertices in heavy-child-first preorder
//      (heavy-light decomposition: bases = chain heads, base_depths = light depth);
//   4. index vertices by colour and precompute block-level pair tables;
//   5. per query, cut both paths into heavy-chain id ranges and sum count_pairs
//      over all range pairs.
// NOTE(review): most tables are VLAs of about vertex_cnt unsigned each, so stack
// usage reaches several MB for n = 1e5 — TODO confirm the stack limit.
int main() {
unsigned at, vertex_cnt;
unsigned short query_cnt;
// NOTE(review): unsigned short fits q <= 50000 but would truncate larger q.
scanf("%u %hu", &vertex_cnt, &query_cnt);
unsigned colors[vertex_cnt + 1];
for (at = 0; at < vertex_cnt; scanf("%u", &colors[at++]));
// Sentinel vertex with the largest possible value: after compress() it gets the
// unique last id, so colors[vertex_cnt] equals the number of distinct colours.
colors[at] = 0xFFFFFFFFU;
compress(at + 1, colors);
unsigned ancestors[at + 1];
{
unsigned ancestor, descendant;
// Add edges as parent links (0xFFFFFFFF = no parent). If the second endpoint
// already has a parent, both endpoints' root paths are reversed first so the
// new edge can still become a parent link.
// NOTE(review): each reversal walks a whole root path, so adversarial edge
// orders might cost O(n^2) in total — TODO confirm.
for (memset(ancestors, 0xFFU, sizeof(ancestors)); --at; ancestors[descendant] = ancestor) {
scanf("%u %u", &ancestor, &descendant);
--ancestor;
if (ancestors[--descendant] != 0xFFFFFFFFU) {
unsigned root = descendant, next;
for (; ancestor != 0xFFFFFFFFU; ancestor = next) {
next = ancestors[ancestor];
ancestors[ancestor] = root;
root = ancestor;
}
for (; ancestors[descendant] != 0xFFFFFFFFU; descendant = next) {
next = ancestors[descendant];
ancestors[descendant] = ancestor;
ancestor = descendant;
}
}
}
// Re-root at vertex 0 (input node 1) by reversing its path to the current root.
for (ancestor = 0xFFFFFFFFU; at != 0xFFFFFFFFU; at = descendant) {
descendant = ancestors[at];
ancestors[at] = ancestor;
ancestor = at;
}
}
// ids: original vertex -> new id; weights: subtree sizes (by new id after
// relabelling); bases: head of the heavy chain of each new id; base_depths:
// number of light edges above each new id; dist: largest base_depths value.
// NOTE(review): this `history` appears unused; the block below declares its own.
unsigned
others,
ids[vertex_cnt + 1],
weights[vertex_cnt],
bases[vertex_cnt + 1],
history[vertex_cnt];
unsigned char
base_depths[vertex_cnt],
dist = 0;
{
unsigned
history[vertex_cnt],
indices[vertex_cnt + 1],
descendants[vertex_cnt];
// Group children by parent: descendants[indices[p] .. indices[p + 1]) are p's
// children (vertex_cnt is its own sentinel child).
memset(indices, 0, sizeof(indices));
for (ancestors[vertex_cnt] = (at = vertex_cnt); at; indices[ancestors[at--]]++);
for (; ++at <= vertex_cnt; indices[at] += indices[at - 1]);
for (; --at; descendants[--indices[ancestors[at]]] = at);
// Subtree sizes by iterative post-order: a vertex is seen once to push its
// children (weight set to 1) and again afterwards to add their weights.
history[0] = 0;
memset(weights, 0, sizeof(weights));
for (at = 1; at--; )
if (weights[history[at]])
for (others = indices[history[at]];
others < indices[history[at] + 1];
weights[history[at]] += weights[descendants[others++]]);
else {
weights[history[at]] = 1;
memcpy(
&history[at + 1],
&descendants[indices[history[at]]],
(indices[history[at] + 1] - indices[history[at]]) * sizeof(descendants[0])
);
at += indices[history[at] + 1] - indices[history[at]] + 1;
}
unsigned
orig_ancestors[vertex_cnt + 1],
orig_colors[vertex_cnt + 1],
orig_weights[vertex_cnt];
memcpy(orig_ancestors, ancestors, sizeof(ancestors));
memcpy(orig_weights, weights, sizeof(weights));
memcpy(orig_colors, colors, sizeof(colors));
// Relabel in preorder, heaviest child first: it gets parent id + 1 and stays on
// the parent's chain; every lighter child starts a new chain (bases[id] == id).
// ancestors/weights/colors are rewritten in place to be indexed by new id.
base_depths[0] = (bases[0] = (ids[0] = 0));
bases[vertex_cnt] = (ids[vertex_cnt] = vertex_cnt);
for (at = 1; at--;) {
unsigned
id = ids[history[at]],
base = bases[id++],
branches = indices[history[at] + 1] - indices[history[at]];
heap_sort(&descendants[indices[history[at]]], orig_weights, branches);
memcpy(&history[at], &descendants[indices[history[at]]], branches * sizeof(descendants[0]));
for (others = (at += branches); branches--; base = id) {
ids[history[--others]] = id;
ancestors[id] = ids[orig_ancestors[history[others]]];
weights[id] = orig_weights[history[others]];
colors[id] = orig_colors[history[others]];
bases[id] = base;
base_depths[id] = base_depths[ancestors[id]] + (base == id);
if (dist < base_depths[id])
dist = base_depths[id];
id += weights[id];
}
}
}
// base_ids[v] = index of v's chain (chains are contiguous id runs);
// base_ids[vertex_cnt] = number of chains.
unsigned base_ids[vertex_cnt + 1];
for (base_ids[0] = (others = (at = 0)); others < vertex_cnt; base_ids[others] = base_ids[at] + 1)
for (at = others; bases[at] == bases[others]; base_ids[others++] = base_ids[at]);
// ancestral_bases[c][k] = vertex reached after k + 1 light-edge jumps above chain c.
// NOTE(review): when dist == 0 (the tree is a single heavy path) this is a
// zero-length VLA and the [..][0] writes below are out of bounds — TODO confirm.
unsigned ancestral_bases[base_ids[vertex_cnt]][dist];
for (ancestors[0] = 0; others--; ancestral_bases[base_ids[others]][0] = ancestors[others]);
while ((++others + 1) < dist)
for (at = 0; ++at < base_ids[vertex_cnt];
ancestral_bases[at][others + 1] = ancestors[bases[ancestral_bases[at][others]]]);
// Counting sort of vertex ids by colour: members[indexed_colors[c] ..
// indexed_colors[c + 1]) are the ids of colour c, ascending. The sentinel colour's
// range is made empty so the loop over repeated colours below stops there.
unsigned
indexed_colors[colors[vertex_cnt] + 2],
members[vertex_cnt + 1];
memset(indexed_colors, 0, sizeof(indexed_colors));
for (at = vertex_cnt + 1; at--; indexed_colors[colors[at]]++);
for (; ++at < colors[vertex_cnt]; indexed_colors[at + 1] += indexed_colors[at]);
for (at = vertex_cnt + 1; at--; members[--indexed_colors[colors[at]]] = at);
indexed_colors[colors[vertex_cnt] + 1] = indexed_colors[colors[vertex_cnt]];
// Blocks of `levels` consecutive ids.
unsigned
levels = floor_log2(vertex_cnt) + 1,
block_cnt = (vertex_cnt / levels) + 1,
locations[vertex_cnt + 1],
overlapping[block_cnt];
// pairs[x][y]: cross-block same-colour pair counts (later 2D prefix sums). The
// pointer is shifted by one row and one column so index -1 is addressable.
// NOTE(review): (1 + block_cnt)^2 unsigned longs is ~277 MB for n = 1e5 with an
// 8-byte unsigned long; calloc's result is not checked — TODO confirm memory limit.
unsigned long (*pairs)[block_cnt][block_cnt] = calloc(
(1 + block_cnt) * (1 + block_cnt),
sizeof(pairs[0][0][0])
);
pairs = (void *)&pairs[0][1][1];
for (at = vertex_cnt + 1; at--; locations[members[at]] = at);
memset(overlapping, 0, sizeof(overlapping));
// Only colours with >= 2 vertices contribute; ids are frequency-descending, so
// the loop stops at the first colour with fewer members.
for (at = 0; (indexed_colors[at + 1] - indexed_colors[at]) > 1; at++) {
others = indexed_colors[at];
// block_bases: first member of colour `at` in each block it touches, plus an end marker.
unsigned
block_bases[indexed_colors[at + 1] - others + 1],
cnt = 1;
for (block_bases[0] = members[others]; at == colors[members[++others]]; )
if ((members[others] / levels) != (block_bases[cnt - 1] / levels))
block_bases[cnt++] = members[others];
block_bases[cnt] = members[others];
for (others = 0; others < cnt; others++) {
// density = members of this colour inside the block.
unsigned long density = locations[block_bases[others + 1]] - locations[block_bases[others]];
overlapping[block_bases[others] / levels] += density * (density - 1);
unsigned block = others;
for (; ++block < cnt; pairs[0][block_bases[others] / levels][block_bases[block] / levels]
+= density * (locations[block_bases[block + 1]] - locations[block_bases[block]]));
}
}
// Prefix sums: overlapping over blocks; pairs over columns, then over rows.
for (at = 0; ++at < block_cnt; overlapping[at] += overlapping[at - 1])
pairs[0][0][at] += pairs[0][0][at - 1];
for (at = 0; ++at < block_cnt; )
for (others = 0; ++others < block_cnt; pairs[0][at][others] += pairs[0][at][others - 1]);
for (at = 0; ++at < block_cnt; )
for (others = 0; others < block_cnt; others++)
pairs[0][at][others] += pairs[0][at - 1][others];
colored_tree_t *tree = &(colored_tree_t) {
.members = members,
.colors = colors,
.indices = indexed_colors,
.locations = locations
};
while (query_cnt--) {
range_t left, right;
// NOTE(review): %u is given pointers to the int fields of range_t — a format mismatch.
scanf("%u %u %u %u", &left.low, &left.high, &right.low, &right.high);
// 1-based -> 0-based for both halves at once (inputs are >= 1, so no borrow).
left.packd -= 0x0000000100000001UL;
right.packd -= 0x0000000100000001UL;
left.low = ids[left.low];
left.high = ids[left.high];
right.low = ids[right.low];
right.high = ids[right.high];
// Order each range so that low <= high (swap the two halves).
if (left.high < left.low)
left.packd = (left.packd << 32) | (left.packd >> 32);
if (right.high < right.low)
right.packd = (right.packd << 32) | (right.packd >> 32);
// Swap the two paths when right lies entirely below left.low (the count is symmetric).
if (right.high < left.low) {
left.packd ^= right.packd;
right.packd ^= left.packd;
left.packd ^= right.packd;
}
// Heavy-chain id ranges making up each path.
// NOTE(review): 32 slots appear sized for n <= 1e5 (about 16 light edges per side
// plus the LCA chain) — TODO confirm for larger n.
struct {
range_t members[32];
unsigned cnt;
}
a = {.cnt = 0},
b = {.cnt = 0};
unsigned common = nearest_common_ancestor(
dist, base_ids[vertex_cnt], vertex_cnt,
base_ids, ancestral_bases,
base_depths, weights,
left.low, left.high
);
// Climb from each endpoint chain by chain, recording [chain head, vertex],
// until reaching the LCA's chain; then add [LCA, endpoint-side vertex] on it.
for (at = left.low; bases[at] != bases[common]; at = ancestral_bases[base_ids[at]][0])
a.members[a.cnt++].packd = bases[at] | ((unsigned long)at << 32);
for (others = left.high; bases[others] != bases[common]; others = ancestral_bases[base_ids[others]][0])
a.members[a.cnt++].packd = bases[others] | ((unsigned long)others << 32);
a.members[a.cnt++].packd = common | ((unsigned long)((at != common) ? at : others) << 32);
common = nearest_common_ancestor(
dist, base_ids[vertex_cnt], vertex_cnt,
base_ids, ancestral_bases,
base_depths, weights,
right.low, right.high
);
for (at = right.low; bases[at] != bases[common]; at = ancestral_bases[base_ids[at]][0])
b.members[b.cnt++].packd = bases[at] | ((unsigned long)at << 32);
for (others = right.high; bases[others] != bases[common]; others = ancestral_bases[base_ids[others]][0])
b.members[b.cnt++].packd = bases[others] | ((unsigned long)others << 32);
b.members[b.cnt++].packd = common | ((unsigned long)((at != common) ? at : others) << 32);
// Sum the pair counts over every (left range, right range) combination.
unsigned long total = 0;
for (at = 0; at < a.cnt; at++)
for (others = 0; others < b.cnt;
total += count_pairs(
block_cnt, levels, pairs[0], overlapping, tree,
a.members[at], b.members[others++]
)
);
printf("%lu\n", total);
}
return 0;
}
In Python3 :
import sys
from filecmp import cmp
from os import linesep
from time import time
from collections import Counter
# memoized_BIT_prevs[h] lists every Fenwick position p (1 <= p <= 10**5) whose
# "next" index p + lowbit(p) equals h, i.e. the positions that h covers when
# walking a Fenwick tree upward. Successors beyond 10**5 are dropped.
# Loop names are chosen so the builtin `next` is not shadowed at module scope.
memoized_BIT_prevs = [[] for _ in range(10**5 + 1)]
for _position in range(1, 10**5 + 1):
    _successor = _position + (_position & -_position)
    if _successor < len(memoized_BIT_prevs):
        memoized_BIT_prevs[_successor].append(_position)
del _position, _successor
class UF(object):
    """Union-find over hashable keys with union by rank and path halving."""

    __slots__ = ['uf', 'ranks']

    def __init__(self):
        self.uf = {}     # key -> parent key (a root maps to itself)
        self.ranks = {}  # key -> rank (upper bound on tree height)

    def UFadd(self, a):
        """Register `a` as a new singleton set."""
        self.uf[a] = a
        self.ranks[a] = 0

    def UFfind(self, a):
        """Return the root of `a`'s set, pointing visited nodes at their grandparents."""
        parent = self.uf
        node = a
        while parent[node] != node:
            grandparent = parent[parent[node]]
            parent[node] = grandparent
            node = grandparent
        return node

    def UFcombine(self, a, b):
        """Merge the sets of `a` and `b`; on equal ranks `a`'s root wins and gains a rank."""
        root_a, root_b = self.UFfind(a), self.UFfind(b)
        rank = self.ranks
        if rank[root_a] < rank[root_b]:
            self.uf[root_a] = root_b
            return
        self.uf[root_b] = root_a
        if rank[root_a] == rank[root_b]:
            rank[root_a] += 1
class mycounter(object):
    """Small dict-backed multiset with integer counts."""

    __slots__ = ['c']

    def __init__(self, key=None):
        self.c = {} if key is None else {key: 1}

    def inc(self, key):
        """Add one occurrence of `key`."""
        counts = self.c
        counts[key] = counts.get(key, 0) + 1

    def addto(self, other):
        """In place: self += other."""
        counts = self.c
        for key, amount in other.c.items():
            counts[key] = counts.get(key, 0) + amount

    def subfrom(self, other, mult=1):
        """In place: self -= mult * other. Every key of `other` must already be in self."""
        counts = self.c
        for key, amount in other.c.items():
            assert key in counts
            counts[key] = counts[key] - mult * amount

    def innerprodwith(self, other):
        """Sum over shared keys of self[k] * other[k], iterating the smaller dict."""
        small, large = self.c, other.c
        if len(small) > len(large):
            small, large = large, small
        return sum(small[k] * large[k] for k in small if k in large)
class mycounter2(object):
    """List-backed multiset with one deferred subtraction (alternative to mycounter).

    Occurrences are appended to ``q``; ``subfrom`` records another counter's
    occurrences, to be subtracted ``negmult`` times when the inner product is taken.
    """

    __slots__ = ['q', 'doubleneg', 'negmult']

    def __init__(self):
        self.q = []
        self.doubleneg = None  # keys subtracted via subfrom(), or None
        self.negmult = 0       # multiplier applied to doubleneg

    def inc(self, key):
        """Add one occurrence of `key`."""
        self.q.append(key)

    def addto(self, other):
        """In place: self += other (occurrences only)."""
        self.q.extend(other.q)

    def subfrom(self, other, mult=1):
        """Record self -= mult * other; only one subtraction is supported."""
        assert self.doubleneg is None and other.doubleneg is None
        self.doubleneg = other.q
        # Honour `mult` like mycounter.subfrom (previously a factor of 2 was hard-coded).
        self.negmult = mult

    def _counts(self):
        """Net count per key, applying the recorded subtraction if any."""
        counts = Counter(self.q)
        if self.doubleneg is not None:
            for key in self.doubleneg:
                counts[key] -= self.negmult
        return counts

    def innerprodwith(self, other):
        """Sum over keys of net count in self times net count in other.

        Works whether or not subfrom() was called (previously a missing
        subtraction raised TypeError) and runs in linear time.
        """
        mine = self._counts()
        theirs = other._counts()
        if len(mine) > len(theirs):
            mine, theirs = theirs, mine
        return sum(v * theirs[k] for k, v in mine.items() if k in theirs)
# Per-instance machinery for the path-pair counting queries:
#  - C / above / below: a Fenwick-style structure laid along the DFS root path;
#    find(v) sums C along the `above` chain to get value counts for root..v;
#  - lca: offline LCAs (Tarjan with UF) for every vertex pair listed in `pairs`;
#  - dists: depth of each vertex (root = 0), used to size path intersections.
class geodcounter(object):
__slots__ = ['size','C','above','below','values','lca','dists']
# neb: adjacency lists indexed by vertex; root: DFS root; vals: value per vertex
# (index = vertex); pairs[v]: set of vertices whose LCA with v is required.
def __init__(self,neb,root,vals,pairs):
self.size = len(neb)
self.C = [mycounter() for _ in range(self.size)]
self.above = [None]*self.size
self.below = [list() for _ in range(self.size)]
self.values = vals
self.lca = {}
self.dists = [None]*self.size
above = self.above
below = self.below
lca = self.lca
dists = self.dists
# geod[h] = vertex at height h (1-based) on the current root path.
geod = [None] #one-indexed to fit BIT
gpush = geod.append
gpop = geod.pop
height = 0
# Iterative DFS: a vertex is pushed once to be entered and, after entry, pushed
# again (below its children) so it is seen a second time when its subtree is done.
traverse_stack = [(root,None)]
tpush = traverse_stack.append
tpop = traverse_stack.pop
visited = set()
ancestors = {}
lca_done = set()
uf = UF()
while(traverse_stack):
curr,parent = tpop()
# NOTE(review): nothing pushes a None vertex here, so this break looks unreachable.
if curr is None:
break
if curr not in visited:
# First visit (pre-order): link curr into the Fenwick structure by height.
dists[curr] = height
height += 1
aboveht = height - (height & -height)
above[curr] = geod[aboveht]
for nextht in memoized_BIT_prevs[height]:
below[geod[nextht]].append(curr)
uf.UFadd(curr)
visited.add(curr)
tpush((curr,parent))
gpush(curr)
ancestors[curr] = curr
for child in neb[curr]:
if child in visited:
continue
tpush((child,curr))
else:
# Second visit (post-order): subtree of curr is complete.
gcurr = gpop()
assert gcurr == curr
# tho self.below not complete yet,
# it is for subtree rooted at curr, so OK to call notice
self.notice(curr)
height -= 1
#this portion implements Tarjan's LCA alg
lca_done.add(curr)
for v in pairs[curr]:
if v in lca_done:
rep_v = uf.UFfind(v)
anc = ancestors[rep_v]
assert (curr,v) not in lca and (v,curr) not in lca
lca[(curr,v)] = anc
lca[(v,curr)] = anc
if parent is not None:
uf.UFcombine(curr,parent)
ancestors[uf.UFfind(curr)] = parent
# Adds node's value to C[node] and to C of every vertex reachable through `below`
# links (the Fenwick "update" chain).
def notice(self,node):
val = self.values[node]
C = self.C
below = self.below
stack = [node]
pop = stack.pop
extend = stack.extend
while stack:
curr = pop()
C[curr].inc(val)
extend(below[curr]) #note that the ORDER we visit them doesn't matter
# Sums C along the `above` chain starting at node (Fenwick prefix query);
# presumably the value counts of the root..node path — TODO confirm.
def find(self,node):
acc = mycounter()
C = self.C
above = self.above
while node is not None:
acc.addto(C[node])
node = above[node]
return acc
# Value counts on the path a..b: find(a) + find(b) + {value of l} - 2 * find(l),
# where l = lca(a, b).
def get_incidence(self,a,b):
l = self.lca[(a,b)]
find_a = self.find(a)
#find_b = self.find(b)
#delta = mycounter(self.values[l])
#pos_part = find_a + find_b + delta
find_a.addto(self.find(b))
find_a.inc(self.values[l])
find_a.subfrom(self.find(l),2)
#answer = pos_part - find_l - find_l
#order matters b/c of how subtraction of counters works
return find_a
# Number of vertices shared by paths a..b and c..d, classified from how often each
# of the six pairwise LCAs repeats. Returns 0 when either path's own LCA appears
# only once; raises RuntimeError for an unrecognised pattern.
def size_intersection(self,a,b,c,d):
prs = ((a,b),(c,d),(a,c),(a,d),(b,c),(b,d))
verts = list(self.lca[x] for x in prs)
key1,key2 = verts[0],verts[1]
check = Counter(verts)
if check[key1] == 1 or check[key2] == 1:
return 0 #empty intersection
most = check.most_common()
counts = list(freq for val,freq in most)
if counts == [6] or counts == [3,3]:
return 1 #intersect in one point
elif counts == [5,1] or counts == [3,2,1]: #root meets at or beyond endpt of intersection
close = most[-2][0]
far = most[-1][0]
return (self.dists[far] - self.dists[close])+1 #size, not length
elif counts == [4,1,1]:
left = most[1][0]
right = most[2][0]
mid = most[0][0]
return (self.dists[left] + self.dists[right] - 2*self.dists[mid])+ 1 #size, not length
else:
raise RuntimeError
# Yields, per (w, x, y, z) command, the same-value pair count between the two
# paths minus the shared vertices (pairs with i == j).
def process_commands(self,commands):
measure_int = self.size_intersection
get_incidence = self.get_incidence
# NOTE(review): `timer` is never used.
timer = 0
for w,x,y,z in commands:
A = get_incidence(w,x)
B = get_incidence(y,z)
innerprod = A.innerprodwith(B)
len_inter = measure_int(w,x,y,z)
ans = innerprod - len_inter
assert ans >= 0
yield ans
def test(num=None):
    """Read a problem instance, answer every query and write one answer per line.

    With ``num is None`` reads stdin and writes stdout. Otherwise reads
    ``./input<num>.txt`` and writes ``./myoutput<num>.txt``. Unless ``num == '00'``,
    it strips the trailing newline from that output. It then compares the output
    with ``output<num>.txt`` and prints a Success/Failure line with the elapsed time.
    """
    if num is None:
        inp = sys.stdin
        out = sys.stdout
    else:
        inp = open('./input' + num + '.txt')
        out = open('./myoutput' + num + '.txt', 'w')
    start_time = time()
    # split() with no argument tolerates repeated blanks, tabs and '\r';
    # split(' ') produced empty tokens that made int() raise ValueError.
    N, Q = map(int, inp.readline().split())
    vals = [0]  # one-indexed so c_i = vals[i]
    vals.extend(map(int, inp.readline().split()))
    neb = [[] for _ in range(N + 1)]
    for _ in range(N - 1):
        a, b = map(int, inp.readline().split())
        neb[a].append(b)
        neb[b].append(a)
    # pairs[v]: every vertex that v appears with in some query; Tarjan's offline
    # LCA inside geodcounter needs all of these pairs up front.
    pairs = [set() for _ in range(N + 1)]
    commands = []
    for _ in range(Q):
        w, x, y, z = map(int, inp.readline().split())
        pairs[w].update((x, y, z))
        pairs[x].update((w, y, z))
        pairs[y].update((w, x, z))
        pairs[z].update((w, x, y))
        commands.append((w, x, y, z))
    count_geod = geodcounter(neb, 1, vals, pairs)  # make 1 root arbitrarily
    for answer in count_geod.process_commands(commands):
        print(answer, file=out)
    end_time = time()
    if num is not None:
        inp.close()
        if num != '00':
            remove_chars = len(linesep)
            out.truncate(out.tell() - remove_chars)  # strip trailing newline
        out.close()
        succeeded = cmp('myoutput' + num + '.txt', 'output' + num + '.txt')
        outcome = " Success" if succeeded else " Failure"
        print("#" + num + outcome, "in {:f}s".format(end_time - start_time))
# Script entry point: read the instance from stdin and print the answers to stdout.
test()
View More Similar Problems
Direct Connections
Enter-View ( EV ) is a linear, street-like country. By linear, we mean all the cities of the country are placed on a single straight line - the x -axis. Thus every city's position can be defined by a single coordinate, xi, the distance from the left borderline of the country. You can treat all cities as single points. Unfortunately, the dictator of telecommunication of EV (Mr. S. Treat Jr.) do
View Solution →Subsequence Weighting
A subsequence of a sequence is a sequence which is obtained by deleting zero or more elements from the sequence. You are given a sequence A in which every element is a pair of integers, i.e., A = [(a1, w1), (a2, w2), ..., (aN, wN)]. For a subsequence B = [(b1, v1), (b2, v2), ..., (bM, vM)] of the given sequence: we call it increasing if for every i (1 <= i < M), bi < bi+1. Weight(B) =
View Solution →Kindergarten Adventures
Meera teaches a class of n students, and every day in her classroom is an adventure. Today is drawing day! The students are sitting around a round table, and they are numbered from 1 to n in the clockwise direction. This means that the students are numbered 1, 2, 3, . . . , n-1, n, and students 1 and n are sitting next to each other. After letting the students draw for a certain period of ti
View Solution →Mr. X and His Shots
A cricket match is going to be held. The field is represented by a 1D plane. A cricketer, Mr. X has N favorite shots. Each shot has a particular range. The range of the ith shot is from Ai to Bi. That means his favorite shot can be anywhere in this range. Each player on the opposite team can field only in a particular range. Player i can field from Ci to Di. You are given the N favorite shots of M
View Solution →Jim and the Skyscrapers
Jim has invented a new flying object called HZ42. HZ42 is like a broom and can only fly horizontally, independent of the environment. One day, Jim started his flight from Dubai's highest skyscraper, traveled some distance and landed on another skyscraper of same height! So much fun! But unfortunately, new skyscrapers have been built recently. Let us describe the problem in one dimensional space
View Solution →Palindromic Subsets
Consider a lowercase English alphabetic letter character denoted by c. A shift operation on some c turns it into the next letter in the alphabet. For example, shift(a) = b, shift(e) = f, and shift(z) = a. Given a zero-indexed string, s, of n lowercase letters, perform q queries on s where each query takes one of the following two forms: 1 i j t: All letters in the inclusive range from i t
View Solution →