Counting On a Tree


Problem Statement :


Taylor loves trees, and this new challenge has him stumped!

Consider a tree, t, consisting of n nodes. Each node is numbered from 1 to n, and each node i has an integer, ci, attached to it.

A query on tree t takes the form w x y z. To process a query, you must print the count of ordered pairs of integers ( i , j ) such that the following four conditions are all satisfied:

 - i ≠ j
 - i ∈ the path from node w to node x.
 - j ∈ the path from node y to node z.
 - ci = cj
Given t and q queries, process each query in order, printing the pair count for each query on a new line.

Input Format

The first line contains two space-separated integers describing the respective values of n (the number of nodes) and q (the number of queries).
The second line contains n space-separated integers describing the respective values of each node (i.e., c1 , c2, . . . , cn ).
Each of the n - 1 subsequent lines contains two space-separated integers, u  and v , defining a bidirectional edge between nodes u and v.
Each of the q subsequent lines contains a w x y z query, defined above.


Constraints

1   <=  n   <=  10^5
1   <=  q   <=  50000
1   <=  ci   <=  10^9
1   <=  u, v , w, x, y , z  <= n

Scoring for this problem is Binary, that means you have to pass all the test cases to get a positive score.

Output Format

For each query, print the count of ordered pairs of integers satisfying the four given conditions on a new line.



Solution :



title-img


                            Solution in C :

In  C++ :





#include <bits/stdc++.h>

using namespace std;

// Capacity of every per-vertex / per-timestamp array below. Linear is indexed
// by the DFS clock, which is advanced twice per vertex.
const int kMaxN = 300005;

// G[v]: neighbours of vertex v (main stores every edge in both directions).
vector<int> G[kMaxN];
// Color[v]  : colour id of v after main remaps the input values to 1..colors.
// Linear[t] : +Color[v] at the entry timestamp of v, -Color[v] at its exit.
// Beg[v]    : entry timestamp of v (0 means "not visited yet" inside DFS).
// End[v]    : exit timestamp of v.
// Depth[v]  : depth assigned by DFS (child depth = parent depth + 1).
int Color[kMaxN], Linear[kMaxN], Beg[kMaxN],
    End[kMaxN], Depth[kMaxN];
// Answer[i]: running result of query i (queries are numbered from 1).
long long Answer[kMaxN];

// n: number of vertices, colors: number of distinct colour ids handed out,
// timer: DFS clock used for Beg/End/Linear.
int n, colors, timer;

// Parent[i][v]: ancestor of v that is 2^i levels above it, 0 if none exists.
int Parent[20][kMaxN];

void DFS(int node) {
    Beg[node] = ++timer;
    Linear[timer] = +Color[node];

    for(auto vec : G[node]) {
        if(!Beg[vec]) {
            Parent[0][vec] = node;
            Depth[vec] = Depth[node] + 1;

            for(int i = 1; Parent[i - 1][vec]; ++i) {
                Parent[i][vec] = Parent[i - 1][Parent[i - 1][vec]];
            }

            DFS(vec);
        }
    }

    End[node] = ++timer;
    Linear[timer] = -Color[node];
}

// Lowest common ancestor of vertices a and b, computed by binary lifting
// over the Parent table (Parent[i][v] is the 2^i-th ancestor, 0 if absent).
int LCA(int a, int b) {
    // Keep `a` as the deeper of the two endpoints.
    if(Depth[b] > Depth[a])
        swap(a, b);

    // Raise `a` to the depth of `b`, consuming the difference bit by bit.
    int gap = Depth[a] - Depth[b];
    for(int bit = 0; gap != 0; ++bit, gap >>= 1) {
        if(gap & 1)
            a = Parent[bit][a];
    }

    if(a == b)
        return a;

    // Lift both vertices together while their ancestors still differ.
    for(int bit = 19; bit >= 0; --bit) {
        const int up_a = Parent[bit][a];
        const int up_b = Parent[bit][b];
        if(up_a != up_b) {
            a = up_a;
            b = up_b;
        }
    }

    // Now a and b are distinct children of the answer.
    assert(Parent[0][a] == Parent[0][b]);
    return Parent[0][a];
}


// One signed sub-query: a pair of vertices (a, b) whose contribution is
// added to Answer[i] with weight sgn. SolveQueries later overwrites a and b
// with their entry timestamps.
struct Query {
    int a;    // first vertex
    int b;    // second vertex
    int i;    // index of the original query this piece belongs to
    int sgn;  // +1 or -1
};

// Every sub-query collected so far, in insertion order.
vector<Query> Queries;

// Records the sub-query (a, b) for query i with the given sign.
// A zero endpoint stands for "no vertex" and makes the call a no-op.
void AddQuery(int a, int b, int i, int sgn) {
    if(a != 0 && b != 0) {
        Queries.push_back(Query {a, b, i, sgn});
    }
}

// Expands the product (a - b) x (c - d) into its four signed sub-queries
// for query i: +(a,c), -(b,c), -(a,d), +(b,d), in that order.
void AddQuery(int a, int b, int c, int d, int i) {
    const int first[2] = {a, b};
    const int second[2] = {c, d};

    for(int s = 0; s < 2; ++s) {
        for(int f = 0; f < 2; ++f) {
            const int sign = (s == f) ? 1 : -1;
            AddQuery(first[f], second[s], i, sign);
        }
    }
}

// Count[s][c]: current multiplicity of colour c on side s (s is 0 or 1).
int Count[2][kMaxN];
// Sum over all colours c of Count[0][c] * Count[1][c], kept up to date by Add.
long long global_ans;

// Applies one signed colour event to side `at`: |col| is the colour and the
// sign of col says whether its multiplicity goes up or down by one.
// col is expected to be non-zero.
void Add(int at, int col) {
    const int colour = abs(col);
    const int step = (col < 0) ? -1 : 1;

    // Changing Count[at][colour] by `step` changes the product for this colour
    // by step times the multiplicity on the opposite side.
    global_ans += 1LL * step * Count[at ^ 1][colour];
    Count[at][colour] += step;
}

void SolveQueries() {
    for(auto &q : Queries) {
        Answer[q.i] -= q.sgn * (1 + Depth[LCA(q.a, q.b)]);
        
        q.a = Beg[q.a];
        q.b = Beg[q.b];
        
        if(q.a > q.b) 
            swap(q.a, q.b);
    }
    
    const int kMagic = 256;
    sort(Queries.begin(), Queries.end(), [](Query &a, Query &b) {
        if(a.a / kMagic == b.a / kMagic)
            return a.b < b.b;
        return a.a < b.a;
    });
    
    int b = 0, e = 0;
    for(auto &q : Queries) {
        while(b < q.a) Add(0, Linear[++b]);
        while(e < q.b) Add(1, Linear[++e]);
        while(b > q.a) Add(0, -Linear[b--]);
        while(e > q.b) Add(1, -Linear[e--]);
        
        Answer[q.i] += global_ans * q.sgn;
    }
    
}

