Easy Addition


Problem Statement :


You are given a tree with N nodes and each has a value associated with it. You are given Q queries, each of which is either an update or a retrieval operation.

Initially all node values are zero.

The update query is of the format

a1 d1 a2 d2 A B
This means you have to add $(a_1 + z \cdot d_1) \cdot (a_2 + z \cdot d_2) \cdot R^z$ to every node on the path from A to B, where $z$ is the distance between that node and A.

The retrieval query is of the format

i j
You need to return the sum of the node values lying in the path from node i to node j modulo 1000000007.

Note:

First all update queries are given and then all retrieval queries.
Distance between 2 nodes is the shortest path length between them taking each edge weight as 1.
Input Format

The first line contains two integers (N and R respectively) separated by a space.

In the next N-1 lines, the ith line describes the ith edge: a line with two integers x y separated by a single space denotes an edge between nodes x and y.

The next line contains 2 space separated integers (U and Q respectively) representing the number of Update and Query operations to follow.

U lines follow. Each of the next U lines contains 6 space separated integers (a1,d1,a2,d2,A and B respectively).

Each of the next Q lines contains 2 space separated integers, i and j respectively.


Output Format

The output contains exactly Q lines; the ith line contains the answer to the ith query.

Constraints

$2 \le N \le 10^5$
$2 \le R \le 10^9$
$1 \le U \le 10^5$
$1 \le Q \le 10^5$
$1 \le a_1, a_2, d_1, d_2 \le 10^8$
$1 \le x, y, i, j, A, B \le N$

Note: for an update operation A can be equal to B, and for a query operation i can be equal to j.



Solution :



title-img


                            Solution in C :

In   C++ :








#include<stdio.h>
#include<string.h>
#include<iostream>
#include<algorithm>
#include<stdlib.h>
#include<set>
#include<map>
#include<bitset>
#include<vector>
#include<assert.h>
using namespace std;
#define FOR(i,a,b) for(int i=(a);i<=(b);i++)
typedef long long int ll;
const int mod = 1000000007;

// Adjacency lists of the tree; nodes are 1-indexed.
vector<int> myvector[100001];
// N = node count, U / Q = number of updates / queries.
// parent[] and depth[] describe the tree rooted at node 1 (filled by dfs_pre; parent[1] == 1).
int N,parent[100001],depth[100001],Q,U;
// R and its modular inverse, Rp[i] = R^i mod `mod`.
// sum[v] is v's own value after dfs_cal() and the root-to-v path sum after dfs_sum().
ll R,Rinv,Rp[100001],sum[100001];
// store[i] = floor(log2(i)) for i >= 1 (store[0] = 0), built by init().
int store[100001];

/*
 * Lazy per-node record pushed towards the root by dfs_cal():
 *   S              constant part
 *   D, factD       linear part and its per-edge increment
 *   B, factB, addB quadratic part, its first difference and its second difference
 * All fields hold residues modulo `mod`.
 *
 * fwd[] records move up multiplied by R, bck[] records move up multiplied by Rinv.
 * A and B are scratch records filled by Update_forward()/Update_backward().
 */
struct node{
    ll S,D,B,factD,factB,addB;

    node(ll a=0,ll b=0,ll c=0,ll d=0,ll e=0,ll f=0){
        S=a,D=b,B=c;
        factD=d;
        factB=e;
        addB=f;
    }

    // (p + q) reduced once by `mod`. ">=" (not ">") makes a sum equal to `mod` become 0,
    // so two operands in [0, mod) always give a result in [0, mod).
    static ll add_mod(ll p,ll q){
        const ll r = p + q;
        return (r >= mod) ? (r - mod) : r;
    }

    // Field-wise sum of two records modulo `mod`.
    node operator + (const node &x) const{
        return node(add_mod(S, x.S),
                    add_mod(D, x.D),
                    add_mod(B, x.B),
                    add_mod(factD, x.factD),
                    add_mod(factB, x.factB),
                    add_mod(addB, x.addB));
    }
}fwd[100001],bck[100001],A,B;
/*
 * inverse(a, b): multiplicative inverse of a modulo b via the extended
 * Euclidean algorithm. Only the Bezout coefficient of `a` is tracked, and it
 * is reduced into [0, b) after every step that produces a non-zero remainder.
 * Intended for gcd(a, b) == 1 (here b is the prime `mod`). If no such step
 * ever happens (e.g. a == 1, or a == 0) the result is 1.
 */
