Counting On a Tree
Problem Statement :
Taylor loves trees, and this new challenge has him stumped! Consider a tree, $t$, consisting of $n$ nodes. Each node is numbered from $1$ to $n$, and each node $i$ has an integer, $c_i$, attached to it. A query on tree $t$ takes the form `w x y z`. To process a query, you must print the count of ordered pairs of integers $(i, j)$ such that the following four conditions are all satisfied: (1) $i \neq j$; (2) $i$ lies on the path from node $w$ to node $x$; (3) $j$ lies on the path from node $y$ to node $z$; (4) $c_i = c_j$. Given $t$ and $q$ queries, process each query in order, printing the pair count for each query on a new line. Input Format: The first line contains two space-separated integers describing the respective values of $n$ (the number of nodes) and $q$ (the number of queries). The second line contains $n$ space-separated integers describing the respective values of each node (i.e., $c_1, c_2, \ldots, c_n$). Each of the $n - 1$ subsequent lines contains two space-separated integers, $u$ and $v$, defining a bidirectional edge between nodes $u$ and $v$. Each of the $q$ subsequent lines contains a `w x y z` query, defined above. Constraints: $1 \le n \le 10^5$; $1 \le q \le 50000$; $1 \le c_i \le 10^9$; $1 \le u, v, w, x, y, z \le n$. Scoring for this problem is binary, which means you have to pass all the test cases to get a positive score. Output Format: For each query, print the count of ordered pairs of integers satisfying the four given conditions on a new line.
Solution :
Solution in C :
In C++ :
#include <bits/stdc++.h>
using namespace std;
// Capacity of every per-node / per-time array. The Euler tour written by DFS
// uses time stamps up to 2n, so this must exceed 2 * max(n).
constexpr int kMaxN = 300005;

// Adjacency lists of the tree, indexed by 1-based node id.
vector<int> G[kMaxN];

int Color[kMaxN];   // compressed colour id of each node (assigned in main, starts at 1)
int Linear[kMaxN];  // Euler tour: +Color[v] at time Beg[v], -Color[v] at time End[v]
int Beg[kMaxN];     // entry time of each node; 0 means "not visited yet"
int End[kMaxN];     // exit time of each node
int Depth[kMaxN];   // depth below the DFS root

// Per user query (1-based): accumulated pair count.
long long Answer[kMaxN];

int n;       // number of nodes
int colors;  // number of distinct colours seen so far
int timer;   // last Euler-tour time stamp handed out

// Binary-lifting table: Parent[k][v] is the 2^k-th ancestor of v (0 when absent).
int Parent[20][kMaxN];
// Euler-tour DFS of the tree (adjacency lists G) starting at `node`.
//
// For every vertex v reached:
//   Beg[v] / End[v]  - entry / exit time stamps (taken from the global `timer`);
//   Linear[Beg[v]]   = +Color[v],  Linear[End[v]] = -Color[v];
//   Depth[v]         = Depth[parent] + 1;
//   Parent[k][v]     = 2^k-th ancestor (binary lifting), 0 when it does not exist.
//
// Implemented with an explicit stack instead of recursion: recursion depth equals
// the tree height, which reaches n (1e5) on a path-shaped tree and can overflow
// the call stack. Neighbours are examined in adjacency-list order and a vertex is
// numbered exactly when the recursive version would have entered it, so all
// time stamps are identical to the previous recursive implementation.
void DFS(int node) {
    vector<pair<int, size_t>> frames;  // (vertex, index of next neighbour to examine)
    Beg[node] = ++timer;
    Linear[timer] = +Color[node];
    frames.emplace_back(node, 0);
    while (!frames.empty()) {
        const int cur = frames.back().first;
        const size_t k = frames.back().second;
        if (k < G[cur].size()) {
            ++frames.back().second;
            const int vec = G[cur][k];
            if (!Beg[vec]) {
                Parent[0][vec] = cur;
                Depth[vec] = Depth[cur] + 1;
                // `i < 20` keeps the fill inside Parent[20][...] regardless of depth.
                for (int i = 1; i < 20 && Parent[i - 1][vec]; ++i)
                    Parent[i][vec] = Parent[i - 1][Parent[i - 1][vec]];
                Beg[vec] = ++timer;
                Linear[timer] = +Color[vec];
                frames.emplace_back(vec, 0);  // descend into vec next
            }
        } else {
            // All neighbours handled: leave `cur`.
            End[cur] = ++timer;
            Linear[timer] = -Color[cur];
            frames.pop_back();
        }
    }
}
int LCA(int a, int b) {
if(Depth[a] < Depth[b])
swap(a, b);
int d = Depth[a] - Depth[b];
for(int i = 0; d; ++i, d /= 2)
if(d % 2)
a = Parent[i][a];
if(a == b)
return a;
for(int i = 19; i >= 0; --i)
if(Parent[i][a] != Parent[i][b])
a = Parent[i][a], b = Parent[i][b];
assert(Parent[0][a] == Parent[0][b]);
return Parent[0][a];
}
// One signed term of a query's inclusion-exclusion expansion.
// Initially a/b are node ids; SolveQueries later replaces them by Euler-tour
// positions (a <= b).
struct Query {
    int a;    // first prefix endpoint
    int b;    // second prefix endpoint
    int i;    // index of the user query this term contributes to
    int sgn;  // +1 or -1
};
vector<Query> Queries;

// Records the term sgn * pairs(P(a), P(b)) for query i, where P(v) is the
// root-to-v path. Node 0 stands for the empty path, so such terms are dropped.
void AddQuery(int a, int b, int i, int sgn) {
    const bool hits_sentinel = (a == 0) || (b == 0);
    if (!hits_sentinel)
        Queries.push_back(Query{a, b, i, sgn});
}

// Expands (P(a) - P(b)) x (P(c) - P(d)) into its four signed terms, emitted in
// the order (a,c,+), (b,c,-), (a,d,-), (b,d,+).
void AddQuery(int a, int b, int c, int d, int i) {
    const int lhs[2] = {a, b};
    const int rhs[2] = {c, d};
    for (int r = 0; r < 2; ++r)
        for (int l = 0; l < 2; ++l)
            AddQuery(lhs[l], rhs[r], i, (l == r) ? 1 : -1);
}
// Count[side][c]: occurrences of colour c in the current prefix for side 0 / 1.
int Count[2][kMaxN];
// Invariant: global_ans == sum over c of Count[0][c] * Count[1][c].
long long global_ans;