int main() {
    int q;
    cin >> n >> q;

    map<int, int> normMap;
    for(int i = 1; i <= n; ++i) {
        cin >> Color[i];
        auto &norm = normMap[Color[i]];
        if(norm == 0) 
            norm = ++colors;
        Color[i] = norm;
    }

    for(int i = 2; i <= n; ++i) {
        int a, b;
        cin >> a >> b;
        G[a].push_back(b);
        G[b].push_back(a);
    }

    DFS(1);
    Depth[0] = -1;

    
/*    
    for(int i = 1; i <= timer; ++i)
        cerr << Linear[i] << " ";
    cerr << endl;
*/
    
    for(int i = 1; i <= q; ++i) {
        int a1, b1, a2, b2;
        cin >> a1 >> b1 >> a2 >> b2;
        
        int l1 = LCA(a1, b1);
        int l2 = LCA(a2, b2);

        AddQuery(a1, Parent[0][l1], a2, Parent[0][l2], i);
        AddQuery(a1, Parent[0][l1], b2, l2, i);
        AddQuery(b1, l1, a2, Parent[0][l2], i);
        AddQuery(b1, l1, b2, l2, i);

    }

    SolveQueries();

    for(int i = 1; i <= q; ++i)
        cout << Answer[i] << "\n";
    cout << endl;

    return 0;
}







In Java :






import java.io.*;
import java.math.*;
import java.text.*;
import java.util.*;
import java.util.regex.*;

// Input/output harness for the problem: parses the tree and the queries from
// standard input, delegates to solve(), and writes the results to the file
// named by the OUTPUT_PATH environment variable.
public class Solution {

    // Complete the solve function below.
    // c: node values, tree: one {u, v} pair per edge, queries: one
    // {w, x, y, z} row per query. The harness writes one element of the
    // returned array per output line.
    // NOTE(review): the body is empty, so as written the method has no return
    // statement; the algorithm still has to be filled in.
    static int[] solve(int[] c, int[][] tree, int[][] queries) {


    }

    private static final Scanner scanner = new Scanner(System.in);

    public static void main(String[] args) throws IOException {
        BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(System.getenv("OUTPUT_PATH")));

        // First line: n (number of nodes) and q (number of queries).
        String[] nq = scanner.nextLine().split(" ");

        int n = Integer.parseInt(nq[0]);

        int q = Integer.parseInt(nq[1]);

        int[] c = new int[n];

        // Second line: the n node values.
        String[] cItems = scanner.nextLine().split(" ");
        scanner.skip("(\r\n|[\n\r\u2028\u2029\u0085])?");

        for (int cItr = 0; cItr < n; cItr++) {
            int cItem = Integer.parseInt(cItems[cItr]);
            c[cItr] = cItem;
        }

        // Next n - 1 lines: one edge (two endpoints) per line.
        int[][] tree = new int[n-1][2];

        for (int treeRowItr = 0; treeRowItr < n-1; treeRowItr++) {
            String[] treeRowItems = scanner.nextLine().split(" ");
            scanner.skip("(\r\n|[\n\r\u2028\u2029\u0085])?");

            for (int treeColumnItr = 0; treeColumnItr < 2; treeColumnItr++) {
                int treeItem = Integer.parseInt(treeRowItems[treeColumnItr]);
                tree[treeRowItr][treeColumnItr] = treeItem;
            }
        }

        // Next q lines: four integers per query.
        int[][] queries = new int[q][4];

        for (int queriesRowItr = 0; queriesRowItr < q; queriesRowItr++) {
            String[] queriesRowItems = scanner.nextLine().split(" ");
            scanner.skip("(\r\n|[\n\r\u2028\u2029\u0085])?");

            for (int queriesColumnItr = 0; queriesColumnItr < 4; queriesColumnItr++) {
                int queriesItem = Integer.parseInt(queriesRowItems[queriesColumnItr]);
                queries[queriesRowItr][queriesColumnItr] = queriesItem;
            }
        }

        int[] result = solve(c, tree, queries);

        // One result per line; the separator is written between items and a
        // final line break is appended after the last one.
        for (int resultItr = 0; resultItr < result.length; resultItr++) {
            bufferedWriter.write(String.valueOf(result[resultItr]));

            if (resultItr != result.length - 1) {
                bufferedWriter.write("\n");
            }
        }

        bufferedWriter.newLine();

        bufferedWriter.close();

        scanner.close();
    }
}








In C :




#include <stdio.h>
#include <string.h>
#include <math.h>
#include <stdlib.h>


#define floor_log2_X86(self) (__builtin_clz(self) ^ 31U)
#define floor_log2 floor_log2_X86

/*
 * Sift-down step on a 1-based max-heap of indices keyed by weights[].
 * Starting from the vacant slot `hole`, the larger child is pulled up for as
 * long as its weight is not smaller than the weight of `item`; `item` is then
 * stored in the slot that ends up vacant.
 */
static void heap_sort_sift(unsigned *heap, const unsigned *weights,
                           unsigned size, unsigned hole, unsigned item) {
    unsigned child = hole << 1;

    while (child <= size) {
        /* Prefer the right sibling when it exists and is strictly heavier. */
        if (child < size && weights[heap[child]] < weights[heap[child | 1]])
            child |= 1;

        if (weights[heap[child]] < weights[item])
            break;

        heap[child >> 1] = heap[child];
        child <<= 1;
    }

    heap[child >> 1] = item;
}

/*
 * Indirect in-place heap sort: reorders the `length` indices stored in `self`
 * so that weights[self[0]] <= weights[self[1]] <= ... The weights themselves
 * are only read.
 */
void heap_sort(unsigned *self, unsigned *weights, unsigned length) {
    unsigned *heap = self - 1; /* view the array as 1-based */
    unsigned root;

    /* Build the max-heap bottom-up. */
    for (root = length >> 1; root > 0; root--)
        heap_sort_sift(heap, weights, length, root, heap[root]);

    /* Repeatedly move the maximum to the end and restore the heap. */
    while (length > 1) {
        unsigned displaced = heap[length];

        heap[length] = heap[1];
        length--;
        heap_sort_sift(heap, weights, length, 1, displaced);
    }
}

