# Rooted Tree

### Problem Statement :

```You are given a rooted tree with N nodes; the root of the tree, R, is also given. Each node of the tree contains a value that is initially empty. You have to maintain the tree under two operations:

- Update Operation
- Report Operation

Update Operation
Each Update Operation begins with the character U. Character U is followed by 3 integers T, V and K. For every node which is a descendant of the node T (including T itself), update its value by adding V + d*K, where V and K are the parameters of the query and d is the distance of the node from T. Note that exactly V is added to node T, since d = 0 there.

Report Operation
Each Report Operation begins with the character Q. Character Q is followed by 2 integers, A and B. Output the sum of the values of the nodes on the path from A to B, modulo (10^9 + 7).

Input Format

The first Line consists of 3 space separated integers, N E R, where N is the number of nodes present, E is the total number of queries (update + report), and R is root of the tree.

Each of the next N-1 lines contains 2 space separated integers, X and Y (X and Y are connected by an edge).

Thereafter, E lines follows: each line can represent either the Update Operation or the Report Operation.

Update Operation is of the form : U T V K.
Report Operation is of the form : Q A B.
Output Format
Output the answer for every given report operation.

Constraints

1 ≤ N ≤ 10^5
1 ≤ E ≤ 10^5
1 ≤ R, X, Y, T, A, B ≤ N
1 ≤ V, K ≤ 10^9
X ≠ Y```

### Solution :

Solutions in C++, Java and Python 3:

In    C++  :

#include <cstdio>
#include <cmath>
#include <iostream>
#include <set>
#include <algorithm>
#include <vector>
#include <map>
#include <cassert>
#include <string>
#include <cstring>

using namespace std;

// rep(i,a,b): counts i over the half-open range [a, b).
#define rep(i,a,b) for(int i = a; i < b; i++)
// S(x) / P(x): scanf/printf shorthands for a single int. Neither is invoked in
// this listing. NOTE(review): the function-like macro P(x) shares its name with
// the parent array P below; it only expands when P is followed by '('.
#define S(x) scanf("%d",&x)
#define P(x) printf("%d\n",x)

typedef long long int LL;
// Modulus applied to every reported sum.
const int mod = 1000000007;
// Capacity of the per-node arrays.
const int MAXN = 100005;
// Adjacency lists of the tree, 0-based node ids (filled in main).
vector<int > g[MAXN];
// dep[v]: depth of v as assigned by dfs (the start node gets the depth passed in).
int dep[MAXN];
// P[v]: parent of v as assigned by dfs; main starts the traversal with parent -1.
int P[MAXN];
// Traversal clock: dfs increments it once when entering and once when leaving a node.
int _tm;
// tin[v] / tout[v]: clock values recorded by dfs on entering / leaving v.
int tin[2*MAXN];
int tout[2*MAXN];
// Number of nodes (read in main).
int n;
// L[v][j]: ancestor of v that is 2^j levels up, or -1 if none (built by processLca).
int L[MAXN][25];

// Three Fenwick arrays indexed by traversal clock. For a node of depth c, QQQ
// combines their prefix sums as bit1*c + bit2*c^2 + bit3.
LL bit1[2*MAXN], bit2[2*MAXN], bit3[2*MAXN];

// Modular exponentiation: returns a^b reduced modulo `mod`.
//
// Parameters:
//   a - base; any value is accepted, it is normalised into [0, mod) first so
//       that the intermediate products cannot overflow a 64-bit integer.
//   b - exponent; values <= 0 yield 1 (the empty product).
//
// Uses iterative square-and-multiply, so no recursion is involved and the
// running time is proportional to the number of bits in b.
LL _pow(LL a, LL b) {
    LL base = a % mod;
    if (base < 0) base += mod;   // C++ '%' keeps the sign of the dividend

    LL result = 1;
    while (b > 0) {
        if (b & 1) result = (result * base) % mod;
        base = (base * base) % mod;
        b >>= 1;
    }
    return result;
}

void dfs(int c, int p, int d) {
P[c] = p;
dep[c] = d;
_tm++;
tin[c] = _tm;
rep(i,0,g[c].size()) {
int u = g[c][i];
if(u != p) dfs(u, c, d+1);
}
_tm++;
tout[c] = _tm;
}

