# Super Maximum Cost Queries

### Problem Statement

Victoria has a tree, $T$, consisting of $N$ nodes numbered from $1$ to $N$. Each edge from node $U_i$ to node $V_i$ in tree $T$ has an integer weight, $W_i$.

Let's define the cost, $C$, of a path from some node $X$ to some other node $Y$ as the maximum weight $W$ of any edge on the unique path from node $X$ to node $Y$.

Victoria wants your help processing $Q$ queries on tree $T$, where each query contains two integers, $L$ and $R$, such that $L \le R$. For each query, she wants to print the number of different paths in $T$ that have a cost, $C$, in the inclusive range $[L, R]$.

Note that the path from some node $X$ to some other node $Y$ is considered the same as the path from node $Y$ to node $X$, i.e. $\{X, Y\}$ is the same as $\{Y, X\}$.

#### Input Format

The first line contains two space-separated integers, $N$ (the number of nodes) and $Q$ (the number of queries), respectively.
Each of the $N - 1$ subsequent lines contains three space-separated integers, $U$, $V$, and $W$, respectively, describing a bidirectional road between nodes $U$ and $V$ which has weight $W$.
The $Q$ subsequent lines each contain two space-separated integers denoting $L$ and $R$.

#### Constraints

- $1 \le N, Q \le 10^5$
- $1 \le U, V \le N$
- $1 \le W \le 10^9$
- $1 \le L \le R \le 10^9$

#### Output Format

For each of the $Q$ queries, print on a new line the number of paths in $T$ having cost $C$ in the inclusive range $[L, R]$.

### Solution

Solutions in C++, Java, C and Python 3 follow.