/*
 * Replaces every entry of values[0..length) by a dense id in place.
 *
 * Equal values receive equal ids. Ids are handed out by decreasing number of
 * occurrences (id 0 goes to a most frequent value); values with the same
 * number of occurrences are numbered in increasing order of the value.
 */
void compress(unsigned length, unsigned values[length]) {
    unsigned
        at,
        order[length];

    /*
     * Identity permutation. Written element by element: filling the array
     * through an `unsigned long *` alias would break the aliasing rules and
     * would only be correct for a little-endian 64-bit `unsigned long`.
     */
    for (at = 0; at < length; at++)
        order[at] = at;

    /* order[] now lists the positions of values[] in increasing value order. */
    heap_sort(order, values, length);

    /*
     * roots[k] is the last position (in sorted order) of the k-th run of
     * equal values, with roots[0] = -1 as the "end" of an empty run zero, so
     * the size of run k is roots[k] - roots[k - 1]. `seen` ends up as the
     * number of runs plus one and `max` as the largest run size.
     */
    unsigned roots[length], seen = 1, max = 0, others;
    for (roots[at = 0] = -1U; at < length; roots[seen++] = at - 1) {
        for (others = at; (at < length) && values[order[at]] == values[order[others]]; at++);

        if (max < (at - others))
            max = (at - others);
    }

    unsigned
        indices[max + 1],
        ranks[seen];

    /*
     * Counting sort of the runs by size, largest first:
     *   1. histogram of run sizes,
     *   2. suffix sums, so indices[s] = number of runs of size >= s,
     *   3. placement from the last run to the first, which keeps runs of
     *      equal size in increasing value order.
     */
    memset(indices, 0, sizeof(indices));
    for (at = 0; ++at < seen; indices[roots[at] - roots[at - 1]]++);
    for (at = max; at--; indices[at] += indices[at + 1]);
    for (at = seen; --at; ranks[--indices[roots[at] - roots[at - 1]]] = at);

    /* `at` is 0 here: write id `at` over every member of run ranks[at]. */
    for (; at < (seen - 1); at++)
        for (others = roots[ranks[at] - 1]; ++others <= roots[ranks[at]]; values[order[others]] = at);
}


// Returns the id of the nearest common ancestor of the vertex ids `lower`
// and `upper`.
//
//   depth      - number of columns in each row of `bases`
//   base_cnt   - number of rows of `bases`
//   vertex_cnt - number of vertices
//   base_ids   - row of `bases` to use for each vertex id
//   bases      - per-row table of ancestor ids
//   depths     - per-vertex level value compared between the two vertices
//   weights    - weights[v] is the number of ids covered starting at v
//   lower/upper - the two vertex ids
//
// NOTE(review): the only visible caller passes lower <= upper and tables built
// from a relabelling in which a vertex and its descendants occupy the id range
// [v, v + weights[v]); confirm both before reusing this elsewhere.
static inline unsigned nearest_common_ancestor(
    unsigned depth,
    unsigned base_cnt,
    unsigned vertex_cnt,
    unsigned base_ids[vertex_cnt],
    unsigned bases[base_cnt][depth],
    unsigned char depths[base_cnt],
    unsigned weights[vertex_cnt],
    unsigned lower,
    unsigned upper
) {
    // `upper` falls inside the id range that starts at `lower`.
    if (upper < (lower + weights[lower]))
        return lower;

    // Replace `upper` by the table entry that compensates the level difference.
    if (depths[upper] > depths[lower])
        upper = bases[base_ids[upper]][depths[upper] - depths[lower] - 1];

    if (upper < lower)
        return upper;

    // Halving search over the row of `upper`: step past entries that are still
    // greater than `lower`. The result is the first entry not greater than
    // `lower` (presumably the row is stored in decreasing id order - confirm
    // against the code that fills the table).
    unsigned *others = bases[base_ids[upper]];
    for (; depth > 1; depth >>= 1)
        if (others[depth >> 1] > lower) {
            others += depth >> 1;
            depth += depth & 1U;
        }

    return others[others[0] > lower];
}

/*
 * Closed interval [low, high] of vertex ids. `packd` overlays both bounds so
 * that a whole interval can be copied, swapped or adjusted in one operation.
 */
typedef union {
    unsigned long packd;
    struct {
        int low;
        int high;
    };
} range_t;

/*
 * Borrowed views of the per-colour lookup tables used by query_all:
 *   members   - vertex ids grouped by colour,
 *   colors    - colour of each vertex id,
 *   indices   - indices[c] .. indices[c + 1] delimits colour c inside members,
 *   locations - position of each vertex id inside members.
 */
typedef struct {
    unsigned *members;
    unsigned *colors;
    unsigned *indices;
    unsigned *locations;
} colored_tree_t;


// Counts the members of the colour list of vertex `at` whose id lies in the
// closed interval [other.low, other.high]; `at` itself is not counted when it
// lies inside the interval.
//
// The colour list of `at` is self->members[indices[color] .. indices[color + 1]).
// NOTE(review): the searches below assume that list is in increasing id order.
unsigned long query_all(colored_tree_t *self, unsigned at, range_t other) {
    unsigned
        color = self->colors[at],
        length = self->indices[color + 1] - self->indices[color],
        *base = &self->members[self->indices[color]];

    // The interval lies completely before or after the colour list.
    if (other.high < base[0] || other.low > base[length - 1])
        return 0;

    // Move `base` to the first list entry that is >= other.low.
    if (self->colors[other.low] != color) {
        // other.low is not in the list: search only the side of `at` that can
        // contain the boundary.
        if (at < other.low) {
            base += self->locations[at] - self->indices[color];
            length = self->indices[color + 1] - self->locations[at];
        } else
            length = self->locations[at] - self->indices[color]; // at > other.low

        for (; length > 1; length >>= 1)
            if (base[length >> 1] < other.low) {
                base += length >> 1;
                length += length & 1;
            }

        base += (base[0] < other.low);
    } else
        // other.low has this colour, so its own list position is the start.
        base += (self->locations[other.low] - self->indices[color]);

    if (base[0] > other.high)
        return 0;

    // Move `ceil` to the last list entry that is <= other.high.
    unsigned *ceil;
    if (self->colors[other.high] != color) {
        // Start from `at` when it is known to lie between base[0] and other.high.
        ceil = (at > base[0] && at < other.high) ? &self->members[self->locations[at]] : base;

        for (length = self->indices[color + 1] - self->locations[ceil[0]]; length > 1; length >>= 1)
            if (ceil[length >> 1] <= other.high) {
                ceil += length >> 1;
                length += length & 1;
            }

        ceil -= (ceil[0] > other.high);
    } else
        ceil = &self->members[self->locations[other.high]];


    // Entries base..ceil inclusive, minus `at` itself when it is in the interval.
    return ceil - base + 1 - (at >= other.low && at <= other.high);
}

