Easy Addition


Problem Statement :


You are given a tree with N nodes and each has a value associated with it. You are given Q queries, each of which is either an update or a retrieval operation.

Initially all node values are zero.

The update query is of the format

a1 d1 a2 d2 A B
This means you have to add $(a_1 + z \cdot d_1) \cdot (a_2 + z \cdot d_2) \cdot R^z$ to every node on the path from A to B, where $z$ is the distance between that node and A.

The retrieval query is of the format

i j
You need to return the sum of the node values lying in the path from node i to node j modulo 1000000007.

Note:

First all update queries are given and then all retrieval queries.
Distance between 2 nodes is the shortest path length between them taking each edge weight as 1.
Input Format

The first line contains two integers (N and R respectively) separated by a space.

In the next N-1 lines, the ith line describes the ith edge: a line with two integers x y separated by a single space denotes an edge between nodes x and y.

The next line contains 2 space separated integers (U and Q respectively) representing the number of Update and Query operations to follow.

U lines follow. Each of the next U lines contains 6 space separated integers (a1,d1,a2,d2,A and B respectively).

Each of the next Q lines contains 2 space separated integers, i and j respectively.


Output Format

The output contains exactly Q lines; the i-th line contains the answer to the i-th query.

Constraints

$2 \le N \le 10^5$

$2 \le R \le 10^9$

$1 \le U \le 10^5$

$1 \le Q \le 10^5$

$1 \le a_1, a_2, d_1, d_2 \le 10^8$
1 <= x, y, i, j, A, B <= N

Note: for an update operation, A can be equal to B, and for a query operation, i can be equal to j.



Solution :



title-img


                            Solutions (C++ and Java) :

In   C++ :








#include<stdio.h>
#include<string.h>
#include<iostream>
#include<algorithm>
#include<stdlib.h>
#include<set>
#include<map>
#include<bitset>
#include<vector>
#include<assert.h>
using namespace std;

// FOR(i,a,b): ascending loop over the closed range [a, b].
#define FOR(i,a,b) for(int i=(a);i<=(b);i++)
// 64-bit signed integer used for every modular computation.
typedef long long int ll;
// Modulus applied to all arithmetic and to every printed answer.
const int mod = 1000000007;

// Adjacency lists of the tree, indexed by 1-based node id.
vector<int> myvector[100001];

// Tree size and the numbers of update / retrieval operations.
int N, Q, U;
// parent[v] and depth[v] of the tree rooted at node 1 (filled by dfs_pre).
int parent[100001];
int depth[100001];

// R, its inverse modulo mod, and the power table Rp[i] = R^i mod mod.
ll R, Rinv, Rp[100001];
// Per-node accumulated values; after dfs_sum they are prefix sums from the root.
ll sum[100001];
// store[i] = floor(log2(i)) lookup table (filled by init).
int store[100001];
struct node{
ll S,D,B,factD,factB,addB;
node(ll a=0,ll b=0,ll c=0,ll d=0,ll e=0,ll f=0){
S=a,D=b,B=c;
factD=d;
factB=e;
addB=f;
}
node operator + (const node &x) const{
ll a,b,c,d,e,f;
a = (x.S + S);
b = (x.D + D);
c = (x.B + B);
d = (x.factD + factD);
e = (x.factB + factB);
f = (x.addB + addB);
a= (a>mod)?(a-mod):a;
b= (b>mod)?(b-mod):b;
c= (c>mod)?(c-mod):c;
d= (d>mod)?(d-mod):d;
e= (e>mod)?(e-mod):e;
f= (f>mod)?(f-mod):f;
return node(a,b,c,d,e,f);
}
}fwd[100001],bck[100001],A,B;
// Modular inverse of a modulo b, via the iterative extended-Euclid recurrence
//   coef_k = coef_{k-2} - q_{k-2} * coef_{k-1}   (mod b),
// updated only while the division still leaves a remainder.
// For gcd(a, b) == 1 the result is the inverse in [0, b). Returns 1 when the very first
// division is exact (e.g. a == 1).
ll inverse(ll a,ll b)
{
    const ll m = b;
    ll prev = 0;      // coefficient two steps back
    ll cur = 1;       // coefficient one step back
    ll result = 1;    // latest coefficient (returned)
    ll divisor = a;
    ll dividend = b;
    while (divisor != 0) {
        const ll quotient = dividend / divisor;
        const ll rem = dividend % divisor;
        if (rem != 0) {
            result = prev - (cur * quotient) % m;
            if (result < 0)
                result += m;
            prev = cur;
            cur = result;
        }
        dividend = divisor;
        divisor = rem;
    }
    return result;
}
void dfs_pre(int root)
{
for(vector<int>::iterator it=myvector[root].begin();it!=myvector[root].end();it++){
if(parent[root]==*it)   continue;
parent[*it]=root;
depth[*it]=depth[root]+1;
dfs_pre(*it);
}
}