ll inverse(ll a,ll b)
{
    const ll modulus = b;
    ll older = 0;   // coefficient from two steps back
    ll newer = 1;   // coefficient from one step back
    ll coef  = 1;   // most recently computed coefficient (the answer)
    for (ll divisor = a, dividend = b; divisor != 0; ) {
        const ll quotient  = dividend / divisor;
        const ll remainder = dividend % divisor;
        if (remainder != 0) {
            coef = older - (newer * quotient) % modulus;
            if (coef < 0)
                coef += modulus;
            older = newer;
            newer = coef;
        }
        dividend = divisor;
        divisor  = remainder;
    }
    return coef;
}
void dfs_pre(int root)
{
for(vector<int>::iterator it=myvector[root].begin();it!=myvector[root].end();it++){
if(parent[root]==*it)   continue;
parent[*it]=root;
depth[*it]=depth[root]+1;
dfs_pre(*it);
}
}

void dfs_cal(int root)          
{
for(vector<int>::iterator it=myvector[root].begin();it!=myvector[root].end();it++){
if(parent[root]==*it)   continue;
dfs_cal(*it);

fwd[root].S = (fwd[root].S + fwd[*it].S*R)%mod;
fwd[root].D = (fwd[root].D + (fwd[*it].D + fwd[*it].factD)*R)%mod;
fwd[root].factD = (fwd[root].factD + fwd[*it].factD*R)%mod;
fwd[root].B = (fwd[root].B +  (fwd[*it].B + fwd[*it].factB)*R) %mod;
fwd[root].factB = (fwd[root].factB + (fwd[*it].factB+fwd[*it].addB)*R)%mod;
fwd[root].addB = ( fwd[root].addB + fwd[*it].addB*R)%mod;

bck[root].S = (bck[*it].S * Rinv  + bck[root].S)%mod;
bck[root].D = (bck[root].D + (bck[*it].D - bck[*it].factD)*Rinv )%mod;
if (bck[root].D<0)  bck[root].D+=mod;
bck[root].factD = (bck[root].factD + bck[*it].factD*Rinv)%mod; 

bck[root].B = (bck[root].B + (bck[*it].B - bck[*it].factB)*Rinv) %mod; 
if(bck[root].B<0)   bck[root].B+=mod;
bck[root].factB = (bck[root].factB + (bck[*it].factB-bck[*it].addB)*Rinv)%mod;
if(bck[root].factB<0)    bck[root].factB+=mod;
bck[root].addB  = ( bck[root].addB + bck[*it].addB*Rinv)%mod;

}

sum[root] = (sum[root] + fwd[root].S+fwd[root].D+fwd[root].B + bck[root].S+bck[root].D+bck[root].B)%mod;
}

void dfs_sum(int root)          
{
for(vector<int>::iterator it=myvector[root].begin();it!=myvector[root].end();it++){
if(parent[root]==*it)   continue;
sum[*it]=(sum[*it]+sum[root])%mod;
dfs_sum(*it);
}
}
// Binary-lifting ancestor table, filled by process(): Root[v][0] = parent[v] and
// Root[v][i] = Root[Root[v][i-1]][i-1], i.e. the 2^i-th ancestor of v. Because
// solve() sets parent[1] = 1, ancestors above the root resolve to the root itself.
// Entries that process() never writes keep the -1 from its memset.
int Root[100001][18];
/*
 * init(): builds the lookup table store[i] = floor(log2(i)) for
 * 1 <= i <= 100000, with store[0] = 0. lca() uses it to pick the highest
 * binary-lifting level worth trying.
 */