// Adds up query_all() results for the id intervals `self` and `other`, using
// precomputed block tables for the block-aligned middle parts.
//
//   cnt         - dimension of the square table `pairs`
//   length      - block length (ids i and j share a block when i / length == j / length)
//   pairs       - two-dimensional prefix-sum table indexed by block numbers
//   overlapping - prefix-sum array indexed by block number
//   tree        - lookup tables handed to query_all
//   self, other - the two closed id intervals
//
// NOTE(review): the tables are read at index "low - 1", which is -1 for block
// 0; presumably the caller passes a pointer shifted by one row and one column
// and an `overlapping` array for which that read is harmless - confirm.
unsigned long count_pairs(
    unsigned cnt,
    unsigned length,
    unsigned long pairs[cnt][cnt],
    unsigned *overlapping,
    colored_tree_t *tree,
    range_t self,
    range_t other
) {
    unsigned long count = 0;
    // Peel ids off both ends of `self` until it starts and ends on block
    // boundaries, counting each peeled id against the whole of `other`.
    for (; (self.low % length) && (self.low <= self.high); count += query_all(tree, self.low++, other));
    for (; ((self.high + 1) % length) && (self.low <= self.high); count += query_all(tree, self.high--, other));

    if (self.low <= self.high) {
        // Same for `other`, counted against what is left of `self`.
        for (; (other.low % length) && (other.low <= other.high); count += query_all(tree, other.low++, self));
        for (; ((other.high + 1) % length) && (other.low <= other.high); count += query_all(tree, other.high--, self));

        if (other.low <= other.high) {
            // Both intervals now consist of whole blocks: switch to block numbers.
            self.low   /= length;
            self.high  /= length;
            other.low  /= length;
            other.high /= length;

            // XOR-swap so that `self` is the interval that starts first.
            if (self.low > other.low) {
                self.packd  ^= other.packd;
                other.packd ^= self.packd;
                self.packd  ^= other.packd;
            }

            // Last block of `self` that lies strictly before `other`.
            unsigned high = (self.high < other.low) ? self.high : (other.low - 1);

            // Rectangle sum over blocks [self.low, high] x [other.low, other.high].
            count +=
                pairs[high][other.high]
                    - pairs[high][other.low - 1UL]
                    - pairs[self.low - 1UL][other.high]
                    + pairs[self.low - 1UL][other.low - 1UL];

            self.low = high + 1;

            // XOR-swap so that `self` is the interval that ends first.
            if (self.high > other.high) {
                self.packd  ^= other.packd;
                other.packd ^= self.packd;
                self.packd  ^= other.packd;
            }

            // Remaining blocks [self.low, self.high]: the per-block term from
            // `overlapping`, twice the rectangle of the range with itself, and
            // the rectangle of the range with the blocks after it up to
            // other.high.
            if (self.low <= self.high)
                count +=
                    (overlapping[self.high] - overlapping[self.low - 1UL])
                        + ((
                        pairs[self.high][self.high]
                            - pairs[self.high][self.low - 1UL]
                            - pairs[self.low - 1UL][self.high]
                            + pairs[self.low - 1UL][self.low - 1UL]
                    ) << 1) + (
                        pairs[self.high][other.high]
                            - pairs[self.high][self.high]
                            - pairs[self.low - 1UL][other.high]
                            + pairs[self.low - 1UL][self.high]
                    );
        }
    }

    return count;
}

