Easy Addition
Problem Statement :
You are given a tree with N nodes, each of which has a value associated with it. You are given Q queries, each of which is either an update or a retrieval operation. Initially all node values are zero.
The update query has the format `a1 d1 a2 d2 A B`. It means you must add $(a_1 + z \cdot d_1)\,(a_2 + z \cdot d_2)\,R^z$ to every node on the path from A to B, where $z$ is the distance between that node and A.
The retrieval query has the format `i j`. You need to return the sum of the node values lying on the path from node i to node j, modulo 1000000007.
Note: all update queries are given first, followed by all retrieval queries. The distance between two nodes is the length of the shortest path between them, taking each edge weight as 1.
Input Format: The first line contains two space-separated integers, N and R. Each of the next N-1 lines describes one edge: two space-separated integers x y denote an edge between nodes x and y. The next line contains two space-separated integers, U and Q, the number of update and query operations that follow. Each of the next U lines contains 6 space-separated integers: a1, d1, a2, d2, A and B. Each of the next Q lines contains 2 space-separated integers, i and j.
Output Format: The output contains exactly Q lines; the i-th line contains the answer to the i-th query.
Constraints: $2 \le N \le 10^5$, $2 \le R \le 10^9$, $1 \le U \le 10^5$, $1 \le Q \le 10^5$, $1 \le a_1, a_2, d_1, d_2 \le 10^8$, $1 \le x, y, i, j, A, B \le N$.
Note: For the update operation, A can be equal to B, and for the query operation, i can be equal to j.
Solution :
Solution in C :
In C++ :
#include<stdio.h>
#include<string.h>
#include<iostream>
#include<algorithm>
#include<stdlib.h>
#include<set>
#include<map>
#include<bitset>
#include<vector>
#include<assert.h>
using namespace std;
// Inclusive counting loop: i runs from a up to and including b.
#define FOR(i,a,b) for(int i=(a);i<=(b);++i)
typedef long long ll;
const int mod = 1000000007;   // every answer and accumulator is kept modulo this

// Adjacency lists of the tree, indexed by 1-based node id.
std::vector<int> myvector[100001];

// Sizes read by solve().
int N;
int Q;
int U;

// Rooted-tree data: solve() sets parent[1] = 1 / depth[1] = 0 for the root,
// dfs_pre() fills the rest.
int parent[100001];
int depth[100001];

ll R;              // ratio read from the input
ll Rinv;           // inverse(R, mod), computed in solve()
ll Rp[100001];     // Rp[i] = R^i mod `mod`, filled in solve()
ll sum[100001];    // per-node values after dfs_cal(), root-to-node prefix sums after dfs_sum()