// Builds the ancestor table L from the parent array P.
// After the call, for every node v < n and every j with 2^j < n,
// L[v][j] is the ancestor 2^j levels above v, or -1 when there is none.
void processLca() {
    // Number of columns in use: all j with 2^j < n.
    int levels = 0;
    while ((1 << levels) < n) ++levels;

    for (int v = 0; v < n; ++v) {
        // Start with "no ancestor" everywhere, then record the direct parent.
        for (int j = 0; j < levels; ++j) L[v][j] = -1;
        L[v][0] = P[v];
    }

    // Column j is derived from column j-1: two jumps of 2^(j-1) make one of 2^j.
    for (int j = 1; j < levels; ++j) {
        for (int v = 0; v < n; ++v) {
            const int half = L[v][j - 1];
            if (half != -1) L[v][j] = L[half][j - 1];
        }
    }
}

// Returns the lowest common ancestor of nodes p and q, using the depth array
// dep, the ancestor table L and the parent array P.
int lca(int p, int q)
{
    // Make p the deeper (or equally deep) endpoint.
    if (dep[p] < dep[q]) {
        const int held = p;
        p = q;
        q = held;
    }

    // Largest exponent worth trying: 2^top <= dep[p] (0 when dep[p] is 0 or 1).
    int top = 0;
    while ((1 << (top + 1)) <= dep[p]) ++top;

    // Lift p until it is on the same level as q.
    for (int k = top; k >= 0; --k) {
        if (dep[p] - (1 << k) >= dep[q]) p = L[p][k];
    }

    if (p == q) return p;

    // Lift both nodes together as long as their 2^k-th ancestors differ.
    for (int k = top; k >= 0; --k) {
        const int up = L[p][k];
        if (up == -1 || up == L[q][k]) continue;
        p = up;
        q = L[q][k];
    }

    // p and q are now distinct children of the answer.
    return P[p];
}

// Fenwick point update: adds val at position idx (1-based) of the tree stored
// in `bit`. Positions beyond the current clock value _tm are ignored.
void update(LL *bit, int idx, LL val) {
    int pos = idx;
    while (pos <= _tm) {
        bit[pos] += val;
        pos += pos & -pos;   // move to the next node responsible for this position
    }
}

// Fenwick prefix query: returns the sum of positions 1..idx of the tree stored
// in `bit`, reduced with `% mod` once at the end.
LL query(LL *bit, int idx) {
    LL total = 0;
    int pos = idx;
    while (pos != 0) {
        total += bit[pos];
        pos -= pos & -pos;   // strip the lowest set bit
    }
    return total % mod;
}

// Evaluates the stored polynomial for node x: with c = dep[x] and the three
// Fenwick prefix sums taken at tin[x], returns
//   (bit1 * c + bit2 * c^2 + bit3) mod `mod`.
// main uses this as the sum of node values on the path from the root to x.
LL QQQ(int x) {
    const LL c = dep[x];
    const int at = tin[x];

    const LL linear = (query(bit1, at) * c) % mod;
    const LL quadratic = query(bit2, at) * (((LL)c*c)%mod);
    const LL constant = query(bit3, at);

    return (((linear + quadratic) % mod) + constant) % mod;
}

