Super Maximum Cost Queries

Problem Statement :

Victoria has a tree, $T$, consisting of $N$ nodes numbered from $1$ to $N$. Each edge from node $U_i$ to node $V_i$ in tree $T$ has an integer weight, $W_i$.

Let's define the cost, $C$, of a path from some node $X$ to some other node $Y$ as the maximum weight $W$ of any edge on the unique path from node $X$ to node $Y$.

Victoria wants your help processing $Q$ queries on tree $T$, where each query contains two integers, $L$ and $R$, such that $L \le R$. For each query, she wants to print the number of different paths in $T$ that have a cost, $C$, in the inclusive range $[L, R]$.

Note that the path from node $X$ to node $Y$ is considered the same as the path from node $Y$ to node $X$; i.e., $\{X, Y\}$ is the same as $\{Y, X\}$.

Input Format

The first line contains 2 space-separated integers, N (the number of nodes) and Q (the number of queries), respectively.
Each of the $N - 1$ subsequent lines contains 3 space-separated integers, $U$, $V$, and $W$, respectively, describing a bidirectional edge between nodes $U$ and $V$ with weight $W$.
The Q subsequent lines each contain 2 space-separated integers denoting L and R.


Constraints:
$1 \le N, Q \le 10^5$
$1 \le U, V \le N$
$1 \le W \le 10^9$
$1 \le L \le R \le 10^9$

Output Format

For each of the $Q$ queries, print the number of paths in $T$ whose cost $C$ lies in the inclusive range $[L, R]$, each on a new line.

Solution :


Solutions (C++, Java, C, and Python 3):

In C++ :