void dfs_cal(int root)          
{
for(vector<int>::iterator it=myvector[root].begin();it!=myvector[root].end();it++){
if(parent[root]==*it)   continue;
dfs_cal(*it);

fwd[root].S = (fwd[root].S + fwd[*it].S*R)%mod;
fwd[root].D = (fwd[root].D + (fwd[*it].D + fwd[*it].factD)*R)%mod;
fwd[root].factD = (fwd[root].factD + fwd[*it].factD*R)%mod;
fwd[root].B = (fwd[root].B +  (fwd[*it].B + fwd[*it].factB)*R) %mod;
fwd[root].factB = (fwd[root].factB + (fwd[*it].factB+fwd[*it].addB)*R)%mod;
fwd[root].addB = ( fwd[root].addB + fwd[*it].addB*R)%mod;

bck[root].S = (bck[*it].S * Rinv  + bck[root].S)%mod;
bck[root].D = (bck[root].D + (bck[*it].D - bck[*it].factD)*Rinv )%mod;
if (bck[root].D<0)  bck[root].D+=mod;
bck[root].factD = (bck[root].factD + bck[*it].factD*Rinv)%mod; 

bck[root].B = (bck[root].B + (bck[*it].B - bck[*it].factB)*Rinv) %mod; 
if(bck[root].B<0)   bck[root].B+=mod;
bck[root].factB = (bck[root].factB + (bck[*it].factB-bck[*it].addB)*Rinv)%mod;
if(bck[root].factB<0)    bck[root].factB+=mod;
bck[root].addB  = ( bck[root].addB + bck[*it].addB*Rinv)%mod;

}

sum[root] = (sum[root] + fwd[root].S+fwd[root].D+fwd[root].B + bck[root].S+bck[root].D+bck[root].B)%mod;
}