// Applies one Euler-tour entry to prefix `at` (0 or 1): a positive `col` adds one
// occurrence of colour |col|, a negative one removes it. global_ans is updated
// so the invariant above keeps holding. `col` must be non-zero.
void Add(int at, int col) {
    const int colour = abs(col);
    const int step = col / colour;
    const long long before = 1LL * Count[0][colour] * Count[1][colour];
    Count[at][colour] += step;
    const long long after = 1LL * Count[0][colour] * Count[1][colour];
    global_ans += after - before;
}
void SolveQueries() {
for(auto &q : Queries) {
Answer[q.i] -= q.sgn * (1 + Depth[LCA(q.a, q.b)]);
q.a = Beg[q.a];
q.b = Beg[q.b];
if(q.a > q.b)
swap(q.a, q.b);
}
const int kMagic = 256;
sort(Queries.begin(), Queries.end(), [](Query &a, Query &b) {
if(a.a / kMagic == b.a / kMagic)
return a.b < b.b;
return a.a < b.a;
});
int b = 0, e = 0;
for(auto &q : Queries) {
while(b < q.a) Add(0, Linear[++b]);
while(e < q.b) Add(1, Linear[++e]);
while(b > q.a) Add(0, -Linear[b--]);
while(e > q.b) Add(1, -Linear[e--]);
Answer[q.i] += global_ans * q.sgn;
}
}
// Reads the tree and the queries, answers every query and prints one count per line.
//
// Path decomposition: with P(v) the root-to-v path (P(0) empty),
//   path(w, x) = P(w) + P(x) - P(lca) - P(parent(lca)),
// so each query expands into up to 16 signed prefix-pair terms (AddQuery).
int main() {
    // About 5e5 integers are read; unsynced streams keep I/O from dominating.
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);

    int q;
    cin >> n >> q;

    // Compress colours to 1..colors; 0 is never used because Add divides by |colour|.
    map<int, int> normMap;
    for (int i = 1; i <= n; ++i) {
        cin >> Color[i];
        auto &norm = normMap[Color[i]];
        if (norm == 0)
            norm = ++colors;
        Color[i] = norm;
    }

    for (int i = 2; i <= n; ++i) {
        int a, b;
        cin >> a >> b;
        G[a].push_back(b);
        G[b].push_back(a);
    }

    DFS(1);
    Depth[0] = -1;

    for (int i = 1; i <= q; ++i) {
        int a1, b1, a2, b2;
        cin >> a1 >> b1 >> a2 >> b2;
        const int l1 = LCA(a1, b1);
        const int l2 = LCA(a2, b2);
        AddQuery(a1, Parent[0][l1], a2, Parent[0][l2], i);
        AddQuery(a1, Parent[0][l1], b2, l2, i);
        AddQuery(b1, l1, a2, Parent[0][l2], i);
        AddQuery(b1, l1, b2, l2, i);
    }

    SolveQueries();

    // Exactly one line per query; no extra blank line at the end.
    for (int i = 1; i <= q; ++i)
        cout << Answer[i] << '\n';
    return 0;
}
In Java :
import java.io.*;
import java.math.*;
import java.text.*;
import java.util.*;
import java.util.regex.*;
public class Solution {

    /** Bucket width used to order prefix-pair terms for Mo's algorithm. */
    private static final int MO_BLOCK = 256;

    /**
     * Answers every query {w, x, y, z}: the number of ordered pairs (i, j) with
     * i != j, i on the path w..x, j on the path y..z and c_i == c_j.
     *
     * <p>Let P(v) be the root-to-v path (P(0) empty). Then
     * path(w, x) = P(w) + P(x) - P(lca) - P(parent(lca)), so each query is a signed
     * sum of at most 16 "pairs between two root paths" terms. Such a term equals
     * sum over colours of cnt_A(colour) * cnt_B(colour), where the counts are taken
     * over Euler-tour prefixes. These terms are evaluated with Mo's algorithm. The
     * i == j pairs, depth(lca(A, B)) + 1 per term, are subtracted separately.
     *
     * @param c       colour of node k stored at c[k - 1]
     * @param tree    the n - 1 edges, each {u, v} with 1-based node ids
     * @param queries rows {w, x, y, z} with 1-based node ids
     * @return one count per query, in input order. The return type is long[]
     *         (the original template stub declared int[] and had no body) because
     *         a count can approach n^2, which overflows int for n up to 1e5.
     */
    static long[] solve(int[] c, int[][] tree, int[][] queries) {
        int n = c.length;
        int qn = queries.length;
        long[] answer = new long[qn];
        if (n == 0) {
            return answer;
        }

        // Compress colours to 1..k so they can index plain arrays.
        int[] color = new int[n + 1];
        Map<Integer, Integer> colourIds = new HashMap<>();
        for (int v = 1; v <= n; v++) {
            Integer id = colourIds.get(c[v - 1]);
            if (id == null) {
                id = colourIds.size() + 1;
                colourIds.put(c[v - 1], id);
            }
            color[v] = id;
        }

        // Adjacency in CSR form: neighbours of v are adj[start[v] .. start[v + 1]).
        int[] start = new int[n + 2];
        for (int[] e : tree) {
            start[e[0] + 1]++;
            start[e[1] + 1]++;
        }
        for (int v = 1; v <= n + 1; v++) {
            start[v] += start[v - 1];
        }
        int[] adj = new int[2 * tree.length];
        int[] fill = Arrays.copyOf(start, n + 2);
        for (int[] e : tree) {
            adj[fill[e[0]]++] = e[1];
            adj[fill[e[1]]++] = e[0];
        }

        // Iterative DFS from node 1 (recursion could overflow on a 1e5-node path).
        // The Euler tour holds +colour at entry and -colour at exit, so the prefix
        // linear[1..beg[v]] holds exactly the colours on the root-to-v path.
        int log = 1;
        while ((1 << log) <= n) {
            log++;
        }
        int[][] up = new int[log][n + 1];   // up[k][v] = 2^k-th ancestor, 0 if none
        int[] depth = new int[n + 1];
        int[] beg = new int[n + 1];
        int[] linear = new int[2 * n + 1];
        int[] cursor = new int[n + 1];      // next adjacency slot to examine per vertex
        int[] stack = new int[n];
        int top = 0;
        int timer = 0;
        beg[1] = ++timer;
        linear[timer] = color[1];
        cursor[1] = start[1];
        stack[top++] = 1;
        while (top > 0) {
            int v = stack[top - 1];
            if (cursor[v] < start[v + 1]) {
                int u = adj[cursor[v]++];
                if (beg[u] == 0) {
                    up[0][u] = v;
                    depth[u] = depth[v] + 1;
                    for (int k = 1; k < log; k++) {
                        up[k][u] = up[k - 1][up[k - 1][u]];
                    }
                    beg[u] = ++timer;
                    linear[timer] = color[u];
                    cursor[u] = start[u];
                    stack[top++] = u;
                }
            } else {
                linear[++timer] = -color[v];
                top--;
            }
        }

        // Expand every query into signed prefix-pair terms.
        int cap = 16 * qn;
        int[] termLo = new int[cap];
        int[] termHi = new int[cap];
        int[] termQuery = new int[cap];
        int[] termSign = new int[cap];
        int terms = 0;
        int[] sign = {1, 1, -1, -1};
        for (int t = 0; t < qn; t++) {
            int w = queries[t][0], x = queries[t][1], y = queries[t][2], z = queries[t][3];
            int l1 = lca(up, depth, w, x);
            int l2 = lca(up, depth, y, z);
            int[] left = {w, x, l1, up[0][l1]};
            int[] right = {y, z, l2, up[0][l2]};
            for (int i = 0; i < 4; i++) {
                for (int j = 0; j < 4; j++) {
                    int a = left[i];
                    int b = right[j];
                    if (a == 0 || b == 0) {
                        continue; // P(0) is the empty path
                    }
                    int s = sign[i] * sign[j];
                    // Root paths to a and b share depth(lca)+1 nodes: the i == j pairs.
                    answer[t] -= (long) s * (depth[lca(up, depth, a, b)] + 1);
                    termLo[terms] = Math.min(beg[a], beg[b]);
                    termHi[terms] = Math.max(beg[a], beg[b]);
                    termQuery[terms] = t;
                    termSign[terms] = s;
                    terms++;
                }
            }
        }

        // Mo's algorithm: bucket by the first end, then sort by the second end.
        Integer[] order = new Integer[terms];
        for (int k = 0; k < terms; k++) {
            order[k] = k;
        }
        Arrays.sort(order, (p, r) -> {
            int bp = termLo[p] / MO_BLOCK;
            int br = termLo[r] / MO_BLOCK;
            if (bp != br) {
                return Integer.compare(bp, br);
            }
            return Integer.compare(termHi[p], termHi[r]);
        });

        int[] inFirst = new int[colourIds.size() + 1];
        int[] inSecond = new int[colourIds.size() + 1];
        long current = 0;   // sum over colours of inFirst * inSecond
        int firstEnd = 0;
        int secondEnd = 0;
        for (int k : order) {
            while (firstEnd < termLo[k]) current += shift(inFirst, inSecond, linear[++firstEnd]);
            while (secondEnd < termHi[k]) current += shift(inSecond, inFirst, linear[++secondEnd]);
            while (firstEnd > termLo[k]) current += shift(inFirst, inSecond, -linear[firstEnd--]);
            while (secondEnd > termHi[k]) current += shift(inSecond, inFirst, -linear[secondEnd--]);
            answer[termQuery[k]] += termSign[k] * current;
        }
        return answer;
    }