int store[100001]; // store[i] = floor(log2(i)), filled by init()
struct node{
ll S,D,B,factD,factB,addB;
node(ll a=0,ll b=0,ll c=0,ll d=0,ll e=0,ll f=0){
S=a,D=b,B=c;
factD=d;
factB=e;
addB=f;
}
node operator + (const node &x) const{
ll a,b,c,d,e,f;
a = (x.S + S);
b = (x.D + D);
c = (x.B + B);
d = (x.factD + factD);
e = (x.factB + factB);
f = (x.addB + addB);
a= (a>mod)?(a-mod):a;
b= (b>mod)?(b-mod):b;
c= (c>mod)?(c-mod):c;
d= (d>mod)?(d-mod):d;
e= (e>mod)?(e-mod):e;
f= (f>mod)?(f-mod):f;
return node(a,b,c,d,e,f);
}
}fwd[100001],bck[100001],A,B;
// Returns the inverse of `a` modulo `b` (solve() calls it as inverse(R, mod)).
// It runs Euclid's algorithm on (b, a) and updates a single Bezout-style
// coefficient on every step whose remainder is non-zero, keeping it in [0, b).
// Meaningful only when gcd(a, b) == 1. If a == 0 the loop never runs and 1 is
// returned.
ll inverse(ll a,ll b)
{
    const ll m = b;
    ll older = 0;
    ll newer = 1;
    ll result = 1;
    ll dividend = b;
    ll divisor = a;
    for (; divisor != 0; ) {
        const ll quot = dividend / divisor;
        const ll rem = dividend % divisor;
        if (rem != 0) {
            result = older - (newer * quot) % m;
            if (result < 0)
                result += m;
            older = newer;
            newer = result;
        }
        dividend = divisor;
        divisor = rem;
    }
    return result;
}
// Fills parent[] and depth[] for every node below `root`, walking the
// adjacency lists in myvector[]. The caller must already have set
// parent[root] and depth[root] (solve() uses parent[1] = 1, depth[1] = 0).
// A neighbour equal to parent[v] is skipped, so each tree edge is followed
// downwards exactly once.
//
// Uses an explicit stack instead of recursion. On a path-shaped tree the
// recursive version nests up to N (1e5) frames and can overflow the call
// stack. The resulting parent[]/depth[] values are identical because each
// node is reached once, from its unique parent.
void dfs_pre(int root)
{
    std::vector<int> pending(1, root);
    while (!pending.empty()) {
        const int cur = pending.back();
        pending.pop_back();
        for (std::size_t k = 0; k < myvector[cur].size(); ++k) {
            const int next = myvector[cur][k];
            if (parent[cur] == next) continue;
            parent[next] = cur;
            depth[next] = depth[cur] + 1;
            pending.push_back(next);
        }
    }
}
void dfs_cal(int root)
{
for(vector<int>::iterator it=myvector[root].begin();it!=myvector[root].end();it++){
if(parent[root]==*it) continue;
dfs_cal(*it);
fwd[root].S = (fwd[root].S + fwd[*it].S*R)%mod;
fwd[root].D = (fwd[root].D + (fwd[*it].D + fwd[*it].factD)*R)%mod;
fwd[root].factD = (fwd[root].factD + fwd[*it].factD*R)%mod;
fwd[root].B = (fwd[root].B + (fwd[*it].B + fwd[*it].factB)*R) %mod;
fwd[root].factB = (fwd[root].factB + (fwd[*it].factB+fwd[*it].addB)*R)%mod;
fwd[root].addB = ( fwd[root].addB + fwd[*it].addB*R)%mod;
bck[root].S = (bck[*it].S * Rinv + bck[root].S)%mod;
bck[root].D = (bck[root].D + (bck[*it].D - bck[*it].factD)*Rinv )%mod;
if (bck[root].D<0) bck[root].D+=mod;
bck[root].factD = (bck[root].factD + bck[*it].factD*Rinv)%mod;
bck[root].B = (bck[root].B + (bck[*it].B - bck[*it].factB)*Rinv) %mod;
if(bck[root].B<0) bck[root].B+=mod;
bck[root].factB = (bck[root].factB + (bck[*it].factB-bck[*it].addB)*Rinv)%mod;
if(bck[root].factB<0) bck[root].factB+=mod;
bck[root].addB = ( bck[root].addB + bck[*it].addB*Rinv)%mod;
}
sum[root] = (sum[root] + fwd[root].S+fwd[root].D+fwd[root].B + bck[root].S+bck[root].D+bck[root].B)%mod;
}
// Turns the per-node values in sum[] into root-to-node prefix sums over the
// subtree of `root`: each child receives its parent's already-updated sum
// (mod `mod`).
//
// Uses an explicit stack instead of recursion so that deep, path-shaped trees
// cannot overflow the call stack. A child is updated when it is pushed, which
// always happens after its parent's own value is final, so the results match
// the recursive version exactly.
void dfs_sum(int root)
{
    std::vector<int> pending(1, root);
    while (!pending.empty()) {
        const int cur = pending.back();
        pending.pop_back();
        for (std::size_t k = 0; k < myvector[cur].size(); ++k) {
            const int next = myvector[cur][k];
            if (parent[cur] == next) continue;
            sum[next] = (sum[next] + sum[cur]) % mod;
            pending.push_back(next);
        }
    }
}
// Binary-lifting table filled by process(): Root[v][0] is parent[v] and
// Root[v][k] is Root[Root[v][k-1]][k-1], i.e. the 2^k-th ancestor of v.
// Entries process() does not fill keep the -1 written by its memset.
// With N <= 100000 only levels 0..16 are filled, so 18 columns suffice.
int Root[100001][18];
// Precomputes store[i] = floor(log2(i)) for 1 <= i <= 100000, with store[0]
// set to 0. lca() uses it to pick the highest useful lifting level.
// Uses the recurrence floor(log2(i)) = floor(log2(i / 2)) + 1 for i >= 2.
void init()
{
    store[0] = 0;
    store[1] = 0;
    for (int i = 2; i <= 100000; ++i)
        store[i] = store[i >> 1] + 1;
}
// Builds the binary-lifting table Root[][] for nodes 1..N from parent[].
// Level 0 is the parent; each further level jumps twice as far, composed from
// the previous level. Every slot starts at -1 and slots that are never filled
// keep that value.
void process(int N)
{
    const int rows = sizeof(Root) / sizeof(Root[0]);
    const int cols = sizeof(Root[0]) / sizeof(Root[0][0]);
    for (int r = 0; r < rows; ++r)
        std::fill(Root[r], Root[r] + cols, -1);

    for (int v = 1; v <= N; ++v)
        Root[v][0] = parent[v];

    int level = 1;
    while ((1 << level) <= N) {
        for (int v = 1; v <= N; ++v) {
            const int mid = Root[v][level - 1];
            if (mid != -1)
                Root[v][level] = Root[mid][level - 1];
        }
        ++level;
    }
}
// Returns the lowest common ancestor of nodes p and q. It uses depth[],
// the binary-lifting table Root[][] built by process(), and the
// floor(log2) table store[] built by init().
// The unused local `temp` of the previous version has been removed.
int lca(int p,int q)
{
    if(depth[p]>depth[q]) std::swap(p,q);   // make q the deeper node
    const int steps=store[depth[q]];
    // Lift q until it is at the same depth as p.
    for(int i=steps;i>=0;i--)
        if(depth[q]-(1<<i) >= depth[p])
            q=Root[q][i];
    if(p==q) return p;
    // Lift both while their 2^i-th ancestors differ; they stop just below the LCA.
    for(int i=steps;i>=0;i--){
        if(Root[p][i]!=Root[q][i])
            p=Root[p][i],q=Root[q][i];
    }
    return parent[p];
}
// Records the A -> LCA half of one update as difference entries in fwd[].
// One entry is added at x (the path start). A second entry, scaled by R^t
// with the opposite sign, is added at parent[y] (y is the LCA). Once
// dfs_cal() pushes entries towards the root (multiplying by R per edge),
// this second entry stops the contribution above y. When y is node 1 (the
// root) there is no node above it, so the second entry is skipped.
//
// S, D, B1 : the constant, linear and quadratic coefficients that solve()
//            derives from a1, d1, a2, d2 (already reduced modulo mod there).
// t        : depth[x] - depth[y] + 1 as passed by solve(), i.e. the number of
//            edges from x to parent[y].
//
// The previous version assembled these entries in the file-scope scratch
// nodes A and B, silently mutating shared global state. Local nodes are used
// now, and sign checks that could never fire have been dropped.
void Update_forward(ll S,ll D,ll B1,ll t,int x,int y)
{
    node start;                      // entry added at x
    start.S = S;
    start.D = 0;
    start.factD = D;
    start.B = 0;
    start.factB = B1;
    if (start.factB < 0) start.factB += mod;
    start.addB = (B1 + B1) % mod;

    // Every `mod - (v % mod)` below is never negative, so no sign fix-up is needed.
    const ll negB = mod - (B1*Rp[t]) % mod;
    node stop;                       // cancelling entry added at parent[y]
    stop.S = mod - (S*Rp[t]) % mod;
    stop.factD = mod - (D*Rp[t]) % mod;
    stop.D = (stop.factD*t) % mod;
    stop.B = (((t*t) % mod) * negB) % mod;
    stop.factB = ((2*t + 1) * negB) % mod;
    stop.addB = (negB + negB) % mod;

    fwd[x] = fwd[x] + start;
    if (y != 1) fwd[parent[y]] = fwd[parent[y]] + stop;
}
// Records the LCA -> B half of one update as difference entries in bck[].
// One entry, scaled by R^t, is added at y (the path end B). A second entry,
// scaled by R^g with the opposite sign, is added at x (the LCA). dfs_cal()
// later pushes bck[] entries towards the root, multiplying by Rinv per edge.
//
// S, D, B1 : the constant, linear and quadratic coefficients from solve().
// t        : distance from A to B (depth[A] + depth[B] - 2*depth[lca] in solve()).
// g        : distance from A to the LCA (depth[A] - depth[lca] in solve()).
// y        : the path end B;  x : the LCA.
//
// Local nodes replace the file-scope scratch nodes A and B, so the routine no
// longer mutates shared global state as a side effect. A sign check on a
// `mod - (v % mod)` value, which can never be negative, has been dropped.
void Update_backward(ll S,ll D,ll B1,ll t,ll g,int y,int x)
{
    node atEnd;                      // entry added at y
    ll scaled = (B1*Rp[t]) % mod;
    atEnd.S = (S*Rp[t]) % mod;
    atEnd.factD = (D*Rp[t]) % mod;
    atEnd.D = (t*atEnd.factD) % mod;
    atEnd.addB = scaled + scaled;
    if (atEnd.addB >= mod) atEnd.addB -= mod;
    atEnd.factB = ((2*t - 1) * scaled) % mod;
    if (atEnd.factB < 0) atEnd.factB += mod;     // 2*t-1 is -1 when t == 0
    atEnd.B = (((t*t) % mod) * scaled) % mod;

    node atAnc;                      // cancelling entry added at x (the LCA)
    scaled = mod - (B1*Rp[g]) % mod;
    atAnc.S = mod - (S*Rp[g]) % mod;
    atAnc.factD = mod - (D*Rp[g]) % mod;
    atAnc.D = (g*atAnc.factD) % mod;
    atAnc.addB = scaled + scaled;
    if (atAnc.addB >= mod) atAnc.addB -= mod;
    atAnc.factB = ((2*g - 1) * scaled) % mod;
    if (atAnc.factB < 0) atAnc.factB += mod;     // 2*g-1 is -1 when g == 0
    atAnc.B = (((g*g) % mod) * scaled) % mod;

    bck[y] = bck[y] + atEnd;
    bck[x] = bck[x] + atAnc;
}
void solve()
{
ll S1,D1,B1,ans;
int Z,anc,x,y,a1,a2,d1,d2;
scanf("%d%lld",&N,&R);
Rinv=inverse(R,mod);
FOR(i,1,N-1){
scanf("%d%d",&x,&y);
myvector[x].push_back(y);
myvector[y].push_back(x);
}
parent[1]=1;
depth[1]=0;
dfs_pre(1);
process(N);
Rp[0]=1;
FOR(i,1,N) Rp[i]=((ll)Rp[i-1]*(ll)R)%mod;
scanf("%d%d",&U,&Q);
while(U--){
scanf("%d%d%d%d%d%d",&a1,&d1,&a2,&d2,&x,&y);
S1 = ((ll)a1*(ll)a2)%mod;
D1 = ((ll)d1*(ll)a2 + (ll)d2*(ll)a1)%mod;
B1 = ((ll)d1*(ll)d2)%mod;
anc=lca(x,y);
Update_forward(S1,D1,B1,depth[x]-depth[anc]+1,x,anc);
Update_backward(S1,D1,B1,depth[y]+depth[x]-2*depth[anc],depth[x]-depth[anc],y,anc);
}
dfs_cal(1);
dfs_sum(1);
while(Q--){
scanf("%d%d",&x,&y);
anc=lca(x,y);
if(anc!=1) ans=(sum[x]+sum[y]-sum[anc]-sum[parent[anc]])%mod;
else ans=(sum[x]+sum[y]-sum[anc])%mod;
if(ans<0) printf("%lld\n",ans+mod);
else printf("%lld\n",ans);
}
}
// Program entry: build the floor(log2) table, then read the input and print
// one answer per query. Falling off the end of main returns 0.
int main()
{
    init();
    solve();
}
In Java :
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.InputMismatchException;
/**
 * Java solution to "Easy Addition".
 *
 * Each update adds (a1 + z*d1) * (a2 + z*d2) * R^z along a tree path. The
 * product is expanded into p0 + p1*z + p2*z^2 (times R^z) with
 * p0 = a1*a2, p1 = d1*a2 + d2*a1 and p2 = d1*d2. Updates are stored as
 * difference entries, folded into node values by one bottom-up pass, and
 * path queries are answered from root-to-node prefix sums.
 */