void dfs_sum(int root)          
{
for(vector<int>::iterator it=myvector[root].begin();it!=myvector[root].end();it++){
if(parent[root]==*it)   continue;
sum[*it]=(sum[*it]+sum[root])%mod;
dfs_sum(*it);
}
}
// Binary-lifting table filled by process(): Root[v][0] = parent[v] and
// Root[v][k] = Root[Root[v][k-1]][k-1], i.e. the 2^k-th ancestor of v.
// Since solve() sets parent[1] = 1, jumps past the root stay at node 1.
// Entries that process() does not fill keep the -1 from its memset.
int Root[100001][18];
// Precomputes store[i] = floor(log2(i)) for 1 <= i <= 100000 (store[0] is 0).
// lca() reads it to bound its binary-lifting loops.
void init()
{
    store[0]=0;
    store[1]=0;
    // floor(log2(i)) = floor(log2(i/2)) + 1 for i >= 2.
    for(int i=2;i<=100000;i++)
        store[i]=store[i>>1]+1;
}
// Builds the binary-lifting table Root[][] for nodes 1..N from parent[]:
// Root[v][0] = parent[v] and Root[v][k] = Root[Root[v][k-1]][k-1] for every k with 2^k <= N.
// Entries that are never computed keep the -1 written by memset.
void process(int N)
{
    memset(Root,-1,sizeof(Root));
    for(int v=1;v<=N;v++)
        Root[v][0]=parent[v];
    for(int lvl=1;(1<<lvl)<=N;lvl++){
        for(int v=1;v<=N;v++){
            const int half=Root[v][lvl-1];   // 2^(lvl-1)-th ancestor of v
            if(half!=-1)
                Root[v][lvl]=Root[half][lvl-1];
        }
    }
}
// Lowest common ancestor of p and q, using binary lifting over Root[][].
// Relies on store[d] == floor(log2(d)) (from init()). It also relies on parent[1] == 1, so jumps
// past the root stay at node 1. (The unused local `temp` of the original was removed.)
int lca(int p,int q)
{
    if(depth[p]>depth[q])   swap(p,q);     // make q the deeper node
    const int steps=store[depth[q]];       // largest jump exponent that can be needed
    // Lift q up to p's depth.
    for(int i=steps;i>=0;i--)
        if(depth[q]-(1<<i) >= depth[p])
            q=Root[q][i];
    if(p==q)    return p;
    // Lift both while their 2^i-th ancestors differ; they stop just below the LCA.
    for(int i=steps;i>=0;i--){
        if(Root[p][i]!=Root[q][i]){
            p=Root[p][i];
            q=Root[q][i];
        }
    }
    return parent[p];
}
// Records the A -> LCA half of one path update in the forward accumulators fwd[].
// Parameters:
//   S, D, B1 : constant, linear and quadratic coefficients of the update polynomial (reduced mod).
//   t        : solve() passes depth[x]-depth[y]+1, the number of nodes from x up to y inclusive.
//   x        : path endpoint A; the start terms (global A) are added to fwd[x].
//   y        : the LCA. The negated, R^t-scaled terms (global B) are added at parent[y],
//              which appears intended to cancel the forward contribution above the LCA.
//              Nothing is added when y is the root (node 1).
// Negations are written as (mod - v) % mod, so every stored value stays in [0, mod).
// The original could store mod itself, and it also had two `< 0` checks that could never fire.
void Update_forward(ll S,ll D,ll B1,ll t,int x,int y)
{
    A.S = S;
    B.S = (mod - (S*Rp[t])%mod) % mod;

    A.factD = D;
    A.D = 0;
    B.factD = (mod - (D*Rp[t])%mod) % mod;
    B.D = (B.factD*t)%mod;

    A.B = 0;
    A.factB = B1;
    A.addB = (B1+B1)%mod;

    const ll brt = (mod - (B1*Rp[t])%mod) % mod;
    B.B = (((t*t)%mod)*brt)%mod;
    B.factB = ((2*t+1)*brt)%mod;
    B.addB = (brt+brt)%mod;

    fwd[x]=fwd[x]+A;
    if(y!=1)    fwd[parent[y]]=fwd[parent[y]]+B;
}
// Records the LCA -> B half of one path update in the backward accumulators bck[].
// Parameters:
//   S, D, B1 : update coefficients (same meaning as in Update_forward).
//   t        : solve() passes dist(A, B); terms scaled by R^t are added at y.
//   g        : solve() passes dist(A, LCA); negated terms scaled by R^g are added at x.
//   y        : path endpoint B.
//   x        : the LCA.
// Negations are written as (mod - v) % mod, so every stored value stays in [0, mod).
// The original's `brt < 0` check could never fire.
void Update_backward(ll S,ll D,ll B1,ll t,ll g,int y,int x)
{
    // Terms added at the endpoint y.
    const ll at = (B1*Rp[t])%mod;
    B.S = (S*Rp[t])%mod;
    B.factD = (D*Rp[t])%mod;
    B.D = (t*B.factD)%mod;
    B.addB = (at+at)%mod;
    B.factB = ((2*t-1)*at)%mod;
    if(B.factB<0)   B.factB += mod;   // 2*t-1 == -1 when t == 0, i.e. A == B
    B.B = (((t*t)%mod)*at)%mod;

    // Negated terms added at the LCA x.
    const ll ag = (mod - (B1*Rp[g])%mod) % mod;
    A.S = (mod - (S*Rp[g])%mod) % mod;
    A.factD = (mod - (D*Rp[g])%mod) % mod;
    A.D = (g*A.factD)%mod;
    A.addB = (ag+ag)%mod;
    A.factB = ((2*g-1)*ag)%mod;
    if(A.factB<0)   A.factB += mod;   // g == 0 when A is itself the LCA
    A.B = (((g*g)%mod)*ag)%mod;

    bck[y] = bck[y] + B;
    bck[x] = bck[x] + A;
}
void solve()
{
ll S1,D1,B1,ans;
int Z,anc,x,y,a1,a2,d1,d2;
scanf("%d%lld",&N,&R);
Rinv=inverse(R,mod);
FOR(i,1,N-1){
scanf("%d%d",&x,&y);
myvector[x].push_back(y);
myvector[y].push_back(x);
}
parent[1]=1;
depth[1]=0;
dfs_pre(1);
process(N);

Rp[0]=1;
FOR(i,1,N)  Rp[i]=((ll)Rp[i-1]*(ll)R)%mod;

scanf("%d%d",&U,&Q);

while(U--){
scanf("%d%d%d%d%d%d",&a1,&d1,&a2,&d2,&x,&y);
S1 = ((ll)a1*(ll)a2)%mod;
D1 = ((ll)d1*(ll)a2 + (ll)d2*(ll)a1)%mod;
B1 = ((ll)d1*(ll)d2)%mod;
anc=lca(x,y);
Update_forward(S1,D1,B1,depth[x]-depth[anc]+1,x,anc);
Update_backward(S1,D1,B1,depth[y]+depth[x]-2*depth[anc],depth[x]-depth[anc],y,anc);
}
dfs_cal(1);
dfs_sum(1);
while(Q--){ 
scanf("%d%d",&x,&y);
anc=lca(x,y);
if(anc!=1)  ans=(sum[x]+sum[y]-sum[anc]-sum[parent[anc]])%mod;
else    ans=(sum[x]+sum[y]-sum[anc])%mod;
if(ans<0)   printf("%lld\n",ans+mod);
else    printf("%lld\n",ans);
}
}