// Reads the tree and the queries from standard input and prints one count per
// query. Outline: compress the colours, orient the edges into a parent array,
// relabel the vertices, build the per-colour and per-block tables, then split
// each query path into id intervals and sum count_pairs() over all pairs of
// intervals.
int main() {
    unsigned at, vertex_cnt;
    unsigned short query_cnt;
    scanf("%u %hu", &vertex_cnt, &query_cnt);

    // Colours, plus one sentinel entry (0xFFFFFFFF) at index vertex_cnt, are
    // replaced by dense ids.
    unsigned colors[vertex_cnt + 1];
    for (at = 0; at < vertex_cnt; scanf("%u", &colors[at++]));
    colors[at] = 0xFFFFFFFFU;
    compress(at + 1, colors);

    // ancestors[v]: parent of v, 0xFFFFFFFF meaning "none". Edges arrive in
    // arbitrary orientation; when the second endpoint already has a parent,
    // parent chains are reversed so that the new link can be stored.
    unsigned ancestors[at + 1];
    {
        unsigned ancestor, descendant;
        for (memset(ancestors, 0xFFU, sizeof(ancestors)); --at; ancestors[descendant] = ancestor) {
            scanf("%u %u", &ancestor, &descendant);
            --ancestor;
            if (ancestors[--descendant] != 0xFFFFFFFFU) {
                unsigned root = descendant, next;
                for (; ancestor != 0xFFFFFFFFU; ancestor = next) {
                    next = ancestors[ancestor];
                    ancestors[ancestor] = root;
                    root = ancestor;
                }
                for (; ancestors[descendant] != 0xFFFFFFFFU; descendant = next) {
                    next = ancestors[descendant];
                    ancestors[descendant] = ancestor;
                    ancestor = descendant;
                }
            }
        }

        // `at` is 0 here: reverse the chain from vertex 0 upwards so that
        // vertex 0 ends up without a parent.
        for (ancestor = 0xFFFFFFFFU; at != 0xFFFFFFFFU; at = descendant) {
            descendant = ancestors[at];
            ancestors[at] = ancestor;
            ancestor = at;
        }
    }

    // ids[v]: new id of original vertex v. weights, bases, base_depths,
    // ancestors and colors are re-indexed by new id in the block below.
    // NOTE(review): this outer `history` is shadowed by the one declared in
    // the block below and does not appear to be used.
    unsigned
        others,
        ids[vertex_cnt + 1],
        weights[vertex_cnt],
        bases[vertex_cnt + 1],
        history[vertex_cnt];

    unsigned char
        base_depths[vertex_cnt],
        dist = 0;

    {
        unsigned
            history[vertex_cnt],
            indices[vertex_cnt + 1],
            descendants[vertex_cnt];

        // Counting sort by parent: the children of v are
        // descendants[indices[v] .. indices[v + 1]).
        memset(indices, 0, sizeof(indices));
        for (ancestors[vertex_cnt] = (at = vertex_cnt); at; indices[ancestors[at--]]++);
        for (; ++at <= vertex_cnt; indices[at] += indices[at - 1]);
        for (; --at; descendants[--indices[ancestors[at]]] = at);

        // Subtree sizes with an explicit stack (`history`): the first visit
        // sets the weight to 1 and pushes the children, the second visit adds
        // the children's weights.
        history[0] = 0;
        memset(weights, 0, sizeof(weights));
        for (at = 1; at--; )
            if (weights[history[at]])
                for (others = indices[history[at]];
                     others < indices[history[at] + 1];
                     weights[history[at]] += weights[descendants[others++]]);
            else {
                weights[history[at]] = 1;
                memcpy(
                    &history[at + 1],
                    &descendants[indices[history[at]]],
                    (indices[history[at] + 1] - indices[history[at]]) * sizeof(descendants[0])
                );
                at += indices[history[at] + 1] - indices[history[at]] + 1;
            }

        unsigned
            orig_ancestors[vertex_cnt + 1],
            orig_colors[vertex_cnt + 1],
            orig_weights[vertex_cnt];

        memcpy(orig_ancestors, ancestors, sizeof(ancestors));
        memcpy(orig_weights, weights, sizeof(weights));
        memcpy(orig_colors, colors, sizeof(colors));

        // Relabelling: the children of a vertex are sorted by subtree size and
        // numbered from the largest to the smallest, each child's subtree
        // taking a contiguous block of weights[id] ids. The first (largest)
        // child inherits the parent's `bases` value; every other child starts
        // its own base. base_depths counts the base changes along the way and
        // `dist` is the maximum of those counts.
        base_depths[0] = (bases[0] = (ids[0] = 0));
        bases[vertex_cnt] = (ids[vertex_cnt] = vertex_cnt);
        for (at = 1; at--;) {
            unsigned
                id = ids[history[at]],
                base = bases[id++],
                branches = indices[history[at] + 1] - indices[history[at]];

            heap_sort(&descendants[indices[history[at]]], orig_weights, branches);
            memcpy(&history[at], &descendants[indices[history[at]]], branches * sizeof(descendants[0]));

            for (others = (at += branches); branches--; base = id) {
                ids[history[--others]] = id;

                ancestors[id] = ids[orig_ancestors[history[others]]];
                weights[id] = orig_weights[history[others]];
                colors[id] = orig_colors[history[others]];

                bases[id] = base;
                base_depths[id] = base_depths[ancestors[id]] + (base == id);

                if (dist < base_depths[id])
                    dist = base_depths[id];

                id += weights[id];
            }
        }
    }

    // base_ids[v]: running number of the group of consecutive ids that share
    // the same `bases` value; base_ids[vertex_cnt] is the number of groups.
    unsigned base_ids[vertex_cnt + 1];
    for (base_ids[0] = (others = (at = 0)); others < vertex_cnt; base_ids[others] = base_ids[at] + 1)
        for (at = others; bases[at] == bases[others]; base_ids[others++] = base_ids[at]);

    // ancestral_bases[g][0]: parent of the smallest id in group g (the loop
    // runs downwards, so the smallest id is written last).
    // ancestral_bases[g][k + 1]: parent of the base of entry k.
    unsigned ancestral_bases[base_ids[vertex_cnt]][dist];
    for (ancestors[0] = 0; others--; ancestral_bases[base_ids[others]][0] = ancestors[others]);
    while ((++others + 1) < dist)
        for (at = 0; ++at < base_ids[vertex_cnt];
             ancestral_bases[at][others + 1] = ancestors[bases[ancestral_bases[at][others]]]);

    // Counting sort of the ids by colour: members[] lists the ids grouped by
    // colour and indexed_colors[c] is the start of colour c.
    unsigned
        indexed_colors[colors[vertex_cnt] + 2],
        members[vertex_cnt + 1];

    memset(indexed_colors, 0, sizeof(indexed_colors));
    for (at = vertex_cnt + 1; at--; indexed_colors[colors[at]]++);
    for (; ++at < colors[vertex_cnt]; indexed_colors[at + 1] += indexed_colors[at]);
    for (at = vertex_cnt + 1; at--; members[--indexed_colors[colors[at]]] = at);
    indexed_colors[colors[vertex_cnt] + 1] = indexed_colors[colors[vertex_cnt]];

    // Ids are grouped into blocks of `levels` consecutive ids.
    unsigned
        levels = floor_log2(vertex_cnt) + 1,
        block_cnt = (vertex_cnt / levels) + 1,
        locations[vertex_cnt + 1],
        overlapping[block_cnt];

    // (block_cnt + 1)^2 zeroed cells; the pointer is then moved one row and
    // one column in, so index -1 in either dimension stays inside the
    // allocation.
    unsigned long (*pairs)[block_cnt][block_cnt] = calloc(
        (1 + block_cnt) * (1 + block_cnt),
        sizeof(pairs[0][0][0])
    );
    pairs = (void *)&pairs[0][1][1];

    // locations[v]: position of id v inside members[].
    for (at = vertex_cnt + 1; at--; locations[members[at]] = at);

    // For every colour with more than one member (the loop stops at the first
    // colour that has at most one): split its members by block, add
    // density * (density - 1) to the block's `overlapping` entry and
    // density_i * density_j to pairs[block_i][block_j] for every later block j.
    memset(overlapping, 0, sizeof(overlapping));
    for (at = 0; (indexed_colors[at + 1] - indexed_colors[at]) > 1; at++) {
        others = indexed_colors[at];

        unsigned
            block_bases[indexed_colors[at + 1] - others + 1],
            cnt = 1;

        for (block_bases[0] = members[others]; at == colors[members[++others]]; )
            if ((members[others] / levels) != (block_bases[cnt - 1] / levels))
                block_bases[cnt++] = members[others];

        block_bases[cnt] = members[others];
        for (others = 0; others < cnt; others++) {
            unsigned long density = locations[block_bases[others + 1]] - locations[block_bases[others]];
            overlapping[block_bases[others] / levels] += density * (density - 1);

            unsigned block = others;
            for (; ++block < cnt; pairs[0][block_bases[others] / levels][block_bases[block] / levels]
                += density * (locations[block_bases[block + 1]] - locations[block_bases[block]]));
        }
    }

    // Turn `overlapping` into a prefix-sum array and `pairs` into a
    // two-dimensional prefix-sum table (rows first, then columns).
    for (at = 0; ++at < block_cnt; overlapping[at] += overlapping[at - 1])
        pairs[0][0][at] += pairs[0][0][at - 1];

    for (at = 0; ++at < block_cnt; )
        for (others = 0; ++others < block_cnt; pairs[0][at][others] += pairs[0][at][others - 1]);

    for (at = 0; ++at < block_cnt; )
        for (others = 0; others < block_cnt; others++)
            pairs[0][at][others] += pairs[0][at - 1][others];

    colored_tree_t *tree = &(colored_tree_t) {
        .members = members,
        .colors = colors,
        .indices = indexed_colors,
        .locations = locations
    };


    while (query_cnt--) {
        // Read the four endpoints and convert them to 0-based (both halves of
        // each packed pair are decremented at once).
        range_t left, right;
        scanf("%u %u %u %u", &left.low, &left.high, &right.low, &right.high);
        left.packd -= 0x0000000100000001UL;
        right.packd -= 0x0000000100000001UL;

        // Original vertex numbers -> relabelled ids.
        left.low = ids[left.low];
        left.high = ids[left.high];

        right.low = ids[right.low];
        right.high = ids[right.high];

        // Order each pair so that low <= high (swap the two 32-bit halves).
        if (left.high < left.low)
            left.packd = (left.packd << 32) | (left.packd >> 32);

        if (right.high < right.low)
            right.packd = (right.packd << 32) | (right.packd >> 32);

        if (right.high < left.low) {
            left.packd  ^= right.packd;
            right.packd ^= left.packd;
            left.packd  ^= right.packd;
        }

        // a / b: the id intervals that make up the first / second path.
        struct {
            range_t members[32];
            unsigned cnt;
        }
            a = {.cnt = 0},
            b = {.cnt = 0};

        unsigned common = nearest_common_ancestor(
            dist, base_ids[vertex_cnt], vertex_cnt,
            base_ids, ancestral_bases,
            base_depths, weights,
            left.low, left.high
        );

        // From each endpoint, emit the interval [bases[v], v] and jump to the
        // parent of that group until the group of `common` is reached; the
        // last interval runs from `common` to whichever endpoint walk stopped
        // below it.
        for (at = left.low; bases[at] != bases[common]; at = ancestral_bases[base_ids[at]][0])
            a.members[a.cnt++].packd = bases[at] | ((unsigned long)at << 32);

        for (others = left.high; bases[others] != bases[common]; others = ancestral_bases[base_ids[others]][0])
            a.members[a.cnt++].packd = bases[others] | ((unsigned long)others << 32);

        a.members[a.cnt++].packd = common | ((unsigned long)((at != common) ? at : others) << 32);

        // Same decomposition for the second path.
        common = nearest_common_ancestor(
            dist, base_ids[vertex_cnt], vertex_cnt,
            base_ids, ancestral_bases,
            base_depths, weights,
            right.low, right.high
        );

        for (at = right.low; bases[at] != bases[common]; at = ancestral_bases[base_ids[at]][0])
            b.members[b.cnt++].packd = bases[at] | ((unsigned long)at << 32);

        for (others = right.high; bases[others] != bases[common]; others = ancestral_bases[base_ids[others]][0])
            b.members[b.cnt++].packd = bases[others] | ((unsigned long)others << 32);

        b.members[b.cnt++].packd = common | ((unsigned long)((at != common) ? at : others) << 32);

        // Sum over every (interval of a, interval of b) combination.
        unsigned long total = 0;
        for (at = 0; at < a.cnt; at++)
            for (others = 0; others < b.cnt;
                 total += count_pairs(
                     block_cnt, levels, pairs[0], overlapping, tree,
                     a.members[at], b.members[others++]
                 )
            );

        printf("%lu\n", total);
    }

    return 0;
}