    /** Lowest common ancestor via binary lifting; up[k][v] is the 2^k-th ancestor of v. */
    private static int lca(int[][] up, int[] depth, int a, int b) {
        if (depth[a] < depth[b]) {
            int t = a;
            a = b;
            b = t;
        }
        for (int k = 0, diff = depth[a] - depth[b]; diff > 0; k++, diff >>= 1) {
            if ((diff & 1) != 0) {
                a = up[k][a];
            }
        }
        if (a == b) {
            return a;
        }
        for (int k = up.length - 1; k >= 0; k--) {
            if (up[k][a] != up[k][b]) {
                a = up[k][a];
                b = up[k][b];
            }
        }
        return up[0][a];
    }

    /**
     * Adds (signedColor > 0) or removes (signedColor < 0) one occurrence of colour
     * |signedColor| in {@code changed} and returns the resulting change of
     * sum(changed[c] * other[c]).
     */
    private static long shift(int[] changed, int[] other, int signedColor) {
        int col = Math.abs(signedColor);
        int delta = signedColor > 0 ? 1 : -1;
        changed[col] += delta;
        return (long) delta * other[col];
    }

    private static final Scanner scanner = new Scanner(System.in);

    public static void main(String[] args) throws IOException {
        BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(System.getenv("OUTPUT_PATH")));
        String[] nq = scanner.nextLine().split(" ");
        int n = Integer.parseInt(nq[0]);
        int q = Integer.parseInt(nq[1]);
        int[] c = new int[n];
        String[] cItems = scanner.nextLine().split(" ");
        scanner.skip("(\r\n|[\n\r\u2028\u2029\u0085])?");
        for (int cItr = 0; cItr < n; cItr++) {
            int cItem = Integer.parseInt(cItems[cItr]);
            c[cItr] = cItem;
        }
        int[][] tree = new int[n-1][2];
        for (int treeRowItr = 0; treeRowItr < n-1; treeRowItr++) {
            String[] treeRowItems = scanner.nextLine().split(" ");
            scanner.skip("(\r\n|[\n\r\u2028\u2029\u0085])?");
            for (int treeColumnItr = 0; treeColumnItr < 2; treeColumnItr++) {
                int treeItem = Integer.parseInt(treeRowItems[treeColumnItr]);
                tree[treeRowItr][treeColumnItr] = treeItem;
            }
        }
        int[][] queries = new int[q][4];
        for (int queriesRowItr = 0; queriesRowItr < q; queriesRowItr++) {
            String[] queriesRowItems = scanner.nextLine().split(" ");
            scanner.skip("(\r\n|[\n\r\u2028\u2029\u0085])?");
            for (int queriesColumnItr = 0; queriesColumnItr < 4; queriesColumnItr++) {
                int queriesItem = Integer.parseInt(queriesRowItems[queriesColumnItr]);
                queries[queriesRowItr][queriesColumnItr] = queriesItem;
            }
        }
        long[] result = solve(c, tree, queries);
        for (int resultItr = 0; resultItr < result.length; resultItr++) {
            bufferedWriter.write(String.valueOf(result[resultItr]));
            if (resultItr != result.length - 1) {
                bufferedWriter.write("\n");
            }
        }
        bufferedWriter.newLine();
        bufferedWriter.close();
        scanner.close();
    }
}
In C :
#include <stdio.h>
#include <string.h>
#include <math.h>
#include <stdlib.h>
/* floor(log2(self)) for a non-zero 32-bit unsigned value, computed from the
   count of leading zero bits (GCC/Clang builtin). Undefined for self == 0. */
static inline unsigned floor_log2_X86(unsigned self) {
    return 31U ^ (unsigned) __builtin_clz(self);
}
#define floor_log2 floor_log2_X86
/* Sorts the index array self[0 .. length) in place so that weights[self[k]] is
   non-decreasing. Uses heapsort (not stable). Internally `self` is shifted down
   by one so the heap can use 1-based positions. */
void heap_sort(unsigned *self, unsigned *weights, unsigned length) {
    unsigned
        at = length >> 1,
        member,
        node;
    /* Build a max-heap by sifting down every internal node. */
    for (self--; at; self[node >> 1] = member) {
        member = self[at];
        for (node = at-- << 1; node <= length; node <<= 1) {
            node |= (node < length) && (weights[self[node]] < weights[self[node | 1]]);
            if (weights[self[node]] < weights[member])
                break ;
            self[node >> 1] = self[node];
        }
    }
    /* Repeatedly move the maximum to the end and restore the heap. */
    for (; length > 1; self[at >> 1] = member) {
        member = self[length];
        self[length--] = self[1];
        for (at = 2; at <= length; at <<= 1) {
            at |= (at < length) && (weights[self[at]] < weights[self[at | 1]]);
            if (weights[self[at]] < weights[member])
                break ;
            self[at >> 1] = self[at];
        }
    }
}

/* Replaces every entry of values[0 .. length) (length >= 1) by a dense rank
   starting at 0. Distinct values are ranked by descending number of occurrences;
   values with equal occurrence counts are ranked by ascending value.
   Example: {5, 3, 5, 0xFFFFFFFF} becomes {0, 1, 0, 2}. */