int main() {
int e,r;
scanf("%d%d%d",&n,&e,&r);
r--;

rep(i,0,n-1) {
int x,y;
scanf("%d%d",&x,&y);
x--;y--;
g[x].push_back(y);
g[y].push_back(x);
}
dfs(r,-1,0);
processLca();

while(e--) {
char s[5];
scanf("%s",s);
if(s[0] == 'U') {
int T,V,K;
scanf("%d%d%d",&T,&V,&K);
T--;
LL k = ((LL)K * _pow(2,mod-2)) % mod;
LL p = dep[T];
LL val;
// printf("%d %d %lld %lld\n",tin[T],tout[T],k,p);

val = (V-2*p*k+k) % mod;
val = (val + mod) % mod;
// printf("%lld\n",val);
update(bit1, tin[T], val);
update(bit1, tout[T]+1, -val);

val = k;
// printf("%lld\n",val);
update(bit2, tin[T], val);
update(bit2, tout[T]+1, -val);

val = (p*p) % mod;
val = (val*k) % mod;
val -= p*(V+k);
val %= mod;
val += mod+V;
val %= mod;
// printf("%lld\n",val);
update(bit3, tin[T], val);
update(bit3, tout[T]+1, -val);

} else {
int A,B;
scanf("%d%d",&A,&B);
A--;B--;
LL ans = 0;
int l = lca(A,B);

ans = QQQ(A)+QQQ(B)-QQQ(l);
if(P[l] != -1) ans -= QQQ(P[l]);
// printf("%lld %lld %lld %d\n",QQQ(A),QQQ(B),QQQ(l),l);
ans %= mod;
ans += mod;
ans %= mod;
printf("%lld\n",ans);
}
}
return 0;
}

In    Java  :

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.BitSet;
import java.util.InputMismatchException;

/**
 * Rooted Tree: subtree updates "add V + d*K" and path-sum queries.
 *
 * <p>Nodes are laid out in preorder so that every subtree is a contiguous
 * index range. For an update at a node of depth p, the doubled contribution to
 * the root-to-x path sum of a node x of depth c in that subtree is the
 * quadratic {@code K*c^2 + (2V + (1-2p)K)*c + (2V(1-p) + K*p*(p-1))}. The three
 * coefficients are range-added in three Fenwick trees (f2, f1, f0); a path
 * query combines four root-to-node sums and multiplies by the inverse of 2.
 */
public class Solution {
    static InputStream is;
    static PrintWriter out;
    static String INPUT = "";
    static int mod = 1000000007;

    /** Reads the whole problem instance from {@link #is} and writes answers to {@link #out}. */
    static void solve()
    {
        int n = ni(), Q = ni(), root = ni()-1;
        int[] from = new int[n-1];
        int[] to = new int[n-1];
        for(int i = 0;i < n-1;i++){
            from[i] = ni()-1;
            to[i] = ni()-1;
        }
        int[][] g = packU(n, from, to);
        int[][] pars = parents3(g, root);
        int[] par = pars[0], dep = pars[2];
        int[][] rights = makeRights(g, par, root);
        // iord[v]: preorder position of v; right[pos]: last preorder position
        // inside the subtree that starts at position pos.
        int[] iord = rights[1], right = rights[2];
        int[][] spar = logstepParents(par);

        // Fenwick trees over preorder positions holding the coefficients of
        // depth^2, depth^1 and depth^0.
        long[] f2 = new long[n+2];
        long[] f1 = new long[n+2];
        long[] f0 = new long[n+2];

        long i2 = invl(2, mod);
        for(int z = 0;z < Q;z++){
            char t = nc();
            if(t == 'U'){
                int tar = ni()-1;
                long v = ni(), K = ni();
                long c = dep[tar];
                long c2 = K;
                long c1 = (2*v + (long)(-2*c+1)*K) % mod;
                long c0 = ((-c+1)*v*2 + (-c)*(-c+1)%mod*K)%mod;

                // Range-add over the subtree [lo, hi): add at the start,
                // cancel just past the end.
                int lo = iord[tar];
                int hi = right[lo] + 1;
                addFenwick(f2, lo, c2);
                addFenwick(f2, hi, -c2);
                addFenwick(f1, lo, c1);
                addFenwick(f1, hi, -c1);
                addFenwick(f0, lo, c0);
                addFenwick(f0, hi, -c0);
            }else if(t == 'Q'){
                int a = ni()-1, b = ni()-1;
                int lca = lca2(a, b, spar, dep);
                int plca = par[lca];

                long vala = val(a, f2, f1, f0, iord, dep);
                long valb = val(b, f2, f1, f0, iord, dep);
                long vall = val(lca, f2, f1, f0, iord, dep);
                long valpl = plca == -1 ? 0L : val(plca, f2, f1, f0, iord, dep);
                // The stored polynomial is twice the path sum, hence the factor i2.
                long ret = (vala + valb - vall - valpl) * i2 % mod;
                if(ret < 0)ret += mod;
                out.println(ret);
            }
        }
    }