// Entry point: build the floor(log2) table, then read the input and answer all queries.
int main()
{
    init();    // must run before solve(): lca() reads store[]
    solve();
}







In   Java  :





import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.InputMismatchException;

/**
 * Offline solution to the path-update / path-sum problem.
 *
 * <p>Each path update A..B is recorded as a few quadratic-in-z terms times R^z, at A, B, the
 * LCA and the LCA's parent. One bottom-up pass spreads them over the whole path, prefix sums
 * from the root are formed, and each query i..j is answered by inclusion/exclusion around
 * lca(i, j).
 */
public class E {
    static InputStream is;
    static PrintWriter out;
    static String INPUT = "";

    /** Returns the inverse of {@code a} modulo {@code mod} (extended Euclid); assumes gcd(a, mod) == 1. */
    public static long invl(long a, long mod) {
        long b = mod;
        long p = 1, q = 0;
        while (b > 0) {
            long c = a / b;
            long d;
            d = a;
            a = b;
            b = d % b;
            d = p;
            p = q;
            q = d - c * q;
        }
        return p < 0 ? p + mod : p;
    }

    /** Reads the input from {@link #is}, processes all updates and writes one answer per query to {@link #out}. */
    public static void solve()
    {
        int mod = 1000000007;
        int n = ni(), R = ni();
        int IR = (int)invl(R, mod);
        int[] from = new int[n-1];
        int[] to = new int[n-1];
        for(int i = 0;i < n-1;i++){
            from[i] = ni()-1;
            to[i] = ni()-1;
        }
        int[][] g = packU(n, from, to);
        int[][] pars = parents3(g, 0);
        int[] par = pars[0], ord = pars[1], dep = pars[2];
        int[][] spar = logstepParents(par);

        // w/wz/wzz: z^0, z^1, z^2 coefficients for the A -> LCA half of each path
        // (pushed to the parent multiplied by R per edge).
        long[] w = new long[n];
        long[] wz = new long[n];
        long[] wzz = new long[n];
        // x/xz/xzz: the same for the LCA -> B half (pushed to the parent multiplied by R^{-1}).
        long[] x = new long[n];
        long[] xz = new long[n];
        long[] xzz = new long[n];

        // pr[i] = R^i mod p.
        long[] pr = new long[n+1];
        pr[0] = 1;
        for(int i = 1;i <= n;i++){
            pr[i] = pr[i-1] * R % mod;
        }

        int U = ni(), Q = ni();
        for(int i = 0;i < U;i++){
            long a1 = ni(), d1 = ni();
            long a2 = ni(), d2 = ni();
            int a = ni()-1, b = ni()-1;
            int lca = lca2(a, b, spar, dep);

            // (a1 + z*d1) * (a2 + z*d2) = p0 + p1*z + p2*z^2
            long p0 = a1*a2%mod;
            long p1 = (d1*a2+d2*a1)%mod;
            long p2 = d1*d2%mod;

            // A side: start the polynomial at A ...
            w[a] += p0; w[a] %= mod;
            wz[a] += p1; wz[a] %= mod;
            wzz[a] += p2; wzz[a] %= mod;
            // ... and cancel it just above the LCA (re-expressed at z = dal, scaled by R^dal).
            if(par[lca] != -1){
                int pl = par[lca];
                int dal = dep[a] - dep[pl];
                long rd = pr[dal];
                w[pl] -= (p0+p1*dal+p2*dal%mod*dal)%mod*rd%mod;
                if(w[pl] < 0)w[pl] += mod;
                wz[pl] -= (p1+2*p2*dal)%mod*rd%mod;
                if(wz[pl] < 0)wz[pl] += mod;
                wzz[pl] -= p2*rd%mod;
                if(wzz[pl] < 0)wzz[pl] += mod;
            }

            // B side: start the polynomial at B, re-expressed around z = dist(A, B) ...
            int dab = dep[a]+dep[b]-2*dep[lca];
            long rx = pr[dab];
            long px = (p0+p1*dab+p2*dab%mod*dab)%mod*rx%mod;
            long pxz = (p1+2*p2*dab)%mod*rx%mod;
            long pxzz = p2*rx%mod;
            x[b] += px; x[b] %= mod;
            xz[b] += pxz; xz[b] %= mod;
            xzz[b] += pxzz; xzz[b] %= mod;
            // ... and cancel it at the LCA, which the A side already covers.
            {
                int dal = dep[a] - dep[lca];
                long rd = pr[dal];
                x[lca] -= (p0+p1*dal+p2*dal%mod*dal)%mod*rd%mod;
                // Normalize the array that was just decremented. The original tested w[lca] and
                // wz[lca] here, which are never negative, so these fixes never fired.
                if(x[lca] < 0)x[lca] += mod;
                xz[lca] -= (p1+2*p2*dal)%mod*rd%mod;
                if(xz[lca] < 0)xz[lca] += mod;
                xzz[lca] -= p2*rd%mod;
                if(xzz[lca] < 0)xzz[lca] += mod;
            }
        }

        // Push each node's coefficients to its parent in reverse BFS order (children first),
        // shifting z by one edge.
        for(int i = n-1;i >= 1;i--){
            int cur = ord[i];
            int p = par[cur];
            w[p] += w[cur]*R+wz[cur]*R+wzz[cur]*R; w[p] %= mod;
            wz[p] += wz[cur]*R+2L*wzz[cur]*R; wz[p] %= mod;
            wzz[p] += wzz[cur]*R; wzz[p] %= mod;
            x[p] += x[cur]*IR-xz[cur]*IR+xzz[cur]*IR; x[p] %= mod;
            xz[p] += xz[cur]*IR-2L*xzz[cur]*IR; xz[p] %= mod;
            xzz[p] += xzz[cur]*IR; xzz[p] %= mod;
        }

        // s[v] = sum of final node values from the root down to v (mod; may be negative until the end).
        long[] s = new long[n];
        for(int i = 0;i < n;i++){
            int cur = ord[i];
            int p = par[cur];
            if(p != -1){
                s[cur] = (s[p] + w[cur] + x[cur]) % mod;
            }else{
                s[cur] = (w[cur] + x[cur]) % mod;
            }
        }

        for(int i = 0;i < Q;i++){
            int u = ni()-1, v = ni()-1;
            int lca = lca2(u, v, spar, dep);
            long ret = s[u] + s[v] - s[lca];
            if(par[lca] != -1)ret -= s[par[lca]];
            ret %= mod;
            if(ret < 0)ret += mod;
            out.println(ret);
        }
    }