void compress(unsigned length, unsigned values[length]) {
    unsigned
        at,
        order[length];
    /* order[k] = k. A plain loop replaces the former 64-bit type-punned stores
       through (unsigned long *)order. Those stores assumed an 8-byte unsigned
       long, little-endian layout and 8-byte alignment of order[], and violated
       strict aliasing. */
    for (at = 0; at < length; at++)
        order[at] = at;
    heap_sort(order, values, length);
    /* Group equal values: roots[g] is the last sorted position of group g
       (roots[0] = -1 as a sentinel); `seen` ends as groups + 1, `max` as the
       largest group size. */
    unsigned roots[length], seen = 1, max = 0, others;
    for (roots[at = 0] = -1U; at < length; roots[seen++] = at - 1) {
        for (others = at; (at < length) && values[order[at]] == values[order[others]]; at++);
        if (max < (at - others))
            max = (at - others);
    }
    /* Counting sort of the groups by size, largest first (ties keep value order). */
    unsigned
        indices[max + 1],
        ranks[seen];
    memset(indices, 0, sizeof(indices));
    for (at = 0; ++at < seen; indices[roots[at] - roots[at - 1]]++);
    for (at = max; at--; indices[at] += indices[at + 1]);
    for (at = seen; --at; ranks[--indices[roots[at] - roots[at - 1]]] = at);
    /* Write the rank of each group back into values. */
    for (; at < (seen - 1); at++)
        for (others = roots[ranks[at] - 1]; ++others <= roots[ranks[at]]; values[order[others]] = at);
}
// Nearest common ancestor of two vertices given by their ids in the heavy-first
// preorder numbering built in main (a subtree occupies a contiguous id range
// [v, v + weights[v])). main always calls it with lower <= upper.
//   depth    - number of columns of `bases` (main passes `dist`)
//   base_ids - chain index of every vertex id
//   bases    - per chain, the vertex reached after 1, 2, ... chain jumps
//              (main passes ancestral_bases)
//   depths   - chain-jump depth of every vertex id (main passes base_depths;
//              NOTE(review): indexed by vertex id although declared with base_cnt)
//   weights  - subtree size of every vertex id
// NOTE(review): the step comments below describe the apparent intent; verify.
static inline unsigned nearest_common_ancestor(
unsigned depth,
unsigned base_cnt,
unsigned vertex_cnt,
unsigned base_ids[vertex_cnt],
unsigned bases[base_cnt][depth],
unsigned char depths[base_cnt],
unsigned weights[vertex_cnt],
unsigned lower,
unsigned upper
) {
// upper lies inside lower's subtree id range, so lower is the answer.
if (upper < (lower + weights[lower]))
return lower;
// Lift upper by chain jumps until it is at lower's chain depth.
if (depths[upper] > depths[lower])
upper = bases[base_ids[upper]][depths[upper] - depths[lower] - 1];
if (upper < lower)
return upper;
// Binary search in upper's chain-ancestor list (ids shrink going up) for the
// boundary around `lower`, then pick the first entry not greater than lower.
unsigned *others = bases[base_ids[upper]];
for (; depth > 1; depth >>= 1)
if (others[depth >> 1] > lower) {
others += depth >> 1;
depth += depth & 1U;
}
return others[others[0] > lower];
}
// An inclusive id range [low, high] that can also be handled as one 64-bit word
// (`packd`), e.g. to swap two ranges with XOR or to swap low/high with a rotate.
// NOTE(review): relies on C11 anonymous struct members, on unsigned long being
// 64 bits, and on little-endian layout: main decrements both halves by
// subtracting 0x0000000100000001UL and builds values as `low | (high << 32)`.
typedef union {
unsigned long packd;
struct {
int low, high;
};
} range_t;
/* Colour-grouped view of the renumbered tree (all arrays are filled in main). */
typedef struct {
    unsigned *members;    /* vertex ids grouped by colour, ascending within a colour */
    unsigned *colors;     /* colour of every vertex id */
    unsigned *indices;    /* colour c occupies members[indices[c] .. indices[c + 1]) */
    unsigned *locations;  /* position of every vertex id inside members */
} colored_tree_t;
// Counts the vertex ids in [other.low, other.high] whose colour equals the
// colour of vertex `at`, not counting `at` itself (see the final return).
// The colour's member list is sorted by id (main fills it that way), so the
// lower and upper ends are found by binary search. When an end of the range
// has the same colour, its position is read directly from `locations`.
unsigned long query_all(colored_tree_t *self, unsigned at, range_t other) {
unsigned
color = self->colors[at],
length = self->indices[color + 1] - self->indices[color],
*base = &self->members[self->indices[color]];
// The range misses every member of this colour.
if (other.high < base[0] || other.low > base[length - 1])
return 0;
// Find the first member >= other.low. Only the half of the list on the
// relevant side of `at` is searched.
if (self->colors[other.low] != color) {
if (at < other.low) {
base += self->locations[at] - self->indices[color];
length = self->indices[color + 1] - self->locations[at];
} else
length = self->locations[at] - self->indices[color]; // at > other.low
for (; length > 1; length >>= 1)
if (base[length >> 1] < other.low) {
base += length >> 1;
length += length & 1;
}
base += (base[0] < other.low);
} else
base += (self->locations[other.low] - self->indices[color]);
if (base[0] > other.high)
return 0;
// Find the last member <= other.high.
unsigned *ceil;
if (self->colors[other.high] != color) {
ceil = (at > base[0] && at < other.high) ? &self->members[self->locations[at]] : base;
for (length = self->indices[color + 1] - self->locations[ceil[0]]; length > 1; length >>= 1)
if (ceil[length >> 1] <= other.high) {
ceil += length >> 1;
length += length & 1;
}
ceil -= (ceil[0] > other.high);
} else
ceil = &self->members[self->locations[other.high]];
// Members in [base, ceil], minus `at` itself when it lies in the range.
return ceil - base + 1 - (at >= other.low && at <= other.high);
}
// Counts same-colour vertex pairs between the id ranges `self` and `other`
// (NOTE(review): presumably ordered pairs (u, v), u != v, u in self, v in other,
// as main sums this over all path-segment pairs; verify).
// Ids are grouped into blocks of `length` ids (main passes `levels`):
//   - partial blocks at either end of a range are peeled off one id at a time
//     and counted with query_all;
//   - whole blocks are counted with the 2-D prefix sums `pairs` (cross-block
//     counts) and the 1-D prefix sums `overlapping` (within-block counts).
// `pairs` and `overlapping` are indexed with -1 at times; main offsets the
// `pairs` allocation by one row/column for that purpose.
unsigned long count_pairs(
unsigned cnt,
unsigned length,
unsigned long pairs[cnt][cnt],
unsigned *overlapping,
colored_tree_t *tree,
range_t self,
range_t other
) {
unsigned long count = 0;
// Peel `self` down to whole blocks.
for (; (self.low % length) && (self.low <= self.high); count += query_all(tree, self.low++, other));
for (; ((self.high + 1) % length) && (self.low <= self.high); count += query_all(tree, self.high--, other));
if (self.low <= self.high) {
// Peel `other` down to whole blocks.
for (; (other.low % length) && (other.low <= other.high); count += query_all(tree, other.low++, self));
for (; ((other.high + 1) % length) && (other.low <= other.high); count += query_all(tree, other.high--, self));
if (other.low <= other.high) {
// Switch to block indices and make `self` start first (XOR swap of packed ranges).
self.low /= length;
self.high /= length;
other.low /= length;
other.high /= length;
if (self.low > other.low) {
self.packd ^= other.packd;
other.packd ^= self.packd;
self.packd ^= other.packd;
}
// Part of `self` strictly before `other` begins: plain rectangle query.
unsigned high = (self.high < other.low) ? self.high : (other.low - 1);
count +=
pairs[high][other.high]
- pairs[high][other.low - 1UL]
- pairs[self.low - 1UL][other.high]
+ pairs[self.low - 1UL][other.low - 1UL];
self.low = high + 1;
if (self.high > other.high) {
self.packd ^= other.packd;
other.packd ^= self.packd;
self.packd ^= other.packd;
}
// Remaining overlap of the two block ranges: within-block pairs, pairs
// inside the overlap (counted in both orders), and overlap x tail.
if (self.low <= self.high)
count +=
(overlapping[self.high] - overlapping[self.low - 1UL])
+ ((
pairs[self.high][self.high]
- pairs[self.high][self.low - 1UL]
- pairs[self.low - 1UL][self.high]
+ pairs[self.low - 1UL][self.low - 1UL]
) << 1) + (
pairs[self.high][other.high]
- pairs[self.high][self.high]
- pairs[self.low - 1UL][other.high]
+ pairs[self.low - 1UL][self.high]
);
}
}
return count;
}
// Reads the tree and the queries, then for each query prints the number of
// same-colour vertex pairs between the two query paths.
// Outline (NOTE(review): summarised from the code below; verify details):
//   1. read colours, append a sentinel and rank them with compress();
//   2. orient the edges into a parent array rooted at vertex 0;
//   3. compute subtree sizes and renumber vertices in heavy-first preorder
//      (heavy-light decomposition: bases = chain head, base_depths = chain depth);
//   4. build per-chain ancestor lists and colour-grouped member lists;
//   5. precompute block-level pair counts;
//   6. split each query path into id ranges along chains and sum count_pairs.
int main() {
unsigned at, vertex_cnt;
unsigned short query_cnt;
scanf("%u %hu", &vertex_cnt, &query_cnt);
// Step 1: colours plus a sentinel 0xFFFFFFFF. Because compress ranks rarer and
// larger values last, the sentinel ends up with the largest rank, so
// colors[vertex_cnt] equals the number of distinct real colours.
unsigned colors[vertex_cnt + 1];
for (at = 0; at < vertex_cnt; scanf("%u", &colors[at++]));
colors[at] = 0xFFFFFFFFU;
compress(at + 1, colors);
// Step 2: ancestors[v] = parent of v (0xFFFFFFFF = none). When an edge's second
// endpoint already has a parent, existing parent pointers are rewritten along
// the paths so the new edge can be attached. The final loop re-roots the result
// at vertex 0.
unsigned ancestors[at + 1];
{
unsigned ancestor, descendant;
for (memset(ancestors, 0xFFU, sizeof(ancestors)); --at; ancestors[descendant] = ancestor) {
scanf("%u %u", &ancestor, &descendant);
--ancestor;
if (ancestors[--descendant] != 0xFFFFFFFFU) {
unsigned root = descendant, next;
for (; ancestor != 0xFFFFFFFFU; ancestor = next) {
next = ancestors[ancestor];
ancestors[ancestor] = root;
root = ancestor;
}
for (; ancestors[descendant] != 0xFFFFFFFFU; descendant = next) {
next = ancestors[descendant];
ancestors[descendant] = ancestor;
ancestor = descendant;
}
}
}
for (ancestor = 0xFFFFFFFFU; at != 0xFFFFFFFFU; at = descendant) {
descendant = ancestors[at];
ancestors[at] = ancestor;
ancestor = at;
}
}
// NOTE(review): this outer `history` is shadowed by the inner one declared in
// the next block and appears unused.
unsigned
others,
ids[vertex_cnt + 1],
weights[vertex_cnt],
bases[vertex_cnt + 1],
history[vertex_cnt];
unsigned char
base_depths[vertex_cnt],
dist = 0;
{
unsigned
history[vertex_cnt],
indices[vertex_cnt + 1],
descendants[vertex_cnt];
// Children lists: descendants[indices[v] .. indices[v + 1]) are v's children.
memset(indices, 0, sizeof(indices));
for (ancestors[vertex_cnt] = (at = vertex_cnt); at; indices[ancestors[at--]]++);
for (; ++at <= vertex_cnt; indices[at] += indices[at - 1]);
for (; --at; descendants[--indices[ancestors[at]]] = at);
// Step 3a: subtree sizes via an explicit stack. A vertex is expanded on its
// first visit (weight 1), and its children's weights are summed on the second.
history[0] = 0;
memset(weights, 0, sizeof(weights));
for (at = 1; at--; )
if (weights[history[at]])
for (others = indices[history[at]];
others < indices[history[at] + 1];
weights[history[at]] += weights[descendants[others++]]);
else {
weights[history[at]] = 1;
memcpy(
&history[at + 1],
&descendants[indices[history[at]]],
(indices[history[at] + 1] - indices[history[at]]) * sizeof(descendants[0])
);
at += indices[history[at] + 1] - indices[history[at]] + 1;
}
unsigned
orig_ancestors[vertex_cnt + 1],
orig_colors[vertex_cnt + 1],
orig_weights[vertex_cnt];
memcpy(orig_ancestors, ancestors, sizeof(ancestors));
memcpy(orig_weights, weights, sizeof(weights));
memcpy(orig_colors, colors, sizeof(colors));
// Step 3b: renumber. Children are sorted by subtree size and the heaviest gets
// the id right after its parent, continuing the parent's chain. Every other
// child starts its own chain (bases[id] == id). From here on ancestors,
// weights and colors are indexed by the new ids.
base_depths[0] = (bases[0] = (ids[0] = 0));
bases[vertex_cnt] = (ids[vertex_cnt] = vertex_cnt);
for (at = 1; at--;) {
unsigned
id = ids[history[at]],
base = bases[id++],
branches = indices[history[at] + 1] - indices[history[at]];
heap_sort(&descendants[indices[history[at]]], orig_weights, branches);
memcpy(&history[at], &descendants[indices[history[at]]], branches * sizeof(descendants[0]));
for (others = (at += branches); branches--; base = id) {
ids[history[--others]] = id;
ancestors[id] = ids[orig_ancestors[history[others]]];
weights[id] = orig_weights[history[others]];
colors[id] = orig_colors[history[others]];
bases[id] = base;
base_depths[id] = base_depths[ancestors[id]] + (base == id);
if (dist < base_depths[id])
dist = base_depths[id];
id += weights[id];
}
}
}
// Step 4a: base_ids[v] = index of v's chain (ids of a chain are consecutive).
unsigned base_ids[vertex_cnt + 1];
for (base_ids[0] = (others = (at = 0)); others < vertex_cnt; base_ids[others] = base_ids[at] + 1)
for (at = others; bases[at] == bases[others]; base_ids[others++] = base_ids[at]);
// ancestral_bases[chain][k]: vertex reached after k + 1 jumps "to the parent of
// the current chain head". Column 0 ends up as the parent of the chain head
// (the smallest id of the chain is written last).
unsigned ancestral_bases[base_ids[vertex_cnt]][dist];
for (ancestors[0] = 0; others--; ancestral_bases[base_ids[others]][0] = ancestors[others]);
while ((++others + 1) < dist)
for (at = 0; ++at < base_ids[vertex_cnt];
ancestral_bases[at][others + 1] = ancestors[bases[ancestral_bases[at][others]]]);
// Step 4b: counting sort of ids by colour. Colour c occupies
// members[indexed_colors[c] .. indexed_colors[c + 1]), with ids ascending.
unsigned
indexed_colors[colors[vertex_cnt] + 2],
members[vertex_cnt + 1];
memset(indexed_colors, 0, sizeof(indexed_colors));
for (at = vertex_cnt + 1; at--; indexed_colors[colors[at]]++);
for (; ++at < colors[vertex_cnt]; indexed_colors[at + 1] += indexed_colors[at]);
for (at = vertex_cnt + 1; at--; members[--indexed_colors[colors[at]]] = at);
indexed_colors[colors[vertex_cnt] + 1] = indexed_colors[colors[vertex_cnt]];
// Step 5: ids are grouped into blocks of `levels` ids.
//   overlapping[b] - same-colour ordered pairs inside block b (then prefix-summed)
//   pairs[b1][b2]  - same-colour pairs with one id in b1 and one in b2, b1 < b2
//                    (then 2-D prefix-summed)
// The allocation has one extra row/column, and `pairs` is shifted so that index
// -1 is valid. NOTE(review): the matrix is about (n / 17)^2 unsigned longs and
// is never freed.
unsigned
levels = floor_log2(vertex_cnt) + 1,
block_cnt = (vertex_cnt / levels) + 1,
locations[vertex_cnt + 1],
overlapping[block_cnt];
unsigned long (*pairs)[block_cnt][block_cnt] = calloc(
(1 + block_cnt) * (1 + block_cnt),
sizeof(pairs[0][0][0])
);
pairs = (void *)&pairs[0][1][1];
for (at = vertex_cnt + 1; at--; locations[members[at]] = at);
memset(overlapping, 0, sizeof(overlapping));
// Only colours with at least two vertices can form pairs. Because compress
// ranks colours by descending frequency, these come first.
for (at = 0; (indexed_colors[at + 1] - indexed_colors[at]) > 1; at++) {
others = indexed_colors[at];
unsigned
block_bases[indexed_colors[at + 1] - others + 1],
cnt = 1;
// block_bases: first member of this colour in each touched block; the last
// entry is the first member past this colour.
for (block_bases[0] = members[others]; at == colors[members[++others]]; )
if ((members[others] / levels) != (block_bases[cnt - 1] / levels))
block_bases[cnt++] = members[others];
block_bases[cnt] = members[others];
for (others = 0; others < cnt; others++) {
unsigned long density = locations[block_bases[others + 1]] - locations[block_bases[others]];
overlapping[block_bases[others] / levels] += density * (density - 1);
unsigned block = others;
for (; ++block < cnt; pairs[0][block_bases[others] / levels][block_bases[block] / levels]
+= density * (locations[block_bases[block + 1]] - locations[block_bases[block]]));
}
}
for (at = 0; ++at < block_cnt; overlapping[at] += overlapping[at - 1])
pairs[0][0][at] += pairs[0][0][at - 1];
for (at = 0; ++at < block_cnt; )
for (others = 0; ++others < block_cnt; pairs[0][at][others] += pairs[0][at][others - 1]);
for (at = 0; ++at < block_cnt; )
for (others = 0; others < block_cnt; others++)
pairs[0][at][others] += pairs[0][at - 1][others];
colored_tree_t *tree = &(colored_tree_t) {
.members = members,
.colors = colors,
.indices = indexed_colors,
.locations = locations
};
// Step 6: answer the queries.
while (query_cnt--) {
range_t left, right;
// NOTE(review): %u stores into the int fields of range_t.
scanf("%u %u %u %u", &left.low, &left.high, &right.low, &right.high);
// Make both endpoints 0-based at once, map them to new ids and order each pair.
left.packd -= 0x0000000100000001UL;
right.packd -= 0x0000000100000001UL;
left.low = ids[left.low];
left.high = ids[left.high];
right.low = ids[right.low];
right.high = ids[right.high];
if (left.high < left.low)
left.packd = (left.packd << 32) | (left.packd >> 32);
if (right.high < right.low)
right.packd = (right.packd << 32) | (right.packd >> 32);
if (right.high < left.low) {
left.packd ^= right.packd;
right.packd ^= left.packd;
left.packd ^= right.packd;
}
// Each path is decomposed into id ranges (one per chain visited).
// NOTE(review): 32 slots per path are assumed to be enough; the worst-case
// bound is not checked here.
struct {
range_t members[32];
unsigned cnt;
}
a = {.cnt = 0},
b = {.cnt = 0};
unsigned common = nearest_common_ancestor(
dist, base_ids[vertex_cnt], vertex_cnt,
base_ids, ancestral_bases,
base_depths, weights,
left.low, left.high
);
for (at = left.low; bases[at] != bases[common]; at = ancestral_bases[base_ids[at]][0])
a.members[a.cnt++].packd = bases[at] | ((unsigned long)at << 32);
for (others = left.high; bases[others] != bases[common]; others = ancestral_bases[base_ids[others]][0])
a.members[a.cnt++].packd = bases[others] | ((unsigned long)others << 32);
a.members[a.cnt++].packd = common | ((unsigned long)((at != common) ? at : others) << 32);
common = nearest_common_ancestor(
dist, base_ids[vertex_cnt], vertex_cnt,
base_ids, ancestral_bases,
base_depths, weights,
right.low, right.high
);
for (at = right.low; bases[at] != bases[common]; at = ancestral_bases[base_ids[at]][0])
b.members[b.cnt++].packd = bases[at] | ((unsigned long)at << 32);
for (others = right.high; bases[others] != bases[common]; others = ancestral_bases[base_ids[others]][0])
b.members[b.cnt++].packd = bases[others] | ((unsigned long)others << 32);
b.members[b.cnt++].packd = common | ((unsigned long)((at != common) ? at : others) << 32);
// Sum the pair counts over every (left range, right range) combination.
unsigned long total = 0;
for (at = 0; at < a.cnt; at++)
for (others = 0; others < b.cnt;
total += count_pairs(
block_cnt, levels, pairs[0], overlapping, tree,
a.members[at], b.members[others++]
)
);
printf("%lu\n", total);
}
return 0;
}
In Python3 :
import sys
from filecmp import cmp
from os import linesep
from time import time
from collections import Counter
def _build_bit_prevs(limit=10**5):
    """Return ``prevs`` of length ``limit + 1`` where ``prevs[k]`` lists, in
    increasing order, every index ``j`` in ``1..limit`` whose Fenwick-tree
    successor ``j + (j & -j)`` equals ``k``. Successors beyond ``limit`` are dropped.

    Built inside a function so the loop variables do not leak into the module
    namespace. The previous top-level loop left a module global ``next`` that
    shadowed the builtin ``next()``.
    """
    prevs = [[] for _ in range(limit + 1)]
    for j in range(1, limit + 1):
        successor = j + (j & -j)
        if successor <= limit:
            prevs[successor].append(j)
    return prevs