void init()
{
    store[0] = 0;
    store[1] = 0;
    for (int i = 2; i <= 100000; ++i)
        store[i] = store[i / 2] + 1;   // floor(log2(i)) == floor(log2(i/2)) + 1
}
/*
 * process(N): fills the binary-lifting table Root[][] for nodes 1..N from
 * parent[]. Level 0 is the direct parent; each higher level jumps twice as
 * far by composing two jumps of the level below. Only levels with 2^level <= N
 * are built; everything else stays -1.
 */
void process(int N)
{
    memset(Root, -1, sizeof(Root));
    for (int v = 1; v <= N; ++v)
        Root[v][0] = parent[v];

    for (int level = 1; (1 << level) <= N; ++level) {
        for (int v = 1; v <= N; ++v) {
            const int half = Root[v][level - 1];
            if (half == -1)
                continue;
            Root[v][level] = Root[half][level - 1];
        }
    }
}
/*
 * lca(p, q): lowest common ancestor of p and q in the tree rooted at node 1.
 * Requires init() (store[]) and process() (Root[][]) to have run.
 * First lifts the deeper node to the other's depth, then lifts both together
 * while their 2^i-th ancestors differ; the answer is then the common parent.
 */
int lca(int p,int q)
{
    if(depth[p]>depth[q])   swap(p,q);      // make q the deeper node
    const int steps=store[depth[q]];         // floor(log2(depth[q]))
    for(int i=steps;i>=0;i--)
        if(depth[q]-(1<<i) >= depth[p])
            q=Root[q][i];
    if(p==q)    return p;
    for(int i=steps;i>=0;i--){
        if(Root[p][i]!=Root[q][i]){
            p=Root[p][i];
            q=Root[q][i];
        }
    }
    return parent[p];
}
/*
 * Update_forward(S, D, B1, t, x, y): records the part of an update that lies on
 * the upward path from x to its ancestor y (solve() passes the LCA as y).
 * The value to add at distance z from x is (S + D*z + B1*z^2) * R^z.
 *
 * Node A, added to fwd[x], is the z = 0 state that dfs_cal() advances one edge
 * at a time towards the root (multiplying by R):
 *   constant part:  S
 *   linear part:    D = 0, factD = D
 *   quadratic part: B = 0, factB = B1, addB = 2*B1 (first/second differences of z^2)
 * Node B, added to fwd[parent[y]], is the negation of that state after t edges
 * (S*R^t, t*D*R^t / D*R^t, t^2*B1*R^t / (2t+1)*B1*R^t / 2*B1*R^t), so nothing
 * propagates above y. When y == 1 (the root) there is nothing above it and B is
 * not added.
 *
 * t: edges from x to parent[y]; solve() passes depth[x] - depth[y] + 1.
 * Writes the global scratch records A and B.
 */
void Update_forward(ll S,ll D,ll B1,ll t,int x,int y)
{
ll brt;

// constant part
A.S = S;
B.S = mod - (S*Rp[t])%mod;              // -S*R^t, kept in [1, mod]

// linear part
A.factD = D;
A.D = 0;
B.factD = mod-(D*Rp[t])%mod;            // -D*R^t
if(B.factD<0)   B.factD+=mod;           // only reachable if D*Rp[t] were negative
B.D = (B.factD*t)%mod;                  // -t*D*R^t

// quadratic part, seed at z = 0
brt = B1;
A.B = 0;
A.factB = brt;
if(A.factB<0)   A.factB+=mod;
A.addB = (brt+brt)%mod;

// quadratic part, negated state after t edges (brt = -B1*R^t)
brt = mod-(B1*Rp[t])%mod;
if(brt<0)   brt+=mod;
B.B = ((ll)((t*t)%mod)*brt)%mod;
B.factB = ((ll)(2*t+1)*brt)%mod;
B.addB = (brt+brt)%mod;

fwd[x]=fwd[x]+A;
if(y!=1)    fwd[parent[y]]=fwd[parent[y]]+ B;   // the root (1) has no parent to cancel at
}
/*
 * Update_backward(S, D, B1, t, g, y, x): records the part of an update that lies
 * on the path from the LCA down to the update's end node. Here y is that end node
 * and x is the LCA (solve() passes them in this order). With `start` denoting the
 * update's first node, solve() passes t = dist(start, y) and g = dist(start, LCA).
 *
 * dfs_cal() advances bck[] records towards the root with factor Rinv while z
 * decreases by one per edge (D -= factD, B -= factB, factB -= addB).
 * Node B, added to bck[y], is the z = t state:
 *   S*R^t, t*D*R^t / D*R^t, t^2*B1*R^t / (2t-1)*B1*R^t / 2*B1*R^t.
 * Node A, added to bck[x], is the negated z = g state. Nodes from the LCA upward
 * therefore get nothing from bck[]; the LCA itself is covered by Update_forward().
 * Writes the global scratch records A and B.
 */