    /** Returns a^n mod {@code mod} by square-and-multiply (not used by solve()). */
    public static long pow(long a, long n, long mod) {
        //        a %= mod;
        long ret = 1;
        int x = 63 - Long.numberOfLeadingZeros(n);
        for (; x >= 0; x--) {
            ret = ret * ret % mod;
            if (n << 63 - x < 0)
                ret = ret * a % mod;
        }
        return ret;
    }

    /** LCA of a and b, using spar[k][v] = 2^k-th ancestor of v (-1 above the root). */
    public static int lca2(int a, int b, int[][] spar, int[] depth) {
        if (depth[a] < depth[b]) {
            b = ancestor(b, depth[b] - depth[a], spar);
        } else if (depth[a] > depth[b]) {
            a = ancestor(a, depth[a] - depth[b], spar);
        }

        if (a == b)
            return a;
        int sa = a, sb = b;
        for (int low = 0, high = depth[a], t = Integer.highestOneBit(high),
                k = Integer.numberOfTrailingZeros(t); t > 0; t >>>= 1, k--) {
            if ((low ^ high) >= t) {
                if (spar[k][sa] != spar[k][sb]) {
                    low |= t;
                    sa = spar[k][sa];
                    sb = spar[k][sb];
                } else {
                    high = low | t - 1;
                }
            }
        }
        return spar[0][sa];
    }