# Index k -> Fenwick indices j whose update chain continues at k.
memoized_BIT_prevs = _build_bit_prevs()
class UF(object):
    """Union-find over hashable items, using union by rank and path halving."""
    __slots__ = ['uf', 'ranks']

    def __init__(self):
        self.uf = {}     # item -> parent item; a root is its own parent
        self.ranks = {}  # root -> rank

    def UFadd(self, a):
        """Register ``a`` as a new singleton set."""
        self.uf[a] = a
        self.ranks[a] = 0

    def UFfind(self, a):
        """Return the root of ``a``'s set, halving the path as it walks.

        This does not fully compress the path."""
        parent = self.uf
        node = a
        while parent[node] != node:
            grandparent = parent[parent[node]]
            parent[node] = grandparent
            node = grandparent
        return node

    def UFcombine(self, a, b):
        """Merge the sets of ``a`` and ``b``.

        The lower-rank root goes under the higher-rank one. On a tie, ``a``'s
        root wins and its rank grows by one."""
        root_a, root_b = self.UFfind(a), self.UFfind(b)
        ranks = self.ranks
        if ranks[root_a] < ranks[root_b]:
            self.uf[root_a] = root_b
            return
        self.uf[root_b] = root_a
        if ranks[root_a] == ranks[root_b]:
            ranks[root_a] += 1