void Update_backward(ll S,ll D,ll B1,ll t,ll g,int y,int x)
{
ll brt;

// constant part
B.S = (S*Rp[t])%mod;
A.S = mod-(S*Rp[g])%mod;

// linear part
B.factD = (D*Rp[t])%mod;
B.D     = (t*B.factD)%mod;

A.factD = mod - (D*Rp[g])%mod;
A.D     = (g*A.factD)%mod;

// quadratic part at y (brt = B1*R^t)
brt = (B1*Rp[t])%mod;
B.addB = brt + brt;
if ( B.addB >=mod ) B.addB -= mod;
B.factB = ((ll)(2*t-1)*brt)%mod;
if ( B.factB <0 ) B.factB += mod;       // 2*t-1 is negative when t == 0 (start == end)
B.B = ((ll)((t*t)%mod)*brt)%mod;

// negated quadratic part at the LCA (brt = -B1*R^g)
brt = mod-(B1*Rp[g])%mod;
if(brt<0)   brt+=mod;
A.addB = brt + brt;
if ( A.addB >=mod ) A.addB -= mod;
A.factB = ((ll)(2*g-1)*brt)%mod;
if ( A.factB <0 ) A.factB += mod; 
A.B = ((ll)((g*g)%mod)*brt)%mod;

bck[y] = bck[y] + B;
bck[x] = bck[x] + A;
}
void solve()
{
ll S1,D1,B1,ans;
int Z,anc,x,y,a1,a2,d1,d2;
scanf("%d%lld",&N,&R);
Rinv=inverse(R,mod);
FOR(i,1,N-1){
scanf("%d%d",&x,&y);
myvector[x].push_back(y);
myvector[y].push_back(x);
}
parent[1]=1;
depth[1]=0;
dfs_pre(1);
process(N);

Rp[0]=1;
FOR(i,1,N)  Rp[i]=((ll)Rp[i-1]*(ll)R)%mod;

scanf("%d%d",&U,&Q);

while(U--){
scanf("%d%d%d%d%d%d",&a1,&d1,&a2,&d2,&x,&y);
S1 = ((ll)a1*(ll)a2)%mod;
D1 = ((ll)d1*(ll)a2 + (ll)d2*(ll)a1)%mod;
B1 = ((ll)d1*(ll)d2)%mod;
anc=lca(x,y);
Update_forward(S1,D1,B1,depth[x]-depth[anc]+1,x,anc);
Update_backward(S1,D1,B1,depth[y]+depth[x]-2*depth[anc],depth[x]-depth[anc],y,anc);
}
dfs_cal(1);
dfs_sum(1);
while(Q--){ 
scanf("%d%d",&x,&y);
anc=lca(x,y);
if(anc!=1)  ans=(sum[x]+sum[y]-sum[anc]-sum[parent[anc]])%mod;
else    ans=(sum[x]+sum[y]-sum[anc])%mod;
if(ans<0)   printf("%lld\n",ans+mod);
else    printf("%lld\n",ans);
}
}

// Entry point: build the log2 table first (lca() reads store[] during solve()),
// then read the input and answer every query.
int main()
{
    init();
    solve();
}







In   Java  :





import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.InputMismatchException;

/**
 * Java solution for "Easy Addition".
 *
 * Each update is expanded to (p0 + p1*z + p2*z^2) * R^z and recorded lazily at a
 * few nodes. A bottom-up pass turns the records into per-node values, and a
 * top-down pass turns those into root-to-node prefix sums used to answer
 * path-sum queries.
 */