public class E {
    /** Input source: stdin, or the bytes of INPUT when INPUT is non-empty (see main). */
    static InputStream is;
    static PrintWriter out;
    /** Optional inline input; when non-empty it replaces stdin and enables tr() output. */
    static String INPUT = "";

    /**
     * Returns the inverse of {@code a} modulo {@code mod} using the extended
     * Euclidean algorithm. A negative coefficient gets {@code mod} added once.
     * Assumes gcd(a, mod) == 1.
     */
    public static long invl(long a, long mod) {
        long b = mod;
        long p = 1, q = 0;
        while (b > 0) {
            long c = a / b;
            long d;
            d = a;
            a = b;
            b = d % b;
            d = p;
            p = q;
            q = d - c * q;
        }
        return p < 0 ? p + mod : p;
    }

    /**
     * Reads one test case and prints one answer per query.
     *
     * The a -> lca half of an update lives in (w, wz, wzz): the constant,
     * linear and quadratic coefficients, pushed towards the root multiplying
     * by R per edge. The lca -> b half lives in (x, xz, xzz) and is pushed
     * towards the root multiplying by R^{-1} per edge. Cancelling entries stop
     * each half at the right node.
     */
    public static void solve()
    {
        int mod = 1000000007;
        int n = ni(), R = ni();
        int IR = (int)invl(R, mod);
        int[] from = new int[n-1];
        int[] to = new int[n-1];
        for(int i = 0;i < n-1;i++){
            from[i] = ni()-1;
            to[i] = ni()-1;
        }
        int[][] g = packU(n, from, to);
        int[][] pars = parents3(g, 0);
        int[] par = pars[0], ord = pars[1], dep = pars[2];
        int[][] spar = logstepParents(par);
        long[] w = new long[n];
        long[] wz = new long[n];
        long[] wzz = new long[n];
        long[] x = new long[n];
        long[] xz = new long[n];
        long[] xzz = new long[n];
        // pr[i] = R^i mod p. (An unused table of inverse powers was removed.)
        long[] pr = new long[n+1];
        pr[0] = 1;
        for(int i = 1;i <= n;i++){
            pr[i] = pr[i-1] * R % mod;
        }
        int U = ni(), Q = ni();
        for(int i = 0;i < U;i++){
            long a1 = ni(), d1 = ni();
            long a2 = ni(), d2 = ni();
            int a = ni()-1, b = ni()-1;
            int lca = lca2(a, b, spar, dep);
            long p0 = a1*a2%mod;
            long p1 = (d1*a2+d2*a1)%mod;
            long p2 = d1*d2%mod;
            // a -> lca half: start at a ...
            w[a] += p0; w[a] %= mod;
            wz[a] += p1; wz[a] %= mod;
            wzz[a] += p2; wzz[a] %= mod;
            // ... and cancel just above lca (nothing to cancel above the root).
            if(par[lca] != -1){
                int pl = par[lca];
                int dal = dep[a] - dep[pl];
                long rd = pr[dal];
                w[pl] -= (p0+p1*dal+p2*dal%mod*dal)%mod*rd%mod;
                if(w[pl] < 0)w[pl] += mod;
                wz[pl] -= (p1+2*p2*dal)%mod*rd%mod;
                if(wz[pl] < 0)wz[pl] += mod;
                wzz[pl] -= p2*rd%mod;
                if(wzz[pl] < 0)wzz[pl] += mod;
            }
            // lca -> b half: start at b with the polynomial evaluated at dist(a, b) ...
            int dab = dep[a]+dep[b]-2*dep[lca];
            long rx = pr[dab];
            long px = (p0+p1*dab+p2*dab%mod*dab)%mod*rx%mod;
            long pxz = (p1+2*p2*dab)%mod*rx%mod;
            long pxzz = p2*rx%mod;
            x[b] += px; x[b] %= mod;
            xz[b] += pxz; xz[b] %= mod;
            xzz[b] += pxzz; xzz[b] %= mod;
            // ... and cancel at lca itself (lca is already covered by the w half).
            {
                int dal = dep[a] - dep[lca];
                long rd = pr[dal];
                x[lca] -= (p0+p1*dal+p2*dal%mod*dal)%mod*rd%mod;
                // Normalise the array that was just decremented (previously
                // this tested w[lca] / wz[lca], so the fix-up never fired).
                if(x[lca] < 0)x[lca] += mod;
                xz[lca] -= (p1+2*p2*dal)%mod*rd%mod;
                if(xz[lca] < 0)xz[lca] += mod;
                xzz[lca] -= p2*rd%mod;
                if(xzz[lca] < 0)xzz[lca] += mod;
            }
        }
        // Fold entries towards the root in reverse BFS order (children before parents).
        for(int i = n-1;i >= 1;i--){
            int cur = ord[i];
            int p = par[cur];
            w[p] += w[cur]*R+wz[cur]*R+wzz[cur]*R; w[p] %= mod;
            wz[p] += wz[cur]*R+2L*wzz[cur]*R; wz[p] %= mod;
            wzz[p] += wzz[cur]*R; wzz[p] %= mod;
            x[p] += x[cur]*IR-xz[cur]*IR+xzz[cur]*IR; x[p] %= mod;
            xz[p] += xz[cur]*IR-2L*xzz[cur]*IR; xz[p] %= mod;
            xzz[p] += xzz[cur]*IR; xzz[p] %= mod;
        }
        // s[v] = sum of node values on the root -> v path (may be negative; normalised on output).
        long[] s = new long[n];
        for(int i = 0;i < n;i++){
            int cur = ord[i];
            int p = par[cur];
            if(p != -1){
                s[cur] = (s[p] + w[cur] + x[cur]) % mod;
            }else{
                s[cur] = (w[cur] + x[cur]) % mod;
            }
        }
        for(int i = 0;i < Q;i++){
            int u = ni()-1, v = ni()-1;
            int lca = lca2(u, v, spar, dep);
            long ret = s[u] + s[v] - s[lca];
            if(par[lca] != -1)ret -= s[par[lca]];
            ret %= mod;
            if(ret < 0)ret += mod;
            out.println(ret);
        }
    }