class mycounter(object):
    """Dict-backed multiset supporting in-place add/subtract and a dot product."""
    __slots__ = ['c']

    def __init__(self, key=None):
        # Start empty, or with a single occurrence of ``key`` when one is given.
        self.c = {} if key is None else {key: 1}

    def inc(self, key):
        """Add one occurrence of ``key``."""
        counts = self.c
        counts[key] = counts.get(key, 0) + 1

    def addto(self, other):
        """Add every count of ``other`` into this counter."""
        counts = self.c
        for k, v in other.c.items():
            counts[k] = counts.get(k, 0) + v

    def subfrom(self, other, mult=1):  # subtract other from self, never neg
        """Subtract ``mult`` times every count of ``other``.

        Each key of ``other`` must already be present here (asserted)."""
        counts = self.c
        for k, v in other.c.items():
            assert k in counts
            counts[k] -= mult * v

    def innerprodwith(self, other):
        """Return the sum over shared keys of count_self * count_other."""
        small, large = self.c, other.c
        if len(small) > len(large):
            small, large = large, small
        return sum(v * large[k] for k, v in small.items() if k in large)
class mycounter2(object):
    """List-backed alternative to ``mycounter``: adds are recorded as a list of keys.

    At most one ``subfrom`` is supported, and its multiplier is applied in
    ``innerprodwith``. Fixes relative to the previous version:
    ``innerprodwith`` no longer crashes when ``subfrom`` was never called; it
    honours ``mult`` instead of hard-coding 2; and it counts with ``Counter``
    instead of repeated ``list.count`` scans.
    """
    __slots__ = ['q', 'doubleneg', 'negmult']

    def __init__(self):
        self.q = []            # keys added so far, with repetition
        self.doubleneg = None  # keys subtracted via subfrom (list), or None
        self.negmult = 0       # multiplier applied to every key in doubleneg

    def inc(self, key):
        self.q.append(key)

    def addto(self, other):
        self.q.extend(other.q)

    def subfrom(self, other, mult=1):  # subtract mult * other from self; only once
        assert self.doubleneg is None and other.doubleneg is None
        # Copy so later changes to ``other`` do not alter this counter.
        self.doubleneg = list(other.q)
        self.negmult = mult

    def _net_counts(self):
        """Return a Counter of added keys minus ``negmult`` times the subtracted keys."""
        counts = Counter(self.q)
        if self.doubleneg:
            for key, times in Counter(self.doubleneg).items():
                counts[key] -= self.negmult * times
        return counts

    def innerprodwith(self, other):
        """Return the sum over keys of net_self[key] * net_other[key]."""
        mine = self._net_counts()
        theirs = other._net_counts()
        if len(mine) > len(theirs):
            mine, theirs = theirs, mine
        return sum(value * theirs[key] for key, value in mine.items() if key in theirs)