In Python3 :





import sys
from filecmp import cmp
from os import linesep
from time import time
from collections import Counter

# memoized_BIT_prevs[h] lists, in increasing order, every index i in 1..10**5
# with i + (i & -i) == h, i.e. the Fenwick-tree indices whose update
# successor is h. Index 0 is unused.
memoized_BIT_prevs = [list() for _ in range(10**5+1)]
for i in range(1,10**5+1):
    # Named `succ` rather than `next` so the builtin is not shadowed at
    # module level.
    succ = i + (i & -i)
    if succ < len(memoized_BIT_prevs):
        memoized_BIT_prevs[succ].append(i)

class UF(object):
    """Disjoint-set forest over arbitrary hashable items.

    ``uf`` maps every item to its parent (a root is its own parent) and
    ``ranks`` holds the rank used to decide which root survives a merge.
    """
    __slots__=['uf','ranks']

    def __init__(self):
        self.uf = dict()
        self.ranks = dict()

    def UFadd(self,a):
        """Register ``a`` as a singleton set of rank 0."""
        self.uf[a] = a
        self.ranks[a] = 0

    def UFfind(self,a):
        """Return the root of the set containing ``a``.

        Every visited item is re-pointed at its grandparent on the way up,
        so the path is shortened but not flattened completely.
        """
        parent = self.uf
        node = a
        while True:
            up = parent[node]
            if up == node:
                return node
            grand = parent[up]
            parent[node] = grand
            node = grand

    def UFcombine(self,a,b):
        """Merge the sets of ``a`` and ``b``; the root of higher rank wins.

        On equal ranks the root of ``a`` becomes the parent and its rank
        grows by one.
        """
        root_a = self.UFfind(a)
        root_b = self.UFfind(b)
        rank_a = self.ranks[root_a]
        rank_b = self.ranks[root_b]
        if rank_a < rank_b:
            self.uf[root_a] = root_b
            return
        self.uf[root_b] = root_a
        if rank_a == rank_b:
            self.ranks[root_a] += 1

class mycounter(object):
    """Minimal multiset: ``c`` maps a key to its (possibly negative) count."""
    __slots__ = ['c']

    def __init__(self,key = None):
        # Empty by default, otherwise a single occurrence of ``key``.
        self.c = {} if key is None else {key: 1}

    def inc(self,key):
        """Add one occurrence of ``key``."""
        counts = self.c
        counts[key] = counts.get(key, 0) + 1

    def addto(self,other):
        """Add every count of ``other`` to this counter, in place."""
        mine = self.c
        for key, amount in other.c.items():
            mine[key] = mine.get(key, 0) + amount

    def subfrom(self,other,mult = 1):
        """Subtract ``mult`` times every count of ``other``, in place.

        Every key of ``other`` must already be present here.
        """
        mine = self.c
        for key, amount in other.c.items():
            assert key in mine
            mine[key] -= mult*amount

    def innerprodwith(self,other):
        """Return the sum over shared keys of the product of both counts."""
        small, large = self.c, other.c
        if len(small) > len(large):
            small, large = large, small  # iterate over the smaller table
        total = 0
        for key, amount in small.items():
            if key in large:
                total += amount * large[key]
        return total