public class E {
    static InputStream is;
    static PrintWriter out;
    static String INPUT = "";

    /**
     * Modular inverse of {@code a} modulo {@code mod} (extended Euclidean algorithm).
     * Assumes gcd(a, mod) == 1; a negative coefficient is shifted up by {@code mod}.
     */
    public static long invl(long a, long mod) {
        long b = mod;
        long p = 1, q = 0;
        while (b > 0) {
            long c = a / b;
            long d;
            d = a;
            a = b;
            b = d % b;
            d = p;
            p = q;
            q = d - c * q;
        }
        return p < 0 ? p + mod : p;
    }

    /**
     * Reads the tree, the updates and the queries from {@link #is} and prints one
     * answer per query to {@link #out}.
     *
     * An update adds (p0 + p1*z + p2*z^2) * R^z along the path a to b, where z is the
     * distance from a, p0 = a1*a2, p1 = d1*a2 + d2*a1 and p2 = d1*d2. The path is
     * split at lca(a, b):
     *  - w / wz / wzz hold the a..lca half. A record (c0, c1, c2) at a node stands
     *    for c0 + c1*k + c2*k^2 (times R^k) at the node k edges further up, so c0
     *    is this node's own value. The bottom-up pass shifts k by one and multiplies
     *    by R per edge. A negated copy at par[lca] stops the contribution above lca.
     *  - x / xz / xzz hold the lca..b half. They are seeded at b with the
     *    polynomial re-centred at z = dist(a, b) and shifted by -1 (times R^-1) per
     *    edge. A negated copy at lca stops the contribution there, because lca is
     *    already covered by the first half.
     * During the update phase every record stays a residue in [0, mod).
     */
    public static void solve()
    {
        int mod = 1000000007;
        int n = ni(), R = ni();
        int IR = (int)invl(R, mod);
        int[] from = new int[n-1];
        int[] to = new int[n-1];
        for(int i = 0;i < n-1;i++){
            from[i] = ni()-1;
            to[i] = ni()-1;
        }
        int[][] g = packU(n, from, to);
        int[][] pars = parents3(g, 0);
        int[] par = pars[0], ord = pars[1], dep = pars[2];
        int[][] spar = logstepParents(par);

        // Forward records (a..lca half), propagated upward with factor R.
        long[] w = new long[n];
        long[] wz = new long[n];
        long[] wzz = new long[n];
        // Backward records (lca..b half), propagated upward with factor R^-1.
        long[] x = new long[n];
        long[] xz = new long[n];
        long[] xzz = new long[n];

        // pr[i] = R^i, pir[i] = R^-i (mod).
        long[] pr = new long[n+1];
        long[] pir = new long[n+1];
        pr[0] = 1;
        pir[0] = 1;
        for(int i = 1;i <= n;i++){
            pr[i] = pr[i-1] * R % mod;
            pir[i] = pir[i-1] * IR % mod;
        }

        int U = ni(), Q = ni();
        for(int i = 0;i < U;i++){
            long a1 = ni(), d1 = ni();
            long a2 = ni(), d2 = ni();
            int a = ni()-1, b = ni()-1;
            int lca = lca2(a, b, spar, dep);

            long p0 = a1*a2%mod;
            long p1 = (d1*a2+d2*a1)%mod;
            long p2 = d1*d2%mod;

            // Seed the forward half at a (z = 0).
            w[a] += p0; w[a] %= mod;
            wz[a] += p1; wz[a] %= mod;
            wzz[a] += p2; wzz[a] %= mod;
            // Cancel it just above lca (there is nothing above the root).
            if(par[lca] != -1){
                int pl = par[lca];
                int dal = dep[a] - dep[pl];
                long rd = pr[dal];
                w[pl] -= (p0+p1*dal+p2*dal%mod*dal)%mod*rd%mod;
                if(w[pl] < 0)w[pl] += mod;
                wz[pl] -= (p1+2*p2*dal)%mod*rd%mod;
                if(wz[pl] < 0)wz[pl] += mod;
                wzz[pl] -= p2*rd%mod;
                if(wzz[pl] < 0)wzz[pl] += mod;
            }

            // Seed the backward half at b (z = dist(a, b)).
            int dab = dep[a]+dep[b]-2*dep[lca];
            long rx = pr[dab];
            long px = (p0+p1*dab+p2*dab%mod*dab)%mod*rx%mod;
            long pxz = (p1+2*p2*dab)%mod*rx%mod;
            long pxzz = p2*rx%mod;
            x[b] += px; x[b] %= mod;
            xz[b] += pxz; xz[b] %= mod;
            xzz[b] += pxzz; xzz[b] %= mod;
            // Cancel it at lca itself (z = dist(a, lca)); lca belongs to the forward half.
            {
                int dal = dep[a] - dep[lca];
                long rd = pr[dal];
                x[lca] -= (p0+p1*dal+p2*dal%mod*dal)%mod*rd%mod;
                if(x[lca] < 0)x[lca] += mod;
                xz[lca] -= (p1+2*p2*dal)%mod*rd%mod;
                if(xz[lca] < 0)xz[lca] += mod;
                xzz[lca] -= p2*rd%mod;
                if(xzz[lca] < 0)xzz[lca] += mod;
            }
        }

        // Bottom-up pass (reverse BFS order): shift each record one edge up into its parent.
        // Java's % keeps the dividend's sign, so backward records may turn negative here;
        // the final answer is normalised below.
        for(int i = n-1;i >= 1;i--){
            int cur = ord[i];
            int p = par[cur];
            w[p] += w[cur]*R+wz[cur]*R+wzz[cur]*R; w[p] %= mod;
            wz[p] += wz[cur]*R+2L*wzz[cur]*R; wz[p] %= mod;
            wzz[p] += wzz[cur]*R; wzz[p] %= mod;
            x[p] += x[cur]*IR-xz[cur]*IR+xzz[cur]*IR; x[p] %= mod;
            xz[p] += xz[cur]*IR-2L*xzz[cur]*IR; xz[p] %= mod;
            xzz[p] += xzz[cur]*IR; xzz[p] %= mod;
        }

        // Top-down pass (BFS order): s[v] = sum of node values on the root..v path.
        long[] s = new long[n];
        for(int i = 0;i < n;i++){
            int cur = ord[i];
            int p = par[cur];
            if(p != -1){
                s[cur] = (s[p] + w[cur] + x[cur]) % mod;
            }else{
                s[cur] = (w[cur] + x[cur]) % mod;
            }
        }

        for(int i = 0;i < Q;i++){
            int u = ni()-1, v = ni()-1;
            int lca = lca2(u, v, spar, dep);
            long ret = s[u] + s[v] - s[lca];
            if(par[lca] != -1)ret -= s[par[lca]];
            ret %= mod;
            if(ret < 0)ret += mod;
            out.println(ret);
        }
    }