class geodcounter(object):
    """Answers colour-pair queries on a tree rooted at ``root``.

    Built in a single iterative DFS:
      * ``C[v]`` / ``above`` / ``below`` form a Fenwick-tree-like structure over
        the current root path. ``find(v)`` sums ``C`` along the ``above`` links
        starting at v. (NOTE(review): this appears to yield the colour counts of
        the root-to-v path; verify.)
      * ``lca`` maps every (u, v) listed in ``pairs`` to their lowest common
        ancestor, computed with Tarjan's offline LCA algorithm using ``UF``.
      * ``dists[v]`` is the depth of v (root = 0).
    """
    __slots__ = ['size','C','above','below','values','lca','dists']
    def __init__(self,neb,root,vals,pairs):
        # neb: adjacency lists; vals[v]: colour of v; pairs[v]: nodes whose LCA
        # with v is needed (filled by test()).
        self.size = len(neb)
        self.C = [mycounter() for _ in range(self.size)]
        self.above = [None]*self.size
        self.below = [list() for _ in range(self.size)]
        self.values = vals
        self.lca = {}
        self.dists = [None]*self.size
        above = self.above
        below = self.below
        lca = self.lca
        dists = self.dists
        geod = [None] #one-indexed to fit BIT
        gpush = geod.append
        gpop = geod.pop
        height = 0
        traverse_stack = [(root,None)]
        tpush = traverse_stack.append
        tpop = traverse_stack.pop
        visited = set()
        ancestors = {}
        lca_done = set()
        uf = UF()
        while(traverse_stack):
            curr,parent = tpop()
            if curr is None:
                break
            if curr not in visited:
                # First visit (pre-order): record depth, hook curr into the BIT
                # links of the current root path (geod), and schedule a second
                # visit after the children.
                dists[curr] = height
                height += 1
                aboveht = height - (height & -height)
                above[curr] = geod[aboveht]
                for nextht in memoized_BIT_prevs[height]:
                    below[geod[nextht]].append(curr)
                uf.UFadd(curr)
                visited.add(curr)
                tpush((curr,parent))
                gpush(curr)
                ancestors[curr] = curr
                for child in neb[curr]:
                    if child in visited:
                        continue
                    tpush((child,curr))
            else:
                # Second visit (post-order): the subtree of curr is finished.
                gcurr = gpop()
                assert gcurr == curr
                # tho self.below not complete yet,
                # it is for subtree rooted at curr, so OK to call notice
                self.notice(curr)
                height -= 1
                #this portion implements Tarjan's LCA alg
                lca_done.add(curr)
                for v in pairs[curr]:
                    if v in lca_done:
                        rep_v = uf.UFfind(v)
                        anc = ancestors[rep_v]
                        assert (curr,v) not in lca and (v,curr) not in lca
                        lca[(curr,v)] = anc
                        lca[(v,curr)] = anc
                if parent is not None:
                    uf.UFcombine(curr,parent)
                    ancestors[uf.UFfind(curr)] = parent
    def notice(self,node):
        """Add ``values[node]`` to ``C`` of ``node`` and of every node reachable
        from it through ``below`` links."""
        val = self.values[node]
        C = self.C
        below = self.below
        stack = [node]
        pop = stack.pop
        extend = stack.extend
        while stack:
            curr = pop()
            C[curr].inc(val)
            extend(below[curr]) #note that the ORDER we visit them doesn't matter
    def find(self,node):
        """Return a new mycounter summing ``C`` over ``node`` and its ``above`` chain."""
        acc = mycounter()
        C = self.C
        above = self.above
        while node is not None:
            acc.addto(C[node])
            node = above[node]
        return acc
    def get_incidence(self,a,b):
        """Return a mycounter equal to find(a) + find(b) + {colour of lca} - 2 * find(lca).

        This is the colour multiset of the a-b path, assuming find(v) gives the
        root-to-v colours. The pair (a, b) must be present in ``self.lca``.
        """
        l = self.lca[(a,b)]
        find_a = self.find(a)
        #find_b = self.find(b)
        #delta = mycounter(self.values[l])
        #pos_part = find_a + find_b + delta
        find_a.addto(self.find(b))
        find_a.inc(self.values[l])
        find_a.subfrom(self.find(l),2)
        #answer = pos_part - find_l - find_l
        #order matters b/c of how subtraction of counters works
        return find_a
    def size_intersection(self,a,b,c,d):
        """Return the number of nodes shared by the paths a-b and c-d.

        It classifies the multiset of the six pairwise LCAs and raises
        RuntimeError for a pattern it does not recognise."""
        prs = ((a,b),(c,d),(a,c),(a,d),(b,c),(b,d))
        verts = list(self.lca[x] for x in prs)
        key1,key2 = verts[0],verts[1]
        check = Counter(verts)
        if check[key1] == 1 or check[key2] == 1:
            return 0 #empty intersection
        most = check.most_common()
        counts = list(freq for val,freq in most)
        if counts == [6] or counts == [3,3]:
            return 1 #intersect in one point
        elif counts == [5,1] or counts == [3,2,1]: #root meets at or beyond endpt of intersection
            close = most[-2][0]
            far = most[-1][0]
            return (self.dists[far] - self.dists[close])+1 #size, not length
        elif counts == [4,1,1]:
            left = most[1][0]
            right = most[2][0]
            mid = most[0][0]
            return (self.dists[left] + self.dists[right] - 2*self.dists[mid])+ 1 #size, not length
        else:
            raise RuntimeError
    def process_commands(self,commands):
        """Yield, for each (w, x, y, z), the same-colour pairs between the paths
        w-x and y-z, minus the nodes the two paths share (the i == j pairs)."""
        measure_int = self.size_intersection
        get_incidence = self.get_incidence
        timer = 0  # NOTE(review): unused
        for w,x,y,z in commands:
            A = get_incidence(w,x)
            B = get_incidence(y,z)
            innerprod = A.innerprodwith(B)
            len_inter = measure_int(w,x,y,z)
            ans = innerprod - len_inter
            assert ans >= 0
            yield ans