    /** Modular inverse of {@code a} modulo {@code mod} via the extended Euclidean algorithm. */
    public static long invl(long a, long mod)
    {
        long b = mod;
        long p = 1, q = 0;
        while(b > 0){
            long c = a / b;
            long d;
            d = a; a = b; b = d % b;
            d = p; p = q; q = d - c * q;
        }
        return p < 0 ? p + mod : p;
    }

    /** Recovers the plain array of point values from a Fenwick tree. */
    public static long[] restoreFenwick(long[] ft)
    {
        int n = ft.length-1;
        long[] ret = new long[n];
        for(int i = 0;i < n;i++)ret[i] = sumFenwick(ft, i);
        for(int i = n-1;i >= 1;i--)ret[i] -= ret[i-1];
        return ret;
    }

    /**
     * Evaluates {@code f2*d^2 + f1*d + f0 (mod)} for node {@code a}, where the
     * coefficients are Fenwick prefix sums at the node's preorder position and
     * {@code d} is its depth. The result may be negative.
     */
    static long val(int a, long[] f2, long[] f1,
            long[] f0, int[] iord, int[] dep){
        long quad = sumFenwick(f2, iord[a]) % mod * dep[a];
        long lin = (quad + sumFenwick(f1, iord[a])) % mod * dep[a];
        return (lin + sumFenwick(f0, iord[a])) % mod;
    }

    /** Sum of the 0-based positions {@code 0..i} of the Fenwick tree. */
    public static long sumFenwick(long[] ft, int i)
    {
        long sum = 0;
        for(i++;i > 0;i -= i&-i)sum += ft[i];
        return sum;
    }

    /** Adds {@code v} at 0-based position {@code i}; positions past the end are ignored. */
    public static void addFenwick(long[] ft, int i, long v)
    {
        if(v == 0)return;
        int n = ft.length;
        for(i++;i < n;i += i&-i)ft[i] += v;
    }

    /** Lowest common ancestor of {@code a} and {@code b} using the binary-lifting table {@code spar}. */
    public static int lca2(int a, int b, int[][] spar, int[] depth) {
        if(depth[a] < depth[b]){
            b = ancestor(b, depth[b] - depth[a], spar);
        }else if(depth[a] > depth[b]){
            a = ancestor(a, depth[a] - depth[b], spar);
        }

        if(a == b)
            return a;
        int sa = a, sb = b;
        for(int low = 0, high = depth[a], t =
                Integer.highestOneBit(high), k = Integer
                .numberOfTrailingZeros(t);t > 0;t >>>= 1, k--){
            if((low ^ high) >= t){
                if(spar[k][sa] != spar[k][sb]){
                    low |= t;
                    sa = spar[k][sa];
                    sb = spar[k][sb];
                }else{
                    high = low | t - 1;
                }
            }
        }
        return spar[0][sa];
    }

    /** Returns the ancestor {@code m} levels above {@code a}, or -1 if it does not exist. */
    protected static int ancestor(int a,
            int m, int[][] spar) {
        for(int i = 0;m > 0 && a != -1;m >>>= 1, i++){
            if((m & 1) == 1)
                a = spar[i][a];
        }
        return a;
    }

    /** Builds the table {@code pars[j][v]} = ancestor 2^j levels above v (-1 if none). */
    public static int[][] logstepParents(int[] par) {
        int n = par.length;
        int m = Integer.numberOfTrailingZeros(
                Integer.highestOneBit(n - 1)) + 1;
        int[][] pars = new int[m][n];
        pars[0] = par;
        for(int j = 1;j < m;j++){
            for(int i = 0;i < n;i++){
                pars[j][i] = pars[j - 1][i] == -1 ? -1
                        : pars[j - 1][pars[j - 1][i]];
            }
        }
        return pars;
    }