    /**
     * Computes a^n mod {@code mod} by left-to-right binary exponentiation.
     * {@code a} is not reduced first. Not called from solve().
     */
    public static long pow(long a, long n, long mod) {
        // a %= mod;
        long ret = 1;
        int x = 63 - Long.numberOfLeadingZeros(n);
        for (; x >= 0; x--) {
            ret = ret * ret % mod;
            if (n << 63 - x < 0)
                ret = ret * a % mod;
        }
        return ret;
    }

    /**
     * Lowest common ancestor of a and b. Uses the binary-lifting table
     * {@code spar} (spar[k][v] = 2^k-th ancestor of v, or -1) and depths.
     */
    public static int lca2(int a, int b,
            int[][] spar, int[] depth) {
        if (depth[a] < depth[b]) {
            b = ancestor(b, depth[b] - depth[a], spar);
        } else if (depth[a] > depth[b]) {
            a = ancestor(a, depth[a] - depth[b], spar);
        }
        if (a == b)
            return a;
        int sa = a, sb = b;
        for (int low = 0, high = depth[a], t =
                Integer.highestOneBit(high), k = Integer
                .numberOfTrailingZeros(t); t > 0; t >>>= 1, k--) {
            if ((low ^ high) >= t) {
                if (spar[k][sa] != spar[k][sb]) {
                    low |= t;
                    sa = spar[k][sa];
                    sb = spar[k][sb];
                } else {
                    high = low | t - 1;
                }
            }
        }
        return spar[0][sa];
    }