    /** Returns the m-th ancestor of a, or -1 if it does not exist. */
    protected static int ancestor(int a, int m, int[][] spar) {
        for (int i = 0; m > 0 && a != -1; m >>>= 1, i++) {
            if ((m & 1) == 1)
                a = spar[i][a];
        }
        return a;
    }

    /** Builds pars[k][v] = 2^k-th ancestor of v (-1 when it does not exist). */
    public static int[][] logstepParents(int[] par) {
        int n = par.length;
        int m = Integer.numberOfTrailingZeros(Integer.highestOneBit(n - 1)) + 1;
        int[][] pars = new int[m][n];
        pars[0] = par;
        for (int j = 1; j < m; j++) {
            for (int i = 0; i < n; i++) {
                pars[j][i] = pars[j - 1][i] == -1 ? -1
                        : pars[j - 1][pars[j - 1][i]];
            }
        }
        return pars;
    }

    /** BFS from root; returns {parent (-1 for root), BFS order, depth}. */
    public static int[][] parents3(int[][] g, int root) {
        int n = g.length;
        int[] par = new int[n];
        Arrays.fill(par, -1);

        int[] depth = new int[n];
        depth[root] = 0;

        int[] q = new int[n];
        q[0] = root;
        for (int p = 0, r = 1; p < r; p++) {
            int cur = q[p];
            for (int nex : g[cur]) {
                if (par[cur] != nex) {
                    q[r++] = nex;
                    par[nex] = cur;
                    depth[nex] = depth[cur] + 1;
                }
            }
        }
        return new int[][] { par, q, depth };
    }

    /** Packs an undirected edge list into per-node adjacency arrays. */
    static int[][] packU(int n, int[] from, int[] to) {
        int[][] g = new int[n][];
        int[] p = new int[n];
        for (int f : from)
            p[f]++;
        for (int t : to)
            p[t]++;
        for (int i = 0; i < n; i++)
            g[i] = new int[p[i]];
        for (int i = 0; i < from.length; i++) {
            g[from[i]][--p[from[i]]] = to[i];
            g[to[i]][--p[to[i]]] = from[i];
        }
        return g;
    }

    public static void main(String[] args) throws Exception
    {
        long S = System.currentTimeMillis();
        is = INPUT.isEmpty() ? System.in : new ByteArrayInputStream(INPUT.getBytes());
        out = new PrintWriter(System.out);

        solve();
        out.flush();
        long G = System.currentTimeMillis();
        tr(G-S+"ms");
    }

    // ---- buffered input helpers ----

    private static boolean eof()
    {
        if(lenbuf == -1)return true;
        int lptr = ptrbuf;
        while(lptr < lenbuf)if(!isSpaceChar(inbuf[lptr++]))
            return false;

        try {
            is.mark(1000);
            while(true){
                int b = is.read();
                if(b == -1){
                    is.reset();
                    return true;
                }else if(!isSpaceChar(b)){
                    is.reset();
                    return false;
                }
            }
        } catch (IOException e) {
            return true;
        }
    }

    private static byte[] inbuf = new byte[1024];
    static int lenbuf = 0, ptrbuf = 0;

    private static int readByte()
    {
        if(lenbuf == -1)throw new InputMismatchException();
        if(ptrbuf >= lenbuf){
            ptrbuf = 0;
            try { lenbuf = is.read(inbuf); } catch (IOException e) { throw new InputMismatchException(); }
            if(lenbuf <= 0)return -1;
        }
        return inbuf[ptrbuf++];
    }