    /**
     * Returns {@code {ord, iord, right}}: the preorder, its inverse, and for
     * each preorder position the last position belonging to that subtree.
     */
    public static int[][] makeRights(int[][] g, int[] par, int root)
    {
        int n = g.length;
        int[] ord = sortByPreorder(g, root);
        int[] iord = new int[n];
        for(int i = 0;i < n;i++)iord[ord[i]] = i;

        int[] right = new int[n];
        // Children come later in preorder, so scanning backwards sees them first.
        for(int i = n-1;i >= 0;i--){
            int v = i;
            for(int e : g[ord[i]]){
                if(e != par[ord[i]]){
                    v = Math.max(v, right[iord[e]]);
                }
            }
            right[i] = v;
        }
        return new int[][]{ord, iord, right};
    }

    /** Iterative depth-first preorder of the tree starting at {@code root}. */
    public static int[] sortByPreorder(int[][] g, int root){
        int n = g.length;
        int[] stack = new int[n];
        int[] ord = new int[n];
        BitSet ved = new BitSet();
        stack[0] = root;
        int p = 1;
        int r = 0;
        ved.set(root);
        while(p > 0){
            int cur = stack[p-1];
            ord[r++] = cur;
            p--;
            for(int e : g[cur]){
                if(!ved.get(e)){
                    stack[p++] = e;
                    ved.set(e);
                }
            }
        }
        return ord;
    }

    /** Breadth-first pass returning {@code {parent, visit order, depth}}; the root's parent is -1. */
    public static int[][] parents3(int[][] g, int root) {
        int n = g.length;
        int[] par = new int[n];
        Arrays.fill(par, -1);

        int[] depth = new int[n];
        depth[root] = 0;

        int[] q = new int[n];
        q[0] = root;
        for(int p = 0, r = 1;p < r;p++){
            int cur = q[p];
            for(int nex : g[cur]){
                if(par[cur] != nex){
                    q[r++] = nex;
                    par[nex] = cur;
                    depth[nex] = depth[cur] + 1;
                }
            }
        }
        return new int[][] { par, q, depth };
    }

    /** Packs an undirected edge list into adjacency arrays. */
    static int[][] packU(int n, int[] from, int[] to) {
        int[][] g = new int[n][];
        int[] p = new int[n];
        for(int f : from)
            p[f]++;
        for(int t : to)
            p[t]++;
        for(int i = 0;i < n;i++)
            g[i] = new int[p[i]];
        for(int i = 0;i < from.length;i++){
            g[from[i]][--p[from[i]]] = to[i];
            g[to[i]][--p[to[i]]] = from[i];
        }
        return g;
    }

    public static void main(String[] args) throws Exception
    {
        long S = System.currentTimeMillis();
        is = INPUT.isEmpty() ? System.in :
            new ByteArrayInputStream(INPUT.getBytes());
        out = new PrintWriter(System.out);

        solve();
        out.flush();
        long G = System.currentTimeMillis();
        tr(G-S+"ms");
    }

    private static boolean eof()
    {
        if(lenbuf == -1)return true;
        int lptr = ptrbuf;
        while(lptr < lenbuf)if(!isSpaceChar(inbuf[lptr++]))
            return false;

        try {
            is.mark(1000);
            while(true){
                int b = is.read();
                if(b == -1){
                    is.reset();
                    return true;
                }else if(!isSpaceChar(b)){
                    is.reset();
                    return false;
                }
            }
        } catch (IOException e) {
            return true;
        }
    }

    private static byte[] inbuf = new byte[1024];
    static int lenbuf = 0, ptrbuf = 0;

    /** Next input byte, or -1 at end of input. */
    private static int readByte()
    {
        if(lenbuf == -1)throw new InputMismatchException();
        if(ptrbuf >= lenbuf){
            ptrbuf = 0;
            try { lenbuf = is.read(inbuf); }
            catch (IOException e) { throw new InputMismatchException(); }
            if(lenbuf <= 0)return -1;
        }
        return inbuf[ptrbuf++];
    }

    /** Everything outside printable ASCII 33..126 (including -1) counts as a separator. */
    private static boolean isSpaceChar(int c)
    { return !(c >= 33 && c <= 126); }

    /** Skips separators and returns the first non-separator byte (or -1). */
    private static int skip()
    { int b; while(
            (b = readByte()) != -1 && isSpaceChar(b)); return b; }

    private static double nd() { return Double.parseDouble(ns()); }
    private static char nc() { return (char)skip(); }