    /** a^n mod {@code mod} by square-and-multiply; returns 1 for n == 0. */
    public static long pow(long a, long n, long mod) {
        //        a %= mod;
        long ret = 1;
        int x = 63 - Long.numberOfLeadingZeros(n);
        for (; x >= 0; x--) {
            ret = ret * ret % mod;
            if (n << 63 - x < 0)
                ret = ret * a % mod;
        }
        return ret;
    }

    /**
     * Lowest common ancestor of a and b, using the doubling table {@code spar}
     * (spar[k][v] = 2^k-th ancestor of v, or -1) and node depths.
     */
    public static int lca2(int a, int b,
            int[][] spar, int[] depth) {
        if (depth[a] < depth[b]) {
            b = ancestor(b, depth[b] - depth[a], spar);
        } else if (depth[a] > depth[b]) {
            a = ancestor(a, depth[a] - depth[b], spar);
        }

        if (a == b)
            return a;
        int sa = a, sb = b;
        for (int low = 0, high = depth[a], t =
                Integer.highestOneBit(high), k = Integer
                .numberOfTrailingZeros(t); t > 0; t >>>= 1, k--) {
            if ((low ^ high) >= t) {
                if (spar[k][sa] != spar[k][sb]) {
                    low |= t;
                    sa = spar[k][sa];
                    sb = spar[k][sb];
                } else {
                    high = low | t - 1;
                }
            }
        }
        return spar[0][sa];
    }