    private static boolean isSpaceChar(int c) { return !(c >= 33 && c <= 126); }
    private static int skip() { int b; while((b = readByte()) != -1 && isSpaceChar(b)); return b; }

    private static double nd() { return Double.parseDouble(ns()); }
    private static char nc() { return (char)skip(); }

    private static String ns()
    {
        int b = skip();
        StringBuilder sb = new StringBuilder();
        while(!(isSpaceChar(b))){
            sb.appendCodePoint(b);
            b = readByte();
        }
        return sb.toString();
    }

    private static char[] ns(int n)
    {
        char[] buf = new char[n];
        int b = skip(), p = 0;
        while(p < n && !(isSpaceChar(b))){
            buf[p++] = (char)b;
            b = readByte();
        }
        return n == p ? buf : Arrays.copyOf(buf, p);
    }

    private static char[][] nm(int n, int m)
    {
        char[][] map = new char[n][];
        for(int i = 0;i < n;i++)map[i] = ns(m);
        return map;
    }

    private static int[] na(int n)
    {
        int[] a = new int[n];
        for(int i = 0;i < n;i++)a[i] = ni();
        return a;
    }

    private static int ni()
    {
        int num = 0, b;
        boolean minus = false;
        while((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'));
        if(b == '-'){
            minus = true;
            b = readByte();
        }

        while(true){
            if(b >= '0' && b <= '9'){
                num = num * 10 + (b - '0');
            }else{
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    private static long nl()
    {
        long num = 0;
        int b;
        boolean minus = false;
        while((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'));
        if(b == '-'){
            minus = true;
            b = readByte();
        }

        while(true){
            if(b >= '0' && b <= '9'){
                num = num * 10 + (b - '0');
            }else{
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    private static void tr(Object... o) { if(INPUT.length() != 0) System.out.println(Arrays.deepToString(o)); }
}
                        








View More Similar Problems

Components in a graph

There are 2 * N nodes in an undirected graph, and a number of edges connecting some nodes. In each edge, the first value will be between 1 and N, inclusive. The second node will be between N + 1 and 2 * N, inclusive. Given a list of edges, determine the size of the smallest and largest connected components that have 2 or more nodes. A node can have any number of connections. The highest node valu

View Solution →

Kundu and Tree

Kundu is a true tree lover. A tree is a connected graph having N vertices and N-1 edges. Today when he got a tree, he colored each edge with one of either red(r) or black(b) color. He is interested in knowing how many triplets (a,b,c) of vertices there are such that there is at least one edge having red color on all three paths, i.e. from vertex a to b, vertex b to c and vertex c to a. Note that

View Solution →

Super Maximum Cost Queries

Victoria has a tree, T, consisting of N nodes numbered from 1 to N. Each edge from node Ui to Vi in tree T has an integer weight, Wi. Let's define the cost, C, of a path from some node X to some other node Y as the maximum weight (W) for any edge in the unique path from node X to node Y. Victoria wants your help processing Q queries on tree T, where each query contains 2 integers, L and

View Solution →

Contacts

We're going to make our own Contacts application! The application must perform two types of operations: 1. add name, where name is a string denoting a contact name. This must store name as a new contact in the application. 2. find partial, where partial is a string denoting a partial name to search the application for. It must count the number of contacts starting with partial and print the co

View Solution →

No Prefix Set

There is a given list of strings where each string contains only lowercase letters from a - j, inclusive. The set of strings is said to be a GOOD SET if no string is a prefix of another string. In this case, print GOOD SET. Otherwise, print BAD SET on the first line followed by the string being checked. Note If two strings are identical, they are prefixes of each other. Function Descriptio

View Solution →

Cube Summation

You are given a 3-D Matrix in which each block contains 0 initially. The first block is defined by the coordinate (1,1,1) and the last block is defined by the coordinate (N,N,N). There are two types of queries. UPDATE x y z W updates the value of block (x,y,z) to W. QUERY x1 y1 z1 x2 y2 z2 calculates the sum of the value of blocks whose x coordinate is between x1 and x2 (inclusive), y coor

View Solution →