```
In C++ :

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<set>
#include<map>
#include<queue>
#include<cassert>
// Shorthand macros and type aliases used throughout this C++ solution.
// PB / MP: short names for push_back / make_pair.
#define PB push_back
#define MP make_pair
// sz(v): size of container v converted to the signed type `in`.
#define sz(v) (in((v).size()))
// forn(i,n): loop i = 0, 1, ..., n-1 with i of type `in`.
#define forn(i,n) for(in i=0;i<(n);++i)
// forv(i,v): loop i over the indices of container v.
#define forv(i,v) forn(i,sz(v))
// fors(i,s): iterator loop over every element of s.
#define fors(i,s) for(auto i=(s).begin();i!=(s).end();++i)
// all(v): the begin/end pair, e.g. for sort(all(v)).
#define all(v) (v).begin(),(v).end()
using namespace std;
// `in` is the program's integer type: 64-bit, so pair counts (up to about
// N*(N-1)/2 = 5e9 for N = 1e5) do not overflow.
typedef long long in;
typedef vector<in> VI;
typedef vector<VI> VVI;
struct unifnd{
VI ht,pr,ss;
in fnd(in a){
in ta=a;
while(a!=pr[a])a=pr[a];
in tt=ta;
while(ta!=a){
tt=pr[ta];
pr[ta]=a;
ta=tt;
}
return a;
}
in uni(in a, in b){
a=fnd(a);
b=fnd(b);
if(a==b)return 0;
if(ht[b]<ht[a])swap(a,b);
pr[a]=b;
in r=ss[a]*ss[b];
ss[b]+=ss[a];
ht[b]+=(ht[a]==ht[b]);
return r;
}
void ini(in n){
ht.resize(n);
pr.resize(n);
ss.resize(n);
forn(i,n){
ht[i]=0;
ss[i]=1;
pr[i]=i;
}
}
};
VI ans;
unifnd cf;
// One sweep event, ordered by weight w and then by type typ.
//   typ == 0: tree edge between 0-based nodes u and v with weight w.
//   typ == 1: checkpoint for query index u; processing it adds
//             v * (pairs connected so far) to that query's answer, where
//             v is -1 (placed at w = L-1) or +1 (placed at w = R).
// At equal w, edges (typ 0) sort before checkpoints (typ 1), so a checkpoint
// at weight w already includes every edge of weight exactly w.
struct ev{
    in typ, u, v, w;

    // explicit: prevents a bare integer from silently converting into an ev.
    explicit ev(in a=0, in b=0, in c=0, in d=0) : typ(a), u(b), v(c), w(d) {}

    // Strict weak ordering used by sort. The operand is taken by const
    // reference so no event is copied per comparison.
    bool operator<(const ev& cp) const{
        if (w != cp.w)
            return w < cp.w;
        return typ < cp.typ;
    }
};
in sm = 0;  // number of node pairs connected by the edge events processed so far

// Applies one sweep event. An edge event merges two components and adds the
// newly connected pairs to sm. A query checkpoint (typ == 1) adds v * sm to
// the answer of query u.
void prev(ev tp){
    if (tp.typ != 1) {
        sm += cf.uni(tp.u, tp.v);
        return;
    }
    ans[tp.u] += tp.v * sm;
}
vector<ev> evs;  // all edge and query events, sorted before the sweep
int main(){
ios::sync_with_stdio(0);
cin.tie(0);
in n,q;
cin>>n>>q;
cf.ini(n);
ans.resize(q,0);
in ta,tb,tc;
forn(i,n-1){
cin>>ta>>tb>>tc;
--ta;
--tb;
evs.PB(ev(0,ta,tb,tc));
}
forn(i,q){
cin>>ta>>tb;
evs.PB(ev(1,i,-1,ta-1));
evs.PB(ev(1,i,1,tb));
}
sort(all(evs));
forv(i,evs)
prev(evs[i]);
forv(i,ans)
cout<<ans[i]<<"\n";
return 0;
}

In Java :

import java.util.Arrays;
import java.util.Comparator;
import java.util.Scanner;

import static java.lang.System.out;

/**
 * A distinct edge weight paired with a path count. After the prefix-sum pass
 * in {@code MaximumCostQueries.main}, {@code count} holds the number of paths
 * whose cost is at most {@code weight}.
 */
class WeightCount {
    private int weight;
    private long count;

    public WeightCount(int weight, long count) {
        this.weight = weight;
        this.count = count;
    }

    public int getWeight() {
        return weight;
    }
    public long getCount() {
        return count;
    }

    public void setWeight(int weight) {
        this.weight = weight;
    }
    public void setCount(long count) {
        this.count = count;
    }

    /**
     * Finds the last index in {@code array[0..size)} whose weight is
     * {@code <= key}. The first {@code size} entries must be sorted by
     * ascending weight.
     *
     * @param array entries sorted by weight; may be {@code null}
     * @param size  number of valid entries at the front of {@code array}
     * @param key   inclusive weight bound
     * @return index of the last entry with weight {@code <= key}, or -1 if no
     *         such entry exists (including a {@code null} array or
     *         {@code size <= 0})
     */
    public static int lower(WeightCount[] array, int size, int key) {
        // An empty range has no entry <= key. The old code returned 0 here,
        // which pointed callers at an unfilled slot.
        if (array == null || size <= 0)
            return -1;

        // Binary search for the first index whose weight exceeds key.
        int lo = 0;
        int hi = size;
        while (lo < hi) {
            int mid = (lo + hi) >>> 1;
            if (array[mid].getWeight() <= key)
                lo = mid + 1;
            else
                hi = mid;
        }
        return lo - 1;
    }
}

class Edge implements Comparable<Edge> {
private int u;
private int v;
private int w;

public Edge(int u, int v, int w) {
this.u = u;
this.v = v;
this.w = w;
}

public int getU() {
return u;
}
public int getV() {
return v;
}
public int getW() {
return w;
}

public void setU(int u) {
this.u = u;
}
public void setV(int v) {
this.v = v;
}
public void setW(int w) {
this.w = w;
}

public int compareTo(Edge e) {
if (e != null) {
int tmp = e.getW();
if (w < tmp)
return -1;
if (w > tmp)
return 1;
}

return 0;
}
}

/** Union-find over nodes 1..n with union by size and path compression. */
class DisjointSet {
    private static final int DEFAULT_SIZE = 31;

    private int[] idx;   // parent link; idx[i] == i marks a root
    private int[] size;  // component size, meaningful at roots
    private int n;
    private int components;

    /** Creates singleton sets {1}..{n}; an n below 1 falls back to DEFAULT_SIZE. */
    public DisjointSet(int n) {
        int count = (n < 1) ? DEFAULT_SIZE : n;

        idx = new int[count + 1];
        size = new int[count + 1];
        this.n = count;
        components = count;

        for (int node = 1; node <= count; node++) {
            idx[node] = node;
            size[node] = 1;
        }
    }

    public DisjointSet() {
        this(DEFAULT_SIZE);
    }

    /** Root of i's set, compressing the path; returns 0 when i is outside 1..n. */
    private int root(int i) {
        if (i < 1 || i > n)
            return 0;

        int top = i;
        while (idx[top] != top)
            top = idx[top];

        for (int cur = i; idx[cur] != top; ) {
            int next = idx[cur];
            idx[cur] = top;
            cur = next;
        }
        return top;
    }

    /**
     * Merges the sets containing p and q.
     *
     * @return size(P) * size(Q), the number of node pairs newly connected, or
     *         0 when p and q already share a root (including both out of range)
     */
    public long join(int p, int q) {
        int rootP = root(p);
        int rootQ = root(q);
        if (rootP == rootQ)
            return 0;

        long joined = (long) size[rootP] * size[rootQ];
        // Attach the smaller tree under the larger; ties attach Q under P.
        boolean pSmaller = size[rootP] < size[rootQ];
        int big = pSmaller ? rootQ : rootP;
        int small = pSmaller ? rootP : rootQ;
        idx[small] = big;
        size[big] += size[small];

        components--;
        return joined;
    }

    public boolean isConnected(int p, int q) {
        return root(p) == root(q);
    }
}

/**
 * Solves "Super Maximum Cost Queries": for each query [L, R], counts the
 * unordered node pairs whose path cost (maximum edge weight) lies in [L, R].
 *
 * Edges are joined in ascending weight order with a disjoint set. Joining
 * components of sizes a and b creates a*b new pairs whose cost is exactly the
 * current weight. Prefix sums over the distinct weights then answer a query
 * as count(cost <= R) - count(cost <= L - 1).
 */
public class MaximumCostQueries {
    private static final int MAX_N = 100000;
    private static final int MAX_Q = 100000;

    /**
     * Number of paths with cost {@code <= key}, given the first {@code k}
     * prefix-summed entries of {@code wc}, which are sorted by weight.
     * Returns 0 when no weight is {@code <= key}, and when {@code k == 0}
     * (a single-node tree). Previously these cases indexed wc[-1] or the
     * null wc[0].
     */
    private static long pathsUpTo(WeightCount[] wc, int k, int key) {
        if (k <= 0)
            return 0;
        int idx = WeightCount.lower(wc, k, key);
        return (idx < 0) ? 0 : wc[idx].getCount();
    }

    public static void main(String[] args) {
        try (Scanner sc = new Scanner(System.in)) {
            int n = sc.nextInt();
            int q = sc.nextInt();
            if (n < 1 || n > MAX_N ||
                q < 1 || q > MAX_Q)
                return;

            Edge[] edges = new Edge[n - 1];
            for (int i = 0; i < n - 1; i++) {
                int u = sc.nextInt();
                int v = sc.nextInt();
                int w = sc.nextInt();
                edges[i] = new Edge(u, v, w);
            }
            Arrays.sort(edges);

            // One entry per distinct weight: (weight, #pairs whose cost is exactly it).
            DisjointSet ds = new DisjointSet(n);
            WeightCount[] wc = new WeightCount[n];
            int k = 0;
            for (int i = 0; i < edges.length; ) {
                int w = edges[i].getW();
                long pairs = 0;
                do {
                    pairs += ds.join(edges[i].getU(), edges[i].getV());
                    i++;
                } while (i < edges.length && edges[i].getW() == w);
                wc[k++] = new WeightCount(w, pairs);
            }

            // Prefix sums: count becomes #pairs with cost <= weight.
            for (int g = 1; g < k; g++)
                wc[g].setCount(wc[g - 1].getCount() + wc[g].getCount());

            // Buffer all answers; one println per query is slow for 1e5 lines.
            StringBuilder sb = new StringBuilder();
            String eol = System.lineSeparator();
            while (q-- > 0) {
                int l = sc.nextInt();
                int r = sc.nextInt();
                long result = pathsUpTo(wc, k, r) - pathsUpTo(wc, k, l - 1);
                sb.append(result).append(eol);
            }
            out.print(sb);
            out.flush();
        }
    }
}

In C :

#include <stdio.h>

static long long int a[100000][3], parent[100001], n;

/*
 * Merges the two sorted runs a[p..q] and a[q+1..r] (inclusive) into a single
 * run ordered by ascending a[][2]. Each row is a {u, v, w} edge triple at
 * sort time. On equal weights, rows from the left run come first, so the
 * merge is stable.
 *
 * The scratch buffers are long long to match a[][], so no value is narrowed.
 * Each run is bounds-checked instead of relying on a sentinel weight
 * (formerly 1000000001) that had to exceed every input.
 */
void mer(int p, int q, int r)
{
    static long long int le[100001][3], ri[100001][3];
    int n1 = q - p + 1, n2 = r - q;
    int i, j, k, c;

    for (i = 0; i < n1; i++)
        for (c = 0; c < 3; c++)
            le[i][c] = a[p + i][c];
    for (j = 0; j < n2; j++)
        for (c = 0; c < 3; c++)
            ri[j][c] = a[q + 1 + j][c];

    i = j = 0;
    for (k = p; k <= r; k++)
    {
        /* Take from the left run while it has rows and its weight is not larger. */
        if (j >= n2 || (i < n1 && le[i][2] <= ri[j][2]))
        {
            for (c = 0; c < 3; c++)
                a[k][c] = le[i][c];
            i++;
        }
        else
        {
            for (c = 0; c < 3; c++)
                a[k][c] = ri[j][c];
            j++;
        }
    }
}
/* Sorts rows a[p..r] (inclusive) by a[][2] using top-down merge sort. */
void merge_sort(int p, int r)
{
    int q;

    if (p >= r)
        return;             /* zero or one row: already sorted */

    q = (p + r) / 2;
    merge_sort(p, q);       /* left half  a[p..q]   */
    merge_sort(q + 1, r);   /* right half a[q+1..r] */
    mer(p, q, r);
}

/*
 * Returns the root of x's set in parent[] (a root satisfies parent[r] == r)
 * and re-links every node on the path from x directly to that root.
 */
int getParent(int x)
{
    int root = x;
    int next;

    while (parent[root] != root)
        root = (int)parent[root];

    while (parent[x] != root)
    {
        next = (int)parent[x];
        parent[x] = root;
        x = next;
    }
    return root;
}

/*
 * Binary search over the compressed table a[0..n] built by main.
 * a[0] is never written, so it stays all zero: a[0][2] == 0 and a[0][0] == 0.
 * For a tree with at least one edge, the weights a[1..n][2] are distinct and
 * ascending.
 * Returns the largest index idx in [0, n] with a[idx][2] <= x. a[idx][0] is
 * then the number of paths with cost <= x.
 * NOTE(review): assumes x >= 0 (main passes L-1 >= 0 and R >= 1); behaviour
 * for negative x is not a meaningful answer.
 */
int bin_search(int x)
{
int low, mid, upp;
low = 0;
upp = n;
mid = (low + upp) / 2;
/* Invariant: a[low][2] <= x, and the answer lies in [low, upp]. */
while (low<upp)
{
if (x<a[mid][2])
{
/* mid is too far right; the answer is at most mid - 1. */
upp = mid - 1;
}
else if (x>a[mid][2])
{
/* Inside the loop mid < upp <= n, so a[mid + 1] is within the table. */
if (x >= a[mid + 1][2])
low = mid + 1;
/* a[mid][2] < x < a[mid + 1][2]: mid is the answer. */
if (x<a[mid + 1][2])
break;
}
else
{
/* Exact weight match. */
break;
}
mid = (low + upp) / 2;
}
return mid;
}

/*
 * Reads the tree and the queries from stdin and prints one answer per query.
 * Phases:
 *   1. read the edges into a[1..n-1] as {u, v, w}; every node 1..n starts as
 *      a singleton set;
 *   2. sort the edges by weight and union them in that order, overwriting
 *      a[i][0] with the number of node pairs that edge newly connects;
 *   3. compress a[1..] in place to one row per distinct weight, with the
 *      cumulative pair count in a[j][0] and the weight in a[j][2];
 *   4. answer [x, y] as count(cost <= y) - count(cost <= x - 1).
 * Note: k and z are declared but unused.
 */
int main() {
static long long int q, count[100001], i, j, k, x, y, z, px, py;
scanf("%lld%lld", &n, &q);
for (i = 1; i<n; i++)
{
parent[i] = i;
count[i] = 1;
scanf("%lld%lld%lld", &a[i][0], &a[i][1], &a[i][2]);
}
/* The loop leaves i == n (for n >= 1); initialise that last node as well. */
parent[i] = i;
count[i] = 1;
merge_sort(1, n - 1);
for (i = 1; i<n; i++)
{
x = a[i][0];
y = a[i][1];
px = getParent(x);
py = getParent(y);
/* Pairs newly connected by this edge; assumes a tree input, so px != py. */
a[i][0] = count[px] * count[py];
/* Union by size: attach the smaller component under the larger one. */
if (count[px] >= count[py])
{
count[px] += count[py];
parent[py] = px;
}
else
{
count[py] += count[px];
parent[px] = py;
}
}
/* Merge rows with equal weight into row j and turn counts into prefix sums. */
for (i = 2, j = 1; i<n; i++)
{
if (a[j][2] == a[i][2])
{
a[j][0] += a[i][0];
}
else
{
j++;
a[j][0] = a[i][0] + a[j - 1][0];
a[j][2] = a[i][2];
}
}
/* n now indexes the last row of the compressed table used by bin_search. */
n = j;
while (q--)
{
scanf("%lld%lld", &x, &y);
px = bin_search(x - 1);
py = bin_search(y);
printf("%lld\n", a[py][0] - a[px][0]);
}
return 0;
}

In Python3 :

import bisect as bs
import sys


def count_paths(n, edges, queries):
    """Answer "Super Maximum Cost Queries" for a weighted tree.

    Args:
        n: number of nodes, labelled 1..n.
        edges: iterable of (u, v, w) tree edges with 1-based endpoints.
        queries: iterable of (l, r) pairs.

    Returns:
        A list holding, for each query, the number of unordered node pairs
        whose path cost (maximum edge weight on the path) lies in [l, r].
    """
    # union[x] < 0: x is a root whose component has -union[x] nodes;
    # otherwise union[x] is x's parent.
    union = [-1] * n

    def getroot(x):
        root = x
        while union[root] >= 0:
            root = union[root]
        while x != root:  # path compression
            nxt = union[x]
            union[x] = root
            x = nxt
        return root

    # When edges are added in ascending weight order, an edge of weight c that
    # joins components of sizes a and b creates exactly a * b new paths of cost c.
    paths = {}
    for u, v, c in sorted(edges, key=lambda e: e[2]):
        u = getroot(u - 1)
        v = getroot(v - 1)
        if u == v:  # not a tree edge; linking a root to itself would loop forever
            continue
        paths[c] = paths.get(c, 0) + union[u] * union[v]
        if union[u] < union[v]:  # make u the smaller component
            u, v = v, u
        union[v] += union[u]
        union[u] = v

    # cumulative[i] = number of paths whose cost is one of the i smallest weights.
    weights = sorted(paths)
    cumulative = [0]
    for c in weights:
        cumulative.append(cumulative[-1] + paths[c])

    # count(cost <= r) - count(cost < l). No index is negative, even for l <= 0.
    return [cumulative[bs.bisect_right(weights, r)]
            - cumulative[bs.bisect_left(weights, l)]
            for l, r in queries]


def main():
    """Read the problem input from stdin and print one answer per line."""
    it = iter(map(int, sys.stdin.buffer.read().split()))
    n, q = next(it), next(it)
    edges = [(next(it), next(it), next(it)) for _ in range(n - 1)]
    queries = [(next(it), next(it)) for _ in range(q)]
    sys.stdout.write("".join(f"{ans}\n" for ans in count_paths(n, edges, queries)))


if __name__ == "__main__":
    main()
```