    /** Returns the m-th ancestor of a, or -1 if the walk passes the root. */
    protected static int ancestor(int a, int m, int[][] spar) {
        for (int i = 0; m > 0 && a != -1; m >>>= 1, i++) {
            if ((m & 1) == 1)
                a = spar[i][a];
        }
        return a;
    }

    /** Builds spar[k][v] = 2^k-th ancestor of v (-1 past the root) from the parent array. */
    public static int[][] logstepParents(int[] par) {
        int n = par.length;
        int m = Integer.numberOfTrailingZeros(
                Integer.highestOneBit(n - 1)) + 1;
        int[][] pars = new int[m][n];
        pars[0] = par;
        for (int j = 1; j < m; j++) {
            for (int i = 0; i < n; i++) {
                pars[j][i] = pars[j - 1][i] == -1 ? -1
                        : pars[j - 1][pars[j - 1][i]];
            }
        }
        return pars;
    }

    /**
     * BFS from {@code root}. Returns {par, order, depth}, where par[root] == -1
     * and order is the BFS visiting order.
     */
    public static int[][] parents3(
            int[][] g, int root) {
        int n = g.length;
        int[] par = new int[n];
        Arrays.fill(par, -1);
        int[] depth = new int[n];
        depth[root] = 0;
        int[] q = new int[n];
        q[0] = root;
        for (int p = 0, r = 1; p < r; p++) {
            int cur = q[p];
            for (int nex : g[cur]) {
                if (par[cur] != nex) {
                    q[r++] = nex;
                    par[nex] = cur;
                    depth[nex] = depth[cur] + 1;
                }
            }
        }
        return new int[][] { par, q, depth };
    }