// Shorthand macros used by the C++ solution below.
// NOTE(review): this listing shows no #include directives (vector etc. are
// used without one); presumably they were lost — TODO confirm.
#define PB push_back
#define MP make_pair
// sz(v): container size converted to the solution's integer type `in`.
#define sz(v) (in((v).size()))
#define forn(i,n) for(in i=0;i<(n);++i)
#define forv(i,v) forn(i,sz(v))
#define fors(i,s) for(auto i=(s).begin();i!=(s).end();++i)
#define all(v) (v).begin(),(v).end()
using namespace std;
// `in` is the integer type used throughout: 64-bit, presumably so that
// products such as ss[a]*ss[b] in unifnd::uni do not overflow.
typedef long long in;
typedef vector<in> VI;
typedef vector<VI> VVI;
// Union-find (disjoint-set) helper.
// NOTE(review): by their names, ht is presumably rank/height, pr the parent
// array and ss the set sizes — inferred from names only; confirm.
// NOTE(review): this listing is visibly truncated — closing braces and most
// statements of fnd/uni/ini are missing.
struct unifnd{
  VI ht,pr,ss;
  // Find: should return the representative of a's set.
  // NOTE(review): as shown, ta/tt are initialised but never used and `a` is
  // returned unchanged; the root walk / path compression appears missing.
  in fnd(in a){
    in ta=a;
    in tt=ta;
    return a;
  // Union: returns 0 when a == b, otherwise ss[a]*ss[b] — presumably the
  // number of node pairs newly connected, assuming a and b are roots and ss
  // holds set sizes (TODO confirm). The pr/ss update is not visible here.
  in uni(in a, in b){
    if(a==b)return 0;
    in r=ss[a]*ss[b];
    return r;
  // Initialiser for n elements; its body is missing from this listing.
  void ini(in n){
// Global state of the C++ solution.
// NOTE(review): ans is presumably the per-query answer list — usage is not
// visible in this listing; confirm.
VI ans;
// Shared union-find instance.
unifnd cf;
// An "event" record with a kind tag (typ) and three payload fields.
// NOTE(review): the meaning of the typ values (edge vs. query?) is not visible
// in this listing, and the constructor body is missing — confirm.
struct ev{
  in typ,u,v,w;
  ev(in a=0, in b=0, in c=0, in d=0){
  // Ordering: appears to compare by w first, then by typ; the guarding
  // conditions between these returns are missing from this listing.
  bool operator<(const ev cp)const{
      return w<cp.w;
      return typ<cp.typ;
    return 0;
// Running accumulator, initialised to 0. NOTE(review): presumably the total
// path count so far; its uses are not visible here — confirm.
in sm=0;
// Processes one event. NOTE(review): the body is missing from this listing.
void prev(ev tp){
// Event list, presumably sorted with ev::operator< before processing.
vector<ev> evs;
// Entry point of the C++ solution.
// NOTE(review): input reading, event processing and output are missing from
// this listing; only the declarations and the final return survive.
int main(){
  in n,q;
  in ta,tb,tc;
  return 0;

In Java :

import java.util.Arrays;
import java.util.Comparator;
import java.util.Scanner;

import static java.lang.System.out;

/**
 * A (weight, count) pair. In this solution the counts are later turned into
 * prefix sums, so that {@code count} is the number of paths whose cost is at most
 * {@code weight}.
 */
class WeightCount {
    private int weight;
    private long count;

    public WeightCount(int weight, long count) {
        this.weight = weight;
        this.count = count;
    }

    public int getWeight() {
        return weight;
    }

    public long getCount() {
        return count;
    }

    public void setWeight(int weight) {
        this.weight = weight;
    }

    public void setCount(long count) {
        this.count = count;
    }

    /**
     * Finds the last index in {@code array[0, size)} whose weight is {@code <= key}.
     * The first {@code size} elements must be sorted by ascending weight.
     *
     * @param array the sorted entries
     * @param size  number of leading entries of {@code array} to search
     * @param key   the weight bound
     * @return that index, or -1 when {@code array} is null, {@code size <= 0},
     *         or every weight is greater than {@code key}
     */
    public static int lower(WeightCount[] array, int size, int key) {
        if (array == null || size <= 0) {
            // An empty range has no element <= key.
            return -1;
        }
        // Invariant: indices < lo have weight <= key; indices >= hi have weight > key.
        int lo = 0;
        int hi = size;
        while (lo < hi) {
            int mid = (lo + hi) >>> 1;
            if (array[mid].getWeight() <= key) {
                lo = mid + 1;
            } else {
                hi = mid;
            }
        }
        return lo - 1;
    }
}

class Edge implements Comparable<Edge> {
    private int u;
    private int v;
    private int w;

    public Edge(int u, int v, int w) {
	this.u = u;
	this.v = v;
	this.w = w;

    public int getU() {
	return u;	
    public int getV() {
	return v;	
    public int getW() {
	return w;	
    public void setU(int u) {
	this.u = u;
    public void setV(int v) {
	this.v = v;
    public void setW(int w) {
	this.w = w;
    public int compareTo(Edge e) {
	if (e != null) {
	    int tmp = e.getW();
	    if (w < tmp)
		return -1;
	    if (w > tmp)
		return 1;	    

	return 0;

/**
 * Union-find over the ids {@code 1..n} with union by size and path compression.
 * Ids outside that range are ignored: they never join and are never connected.
 */
class DisjointSet {
    private static final int DEFAULT_SIZE = 31;

    private final int[] idx;   // parent links; idx[i] == i marks a root
    private final int[] size;  // set size, valid at roots only
    private final int n;
    private int components;

    /** Creates {@code n} singleton sets (ids 1..n); {@code n < 1} falls back to 31. */
    public DisjointSet(int n) {
        if (n < 1) {
            n = DEFAULT_SIZE;
        }
        idx = new int[n + 1];
        size = new int[n + 1];
        this.n = n;
        components = n;
        for (int i = n; i > 0; i--) {
            idx[i] = i;
            size[i] = 1;
        }
    }

    /** Creates 31 singleton sets. */
    public DisjointSet() {
        this(DEFAULT_SIZE);
    }

    /** Returns the current number of disjoint sets. */
    public int getComponents() {
        return components;
    }

    /** Returns the root of {@code i}'s set, or 0 if {@code i} is out of range. */
    private int root(int i) {
        if (i < 1 || i > n) {
            return 0;
        }
        int p = i;
        while (idx[p] != p) {
            p = idx[p];
        }
        // Path compression: point every node on the walked path straight at the root.
        while (idx[i] != p) {
            int next = idx[i];
            idx[i] = p;
            i = next;
        }
        return p;
    }

    /**
     * Merges the sets containing {@code p} and {@code q}.
     *
     * @return the product of the two set sizes (the number of id pairs newly
     *         connected), or 0 if they were already connected or either id is
     *         out of range
     */
    public long join(int p, int q) {
        int rootP = root(p);
        int rootQ = root(q);
        // Root 0 means "out of range"; merging it would corrupt the unused slot 0.
        if (rootP == 0 || rootQ == 0 || rootP == rootQ) {
            return 0;
        }
        long result = (long) size[rootP] * size[rootQ];
        if (size[rootP] < size[rootQ]) {
            idx[rootP] = rootQ;
            size[rootQ] += size[rootP];
        } else {
            idx[rootQ] = rootP;
            size[rootP] += size[rootQ];
        }
        components--;
        return result;
    }

    /** Returns true if both ids are in range and belong to the same set. */
    public boolean isConnected(int p, int q) {
        int rootP = root(p);
        return rootP != 0 && rootP == root(q);
    }
}

public class MaximumCostQueries {
    private static final int MAX_N = 100000;
    private static final int MAX_Q = 100000;

    public static void main(String[] args) {
	Scanner sc = new Scanner(;

	int n = sc.nextInt();
	int q = sc.nextInt();
	if (n < 1 || n > MAX_N ||
	    q < 1 || q > MAX_Q)

	Edge[] edges = new Edge[n - 1];
	int i, u, v, w;	
	for (i = n - 2; i >= 0; i--) {
	    u = sc.nextInt();
	    v = sc.nextInt();
	    w = sc.nextInt();
	    edges[i] = new Edge(u, v, w);
	DisjointSet ds = new DisjointSet(n);
	WeightCount[] wc = new WeightCount[n];
	int j, k, limit;
	long result;

	limit = edges.length;
	for (k = i = 0; i < limit; i = j) {
	    result = 0;
	    w = edges[i].getW();
	    j = i;

	    do {
		result += ds.join(edges[j].getU(), edges[j].getV());
	    } while (j < limit && edges[j].getW() == w);

	    wc[k++] = new WeightCount(w, result);

	// out.println("wc:");
	// for (i = 0; i < k; i++)
	//     out.println(wc[i].getWeight() + ": " + wc[i].getCount());

	for (i = 1; i < k; i++)
	    wc[i].setCount(wc[i - 1].getCount() + wc[i].getCount());

	while (q-- > 0) {
	    i = sc.nextInt();
	    j = sc.nextInt();

	    u = WeightCount.lower(wc, k, i - 1);
	    v = WeightCount.lower(wc, k, j);

	    result = wc[v].getCount() - ((u < 0) ? 0 : wc[u].getCount());

In C :

#include <stdio.h>

/* a[i] = {u, v, w} for edge i (edges stored at 1..n-1; a[0] stays all zero and
   serves as a sentinel). After the union pass a[i][0] is overwritten with a
   pair count; after compression a[1..n][0] holds cumulative counts.
   parent[] is the union-find parent array over nodes 1..n.
   n is the node count as read, later reused as the number of distinct weights. */
static long long int a[100000][3], parent[100001], n;
void mer(int p, int q, int r)
	static int le[100001][3], ri[100001][3], i, j, k;
	int n1 = q - p + 1, n2 = r - q;
	for (i = 0; i<n1; i++)
		le[i][0] = a[p + i][0];
		le[i][1] = a[p + i][1];
		le[i][2] = a[p + i][2];
	for (j = 0; j<n2; j++)
		ri[j][0] = a[q + j + 1][0];
		ri[j][1] = a[q + j + 1][1];
		ri[j][2] = a[q + j + 1][2];
	le[n1][2] = ri[n2][2] = 1000000001;
	i = j = 0;
	for (k = p; k <= r; k++)
		if (le[i][2] <= ri[j][2])
			a[k][0] = le[i][0];
			a[k][1] = le[i][1];
			a[k][2] = le[i][2];
			a[k][0] = ri[j][0];
			a[k][1] = ri[j][1];
			a[k][2] = ri[j][2];
/* Sorts a[p..r] (inclusive) by ascending weight a[][2] using top-down merge sort. */
void merge_sort(int p, int r)
{
	int q;
	if (p < r) {
		q = p + (r - p) / 2;
		merge_sort(p, q);
		merge_sort(q + 1, r);
		mer(p, q, r);
	}
}

int getParent(int x)
	if (x == parent[x])
		return x;
	parent[x] = getParent(parent[x]);
	return parent[x];

int bin_search(int x)
	int low, mid, upp;
	low = 0;
	upp = n;
	mid = (low + upp) / 2;
	while (low<upp)
		if (x<a[mid][2])
			upp = mid - 1;
		else if (x>a[mid][2])
			if (x >= a[mid + 1][2])
				low = mid + 1;
			if (x<a[mid + 1][2])
		mid = (low + upp) / 2;
	return mid;

/* Reads the tree and the queries. For each query [L, R] it prints the number of
   paths whose maximum edge weight lies in [L, R]. */
int main() {
	static long long int q, count[100001], i, j, x, y, px, py;
	if (scanf("%lld%lld", &n, &q) != 2)
		return 1;
	/* Read the n - 1 edges into a[1..n-1]; every node starts as its own set. */
	for (i = 1; i < n; i++) {
		parent[i] = i;
		count[i] = 1;
		if (scanf("%lld%lld%lld", &a[i][0], &a[i][1], &a[i][2]) != 3)
			return 1;
	}
	parent[n] = n;  /* the loop above stops before node n */
	count[n] = 1;
	merge_sort(1, n - 1);
	/* Kruskal pass in ascending weight order: edge i is the heaviest edge on every
	   path it newly connects, so it contributes count[px] * count[py] paths. */
	for (i = 1; i < n; i++) {
		px = getParent(a[i][0]);
		py = getParent(a[i][1]);
		a[i][0] = count[px] * count[py];
		if (count[px] >= count[py]) {
			count[px] += count[py];
			parent[py] = px;
		} else {
			count[py] += count[px];
			parent[px] = py;
		}
	}
	/* Merge equal weights and turn a[1..j][0] into prefix sums (a[0][0] == 0). */
	j = (n > 1) ? 1 : 0;
	for (i = 2; i < n; i++) {
		if (a[j][2] == a[i][2]) {
			a[j][0] += a[i][0];
		} else {
			j++;
			a[j][0] = a[i][0] + a[j - 1][0];
			a[j][2] = a[i][2];
		}
	}
	n = j;
	/* Answer = #paths with cost <= R minus #paths with cost <= L - 1. */
	while (q--) {
		if (scanf("%lld%lld", &x, &y) != 2)
			return 1;
		px = bin_search(x - 1);
		py = bin_search(y);
		printf("%lld\n", a[py][0] - a[px][0]);
	}
	return 0;
}

In Python3 :

import bisect as bs
# Header: node count n and query count q, followed by the n - 1 tree edges as
# [u, v, w] lists, kept in ascending order of weight w.
n, q = map(int, input().split())
edges = sorted(
    [list(map(int, input().split())) for _ in range(n - 1)],
    key=lambda edge: edge[2],
)
# paths: cost -> number of node pairs whose connecting path has exactly that cost.
paths = dict()
# Union-find over 0-based node ids; a root holds -(size of its set).
union = [-1 for _ in range(n)]

def getroot(x):
    """Return the root of x's set in `union`, re-pointing every node visited
    on the way directly at that root."""
    root = x
    while union[root] >= 0:
        root = union[root]
    while x != root:
        nxt = union[x]
        union[x] = root
        x = nxt
    return root

# Kruskal pass in ascending cost order: an edge of cost c that joins two sets
# is the heaviest edge on every path it newly creates.
for u, v, c in edges:
    u, v = getroot(u - 1), getroot(v - 1)
    # Roots store negated sizes, so this product is the positive number of
    # node pairs newly connected, each with path cost exactly c.
    paths[c] = union[u] * union[v] + paths.get(c, 0)
    # Union by size: afterwards u is the smaller set and is hung under v.
    if union[v] > union[u]:
        u, v = v, u
    union[v] += union[u]
    union[u] = v

# Build a sorted lookup. a[i] is the i-th smallest distinct cost; a[0] = 0 is
# a sentinel below every real cost, since W >= 1. b[i] is the number of paths
# whose cost is <= a[i].
paths = list(sorted(paths.items()))
a = [0]
b = [0]
for x, y in paths:
    a.append(x)  # previously missing: bisect only ever saw [0], so every answer was 0
    b.append(b[-1] + y)

# Each answer = (#paths with cost <= R) - (#paths with cost < L).
out = []
for _ in range(q):
    l, r = map(int, input().split())
    out.append(b[bs.bisect(a, r) - 1] - b[bs.bisect_left(a, l) - 1])
if out:
    print("\n".join(map(str, out)))