class mycounter2(object):
    """List-backed variant of ``mycounter``.

    ``q`` is the list of keys added so far. ``doubleneg`` is either ``None``
    or a list of keys that each count as minus two occurrences.
    """
    __slots__ = ['q','doubleneg']
    def __init__(self):
        self.q = []
        self.doubleneg = None #will be list
    def inc(self,key):
        """Add one occurrence of ``key``."""
        self.q.append(key)
    def addto(self,other):
        """Append every key of ``other`` (``other.doubleneg`` is not merged)."""
        self.q.extend(other.q)
    def subfrom(self,other,mult = 1): #subtract other from self, never neg
        """Record ``other.q`` as the doubly-subtracted part of this counter.

        May be called at most once, and only while neither counter has a
        ``doubleneg`` list. NOTE(review): ``mult`` is accepted for signature
        compatibility with ``mycounter.subfrom`` but ignored; the factor
        applied in ``innerprodwith`` is always 2.
        """
        assert self.doubleneg is None and other.doubleneg is None
        self.doubleneg = other.q
    def innerprodwith(self,other):
        """Return the inner product of the two effective count vectors.

        The effective count of a key is its multiplicity in ``q`` minus twice
        its multiplicity in ``doubleneg``. A counter whose ``doubleneg`` is
        still ``None`` is treated as having nothing subtracted (this used to
        raise), and multiplicities are tallied once instead of calling
        ``list.count`` for every key.
        """
        X = self
        Y = other
        if len(X.q) > len(Y.q):
            X,Y = Y,X
        x_pos = Counter(X.q)
        x_neg = Counter(X.doubleneg or ())
        y_pos = Counter(Y.q)
        y_neg = Counter(Y.doubleneg or ())
        # Keys absent from X contribute zero, so X's keys are enough.
        keys = set(x_pos)
        keys.update(x_neg)
        return sum((x_pos[s]-2*x_neg[s])*(y_pos[s]-2*y_neg[s]) for s in keys)

class geodcounter(object):
    """Per-tree index used to answer the path-pair queries.

    Built once from the adjacency lists; afterwards ``process_commands``
    yields one answer per query.

    Attributes (all indexed by vertex unless noted):
      size   - len(neb)
      C      - one ``mycounter`` per vertex, filled by ``notice``
      above  - vertex to continue with when accumulating in ``find`` (or None)
      below  - vertices that must also be updated when this vertex is noticed
      values - the vertex values passed in as ``vals``
      lca    - dict (u, v) -> lowest common ancestor, for the requested pairs
      dists  - depth of the vertex (root = 0)
    """
    __slots__ = ['size','C','above','below','values','lca','dists']
    def __init__(self,neb,root,vals,pairs):
        """Traverse the tree from ``root`` and build every table.

        neb   - adjacency lists, neb[v] = neighbours of v
        root  - vertex to start from
        vals  - vals[v] = value attached to v
        pairs - pairs[v] = vertices whose LCA with v will be looked up later
        """
        self.size = len(neb)
        self.C = [mycounter() for _ in range(self.size)]
        self.above = [None]*self.size
        self.below = [list() for _ in range(self.size)]
        self.values = vals
        self.lca = {}
        self.dists = [None]*self.size

        above = self.above
        below = self.below
        lca = self.lca
        dists = self.dists

        # geod is the current root-to-vertex path; geod[h] is the vertex at
        # 1-based height h on it and geod[0] is a None placeholder.
        geod = [None] #one-indexed to fit BIT
        gpush = geod.append
        gpop = geod.pop

        height = 0
        # Each vertex is pushed as (vertex, parent) and popped twice: once on
        # entry (not yet visited) and once on exit.
        traverse_stack = [(root,None)]
        tpush = traverse_stack.append
        tpop  = traverse_stack.pop
        visited = set()

        ancestors = {}
        lca_done = set()
        uf = UF()
        while(traverse_stack):
            curr,parent = tpop()
            # NOTE(review): nothing visible pushes a None vertex, so this
            # looks purely defensive.
            if curr is None:
                break
            if curr not in visited:
                # Entry: depth is recorded before the increment, so `height`
                # below is the 1-based position of curr on the path.
                dists[curr] = height
                height += 1

                # Fenwick-style links along the path: `above` drops the lowest
                # set bit of the height; `below` is filled from the heights
                # whose update successor is this height.
                aboveht = height - (height & -height)
                above[curr] = geod[aboveht]
                for nextht in memoized_BIT_prevs[height]:
                    below[geod[nextht]].append(curr)

                uf.UFadd(curr)
                visited.add(curr)
                # Re-push curr so that it is popped again after its subtree.
                tpush((curr,parent))
                gpush(curr)
                ancestors[curr] = curr
                for child in neb[curr]:
                    if child in visited:
                        continue
                    tpush((child,curr))
            else:
                # Exit: curr must be the last vertex of the current path.
                gcurr = gpop()
                assert gcurr == curr
                # tho self.below not complete yet,
                # it is for subtree rooted at curr, so OK to call notice
                self.notice(curr)
                height -= 1

                #this portion implements Tarjan's LCA alg
                lca_done.add(curr)
                for v in pairs[curr]:
                    if v in lca_done:
                        rep_v = uf.UFfind(v)
                        anc = ancestors[rep_v]
                        assert (curr,v) not in lca and (v,curr) not in lca
                        # Stored under both orderings of the pair.
                        lca[(curr,v)] = anc
                        lca[(v,curr)] = anc
                if parent is not None:
                    # Merge the finished subtree into the parent's set and
                    # make the parent the ancestor recorded for that set.
                    uf.UFcombine(curr,parent)
                    ancestors[uf.UFfind(curr)] = parent
    def notice(self,node):
        """Add the value of ``node`` to its own counter and to the counter of
        every vertex reachable through ``below`` links."""
        val = self.values[node]
        C = self.C
        below = self.below
        stack = [node]
        pop = stack.pop
        extend = stack.extend
        while stack:
            curr = pop()
            C[curr].inc(val)
            extend(below[curr]) #note that the ORDER we visit them doesn't matter
    def find(self,node):
        """Return a new counter: the sum of ``C`` over ``node`` and the chain
        of ``above`` links starting at it (presumably the multiset of values
        on the root-to-node path)."""
        acc = mycounter()
        C = self.C
        above = self.above
        while node is not None:
            acc.addto(C[node])
            node = above[node]
        return acc
    def get_incidence(self,a,b):
        """Return the counter find(a) + find(b) + {value of lca} - 2*find(lca)
        for the pair (a, b); the pair must have been requested via ``pairs``."""
        l = self.lca[(a,b)]
        find_a = self.find(a)
        #find_b = self.find(b)
        #delta = mycounter(self.values[l])
        #pos_part = find_a + find_b + delta
        find_a.addto(self.find(b))
        find_a.inc(self.values[l])
        find_a.subfrom(self.find(l),2)
        #answer = pos_part - find_l - find_l
        #order matters b/c of how subtraction of counters works
        return find_a
    def size_intersection(self,a,b,c,d):
        """Return the number of vertices shared by path a-b and path c-d.

        The case is decided from how often each vertex occurs among the six
        pairwise LCAs of the four endpoints. Raises RuntimeError for a
        frequency pattern that is not handled.
        NOTE(review): the branches below index into ``most_common()``; where
        several vertices have the same frequency, check that the ordering
        assumption is harmless for the formula used.
        """
        prs = ((a,b),(c,d),(a,c),(a,d),(b,c),(b,d))
        verts = list(self.lca[x] for x in prs)
        # key1 / key2: the LCAs of the two query paths themselves.
        key1,key2 = verts[0],verts[1]
        check = Counter(verts)
        if check[key1] == 1 or check[key2] == 1:
            return 0 #empty intersection
        most = check.most_common()
        counts = list(freq for val,freq in most)
        if counts == [6] or counts == [3,3]:
            return 1 #intersect in one point
        elif counts == [5,1] or counts == [3,2,1]: #root meets at or beyond endpt of intersection
            close = most[-2][0]
            far = most[-1][0]
            return (self.dists[far] - self.dists[close])+1 #size, not length
        elif counts == [4,1,1]:
            left = most[1][0]
            right = most[2][0]
            mid = most[0][0]
            return (self.dists[left] + self.dists[right] - 2*self.dists[mid])+ 1 #size, not length
        else:
            raise RuntimeError
    def process_commands(self,commands):
        """Generator: for every query (w, x, y, z) yield the inner product of
        the two path counters minus the size of the path intersection."""
        measure_int = self.size_intersection
        get_incidence = self.get_incidence
        # NOTE(review): `timer` is assigned but never read.
        timer = 0
        for w,x,y,z in commands:
            A = get_incidence(w,x)
            B = get_incidence(y,z)
            # Equal-value pairs (i, j), still including the pairs with i == j.
            innerprod = A.innerprodwith(B)
            len_inter = measure_int(w,x,y,z)
            ans = innerprod - len_inter
            assert ans >= 0
            yield ans