    /** Packs an undirected edge list into jagged adjacency arrays. */
    static int[][] packU(int n, int[] from, int[] to) {
        int[][] g = new int[n][];
        int[] p = new int[n];
        for (int f : from)
            p[f]++;
        for (int t : to)
            p[t]++;
        for (int i = 0; i < n; i++)
            g[i] = new int[p[i]];
        for (int i = 0; i < from.length; i++) {
            g[from[i]][--p[from[i]]] = to[i];
            g[to[i]][--p[to[i]]] = from[i];
        }
        return g;
    }

    /** Reads from stdin (or INPUT), runs solve(), flushes; tr() prints elapsed time only when INPUT is set. */
    public static void main(String[] args) throws Exception
    {
        long S = System.currentTimeMillis();
        is = INPUT.isEmpty() ? System.in :
            new ByteArrayInputStream(INPUT.getBytes());
        out = new PrintWriter(System.out);
        solve();
        out.flush();
        long G = System.currentTimeMillis();
        tr(G-S+"ms");
    }

    /** True if only whitespace remains in the buffered input and the stream. */
    private static boolean eof()
    {
        if(lenbuf == -1)return true;
        int lptr = ptrbuf;
        while(lptr < lenbuf)if(!isSpaceChar(inbuf[lptr++]))
            return false;
        try {
            is.mark(1000);
            while(true){
                int b = is.read();
                if(b == -1){
                    is.reset();
                    return true;
                }else if(!isSpaceChar(b)){
                    is.reset();
                    return false;
                }
            }
        } catch (IOException e) {
            return true;
        }
    }