def test(num = None):
    """Read the problem input, answer every query and print one result per line.

    With ``num`` None, reads stdin and writes stdout. Otherwise reads
    ``./input<num>.txt`` and writes ``./myoutput<num>.txt``. It then strips the
    trailing newline (except for case '00') and prints whether the output equals
    ``output<num>.txt``, together with the elapsed time.
    """
    if num is None:
        inp = sys.stdin
        out = sys.stdout
    else:
        inp = open('./input'+num+'.txt')
        out = open('./myoutput'+num+'.txt','w')
    start_time = time()
    # split() without an argument tolerates repeated spaces, tabs and '\r';
    # split(' ') produced empty fields (and int('') errors) on such input.
    N,Q = tuple(map(int,inp.readline().split()))
    vals = [0] #one-indexed so c_i = C[i]
    vals.extend(map(int,inp.readline().split()))
    neb = [list() for x in range(N+1)]
    for _ in range(N-1):
        a, b = tuple(map(int,inp.readline().split()))
        neb[a].append(b)
        neb[b].append(a)
    # Every LCA the queries will need, for the offline (Tarjan) LCA pass.
    pairs = [set() for _ in range(N+1)]
    commands = []
    for _ in range(Q):
        w,x,y,z = tuple(map(int,inp.readline().split()))
        pairs[w].update((x,y,z))
        pairs[x].update((w,y,z))
        pairs[y].update((w,x,z))
        pairs[z].update((w,x,y))
        commands.append((w,x,y,z))
    count_geod = geodcounter(neb,1,vals,pairs) #make 1 root arbitrarily
    for answer in count_geod.process_commands(commands):
        print(answer,file=out)
    end_time = time()
    if num is not None:
        inp.close()
        if num != '00':
            remove_chars = len(linesep)
            out.truncate(out.tell() - remove_chars) #strip trailing newline
        out.close()
        succeeded = cmp('myoutput'+num+'.txt','output'+num+'.txt')
        outcome = " Success" if succeeded else " Failure"
        print("#"+num+outcome,"in {:f}s".format(end_time-start_time))


# Only run when executed as a script, so importing this module has no side effects.
if __name__ == '__main__':
    test()
View More Similar Problems
Get Node Value
This challenge is part of a tutorial track by MyCodeSchool. Given a pointer to the head of a linked list and a specific position, determine the data value at that position. Count backwards from the tail node. The tail is at position 0, its parent is at 1, and so on. Example: head refers to 3 -> 2 -> 1 -> 0 -> NULL, positionFromTail = 2. Each of the data values matches its distance from the t
View Solution →Delete duplicate-value nodes from a sorted linked list
This challenge is part of a tutorial track by MyCodeSchool You are given the pointer to the head node of a sorted linked list, where the data in the nodes is in ascending order. Delete nodes and return a sorted list with each distinct value in the original list. The given head pointer may be null indicating that the list is empty. Example head refers to the first node in the list 1 -> 2 -
View Solution →Cycle Detection
A linked list is said to contain a cycle if any node is visited more than once while traversing the list. Given a pointer to the head of a linked list, determine if it contains a cycle. If it does, return 1. Otherwise, return 0. Example: head refers to 1 -> 2 -> 3 -> NULL. The numbers shown are the node numbers, not their data values. There is no cycle in this list, so return 0. head refer
View Solution →Find Merge Point of Two Lists
This challenge is part of a tutorial track by MyCodeSchool Given pointers to the head nodes of 2 linked lists that merge together at some point, find the node where the two lists merge. The merge point is where both lists point to the same node, i.e. they reference the same memory location. It is guaranteed that the two head nodes will be different, and neither will be NULL. If the lists share
View Solution →Inserting a Node Into a Sorted Doubly Linked List
Given a reference to the head of a doubly-linked list and an integer, data, create a new DoublyLinkedListNode object having data value data and insert it at the proper location to maintain the sort. Example: head refers to the list 1 <-> 2 <-> 4 -> NULL and data = 3. Return a reference to the new list: 1 <-> 2 <-> 3 <-> 4 -> NULL. Function Description: Complete the sortedInsert function
View Solution →Reverse a doubly linked list
This challenge is part of a tutorial track by MyCodeSchool Given the pointer to the head node of a doubly linked list, reverse the order of the nodes in place. That is, change the next and prev pointers of the nodes so that the direction of the list is reversed. Return a reference to the head node of the reversed list. Note: The head node might be NULL to indicate that the list is empty.
View Solution →