    /** Reads the next whitespace-delimited token. */
    private static String ns()
    {
        int b = skip();
        StringBuilder sb = new StringBuilder();
        while(!(isSpaceChar(b)))
        { // when nextLine, (isSpaceChar(b) && b != ' ')
            sb.appendCodePoint(b);
            b = readByte();   // advance, otherwise the loop never ends
        }
        return sb.toString();
    }

    /** Reads at most {@code n} characters of the next token. */
    private static char[] ns(int n)
    {
        char[] buf = new char[n];
        int b = skip(), p = 0;
        while(p < n && !(isSpaceChar(b))){
            buf[p++] = (char)b;
            b = readByte();
        }
        return n == p ? buf : Arrays.copyOf(buf, p);
    }

    private static char[][] nm(int n, int m)
    {
        char[][] map = new char[n][];
        for(int i = 0;i < n;i++)map[i] = ns(m);
        return map;
    }

    private static int[] na(int n)
    {
        int[] a = new int[n];
        for(int i = 0;i < n;i++)a[i] = ni();
        return a;
    }

    /** Reads the next (optionally negative) decimal int, skipping anything before it. */
    private static int ni()
    {
        int num = 0, b;
        boolean minus = false;
        while((b = readByte()) != -1 && !(
                (b >= '0' && b <= '9') || b == '-'));
        if(b == '-'){
            minus = true;
            b = readByte();
        }

        while(true){
            if(b >= '0' && b <= '9'){
                num = num * 10 + (b - '0');
            }else{
                return minus ? -num : num;
            }
            b = readByte();   // consume the digit just processed
        }
    }

    /** Reads the next (optionally negative) decimal long, skipping anything before it. */
    private static long nl()
    {
        long num = 0;
        int b;
        boolean minus = false;
        while((b = readByte()) != -1 && !(
                (b >= '0' && b <= '9') || b == '-'));
        if(b == '-'){
            minus = true;
            b = readByte();
        }

        while(true){
            if(b >= '0' && b <= '9'){
                num = num * 10 + (b - '0');
            }else{
                return minus ? -num : num;
            }
            b = readByte();   // consume the digit just processed
        }
    }

    /** Debug trace; prints only when a hard-coded INPUT is in use. */
    private static void tr(Object... o)
    { if(INPUT.length() != 0)
        System.out.println(Arrays.deepToString(o)); }
}

In    Python3  :

class Node(object):
    """One tree node.

    Attributes are plain fields: `parent` is the parent Node (None when no
    parent has been set), `value` is the initial value, `children` is an
    initially empty set, and `update` is the set of update records attached
    to this node.
    """

    def __init__(self, value=0, parent=None):
        self.children=set()
        self.parent=parent
        self.value=value
        self.update=set()

    def setParent(self, node):
        """Set `node` as this node's parent."""
        self.parent=node

def root_path(bottom):
    """Return the nodes from `bottom` up to the top of its parent chain.

    The list starts with `bottom` itself and ends with the node whose
    `parent` is None. Passing None returns an empty list.
    """
    chain=[]
    while bottom is not None:
        chain.append(bottom)
        bottom=bottom.parent
    return chain

def height(bottom, top):
    """Count the nodes on the parent chain from `bottom` up to `top`, inclusive.

    If `top` is never reached, the walk runs off the end of the chain and
    the result is the chain length plus one.
    """
    res=1
    while bottom!=top and bottom is not None:
        bottom=bottom.parent
        res+=1
    return res

