Counting On a Tree


Problem Statement :


Taylor loves trees, and this new challenge has him stumped!

Consider a tree, t, consisting of n nodes. Each node is numbered from 1 to n, and each node i has an integer, c_i, attached to it.

A query on tree t takes the form w x y z. To process a query, you must print the count of ordered pairs of integers (i, j) such that the following four conditions are all satisfied:

 i ≠ j
 i is on the path from node w to node x.
 j is on the path from node y to node z.
 c_i = c_j
Given t and q queries, process each query in order, printing the pair count for each query on a new line.
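To make the four conditions concrete, here is a brute-force reference counter in Python. It is a sketch, not one of the solutions below: it materializes both paths and compares every pair, so it costs O(|path 1| * |path 2|) per query and is far too slow for the stated limits, but it is handy for checking a fast solution on small trees. The helper names `path_nodes` and `brute_force`, and the 1-indexed `colors` list with an unused slot 0, are assumptions for illustration.

```python
from collections import deque


def path_nodes(adj, u, v):
    """Return the set of nodes on the unique tree path from u to v."""
    parent = {u: None}          # BFS parents, rooted at u
    queue = deque([u])
    while queue:
        node = queue.popleft()
        for nxt in adj[node]:
            if nxt not in parent:
                parent[nxt] = node
                queue.append(nxt)
    nodes = set()
    while v is not None:        # walk back from v to u
        nodes.add(v)
        v = parent[v]
    return nodes


def brute_force(colors, adj, w, x, y, z):
    """Count ordered pairs (i, j): i != j, i on path(w, x), j on path(y, z), c_i == c_j."""
    first = path_nodes(adj, w, x)
    second = path_nodes(adj, y, z)
    return sum(1 for i in first for j in second
               if i != j and colors[i] == colors[j])
```

For example, on a made-up tree with edges 1-2, 2-3, 2-4, 4-5 and colors 1 2 1 2 1, the query 1 3 5 3 compares path {1, 2, 3} with path {5, 4, 2, 3} and counts 4 pairs.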

Input Format

The first line contains two space-separated integers describing the respective values of n (the number of nodes) and q (the number of queries).
The second line contains n space-separated integers describing the respective values of each node (i.e., c_1, c_2, ..., c_n).
Each of the n - 1 subsequent lines contains two space-separated integers, u and v, defining a bidirectional edge between nodes u and v.
Each of the q subsequent lines contains a w x y z query, defined above.
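A minimal reader for this format, sketched in Python. The name `parse_input` and the returned tuple layout are assumptions, not part of the problem. It splits the whole input on whitespace, so it does not depend on how the values are broken across lines.

```python
def parse_input(text):
    """Parse the input format into (n, colors, adj, queries); colors is 1-indexed."""
    tokens = iter(text.split())
    n, q = int(next(tokens)), int(next(tokens))
    colors = [0] + [int(next(tokens)) for _ in range(n)]   # slot 0 unused
    adj = [[] for _ in range(n + 1)]
    for _ in range(n - 1):                                   # bidirectional edges
        u, v = int(next(tokens)), int(next(tokens))
        adj[u].append(v)
        adj[v].append(u)
    queries = [tuple(int(next(tokens)) for _ in range(4)) for _ in range(q)]
    return n, colors, adj, queries
```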


Constraints

1 <= n <= 10^5
1 <= q <= 50000
1 <= c_i <= 10^9
1 <= u, v, w, x, y, z <= n
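These limits also fix the integer width of the answer. If the tree is a single path of n nodes, every c_i is equal, and a query asks for the whole path twice, then every ordered pair with i ≠ j counts, so the largest possible answer is

$$ n(n-1) \le 10^5 \cdot (10^5 - 1) = 9\,999\,900\,000 > 2^{31} - 1 = 2\,147\,483\,647 $$

This is why the C++ solution keeps Answer and global_ans in long long; a 32-bit int result would overflow.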

Scoring for this problem is binary: you have to pass all the test cases to get a positive score.

Output Format

For each query, print the count of ordered pairs of integers satisfying the four given conditions on a new line.



Solution :
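The C++ solution below appears to rest on one identity. Write R(v) for the set of nodes on the path from the root (node 1) to v, with R(0) empty for the root's parent. For l = LCA(a, b), the path from a to b, counted as a signed multiset of nodes, equals R(a) + R(b) - R(l) - R(parent(l)). Expanding the product of two such paths gives 16 signed terms, each needing only the same-color count of two root paths, minus |R(u) ∩ R(v)| = depth(LCA(u, v)) + 1 to drop the pairs with i = j. The slow Python sketch below checks that identity with a naive climbing LCA and explicit counters; `root_tree`, `lca`, `root_colors` and `count_pairs` are hypothetical names, and nothing here is tuned for the real limits.

```python
from collections import Counter, deque


def root_tree(adj, root=1):
    """BFS from the root; return parent (the root's parent is 0) and depth maps."""
    parent, depth = {root: 0}, {root: 0}
    queue = deque([root])
    while queue:
        node = queue.popleft()
        for nxt in adj[node]:
            if nxt not in parent:
                parent[nxt] = node
                depth[nxt] = depth[node] + 1
                queue.append(nxt)
    return parent, depth


def lca(parent, depth, a, b):
    """Naive LCA by climbing; O(depth), fine for a sketch."""
    while depth[a] > depth[b]:
        a = parent[a]
    while depth[b] > depth[a]:
        b = parent[b]
    while a != b:
        a, b = parent[a], parent[b]
    return a


def root_colors(colors, parent, v):
    """Color multiset of the root-to-v path (v must not be 0)."""
    counts = Counter()
    while v != 0:
        counts[colors[v]] += 1
        v = parent[v]
    return counts


def count_pairs(colors, adj, w, x, y, z):
    parent, depth = root_tree(adj)

    def signed_terms(a, b):
        l = lca(parent, depth, a, b)
        # path(a, b) = R(a) + R(b) - R(l) - R(parent(l)) as signed node multisets
        return [(a, 1), (b, 1), (l, -1), (parent[l], -1)]

    total = 0
    for u, su in signed_terms(w, x):
        for v, sv in signed_terms(y, z):
            if u == 0 or v == 0:
                continue  # R(0) is empty and contributes nothing
            cu = root_colors(colors, parent, u)
            cv = root_colors(colors, parent, v)
            same_color = sum(cu[c] * cv[c] for c in cu)
            # |R(u) ∩ R(v)|: removes the pairs with i == j
            overlap = depth[lca(parent, depth, u, v)] + 1
            total += su * sv * (same_color - overlap)
    return total
```

Because the count is bilinear in the two paths, the terms for nodes that appear with both signs cancel exactly, which is what lets the fast solution work only with root paths.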




In C++ :





#include <bits/stdc++.h>

using namespace std;

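const int kMaxN = 300005;

// G: adjacency lists. Linear: Euler tour with +color on entry and -color on exit,
// so the prefix Linear[1..Beg[v]] holds exactly the colors on the root-to-v path.
vector<int> G[kMaxN];
int Color[kMaxN], Linear[kMaxN], Beg[kMaxN],
    End[kMaxN], Depth[kMaxN];
long long Answer[kMaxN];   // answers can reach about 10^10, so 64-bit

int n, colors, timer;

int Parent[20][kMaxN];     // binary-lifting table; Parent[0][root] == 0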

void DFS(int node) {
    Beg[node] = ++timer;
    Linear[timer] = +Color[node];

    for(auto vec : G[node]) {
        if(!Beg[vec]) {
            Parent[0][vec] = node;
            Depth[vec] = Depth[node] + 1;

            for(int i = 1; Parent[i - 1][vec]; ++i) {
                Parent[i][vec] = Parent[i - 1][Parent[i - 1][vec]];
            }

            DFS(vec);
        }
    }

    End[node] = ++timer;
    Linear[timer] = -Color[node];
}

int LCA(int a, int b) {
    if(Depth[a] < Depth[b])
        swap(a, b);

    int d = Depth[a] - Depth[b];
    for(int i = 0; d; ++i, d /= 2)
        if(d % 2)
            a = Parent[i][a];

    if(a == b)
        return a;

    for(int i = 19; i >= 0; --i)
        if(Parent[i][a] != Parent[i][b])
            a = Parent[i][a], b = Parent[i][b];

    assert(Parent[0][a] == Parent[0][b]);
    return Parent[0][a];
}


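// A Query asks for the same-color pair count between root->a and root->b,
// weighted by sgn and accumulated into Answer[i]. Node 0 (the root's parent) is an empty path.
struct Query {
    int a, b, i, sgn;
};
vector<Query> Queries;
void AddQuery(int a, int b, int i, int sgn) {
    if(a == 0 || b == 0) return;
    Queries.push_back(Query {a, b, i, sgn});
}
// Expands (root->a minus root->b) x (root->c minus root->d) into four signed terms.
void AddQuery(int a, int b, int c, int d, int i) {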
    AddQuery(a, c, i, 1);
    AddQuery(b, c, i, -1);
    AddQuery(a, d, i, -1);
    AddQuery(b, d, i, 1);
}

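// Count[0] / Count[1]: color counts of the two Mo prefixes.
// global_ans = sum over colors of Count[0][c] * Count[1][c].
int Count[2][kMaxN];
long long global_ans;
void Add(int at, int col) {   // col > 0 adds color col to prefix `at`, col < 0 removes it
    int pos = abs(col);
    int delta = col / pos;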
    global_ans -= 1LL * Count[0][pos] * Count[1][pos];
    Count[at][pos] += delta;
    global_ans += 1LL * Count[0][pos] * Count[1][pos];
}

void SolveQueries() {
    for(auto &q : Queries) {
        // Remove pairs with i == j: root->a and root->b share Depth[LCA(a, b)] + 1 nodes.
        Answer[q.i] -= q.sgn * (1 + Depth[LCA(q.a, q.b)]);
        
        q.a = Beg[q.a];
        q.b = Beg[q.b];
        
        if(q.a > q.b) 
            swap(q.a, q.b);
    }
    
    // Mo's algorithm: sort by block of the first prefix end, then by the second.
    const int kMagic = 256;
    sort(Queries.begin(), Queries.end(), [](const Query &a, const Query &b) {
        if(a.a / kMagic == b.a / kMagic)
            return a.b < b.b;
        return a.a < b.a;
    });
    
    int b = 0, e = 0;
    for(auto &q : Queries) {
        while(b < q.a) Add(0, Linear[++b]);
        while(e < q.b) Add(1, Linear[++e]);
        while(b > q.a) Add(0, -Linear[b--]);
        while(e > q.b) Add(1, -Linear[e--]);
        
        Answer[q.i] += global_ans * q.sgn;
    }
    
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int q;
    cin >> n >> q;

    map<int, int> normMap;
    for(int i = 1; i <= n; ++i) {
        cin >> Color[i];
        auto &norm = normMap[Color[i]];
        if(norm == 0) 
            norm = ++colors;
        Color[i] = norm;
    }

    for(int i = 2; i <= n; ++i) {
        int a, b;
        cin >> a >> b;
        G[a].push_back(b);
        G[b].push_back(a);
    }

    DFS(1);
    Depth[0] = -1;

    for(int i = 1; i <= q; ++i) {
        int a1, b1, a2, b2;
        cin >> a1 >> b1 >> a2 >> b2;
        
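        int l1 = LCA(a1, b1);
        int l2 = LCA(a2, b2);

        // path(a1, b1) = (root->a1 minus root->parent(l1)) + (root->b1 minus root->l1),
        // and likewise for (a2, b2); expanding the product gives these four calls.
        AddQuery(a1, Parent[0][l1], a2, Parent[0][l2], i);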
        AddQuery(a1, Parent[0][l1], b2, l2, i);
        AddQuery(b1, l1, a2, Parent[0][l2], i);
        AddQuery(b1, l1, b2, l2, i);

    }

    SolveQueries();

    for(int i = 1; i <= q; ++i)
        cout << Answer[i] << "\n";

    return 0;
}
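The SolveQueries function above answers every root-path pair with Mo's algorithm. DFS writes +color on entry and -color on exit into Linear, so the prefix Linear[1..Beg[v]] leaves exactly the colors of the root-to-v path; the pointers b and e track two such prefixes while global_ans keeps the sum over colors of Count[0] * Count[1]. The Python sketch below mirrors that idea under a few assumptions: `euler_linear` and `mo_root_dots` are hypothetical names, the DFS is iterative rather than recursive, and block=256 simply copies kMagic.

```python
def euler_linear(colors, adj, root=1):
    """Iterative DFS writing +color on entry and -color on exit.

    Returns (linear, beg); linear[0] is an unused slot so positions are
    1-indexed like Linear/Beg in the C++ code. Colors must be positive.
    """
    linear, beg = [0], {}
    stack = [(root, False)]
    while stack:
        node, leaving = stack.pop()
        if leaving:
            linear.append(-colors[node])
            continue
        beg[node] = len(linear)
        linear.append(colors[node])
        stack.append((node, True))
        for nxt in adj[node]:
            if nxt not in beg:  # in a tree, the only visited neighbor is the parent
                stack.append((nxt, False))
    return linear, beg


def mo_root_dots(linear, pairs, block=256):
    """For each (p, q), return the sum over colors of cnt_p[c] * cnt_q[c],
    where cnt_p is the signed color count of the prefix linear[1..p]."""
    order = sorted(range(len(pairs)),
                   key=lambda k: (pairs[k][0] // block, pairs[k][1]))
    cnt = ({}, {})  # cnt[0] tracks prefix b, cnt[1] tracks prefix e
    dot = 0
    b = e = 0
    result = [0] * len(pairs)

    def add(side, value):
        nonlocal dot
        color, delta = abs(value), (1 if value > 0 else -1)
        dot += delta * cnt[1 - side].get(color, 0)
        cnt[side][color] = cnt[side].get(color, 0) + delta

    for k in order:
        p, q = pairs[k]
        while b < p:
            b += 1
            add(0, linear[b])
        while b > p:
            add(0, -linear[b])
            b -= 1
        while e < q:
            e += 1
            add(1, linear[e])
        while e > q:
            add(1, -linear[e])
            e -= 1
        result[k] = dot
    return result
```

Because the state is just two independent prefixes, the pointers can move in any order and any block size gives the same results; the sort only bounds the total pointer movement.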







In Java :






import java.io.*;
import java.math.*;
import java.text.*;
import java.util.*;
import java.util.regex.*;

public class Solution {

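    // Port of the C++ solution above: Euler tour + binary-lifting LCA + Mo's algorithm.
    // Answers can reach n * (n - 1), about 10^10, so they are returned as long[] instead of int[].
    static long[] solve(int[] c, int[][] tree, int[][] queries) {
        int n = c.length, log = 18;  // 2^17 > 10^5, enough levels for binary lifting

        // Compress colors to 1..k so they can index count arrays.
        int[] sorted = Arrays.stream(c).distinct().sorted().toArray();
        int[] color = new int[n + 1];
        for (int i = 0; i < n; i++) color[i + 1] = Arrays.binarySearch(sorted, c[i]) + 1;

        // Compressed adjacency lists: neighbors of v are adj[start[v] .. start[v + 1] - 1].
        int[] start = new int[n + 2], adj = new int[2 * (n - 1)];
        for (int[] e : tree) { start[e[0] + 1]++; start[e[1] + 1]++; }
        for (int v = 1; v <= n + 1; v++) start[v] += start[v - 1];
        int[] fill = Arrays.copyOf(start, n + 2);
        for (int[] e : tree) { adj[fill[e[0]]++] = e[1]; adj[fill[e[1]]++] = e[0]; }

        // Iterative DFS from node 1: +color on entry, -color on exit, like Linear in the C++ code.
        int[][] up = new int[log][n + 1];
        int[] depth = new int[n + 1], beg = new int[n + 1], linear = new int[2 * n + 1];
        int[] stack = new int[n], next = Arrays.copyOf(start, n + 2);
        int sp = 0, timer = 0;
        stack[sp++] = 1;
        beg[1] = ++timer;
        linear[timer] = color[1];
        while (sp > 0) {
            int v = stack[sp - 1];
            if (next[v] < start[v + 1]) {
                int u = adj[next[v]++];
                if (beg[u] != 0) continue;  // the parent, already visited
                up[0][u] = v;
                depth[u] = depth[v] + 1;
                beg[u] = ++timer;
                linear[timer] = color[u];
                stack[sp++] = u;
            } else {
                linear[++timer] = -color[v];
                sp--;
            }
        }
        for (int j = 1; j < log; j++)
            for (int v = 1; v <= n; v++) up[j][v] = up[j - 1][up[j - 1][v]];

        // Split each path into signed root paths and expand the product into 16 terms.
        int q = queries.length, m = 0;
        long[] answer = new long[q];
        int[] qa = new int[16 * q], qb = new int[16 * q], qid = new int[16 * q], qsg = new int[16 * q];
        int[] signs = {1, 1, -1, -1};
        for (int i = 0; i < q; i++) {
            int[] r = queries[i];
            int l1 = lca(r[0], r[1], depth, up), l2 = lca(r[2], r[3], depth, up);
            int[] us = {r[0], r[1], l1, up[0][l1]}, vs = {r[2], r[3], l2, up[0][l2]};
            for (int s = 0; s < 4; s++)
                for (int t = 0; t < 4; t++) {
                    int u = us[s], v = vs[t], sign = signs[s] * signs[t];
                    if (u == 0 || v == 0) continue;  // node 0 has an empty root path
                    answer[i] -= sign * (depth[lca(u, v, depth, up)] + 1L);  // drop pairs with i == j
                    qa[m] = Math.min(beg[u], beg[v]);
                    qb[m] = Math.max(beg[u], beg[v]);
                    qid[m] = i;
                    qsg[m++] = sign;
                }
        }

        // Mo's order (blocks of 256 on qa, then qb); the term index sits in the low 20 bits.
        long[] keys = new long[m];
        for (int k = 0; k < m; k++) keys[k] = ((long) (qa[k] >> 8) << 40) | ((long) qb[k] << 20) | k;
        Arrays.sort(keys);

        int[] cnt0 = new int[sorted.length + 1], cnt1 = new int[sorted.length + 1];
        long dot = 0;  // sum over colors of cnt0[color] * cnt1[color]
        int b = 0, e = 0;
        for (long key : keys) {
            int k = (int) (key & ((1 << 20) - 1));
            while (b < qa[k]) { int x = linear[++b]; dot += Integer.signum(x) * cnt1[Math.abs(x)]; cnt0[Math.abs(x)] += Integer.signum(x); }
            while (b > qa[k]) { int x = -linear[b--]; dot += Integer.signum(x) * cnt1[Math.abs(x)]; cnt0[Math.abs(x)] += Integer.signum(x); }
            while (e < qb[k]) { int x = linear[++e]; dot += Integer.signum(x) * cnt0[Math.abs(x)]; cnt1[Math.abs(x)] += Integer.signum(x); }
            while (e > qb[k]) { int x = -linear[e--]; dot += Integer.signum(x) * cnt0[Math.abs(x)]; cnt1[Math.abs(x)] += Integer.signum(x); }
            answer[qid[k]] += qsg[k] * dot;
        }
        return answer;
    }

    // Binary-lifting LCA; up[j][v] is the 2^j-th ancestor of v (0 above the root).
    static int lca(int a, int b, int[] depth, int[][] up) {
        if (depth[a] < depth[b]) { int t = a; a = b; b = t; }
        for (int j = up.length - 1; j >= 0; j--)
            if (depth[a] - (1 << j) >= depth[b]) a = up[j][a];
        if (a == b) return a;
        for (int j = up.length - 1; j >= 0; j--)
            if (up[j][a] != up[j][b]) { a = up[j][a]; b = up[j][b]; }
        return up[0][a];
    }

    private static final Scanner scanner = new Scanner(System.in);

    public static void main(String[] args) throws IOException {
        BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(System.getenv("OUTPUT_PATH")));

        String[] nq = scanner.nextLine().split(" ");

        int n = Integer.parseInt(nq[0]);

        int q = Integer.parseInt(nq[1]);

        int[] c = new int[n];

        String[] cItems = scanner.nextLine().split(" ");
        scanner.skip("(\r\n|[\n\r\u2028\u2029\u0085])?");

        for (int cItr = 0; cItr < n; cItr++) {
            int cItem = Integer.parseInt(cItems[cItr]);
            c[cItr] = cItem;
        }

        int[][] tree = new int[n-1][2];

        for (int treeRowItr = 0; treeRowItr < n-1; treeRowItr++) {
            String[] treeRowItems = scanner.nextLine().split(" ");
            scanner.skip("(\r\n|[\n\r\u2028\u2029\u0085])?");

            for (int treeColumnItr = 0; treeColumnItr < 2; treeColumnItr++) {
                int treeItem = Integer.parseInt(treeRowItems[treeColumnItr]);
                tree[treeRowItr][treeColumnItr] = treeItem;
            }
        }

        int[][] queries = new int[q][4];

        for (int queriesRowItr = 0; queriesRowItr < q; queriesRowItr++) {
            String[] queriesRowItems = scanner.nextLine().split(" ");
            scanner.skip("(\r\n|[\n\r\u2028\u2029\u0085])?");

            for (int queriesColumnItr = 0; queriesColumnItr < 4; queriesColumnItr++) {
                int queriesItem = Integer.parseInt(queriesRowItems[queriesColumnItr]);
                queries[queriesRowItr][queriesColumnItr] = queriesItem;
            }
        }

        long[] result = solve(c, tree, queries);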

        for (int resultItr = 0; resultItr < result.length; resultItr++) {
            bufferedWriter.write(String.valueOf(result[resultItr]));

            if (resultItr != result.length - 1) {
                bufferedWriter.write("\n");
            }
        }

        bufferedWriter.newLine();

        bufferedWriter.close();

        scanner.close();
    }
}








In C :




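/* Relies on an LP64 target (64-bit unsigned long, used by range_t to pack two
   32-bit ints) and on the GCC/Clang builtin __builtin_clz. */
#include <stdio.h>
#include <string.h>
#include <math.h>
#include <stdlib.h>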


#define floor_log2_X86(self) (__builtin_clz(self) ^ 31U)
#define floor_log2 floor_log2_X86

void heap_sort(unsigned *self, unsigned *weights, unsigned length) {
    unsigned
        at = length >> 1,
        member,
        node;

    for (self--; at; self[node >> 1] = member) {
        member = self[at];

        for (node = at-- << 1; node <= length; node <<= 1) {
            node |= (node < length) && (weights[self[node]] < weights[self[node | 1]]);
            if (weights[self[node]] < weights[member])
                break ;
            self[node >> 1] = self[node];
        }
    }
    for (; length > 1; self[at >> 1] = member) {
        member = self[length];
        self[length--] = self[1];

        for (at = 2; at <= length; at <<= 1) {
            at |= (at < length) && (weights[self[at]] < weights[self[at | 1]]);
            if (weights[self[at]] < weights[member])
                break ;
            self[at >> 1] = self[at];
        }
    }
}

void compress(unsigned length, unsigned values[length]) {
    unsigned
        at,
        order[length];

    unsigned long sum = 0x0000000100000000UL;
    for (at = 0; at < (length >> 1); sum += 0x0000000200000002UL)
        ((unsigned long *)order)[at++] = sum;
    order[length - 1] = length - 1;

    heap_sort(order, values, length);

    unsigned roots[length], seen = 1, max = 0, others;
    for (roots[at = 0] = -1U; at < length; roots[seen++] = at - 1) {
        for (others = at; (at < length) && values[order[at]] == values[order[others]]; at++);

        if (max < (at - others))
            max = (at - others);
    }

    unsigned
        indices[max + 1],
        ranks[seen];

    memset(indices, 0, sizeof(indices));
    for (at = 0; ++at < seen; indices[roots[at] - roots[at - 1]]++);
    for (at = max; at--; indices[at] += indices[at + 1]);
    for (at = seen; --at; ranks[--indices[roots[at] - roots[at - 1]]] = at);

    for (; at < (seen - 1); at++)
        for (others = roots[ranks[at] - 1]; ++others <= roots[ranks[at]]; values[order[others]] = at);
}


static inline unsigned nearest_common_ancestor(
    unsigned depth,
    unsigned base_cnt,
    unsigned vertex_cnt,
    unsigned base_ids[vertex_cnt],
    unsigned bases[base_cnt][depth],
    unsigned char depths[base_cnt],
    unsigned weights[vertex_cnt],
    unsigned lower,
    unsigned upper
) {
    if (upper < (lower + weights[lower]))
        return lower;

    if (depths[upper] > depths[lower])
        upper = bases[base_ids[upper]][depths[upper] - depths[lower] - 1];

    if (upper < lower)
        return upper;

    unsigned *others = bases[base_ids[upper]];
    for (; depth > 1; depth >>= 1)
        if (others[depth >> 1] > lower) {
            others += depth >> 1;
            depth += depth & 1U;
        }

    return others[others[0] > lower];
}

typedef union {
    unsigned long packd;
    struct {
        int low, high;
    };
} range_t;

typedef struct {
    unsigned
        *members,
        *colors,
        *indices,
        *locations;
} colored_tree_t;


unsigned long query_all(colored_tree_t *self, unsigned at, range_t other) {
    unsigned
        color = self->colors[at],
        length = self->indices[color + 1] - self->indices[color],
        *base = &self->members[self->indices[color]];

    if (other.high < base[0] || other.low > base[length - 1])
        return 0;

    if (self->colors[other.low] != color) {
        if (at < other.low) {
            base += self->locations[at] - self->indices[color];
            length = self->indices[color + 1] - self->locations[at];
        } else
            length = self->locations[at] - self->indices[color]; // at > other.low

        for (; length > 1; length >>= 1)
            if (base[length >> 1] < other.low) {
                base += length >> 1;
                length += length & 1;
            }

        base += (base[0] < other.low);
    } else
        base += (self->locations[other.low] - self->indices[color]);

    if (base[0] > other.high)
        return 0;

    unsigned *ceil;
    if (self->colors[other.high] != color) {
        ceil = (at > base[0] && at < other.high) ? &self->members[self->locations[at]] : base;

        for (length = self->indices[color + 1] - self->locations[ceil[0]]; length > 1; length >>= 1)
            if (ceil[length >> 1] <= other.high) {
                ceil += length >> 1;
                length += length & 1;
            }

        ceil -= (ceil[0] > other.high);
    } else
        ceil = &self->members[self->locations[other.high]];


    return ceil - base + 1 - (at >= other.low && at <= other.high);
}

unsigned long count_pairs(
    unsigned cnt,
    unsigned length,
    unsigned long pairs[cnt][cnt],
    unsigned *overlapping,
    colored_tree_t *tree,
    range_t self,
    range_t other
) {
    unsigned long count = 0;
    for (; (self.low % length) && (self.low <= self.high); count += query_all(tree, self.low++, other));
    for (; ((self.high + 1) % length) && (self.low <= self.high); count += query_all(tree, self.high--, other));

    if (self.low <= self.high) {
        for (; (other.low % length) && (other.low <= other.high); count += query_all(tree, other.low++, self));
        for (; ((other.high + 1) % length) && (other.low <= other.high); count += query_all(tree, other.high--, self));

        if (other.low <= other.high) {
            self.low   /= length;
            self.high  /= length;
            other.low  /= length;
            other.high /= length;

            if (self.low > other.low) {
                self.packd  ^= other.packd;
                other.packd ^= self.packd;
                self.packd  ^= other.packd;
            }

            unsigned high = (self.high < other.low) ? self.high : (other.low - 1);

            count +=
                pairs[high][other.high]
                    - pairs[high][other.low - 1UL]
                    - pairs[self.low - 1UL][other.high]
                    + pairs[self.low - 1UL][other.low - 1UL];

            self.low = high + 1;

            if (self.high > other.high) {
                self.packd  ^= other.packd;
                other.packd ^= self.packd;
                self.packd  ^= other.packd;
            }

            if (self.low <= self.high)
                count +=
                    (overlapping[self.high] - overlapping[self.low - 1UL])
                        + ((
                        pairs[self.high][self.high]
                            - pairs[self.high][self.low - 1UL]
                            - pairs[self.low - 1UL][self.high]
                            + pairs[self.low - 1UL][self.low - 1UL]
                    ) << 1) + (
                        pairs[self.high][other.high]
                            - pairs[self.high][self.high]
                            - pairs[self.low - 1UL][other.high]
                            + pairs[self.low - 1UL][self.high]
                    );
        }
    }

    return count;
}

int main() {
    unsigned at, vertex_cnt;
    unsigned short query_cnt;
    scanf("%u %hu", &vertex_cnt, &query_cnt);

    unsigned colors[vertex_cnt + 1];
    for (at = 0; at < vertex_cnt; scanf("%u", &colors[at++]));
    colors[at] = 0xFFFFFFFFU;
    compress(at + 1, colors);

    unsigned ancestors[at + 1];
    {
        unsigned ancestor, descendant;
        for (memset(ancestors, 0xFFU, sizeof(ancestors)); --at; ancestors[descendant] = ancestor) {
            scanf("%u %u", &ancestor, &descendant);
            --ancestor;
            if (ancestors[--descendant] != 0xFFFFFFFFU) {
                unsigned root = descendant, next;
                for (; ancestor != 0xFFFFFFFFU; ancestor = next) {
                    next = ancestors[ancestor];
                    ancestors[ancestor] = root;
                    root = ancestor;
                }
                for (; ancestors[descendant] != 0xFFFFFFFFU; descendant = next) {
                    next = ancestors[descendant];
                    ancestors[descendant] = ancestor;
                    ancestor = descendant;
                }
            }
        }

        for (ancestor = 0xFFFFFFFFU; at != 0xFFFFFFFFU; at = descendant) {
            descendant = ancestors[at];
            ancestors[at] = ancestor;
            ancestor = at;
        }
    }

    unsigned
        others,
        ids[vertex_cnt + 1],
        weights[vertex_cnt],
        bases[vertex_cnt + 1],
        history[vertex_cnt];

    unsigned char
        base_depths[vertex_cnt],
        dist = 0;

    {
        unsigned
            history[vertex_cnt],
            indices[vertex_cnt + 1],
            descendants[vertex_cnt];

        memset(indices, 0, sizeof(indices));
        for (ancestors[vertex_cnt] = (at = vertex_cnt); at; indices[ancestors[at--]]++);
        for (; ++at <= vertex_cnt; indices[at] += indices[at - 1]);
        for (; --at; descendants[--indices[ancestors[at]]] = at);

        history[0] = 0;
        memset(weights, 0, sizeof(weights));
        for (at = 1; at--; )
            if (weights[history[at]])
                for (others = indices[history[at]];
                     others < indices[history[at] + 1];
                     weights[history[at]] += weights[descendants[others++]]);
            else {
                weights[history[at]] = 1;
                memcpy(
                    &history[at + 1],
                    &descendants[indices[history[at]]],
                    (indices[history[at] + 1] - indices[history[at]]) * sizeof(descendants[0])
                );
                at += indices[history[at] + 1] - indices[history[at]] + 1;
            }

        unsigned
            orig_ancestors[vertex_cnt + 1],
            orig_colors[vertex_cnt + 1],
            orig_weights[vertex_cnt];

        memcpy(orig_ancestors, ancestors, sizeof(ancestors));
        memcpy(orig_weights, weights, sizeof(weights));
        memcpy(orig_colors, colors, sizeof(colors));

        base_depths[0] = (bases[0] = (ids[0] = 0));
        bases[vertex_cnt] = (ids[vertex_cnt] = vertex_cnt);
        for (at = 1; at--;) {
            unsigned
                id = ids[history[at]],
                base = bases[id++],
                branches = indices[history[at] + 1] - indices[history[at]];

            heap_sort(&descendants[indices[history[at]]], orig_weights, branches);
            memcpy(&history[at], &descendants[indices[history[at]]], branches * sizeof(descendants[0]));

            for (others = (at += branches); branches--; base = id) {
                ids[history[--others]] = id;

                ancestors[id] = ids[orig_ancestors[history[others]]];
                weights[id] = orig_weights[history[others]];
                colors[id] = orig_colors[history[others]];

                bases[id] = base;
                base_depths[id] = base_depths[ancestors[id]] + (base == id);

                if (dist < base_depths[id])
                    dist = base_depths[id];

                id += weights[id];
            }
        }
    }

    unsigned base_ids[vertex_cnt + 1];
    for (base_ids[0] = (others = (at = 0)); others < vertex_cnt; base_ids[others] = base_ids[at] + 1)
        for (at = others; bases[at] == bases[others]; base_ids[others++] = base_ids[at]);

    unsigned ancestral_bases[base_ids[vertex_cnt]][dist];
    for (ancestors[0] = 0; others--; ancestral_bases[base_ids[others]][0] = ancestors[others]);
    while ((++others + 1) < dist)
        for (at = 0; ++at < base_ids[vertex_cnt];
             ancestral_bases[at][others + 1] = ancestors[bases[ancestral_bases[at][others]]]);

    unsigned
        indexed_colors[colors[vertex_cnt] + 2],
        members[vertex_cnt + 1];

    memset(indexed_colors, 0, sizeof(indexed_colors));
    for (at = vertex_cnt + 1; at--; indexed_colors[colors[at]]++);
    for (; ++at < colors[vertex_cnt]; indexed_colors[at + 1] += indexed_colors[at]);
    for (at = vertex_cnt + 1; at--; members[--indexed_colors[colors[at]]] = at);
    indexed_colors[colors[vertex_cnt] + 1] = indexed_colors[colors[vertex_cnt]];

    unsigned
        levels = floor_log2(vertex_cnt) + 1,
        block_cnt = (vertex_cnt / levels) + 1,
        locations[vertex_cnt + 1],
        overlapping[block_cnt];

    unsigned long (*pairs)[block_cnt][block_cnt] = calloc(
        (1 + block_cnt) * (1 + block_cnt),
        sizeof(pairs[0][0][0])
    );
    pairs = (void *)&pairs[0][1][1];

    for (at = vertex_cnt + 1; at--; locations[members[at]] = at);

    memset(overlapping, 0, sizeof(overlapping));
    for (at = 0; (indexed_colors[at + 1] - indexed_colors[at]) > 1; at++) {
        others = indexed_colors[at];

        unsigned
            block_bases[indexed_colors[at + 1] - others + 1],
            cnt = 1;

        for (block_bases[0] = members[others]; at == colors[members[++others]]; )
            if ((members[others] / levels) != (block_bases[cnt - 1] / levels))
                block_bases[cnt++] = members[others];

        block_bases[cnt] = members[others];
        for (others = 0; others < cnt; others++) {
            unsigned long density = locations[block_bases[others + 1]] - locations[block_bases[others]];
            overlapping[block_bases[others] / levels] += density * (density - 1);

            unsigned block = others;
            for (; ++block < cnt; pairs[0][block_bases[others] / levels][block_bases[block] / levels]
                += density * (locations[block_bases[block + 1]] - locations[block_bases[block]]));
        }
    }

    for (at = 0; ++at < block_cnt; overlapping[at] += overlapping[at - 1])
        pairs[0][0][at] += pairs[0][0][at - 1];

    for (at = 0; ++at < block_cnt; )
        for (others = 0; ++others < block_cnt; pairs[0][at][others] += pairs[0][at][others - 1]);

    for (at = 0; ++at < block_cnt; )
        for (others = 0; others < block_cnt; others++)
            pairs[0][at][others] += pairs[0][at - 1][others];

    colored_tree_t *tree = &(colored_tree_t) {
        .members = members,
        .colors = colors,
        .indices = indexed_colors,
        .locations = locations
    };


    while (query_cnt--) {
        range_t left, right;
        scanf("%u %u %u %u", &left.low, &left.high, &right.low, &right.high);
        left.packd -= 0x0000000100000001UL;
        right.packd -= 0x0000000100000001UL;

        left.low = ids[left.low];
        left.high = ids[left.high];

        right.low = ids[right.low];
        right.high = ids[right.high];

        if (left.high < left.low)
            left.packd = (left.packd << 32) | (left.packd >> 32);

        if (right.high < right.low)
            right.packd = (right.packd << 32) | (right.packd >> 32);

        if (right.high < left.low) {
            left.packd  ^= right.packd;
            right.packd ^= left.packd;
            left.packd  ^= right.packd;
        }

        struct {
            range_t members[32];
            unsigned cnt;
        }
            a = {.cnt = 0},
            b = {.cnt = 0};

        unsigned common = nearest_common_ancestor(
            dist, base_ids[vertex_cnt], vertex_cnt,
            base_ids, ancestral_bases,
            base_depths, weights,
            left.low, left.high
        );

        for (at = left.low; bases[at] != bases[common]; at = ancestral_bases[base_ids[at]][0])
            a.members[a.cnt++].packd = bases[at] | ((unsigned long)at << 32);

        for (others = left.high; bases[others] != bases[common]; others = ancestral_bases[base_ids[others]][0])
            a.members[a.cnt++].packd = bases[others] | ((unsigned long)others << 32);

        a.members[a.cnt++].packd = common | ((unsigned long)((at != common) ? at : others) << 32);

        common = nearest_common_ancestor(
            dist, base_ids[vertex_cnt], vertex_cnt,
            base_ids, ancestral_bases,
            base_depths, weights,
            right.low, right.high
        );

        for (at = right.low; bases[at] != bases[common]; at = ancestral_bases[base_ids[at]][0])
            b.members[b.cnt++].packd = bases[at] | ((unsigned long)at << 32);

        for (others = right.high; bases[others] != bases[common]; others = ancestral_bases[base_ids[others]][0])
            b.members[b.cnt++].packd = bases[others] | ((unsigned long)others << 32);

        b.members[b.cnt++].packd = common | ((unsigned long)((at != common) ? at : others) << 32);

        unsigned long total = 0;
        for (at = 0; at < a.cnt; at++)
            for (others = 0; others < b.cnt;
                 total += count_pairs(
                     block_cnt, levels, pairs[0], overlapping, tree,
                     a.members[at], b.members[others++]
                 )
            );

        printf("%lu\n", total);
    }

    return 0;
}
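Both compiled solutions shrink the raw colors (up to 10^9) to small dense ids before using them as array indices: the C++ main does it with normMap, and the C code above appears to do the same through its compress function. Only equality between colors matters for this problem, so any one-to-one relabeling is valid. A minimal Python sketch of the first-seen numbering used by the C++ normMap loop follows; `compress_colors` is a hypothetical name, and the C version may assign ids in a different order.

```python
def compress_colors(values):
    """Relabel values with dense ids 1..k in first-seen order; return (ids, k)."""
    ids = {}
    dense = []
    for v in values:
        if v not in ids:
            ids[v] = len(ids) + 1  # next unused id, like ++colors in the C++ code
        dense.append(ids[v])
    return dense, len(ids)
```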








In Python3 :





import sys
from filecmp import cmp
from os import linesep
from time import time
from collections import Counter

memoized_BIT_prevs = [list() for _ in range(10**5+1)]
for i in range(1,10**5+1):
    next = i + (i & -i)
    if next >= len(memoized_BIT_prevs):
        continue
    memoized_BIT_prevs[next].append(i)

class UF(object):
    __slots__=['uf','ranks']
    def __init__(self):
        self.uf = {} #keeps union-find structure
        self.ranks = {}
    def UFadd(self,a):
        self.uf[a] = a #add curr to union-find
        self.ranks[a] = 0
    def UFfind(self,a):
        uf = self.uf
        curr = a
        while curr != uf[curr]:
            next = uf[uf[curr]] #does not fully path compress
            uf[curr]=next
            curr = next
        return curr
    def UFcombine(self,a,b):
        uf = self.uf
        ranks = self.ranks
        a_top = self.UFfind(a)
        b_top = self.UFfind(b)
        rank_a = ranks[a_top]
        rank_b = ranks[b_top]
        if rank_a < rank_b:
            uf[a_top] = b_top
        elif rank_a > rank_b:
            uf[b_top] = a_top
        else:
            uf[b_top] = a_top
            ranks[a_top] += 1

class mycounter(object):
    __slots__ = ['c']
    def __init__(self,key = None):
        if key is None:
            self.c = {}
        else:
            self.c = {key:1}
    def inc(self,key):
        self.c[key] = self.c.get(key,0) + 1
    def addto(self,other):
        for key,otherval in other.c.items():
            self.c[key] = self.c.get(key,0) + otherval
    def subfrom(self,other,mult = 1): #subtract other from self, never neg
        for key,otherval in other.c.items():
            assert key in self.c
            self.c[key] = self.c[key] - mult*otherval
    def innerprodwith(self,other):
        X = self.c
        Y = other.c
        if len(X) > len(Y):
            X,Y = Y,X
        return sum(X[i]*Y[i] for i in X if i in Y)

class mycounter2(object):
    __slots__ = ['q','doubleneg']
    def __init__(self):
        self.q = []
        self.doubleneg = None #will be list
    def inc(self,key):
        self.q.append(key)
    def addto(self,other):
        self.q.extend(other.q)
    def subfrom(self,other,mult = 1): #subtract other from self, never neg
        assert self.doubleneg is None and other.doubleneg is None
        self.doubleneg = other.q
    def innerprodwith(self,other):
        X = self
        Y = other
        if len(X.q) > len(Y.q):
            X,Y = Y,X
        Xq = X.q
        Yq = Y.q
        Xn = X.doubleneg
        Yn = Y.doubleneg
        S = set(Xq)
        S.update(Xn)
        return sum((Xq.count(s)-2*Xn.count(s))*(Yq.count(s)-2*Yn.count(s)) for s in S)

class geodcounter(object):
    __slots__ = ['size','C','above','below','values','lca','dists']
    def __init__(self,neb,root,vals,pairs):
        self.size = len(neb)
        self.C = [mycounter() for _ in range(self.size)]
        self.above = [None]*self.size
        self.below = [list() for _ in range(self.size)]
        self.values = vals
        self.lca = {}
        self.dists = [None]*self.size

        above = self.above
        below = self.below
        lca = self.lca
        dists = self.dists

        geod = [None] #one-indexed to fit BIT
        gpush = geod.append
        gpop = geod.pop

        height = 0
        traverse_stack = [(root,None)]
        tpush = traverse_stack.append
        tpop  = traverse_stack.pop
        visited = set()

        ancestors = {}
        lca_done = set()
        uf = UF()
        while(traverse_stack):
            curr,parent = tpop()
            if curr is None:
                break
            if curr not in visited:
                dists[curr] = height
                height += 1

                aboveht = height - (height & -height)
                above[curr] = geod[aboveht]
                for nextht in memoized_BIT_prevs[height]:
                    below[geod[nextht]].append(curr)

                uf.UFadd(curr)
                visited.add(curr)
                tpush((curr,parent))
                gpush(curr)
                ancestors[curr] = curr
                for child in neb[curr]:
                    if child in visited:
                        continue
                    tpush((child,curr))
            else:
                gcurr = gpop()
                assert gcurr == curr
                # tho self.below not complete yet,
                # it is for subtree rooted at curr, so OK to call notice
                self.notice(curr)
                height -= 1

                #this portion implements Tarjan's LCA alg
                lca_done.add(curr)
                for v in pairs[curr]:
                    if v in lca_done:
                        rep_v = uf.UFfind(v)
                        anc = ancestors[rep_v]
                        assert (curr,v) not in lca and (v,curr) not in lca
                        lca[(curr,v)] = anc
                        lca[(v,curr)] = anc
                if parent is not None:
                    uf.UFcombine(curr,parent)
                    ancestors[uf.UFfind(curr)] = parent
    def notice(self,node):
        val = self.values[node]
        C = self.C
        below = self.below
        stack = [node]
        pop = stack.pop
        extend = stack.extend
        while stack:
            curr = pop()
            C[curr].inc(val)
            extend(below[curr]) #note that the ORDER we visit them doesn't matter
    def find(self,node):
        acc = mycounter()
        C = self.C
        above = self.above
        while node is not None:
            acc.addto(C[node])
            node = above[node]
        return acc
    def get_incidence(self,a,b):
        l = self.lca[(a,b)]
        find_a = self.find(a)
        #find_b = self.find(b)
        #delta = mycounter(self.values[l])
        #pos_part = find_a + find_b + delta
        find_a.addto(self.find(b))
        find_a.inc(self.values[l])
        find_a.subfrom(self.find(l),2)
        #answer = pos_part - find_l - find_l
        #order matters b/c of how subtraction of counters works
        return find_a
    def size_intersection(self,a,b,c,d):
        prs = ((a,b),(c,d),(a,c),(a,d),(b,c),(b,d))
        verts = list(self.lca[x] for x in prs)
        key1,key2 = verts[0],verts[1]
        check = Counter(verts)
        if check[key1] == 1 or check[key2] == 1:
            return 0 #empty intersection
        most = check.most_common()
        counts = list(freq for val,freq in most)
        if counts == [6] or counts == [3,3]:
            return 1 #intersect in one point
        elif counts == [5,1] or counts == [3,2,1]: #root meets at or beyond endpt of intersection
            close = most[-2][0]
            far = most[-1][0]
            return (self.dists[far] - self.dists[close])+1 #size, not length
        elif counts == [4,1,1]:
            left = most[1][0]
            right = most[2][0]
            mid = most[0][0]
            return (self.dists[left] + self.dists[right] - 2*self.dists[mid])+ 1 #size, not length
        else:
            raise RuntimeError
    def process_commands(self,commands):
        measure_int = self.size_intersection
        get_incidence = self.get_incidence
        timer = 0
        for w,x,y,z in commands:
            A = get_incidence(w,x)
            B = get_incidence(y,z)
            innerprod = A.innerprodwith(B)
            len_inter = measure_int(w,x,y,z)
            ans = innerprod - len_inter
            assert ans >= 0
            yield ans

def test(num = None):
    if num is None:
        inp = sys.stdin
        out = sys.stdout
    else:
        inp = open('./input'+num+'.txt')
        out = open('./myoutput'+num+'.txt','w')

    start_time = time()
    N,Q = tuple(map(int,inp.readline().strip().split(' ')))
    vals = [0] #one-indexed so c_i = C[i]
    vals.extend(map(int,inp.readline().strip().split(' ')))
    neb = [list() for x in range(N+1)]
    for _ in range(N-1):
        a, b = tuple(map(int,inp.readline().strip().split(' ')))
        neb[a].append(b)
        neb[b].append(a)

    pairs = [set() for _ in range(N+1)]
    commands = []
    for _ in range(Q):
        w,x,y,z = tuple(map(int,inp.readline().strip().split(' ')))
        pairs[w].update((x,y,z))
        pairs[x].update((w,y,z))
        pairs[y].update((w,x,z))
        pairs[z].update((w,x,y))
        commands.append((w,x,y,z))
    count_geod = geodcounter(neb,1,vals,pairs) #make 1 root arbitrarily

    for answer in count_geod.process_commands(commands):#,desireds):
        print(answer,file=out)
    end_time = time()

    if num is not None:
        inp.close()
        if num != '00':
            remove_chars = len(linesep)
            out.truncate(out.tell() - remove_chars) #strip trailing newline
        out.close()

        succeeded = cmp('myoutput'+num+'.txt','output'+num+'.txt')
        outcome = "  Success" if succeeded else "  Failure"
        print("#"+num+outcome,"in {:f}s".format(end_time-start_time))

test()
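The Python solution computes every LCA it needs offline with Tarjan's algorithm: the UF class plus the post-order branch of geodcounter.__init__. A stripped-down standalone sketch of that idea follows. The name `tarjan_lca` and the dict-based union-find are assumptions for illustration, not the solution's own API; unlike the UF class it uses path halving without union by rank, which is simpler but gives weaker worst-case bounds.

```python
def tarjan_lca(adj, root, queries):
    """Offline LCA: answer every (u, v) pair during one iterative DFS."""
    wanted = {}
    for u, v in queries:
        wanted.setdefault(u, []).append(v)
        wanted.setdefault(v, []).append(u)

    uf = {}

    def find(a):
        while uf[a] != a:
            uf[a] = uf[uf[a]]  # path halving
            a = uf[a]
        return a

    anchor = {}   # set representative -> deepest ancestor that set currently hangs from
    done = set()
    answer = {}
    stack = [(root, 0, False)]
    while stack:
        node, parent, leaving = stack.pop()
        if not leaving:
            uf[node] = node
            anchor[node] = node
            stack.append((node, parent, True))
            for nxt in adj[node]:
                if nxt != parent:
                    stack.append((nxt, node, False))
            continue
        done.add(node)  # post-order: every child of node is finished
        for other in wanted.get(node, ()):
            if other in done:
                answer[(node, other)] = answer[(other, node)] = anchor[find(other)]
        if parent:  # node 0 marks "no parent"
            uf[find(node)] = find(parent)
            anchor[find(parent)] = parent
    return answer
```

On the small tree with edges 1-2, 2-3, 2-4, 4-5, the pair (3, 5) maps to 2 and (5, 4) maps to 4.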
                        