def test(num = None):
    """Solve one instance of the problem.

    With ``num`` left as None the instance is read from standard input and
    the answers go to standard output. Otherwise ``num`` is a string suffix:
    the input is read from ./input<num>.txt, the answers are written to
    ./myoutput<num>.txt, the result is compared with output<num>.txt and a
    success/failure line with the elapsed time is printed.
    """
    if num is None:
        inp = sys.stdin
        out = sys.stdout
    else:
        inp = open('./input'+num+'.txt')
        out = open('./myoutput'+num+'.txt','w')

    def read_ints():
        # split() without an argument tolerates repeated blanks, tabs and
        # trailing whitespace, which split(' ') turns into empty tokens.
        return tuple(map(int,inp.readline().split()))

    start_time = time()
    N,Q = read_ints()
    vals = [0] #one-indexed so c_i = C[i]
    vals.extend(read_ints())
    neb = [list() for x in range(N+1)]
    for _ in range(N-1):
        a, b = read_ints()
        neb[a].append(b)
        neb[b].append(a)

    # pairs[v] collects every vertex whose LCA with v is needed later: the
    # other three endpoints of each query v takes part in.
    pairs = [set() for _ in range(N+1)]
    commands = []
    for _ in range(Q):
        w,x,y,z = read_ints()
        pairs[w].update((x,y,z))
        pairs[x].update((w,y,z))
        pairs[y].update((w,x,z))
        pairs[z].update((w,x,y))
        commands.append((w,x,y,z))
    count_geod = geodcounter(neb,1,vals,pairs) #make 1 root arbitrarily

    for answer in count_geod.process_commands(commands):#,desireds):
        print(answer,file=out)
    end_time = time()

    if num is not None:
        inp.close()
        if num != '00':
            remove_chars = len(linesep)
            out.truncate(out.tell() - remove_chars) #strip trailing newline
        out.close()

        succeeded = cmp('myoutput'+num+'.txt','output'+num+'.txt')
        outcome = "  Success" if succeeded else "  Failure"
        print("#"+num+outcome,"in {:f}s".format(end_time-start_time))

test()
                        








View More Similar Problems

Get Node Value

This challenge is part of a tutorial track by MyCodeSchool Given a pointer to the head of a linked list and a specific position, determine the data value at that position. Count backwards from the tail node. The tail is at position 0, its parent is at 1 and so on. Example head refers to 3 -> 2 -> 1 -> 0 -> NULL positionFromTail = 2 Each of the data values matches its distance from the t

View Solution →

Delete duplicate-value nodes from a sorted linked list

This challenge is part of a tutorial track by MyCodeSchool You are given the pointer to the head node of a sorted linked list, where the data in the nodes is in ascending order. Delete nodes and return a sorted list with each distinct value in the original list. The given head pointer may be null indicating that the list is empty. Example head refers to the first node in the list 1 -> 2 -

View Solution →

Cycle Detection

A linked list is said to contain a cycle if any node is visited more than once while traversing the list. Given a pointer to the head of a linked list, determine if it contains a cycle. If it does, return 1. Otherwise, return 0. Example head refers 1 -> 2 -> 3 -> NULL The numbers shown are the node numbers, not their data values. There is no cycle in this list so return 0. head refer

View Solution →

Find Merge Point of Two Lists

This challenge is part of a tutorial track by MyCodeSchool Given pointers to the head nodes of 2 linked lists that merge together at some point, find the node where the two lists merge. The merge point is where both lists point to the same node, i.e. they reference the same memory location. It is guaranteed that the two head nodes will be different, and neither will be NULL. If the lists share

View Solution →

Inserting a Node Into a Sorted Doubly Linked List

Given a reference to the head of a doubly-linked list and an integer, data, create a new DoublyLinkedListNode object having data value data and insert it at the proper location to maintain the sort. Example head refers to the list 1 <-> 2 <-> 4 -> NULL. data = 3 Return a reference to the new list: 1 <-> 2 <-> 3 <-> 4 -> NULL. Function Description Complete the sortedInsert function

View Solution →

Reverse a doubly linked list

This challenge is part of a tutorial track by MyCodeSchool Given the pointer to the head node of a doubly linked list, reverse the order of the nodes in place. That is, change the next and prev pointers of the nodes so that the direction of the list is reversed. Return a reference to the head node of the reversed list. Note: The head node might be NULL to indicate that the list is empty.

View Solution →