def q2(arr, nodes):
    """Sum of node values on the vertical path from `first` down to `last`.

    `arr` is `[a, b]` with 1-based node numbers; `nodes[a-1]` is expected to
    be an ancestor of (or equal to) `nodes[b-1]`. The result is reduced
    modulo 10**9 + 7.

    Every update record `u` (with `u[0]` = V and `u[1]` = K) stored on a node
    between the top of the chain and `last` contributes to the part of the
    queried path that lies in that node's subtree: `h` nodes starting at
    distance `d` from the updated node.
    """
    a,b=arr
    first=nodes[a-1]
    last=nodes[b-1]
    path=root_path(last)
    h=height(last,first)
    res=0
    while len(path)>0:
        # Nodes of the queried path that lie at or below the current node.
        h=min(h, len(path))
        # Distance from the current node to the first of those nodes.
        d=max(len(path)-h,0)
        c=path.pop()
        for u in c.update:
            # Closed form of sum(u[0] + (d+i)*u[1] for i in range(h)).
            res+=h*u[0] + u[1]*(h*d + h*(h-1)//2)
    return res % (10**9+7)

def isRootBetween(bottom, top):
    """Return True when neither node lies on the other's parent chain.

    Despite the name, this does not test for the root specifically: it is
    False when the nodes are equal or one is an ancestor of the other, and
    True otherwise (the path between them bends at a common ancestor).
    """
    if bottom==top:
        return False
    if bottom in root_path(top):
        return False
    if top in root_path(bottom):
        return False
    return True

def isDeeper(x,y):
    """Return False if `x` is `y` or an ancestor of `y`, True otherwise."""
    p=y
    while p is not None:
        if x==p:
            return False
        p=p.parent
    return True

def query(arr, nodes, root):
    """Sum of node values on the path between nodes `a` and `b`.

    `arr` is `[a, b]` and `root` is the root, all as 1-based node numbers.
    The result is reduced modulo 10**9 + 7.

    When one endpoint is an ancestor of the other the path is vertical and
    `q2` answers it directly. Otherwise the path bends at the lowest common
    ancestor, so the two vertical halves are added and the common ancestor,
    counted twice, is subtracted once.
    """
    a,b=arr
    x=nodes[a-1]
    y=nodes[b-1]
    if a==b==root:
        return q2([root,root], nodes)
    if isRootBetween(x,y):
        # The path bends at the lowest common ancestor, which is not
        # necessarily the root: take the first ancestor of x that is also
        # an ancestor of y (falling back to the root if none is found).
        above_y=set(id(node) for node in root_path(y))
        meet=next((node for node in root_path(x) if id(node) in above_y),
                  nodes[root-1])
        m=nodes.index(meet)+1
        return (q2([m,a], nodes) \
            +q2([m,b], nodes)\
            -q2([m,m], nodes)) % (10**9+7)
    elif isDeeper(x,y):
        # y is the ancestor: query from y down to x.
        return q2([b,a], nodes)
    else:
        return q2([a,b], nodes)

def update(arr, nodes):
    """Record the update `arr = [T, V, K]` on node T (1-based).

    The record is stored in the node's `update` collection as a tuple whose
    first two items are V and K, which is what the query code reads. A third
    item, the current size of the collection, keeps two identical (V, K)
    updates on the same node from collapsing into one set element.
    """
    node=nodes[arr[0]-1]
    node.update.add((arr[1], arr[2], len(node.update)))

# ---- driver: read the tree, orient it from the root, then answer operations ----

# First line: N (nodes), E (operations), R (root, 1-based).
line1=[int(x) for x in input().split()]
nodes=[Node() for i in range(line1[0])]
root=line1[2]-1

def _add_neighbour(d, k, v):
    """Append v to the neighbour list stored under key k, creating it if needed."""
    if d.get(k) is None:
        d[k]=[v]
    else:
        d[k].append(v)

# Undirected adjacency lists keyed by 0-based node index.
tmp=dict()
for i in range(line1[0]-1):
    p,c=[int(x) for x in input().split()]
    _add_neighbour(tmp, p-1, c-1)
    _add_neighbour(tmp, c-1, p-1)

# Walk the tree from the root and assign parents. A set keeps the
# membership test constant-time.
visited=set()
parents=[root]
while len(parents)>0:
    current=parents.pop()
    ltmp=tmp.get(current)
    if ltmp is not None:
        for c in ltmp:
            if c not in visited:
                nodes[c].setParent(nodes[current])
                parents.append(c)
    visited.add(current)

# Operations: "U T V K" records an update, "Q A B" prints a path sum.
for i in range(line1[1]):
    inp=input().split()
    t=inp[0]
    integers=[int(x) for x in inp[1:]]
    if t=='U':
        update(integers, nodes)
    else:
        print(query(integers, nodes, root+1))
```