    /** The m-th ancestor of a, or -1 if it does not exist. */
    protected static int ancestor(int a, int m, int[][] spar) {
        for (int i = 0; m > 0 && a != -1; m >>>= 1, i++) {
            if ((m & 1) == 1)
                a = spar[i][a];
        }
        return a;
    }

    /** Doubling table: pars[j][i] = 2^j-th ancestor of i, or -1 when it does not exist. */
    public static int[][] logstepParents(int[] par) {
        int n = par.length;
        int m = Integer.numberOfTrailingZeros(
                Integer.highestOneBit(n - 1)) + 1;
        int[][] pars = new int[m][n];
        pars[0] = par;
        for (int j = 1; j < m; j++) {
            for (int i = 0; i < n; i++) {
                pars[j][i] = pars[j - 1][i] == -1 ? -1
                        : pars[j - 1][pars[j - 1][i]];
            }
        }
        return pars;
    }

    /**
     * BFS from {@code root}. Returns {par, order, depth}: par[root] == -1, order is
     * the BFS visiting order (parents before children), depth is the edge distance
     * from root.
     */
    public static int[][] parents3(
            int[][] g, int root) {
        int n = g.length;
        int[] par = new int[n];
        Arrays.fill(par, -1);

        int[] depth = new int[n];
        depth[root] = 0;

        int[] q = new int[n];
        q[0] = root;
        for (int p = 0, r = 1; p < r; p++) {
            int cur = q[p];
            for (int nex : g[cur]) {
                if (par[cur] != nex) {
                    q[r++] = nex;
                    par[nex] = cur;
                    depth[nex] = depth[cur] + 1;
                }
            }
        }
        return new int[][] { par, q, depth };
    }

    /** Packs an undirected edge list into per-node adjacency arrays. */
    static int[][] packU(int n, int[] from, int[] to) {
        int[][] g = new int[n][];
        int[] p = new int[n];
        for (int f : from)
            p[f]++;
        for (int t : to)
            p[t]++;
        for (int i = 0; i < n; i++)
            g[i] = new int[p[i]];
        for (int i = 0; i < from.length; i++) {
            g[from[i]][--p[from[i]]] = to[i];
            g[to[i]][--p[to[i]]] = from[i];
        }
        return g;
    }

    public static void main(String[] args) throws Exception
    {
        long S = System.currentTimeMillis();
        is = INPUT.isEmpty() ? System.in :
                new ByteArrayInputStream(INPUT.getBytes());
        out = new PrintWriter(System.out);

        solve();
        out.flush();
        long G = System.currentTimeMillis();
        tr(G-S+"ms");
    }

    // ---- Fast input helpers (buffered byte reader over `is`) ----

    private static boolean eof()
    {
        if(lenbuf == -1)return true;
        int lptr = ptrbuf;
        while(lptr < lenbuf)if(!isSpaceChar(inbuf[lptr++]))
            return false;

        try {
            is.mark(1000);
            while(true){
                int b = is.read();
                if(b == -1){
                    is.reset();
                    return true;
                }else if(!isSpaceChar(b)){
                    is.reset();
                    return false;
                }
            }
        } catch (IOException e) {
            return true;
        }
    }

    private static byte[] inbuf = new byte[1024];
    static int lenbuf = 0, ptrbuf = 0;

    private static int readByte()
    {
        if(lenbuf == -1)throw new InputMismatchException();
        if(ptrbuf >= lenbuf){
            ptrbuf = 0;
            try { lenbuf = is.read(inbuf); } catch (IOException e)
            { throw new InputMismatchException(); }
            if(lenbuf <= 0)return -1;
        }
        return inbuf[ptrbuf++];
    }

    private static boolean isSpaceChar(int c)
    { return !(c >= 33 && c <= 126); }
    private static int skip()
    { int b; while((b = readByte()) != -1 &&
            isSpaceChar(b)); return b; }

    private static double nd()
    { return Double.parseDouble(ns()); }
    private static char nc() { return (char)skip(); }