    private static byte[] inbuf = new byte[1024];
    static int lenbuf = 0, ptrbuf = 0;

    /** Next raw byte from the buffered input, or -1 at end of stream. */
    private static int readByte()
    {
        if(lenbuf == -1)throw new InputMismatchException();
        if(ptrbuf >= lenbuf){
            ptrbuf = 0;
            try { lenbuf = is.read(inbuf); } catch (IOException e)
            { throw new InputMismatchException(); }
            if(lenbuf <= 0)return -1;
        }
        return inbuf[ptrbuf++];
    }

    /** Anything outside printable ASCII 33..126 counts as a separator. */
    private static boolean isSpaceChar(int c)
    { return !(c >= 33 && c <= 126); }

    private static int skip()
    { int b; while((b = readByte()) != -1 &&
            isSpaceChar(b)); return b; }

    private static double nd()
    { return Double.parseDouble(ns()); }

    private static char nc() { return (char)skip(); }

    private static String ns()
    {
        int b = skip();
        StringBuilder sb = new StringBuilder();
        while(!(isSpaceChar(b))){
            sb.appendCodePoint(b);
            b = readByte();
        }
        return sb.toString();
    }

    private static char[] ns(int n)
    {
        char[] buf = new char[n];
        int b = skip(), p = 0;
        while(p < n && !(isSpaceChar(b))){
            buf[p++] = (char)b;
            b = readByte();
        }
        return n == p ? buf : Arrays.copyOf(buf, p);
    }