    private static String ns()
    {
        int b = skip();
        StringBuilder sb = new StringBuilder();
        while(!(isSpaceChar(b))){
            sb.appendCodePoint(b);
            b = readByte();
        }
        return sb.toString();
    }

    private static char[] ns(int n)
    {
        char[] buf = new char[n];
        int b = skip(), p = 0;
        while(p < n && !(isSpaceChar(b))){
            buf[p++] = (char)b;
            b = readByte();
        }
        return n == p ? buf : Arrays.copyOf(buf, p);
    }

    private static char[][] nm(int n, int m)
    {
        char[][] map = new char[n][];
        for(int i = 0;i < n;i++)map[i] = ns(m);
        return map;
    }

    private static int[] na(int n)
    {
        int[] a = new int[n];
        for(int i = 0;i < n;i++)a[i] = ni();
        return a;
    }

    private static int ni()
    {
        int num = 0, b;
        boolean minus = false;
        while((b = readByte()) != -1 && !(
                (b >= '0' && b <= '9') || b == '-'));
        if(b == '-'){
            minus = true;
            b = readByte();
        }

        while(true){
            if(b >= '0' && b <= '9'){
                num = num * 10 + (b - '0');
            }else{
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    private static long nl()
    {
        long num = 0;
        int b;
        boolean minus = false;
        while((b = readByte()) != -1 && !(
                (b >= '0' && b <= '9') || b == '-'));
        if(b == '-'){
            minus = true;
            b = readByte();
        }

        while(true){
            if(b >= '0' && b <= '9'){
                num = num * 10 + (b - '0');
            }else{
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    // Debug trace; prints only when running against the embedded INPUT string.
    private static void tr(Object... o)
    { if(INPUT.length() != 0)
        System.out.println(Arrays.deepToString(o)); }
}
                        








View More Similar Problems

Contacts

We're going to make our own Contacts application! The application must perform two types of operations: 1 . add name, where name is a string denoting a contact name. This must store name as a new contact in the application. find partial, where partial is a string denoting a partial name to search the application for. It must count the number of contacts starting partial with and print the co

View Solution →

No Prefix Set

There is a given list of strings where each string contains only lowercase letters from a - j, inclusive. The set of strings is said to be a GOOD SET if no string is a prefix of another string. In this case, print GOOD SET. Otherwise, print BAD SET on the first line followed by the string being checked. Note If two strings are identical, they are prefixes of each other. Function Descriptio

View Solution →

Cube Summation

You are given a 3-D Matrix in which each block contains 0 initially. The first block is defined by the coordinate (1,1,1) and the last block is defined by the coordinate (N,N,N). There are two types of queries. UPDATE x y z W updates the value of block (x,y,z) to W. QUERY x1 y1 z1 x2 y2 z2 calculates the sum of the value of blocks whose x coordinate is between x1 and x2 (inclusive), y coor

View Solution →

Direct Connections

Enter-View ( EV ) is a linear, street-like country. By linear, we mean all the cities of the country are placed on a single straight line - the x -axis. Thus every city's position can be defined by a single coordinate, xi, the distance from the left borderline of the country. You can treat all cities as single points. Unfortunately, the dictator of telecommunication of EV (Mr. S. Treat Jr.) do

View Solution →

Subsequence Weighting

A subsequence of a sequence is a sequence which is obtained by deleting zero or more elements from the sequence. You are given a sequence A in which every element is a pair of integers i.e A = [(a1, w1), (a2, w2),..., (aN, wN)]. For a subseqence B = [(b1, v1), (b2, v2), ...., (bM, vM)] of the given sequence : We call it increasing if for every i (1 <= i < M ) , bi < bi+1. Weight(B) =

View Solution →

Kindergarten Adventures

Meera teaches a class of n students, and every day in her classroom is an adventure. Today is drawing day! The students are sitting around a round table, and they are numbered from 1 to n in the clockwise direction. This means that the students are numbered 1, 2, 3, . . . , n-1, n, and students 1 and n are sitting next to each other. After letting the students draw for a certain period of ti

View Solution →