    private static char[][] nm(int n, int m)
    {
        char[][] map = new char[n][];
        for(int i = 0;i < n;i++)map[i] = ns(m);
        return map;
    }

    private static int[] na(int n)
    {
        int[] a = new int[n];
        for(int i = 0;i < n;i++)a[i] = ni();
        return a;
    }

    /** Reads the next (optionally negative) decimal int, skipping other bytes. */
    private static int ni()
    {
        int num = 0, b;
        boolean minus = false;
        while((b = readByte()) != -1 && !(
                (b >= '0' && b <= '9') || b == '-'));
        if(b == '-'){
            minus = true;
            b = readByte();
        }
        while(true){
            if(b >= '0' && b <= '9'){
                num = num * 10 + (b - '0');
            }else{
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    /** Reads the next (optionally negative) decimal long, skipping other bytes. */
    private static long nl()
    {
        long num = 0;
        int b;
        boolean minus = false;
        while((b = readByte()) != -1 && !(
                (b >= '0' && b <= '9') || b == '-'));
        if(b == '-'){
            minus = true;
            b = readByte();
        }
        while(true){
            if(b >= '0' && b <= '9'){
                num = num * 10 + (b - '0');
            }else{
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    /** Debug print, active only when INPUT is non-empty. */
    private static void tr(Object... o)
    { if(INPUT.length() != 0)
        System.out.println(Arrays.deepToString(o)); }
}
View More Similar Problems
Pair Sums
Given an array, we define its value to be the value obtained by following these instructions: Write down all pairs of numbers from this array. Compute the product of each pair. Find the sum of all the products. For example, for the array [7, 2, -1, 2], note that (7, 2) is listed twice, once for each occurrence of 2. Given an array of integers, find the largest v
View Solution →Lazy White Falcon
White Falcon just solved the data structure problem below using heavy-light decomposition. Can you help her find a new solution that doesn't require implementing any fancy techniques? There are 2 types of query operations that can be performed on a tree: 1 u x: Assign x as the value of node u. 2 u v: Print the sum of the node values in the unique path from node u to node v. Given a tree wi
View Solution →Ticket to Ride
Simon received the board game Ticket to Ride as a birthday present. After playing it with his friends, he decides to come up with a strategy for the game. There are n cities on the map and n - 1 road plans. Each road plan consists of the following: Two cities which can be directly connected by a road. The length of the proposed road. The entire road plan is designed in such a way that if o
View Solution →Heavy Light White Falcon
Our lazy white falcon finally decided to learn heavy-light decomposition. Her teacher gave an assignment for her to practice this new technique. Please help her by solving this problem. You are given a tree with N nodes and each node's value is initially 0. The problem asks you to operate the following two types of queries: "1 u x" assign x to the value of node u. "2 u v" print the maxim
View Solution →Number Game on a Tree
Andy and Lily love playing games with numbers and trees. Today they have a tree consisting of n nodes and n -1 edges. Each edge i has an integer weight, wi. Before the game starts, Andy chooses an unordered pair of distinct nodes, ( u , v ), and uses all the edge weights present on the unique path from node u to node v to construct a list of numbers. For example, in the diagram below, Andy
View Solution →Heavy Light 2 White Falcon
White Falcon was amazed by what she can do with heavy-light decomposition on trees. As a result, she wants to improve her expertise on heavy-light decomposition. Her teacher gave her another assignment which requires path updates. As always, White Falcon needs your help with the assignment. You are given a tree with N nodes and each node's value Vi is initially 0. Let's denote the path fr
View Solution →