# Lazy White Falcon

### Problem Statement :

```
White Falcon just solved the data structure problem below using heavy-light decomposition. Can you help her find a new solution that doesn't require implementing any fancy techniques?

There are 2 types of query operations that can be performed on a tree:

1 u x: Assign x as the value of node u.
2 u v: Print the sum of the node values in the unique path from node u to node v.
Given a tree with N nodes where each node's value is initially 0, execute Q queries.

Input Format

The first line contains 2 space-separated integers, N and Q, respectively.
The N-1 subsequent lines each contain 2 space-separated integers describing an undirected edge in the tree.
Each of the Q subsequent lines contains a query you must execute.

Constraints

1 <= N, Q <= 10^5
1 <= x <= 1000

It is guaranteed that the input describes a connected tree with N nodes.
Nodes are enumerated with 0-based indexing.

Output Format

For each type-2 query, print its integer result on a new line.
```

### Solution

```
Solutions in C++, Java, C and Python 3.

In C++:

#include <cmath>
#include <cstdio>
#include <vector>
#include <iostream>
#include <algorithm>
using namespace std;

const int N = 100010;
const int LG_N = 20;

int n, q;

// Adjacency lists of the input tree; every undirected edge is stored twice.
std::vector<int> adj[N];

// Fenwick tree indexed by Euler-tour position. Position 0 of `euler` is a
// dummy slot, so real positions start at 1 (a Fenwick index must be > 0).
int tree[2*N];
std::vector<int> euler;
// first[u] / last[u]: Euler-tour positions at which u is entered / left.
int first[N], last[N];

// H[u]: depth of u. P[u][i]: 2^i-th ancestor of u (the root is its own parent).
int H[N], P[N][LG_N];
// val[u]: value currently assigned to u (initially 0).
int val[N];

// Euler tour of the subtree rooted at u, whose parent is p and depth is h.
// Fills H, P, first and last and appends every entry/exit to `euler`.
// Written iteratively: on a path-shaped tree the depth reaches N-1, which can
// overflow the call stack with a recursive version.
void dfs(int u, int p, int h) {
    struct Frame {
        int node;
        std::size_t next;  // index of the next neighbour to explore
    };
    std::vector<Frame> frames;

    auto enter = [&](int node, int parent, int depth) {
        H[node] = depth;
        P[node][0] = parent;
        for (int i = 1; i < LG_N; i++) {
            P[node][i] = P[P[node][i-1]][i-1];
        }
        first[node] = static_cast<int>(euler.size());
        euler.push_back(node);
        frames.push_back({node, 0});
    };

    enter(u, p, h);
    while (!frames.empty()) {
        const int node = frames.back().node;
        if (frames.back().next < adj[node].size()) {
            // Read the neighbour before `enter` may reallocate `frames`.
            const int v = adj[node][frames.back().next++];
            if (v != P[node][0]) {
                enter(v, node, H[node] + 1);
            }
        } else {
            last[node] = static_cast<int>(euler.size());
            euler.push_back(node);
            frames.pop_back();
        }
    }
}

// Lowest common ancestor via binary lifting.
int lca(int u, int v) {
    if (H[u] < H[v]) std::swap(u, v);
    // Lift u to the depth of v.
    for (int i = LG_N - 1; i >= 0; i--) {
        if (H[P[u][i]] >= H[v]) {
            u = P[u][i];
        }
    }
    if (u == v) {
        return u;
    }
    for (int i = LG_N - 1; i >= 0; i--) {
        if (P[u][i] != P[v][i]) {
            u = P[u][i];
            v = P[v][i];
        }
    }
    return P[u][0];
}

// Add `delta` at Euler position idx (idx must be >= 1).
void update(int idx, int delta) {
    const int size = static_cast<int>(euler.size());
    while (idx < size) {
        tree[idx] += delta;
        idx += idx & (-idx);
    }
}

// Sum of Fenwick positions 1..idx. For idx = first[u], this is the sum of the
// values on the root->u path (u included). A subtree closed before u contributes
// +val at its entry and -val at its exit, so it cancels out.
int query(int idx) {
    int ans = 0;
    while (idx > 0) {
        ans += tree[idx];
        idx -= idx & (-idx);
    }
    return ans;
}

int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    std::cin >> n >> q;
    for (int i = 0; i < n - 1; i++) {
        int u, v;
        std::cin >> u >> v;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }

    euler.reserve(2 * n + 1);
    euler.resize(1, 0);  // dummy slot 0
    dfs(0, 0, 0);

    for (int i = 0; i < q; i++) {
        int type;
        std::cin >> type;
        if (type == 1) {
            int u, x;
            std::cin >> u >> x;
            // Apply the change at entry and undo it at exit, so that only
            // prefix sums taken inside u's subtree see it.
            update(first[u], x - val[u]);
            update(last[u], val[u] - x);
            val[u] = x;
        } else {
            int u, v;
            std::cin >> u >> v;
            const int l = lca(u, v);
            // root->u plus root->v counts root->l twice; add l itself back once.
            const int ans = query(first[u]) + query(first[v])
                            - 2 * query(first[l]) + val[l];
            std::cout << ans << "\n";
        }
    }
    return 0;
}

In Java:

import java.io.*;
import java.util.*;
import java.text.*;
import java.math.*;
import java.util.regex.*;

class TreeNode implements Comparable<TreeNode> {
    int index;
    int value;
    int level = -1;    //0 is root.
    TreeNode parent;
    /**
     * Adjacent nodes. This field was used (constructor, tree-rooting pass) but
     * never declared. The rooting pass removes the parent entry, leaving only
     * the children.
     */
    Set<TreeNode> children;
    BranchContainer branch;

    /** Creates node {@code i} with no neighbours and its own empty branch. */
    TreeNode(int i) {
        index = i;
        children = new HashSet<TreeNode>();
        branch = new BranchContainer();
    }

    /** Sets this node's value and keeps its branch's running sum in step. */
    void updateValue(int v) {
        int diff = v - value;
        value = v;
        branch.sum += diff;
    }

    @Override
    public String toString() {
        return "i=" + index + " L=" + level;
    }

    /** Identity is the node index only. */
    @Override
    public int hashCode() {
        final int prime = 31;
        int result = 1;
        result = prime * result + index;
        return result;
    }

    @Override
    public boolean equals(Object obj) {
        if (this == obj)
            return true;
        if (obj == null)
            return false;
        if (getClass() != obj.getClass())
            return false;
        TreeNode other = (TreeNode) obj;
        return index == other.index;
    }

    /** Orders by index; Integer.compare cannot overflow, unlike subtraction. */
    @Override
    public int compareTo(TreeNode o) {
        return Integer.compare(index, o.index);
    }
}

/**
 * A chain ("branch") of tree nodes. {@code list} keeps the nodes in order and
 * {@code set} gives fast membership tests. {@code sum} is the running total of
 * the node values (maintained by TreeNode.updateValue). {@code isTrunk} marks
 * one distinguished branch.
 */
class BranchContainer {
    ArrayList<TreeNode> list;
    HashSet<TreeNode> set;
    int sum;
    boolean isTrunk;

    BranchContainer() {
        this.list = new ArrayList<TreeNode>();
        this.set = new HashSet<TreeNode>();
        this.sum = 0;
        this.isTrunk = false;
    }
}

/**
 * Path-sum queries on a tree with point assignments, without heavy-light
 * decomposition. Each vertex value is added at its Euler-tour entry time and
 * subtracted at its exit time in a Fenwick tree. A prefix sum up to entry(u)
 * is then the sum along root..u, and a path sum follows from one LCA query.
 */
public class Solution {
    /** Number of vertices and number of queries read from the input. */
    int nNodes, nQueries;

    /** Rows of the binary-lifting table; 2^log >= nNodes. */
    private int log;
    /** Adjacency lists stored as singly linked lists in flat arrays. */
    private int[] adjHead, adjNext, adjTo;
    private int edgeCount;
    /** up[k][v] = 2^k-th ancestor of v; the root (0) is its own parent. */
    private int[][] up;
    private int[] depth;
    /** Euler-tour entry / exit times, both in 1..2*nNodes. */
    private int[] first, last;
    /** Current value of every vertex (initially 0). */
    private int[] value;
    /** Fenwick tree over Euler-tour times (index 0 unused). */
    private int[] fenwick;

    /** Allocates storage for a tree with {@code n} vertices. */
    void init(final int n) {
        nNodes = n;
        log = 1;
        while ((1 << log) < n) {
            ++log;
        }
        adjHead = new int[n];
        Arrays.fill(adjHead, -1);
        adjNext = new int[2 * Math.max(n - 1, 0)];
        adjTo = new int[adjNext.length];
        edgeCount = 0;
        up = new int[log][n];
        depth = new int[n];
        first = new int[n];
        last = new int[n];
        value = new int[n];
        fenwick = new int[2 * n + 1];
    }

    /** Records the undirected edge a-b. */
    void addEdge(final int a, final int b) {
        adjTo[edgeCount] = b;
        adjNext[edgeCount] = adjHead[a];
        adjHead[a] = edgeCount++;
        adjTo[edgeCount] = a;
        adjNext[edgeCount] = adjHead[b];
        adjHead[b] = edgeCount++;
    }

    /**
     * Iterative DFS from vertex 0. Fills depth, parents, Euler times and the
     * lifting table. It is iterative because the depth can reach nNodes-1.
     */
    void buildTour() {
        final int[] stack = new int[nNodes];
        final int[] nextEdge = new int[nNodes];
        int top = 0;
        int time = 0;
        stack[top++] = 0;
        nextEdge[0] = adjHead[0];
        first[0] = ++time;
        while (top > 0) {
            final int u = stack[top - 1];
            final int e = nextEdge[u];
            if (e < 0) {
                last[u] = ++time;
                --top;
                continue;
            }
            nextEdge[u] = adjNext[e];
            final int v = adjTo[e];
            if (v == up[0][u]) {
                continue;  // the parent edge
            }
            up[0][v] = u;
            depth[v] = depth[u] + 1;
            nextEdge[v] = adjHead[v];
            first[v] = ++time;
            stack[top++] = v;
        }
        for (int k = 1; k < log; ++k) {
            for (int v = 0; v < nNodes; ++v) {
                up[k][v] = up[k - 1][up[k - 1][v]];
            }
        }
    }

    /** Lowest common ancestor by binary lifting. */
    int lca(int a, int b) {
        if (depth[a] < depth[b]) {
            final int t = a;
            a = b;
            b = t;
        }
        int diff = depth[a] - depth[b];
        for (int k = 0; diff > 0; ++k, diff >>= 1) {
            if ((diff & 1) != 0) {
                a = up[k][a];
            }
        }
        if (a == b) {
            return a;
        }
        for (int k = log - 1; k >= 0; --k) {
            if (up[k][a] != up[k][b]) {
                a = up[k][a];
                b = up[k][b];
            }
        }
        return up[0][a];
    }

    private void add(int i, final int delta) {
        for (; i < fenwick.length; i += i & -i) {
            fenwick[i] += delta;
        }
    }

    private int prefix(int i) {
        int s = 0;
        for (; i > 0; i -= i & -i) {
            s += fenwick[i];
        }
        return s;
    }

    /** Query type 1: assigns {@code x} as the value of vertex {@code u}. */
    void setValue(final int u, final int x) {
        final int delta = x - value[u];
        value[u] = x;
        add(first[u], delta);
        add(last[u], -delta);
    }

    /** Query type 2: sum of vertex values on the path index1..index2. */
    int getSum(final int index1, final int index2) {
        final int l = lca(index1, index2);
        // root..index1 + root..index2 counts root..l twice; add l back once.
        return prefix(first[index1]) + prefix(first[index2])
                - 2 * prefix(first[l]) + value[l];
    }

    private static int readInt(final StreamTokenizer in) throws IOException {
        in.nextToken();
        return (int) in.nval;
    }

    public static void main(String[] args) {
        try {
            final StreamTokenizer in = new StreamTokenizer(
                    new BufferedReader(new InputStreamReader(System.in)));
            final Solution falcon = new Solution();
            final int n = readInt(in);
            falcon.nQueries = readInt(in);
            falcon.init(n);
            for (int i = 0; i < n - 1; ++i) {
                falcon.addEdge(readInt(in), readInt(in));
            }
            falcon.buildTour();

            final PrintWriter out =
                new PrintWriter(new BufferedWriter(
                    new OutputStreamWriter(new FileOutputStream(
                        java.io.FileDescriptor.out), "UTF-8"), 1 << 16));

            for (int i = 0; i < falcon.nQueries; ++i) {
                final int q = readInt(in);
                final int u = readInt(in);
                final int v = readInt(in);
                switch (q) {
                    case 1:
                        falcon.setValue(u, v);
                        break;
                    case 2:
                        out.println(falcon.getSum(u, v));
                        break;
                    default:
                        System.err.println("Invalid query " + q);
                }
            }
            out.flush();
        }
        catch (Exception e) {
            e.printStackTrace(System.err);
        }
    }
}

In C:

#include <stdio.h>
#include <stdlib.h>
/* Adjacency-list cell: one directed half of an undirected edge.
   (Tag renamed from _lnode: identifiers starting with '_' are reserved
   at file scope.) */
typedef struct lnode_s{
  int x;                 /* neighbouring vertex */
  int w;                 /* edge weight; stored but never read */
  struct lnode_s *next;
} lnode;
/* Segment-tree cell holding the sum of its range. */
typedef struct tree_s{
  int sum;
} tree;
void insert_edge(int x,int y,int w);
void dfs0(int u);
void dfs1(int u,int c);
void preprocess();
int lca(int a,int b);
int sum(int v,int tl,
  int tr,int l,int r,tree *t);
void update(int v,int tl,
  int tr,int pos,int new_val,tree *t);
int min(int x,int y);
int max(int x,int y);
int solve(int x,int ancestor);
/* chain_len[c] (vertices on chain c, counted by dfs1) was used throughout
   but never declared, and the original list ended in a stray comma. */
int N,cn,level[100000],DP[18][100000],
  subtree_size[100000],special[100000],
  node_chain[100000],node_idx[100000],
  chain_len[100000];
lnode *table[100000]={0};
tree *chain[100000];

/* Reads the tree, builds the decomposition, then answers the queries:
   type 1 assigns a value to a vertex, any other type prints a path sum. */
int main(){
  int queries, a, b, e, meet;
  scanf("%d%d", &N, &queries);
  for (e = 0; e < N - 1; e++) {
    scanf("%d%d", &a, &b);
    insert_edge(a, b, 1);
  }
  preprocess();
  while (queries--) {
    scanf("%d", &a);          /* query type (same variable reused, as before) */
    if (a == 1) {
      scanf("%d%d", &a, &b);  /* vertex, new value */
      update(1, 0, chain_len[node_chain[a]] - 1,
             node_idx[a], b, chain[node_chain[a]]);
    } else {
      scanf("%d%d", &a, &b);
      meet = lca(a, b);
      /* both halves include `meet`, so subtract its value once */
      printf("%d\n",
             solve(a, meet) + solve(b, meet) -
             sum(1, 0, chain_len[node_chain[meet]] - 1,
                 node_idx[meet], node_idx[meet], chain[node_chain[meet]]));
    }
  }
  return 0;
}
void insert_edge(int x,int y,int w){
lnode *t=malloc(sizeof(lnode));
t->x=y;
t->w=w;
t->next=table[x];
table[x]=t;
t=malloc(sizeof(lnode));
t->x=x;
t->w=w;
t->next=table[y];
table[y]=t;
return;
}
/* First DFS: sets parent (DP[0]), depth (level) and subtree size for every
   vertex below u. special[u] becomes the child with the largest subtree,
   the first one found on ties, or -1 for a leaf. */
void dfs0(int u){
  lnode *e;
  int child;
  subtree_size[u] = 1;
  special[u] = -1;
  for (e = table[u]; e != NULL; e = e->next) {
    child = e->x;
    if (child == DP[0][u])
      continue;                       /* skip the parent */
    DP[0][child] = u;
    level[child] = level[u] + 1;
    dfs0(child);
    subtree_size[u] += subtree_size[child];
    if (special[u] == -1 || subtree_size[child] > subtree_size[special[u]])
      special[u] = child;
  }
}
/* Second DFS: puts u on chain c at the next free index. The heavy child
   (special[u]) stays on chain c; every other child starts chain cn, and cn
   is then incremented. */
void dfs1(int u,int c){
  lnode *e;
  int child;
  node_chain[u] = c;
  node_idx[u] = chain_len[c]++;
  for (e = table[u]; e; e = e->next) {
    child = e->x;
    if (child == DP[0][u])
      continue;
    if (child == special[u]) {
      dfs1(child, c);
    } else {
      dfs1(child, cn++);
    }
  }
}
/* Roots the tree at 0, builds the binary-lifting table DP, decomposes the
   tree into chains, and allocates one all-zero segment tree per chain.
   The original used malloc followed by N zero-updates. Those updates read
   uninitialised sibling cells while recombining sums; calloc gives zeroed
   trees directly and removes the O(N log N) loop. */
void preprocess(){
  int i,j;
  level[0]=0;
  DP[0][0]=0;               /* root is its own parent, so lifting saturates */
  dfs0(0);
  for(i=1;i<18;i++)
    for(j=0;j<N;j++)
      DP[i][j] = DP[i-1][DP[i-1][j]];
  cn=1;
  dfs1(0,0);
  for(i=0;i<cn;i++){
    chain[i]=(tree*)calloc(4*(size_t)chain_len[i],sizeof(tree));
    if(chain[i]==NULL){
      fputs("out of memory\n",stderr);
      exit(1);
    }
  }
  return;
}
/* Lowest common ancestor of a and b by binary lifting over DP. */
int lca(int a,int b){
  int k, tmp, gap;
  if (level[b] < level[a]) {   /* make b the deeper endpoint */
    tmp = a;
    a = b;
    b = tmp;
  }
  gap = level[b] - level[a];
  for (k = 0; k < 18; k++)
    if ((gap >> k) & 1)
      b = DP[k][b];
  if (a == b)
    return a;
  for (k = 17; k >= 0; k--) {
    if (DP[k][a] == DP[k][b])
      continue;
    a = DP[k][a];
    b = DP[k][b];
  }
  return DP[0][a];
}
/* Sum over positions l..r of segment tree t. Node v covers tl..tr; an empty
   range (l > r) yields 0. */
int sum(int v,int tl,int tr,int l,
  int r,tree *t){
  int mid, left_part, right_part;
  if (l > r)
    return 0;
  if (tl == l && tr == r)
    return t[v].sum;
  mid = (tl + tr) / 2;
  left_part = sum(2 * v, tl, mid, l, min(r, mid), t);
  right_part = sum(2 * v + 1, mid + 1, tr, max(l, mid + 1), r, t);
  return left_part + right_part;
}
/* Sets position pos of segment tree t (node v covers tl..tr) to new_val,
   then refreshes every node on the path back up to v. */
void update(int v,int tl,int tr,
  int pos,int new_val,tree *t){
  int top = v, mid;
  while (tl != tr) {           /* walk down to the leaf owning pos */
    mid = (tl + tr) / 2;
    if (pos <= mid) {
      v = v * 2;
      tr = mid;
    } else {
      v = v * 2 + 1;
      tl = mid + 1;
    }
  }
  t[v].sum = new_val;
  while (v != top) {           /* walk back up, recombining children */
    v /= 2;
    t[v].sum = t[v * 2].sum + t[v * 2 + 1].sum;
  }
}
/* Smaller of two ints. */
int min(int x,int y){
  if (y < x)
    return y;
  return x;
}
/* Larger of two ints. */
int max(int x,int y){
  if (y > x)
    return y;
  return x;
}
int solve(int x,int ancestor){
int ans=0;
while(node_chain[x]!=node_chain[ancestor]){
ans+=sum(1,0,chain_len[node_chain[x]]-1,
0,node_idx[x],chain[node_chain[x]]);
}
ans+=sum(1,0,chain_len[node_chain[x]]-1,
node_idx[ancestor],node_idx[x],
chain[node_chain[x]]);
return ans;
}

In Python 3:

class heavy_light_node:
    """One heavy chain of the heavy-light decomposition.

    ``weight[i]`` is the value of the i-th vertex on the chain; index 0 is the
    chain's top vertex, the one closest to the root. ``fenwick`` is a 0-indexed
    Fenwick tree over those weights. ``parent`` is the chain this one hangs off
    (None for the root chain). ``pos`` is the index, on ``parent``, of the
    vertex it is attached to.
    """

    def __init__(self, size):
        self.parent = None
        self.pos = -1
        self.weight = [0] * size
        self.fenwick = [0] * size

    def set_weight(self, i, x):
        """Assign ``x`` as the weight at position ``i`` of this chain."""
        d = x - self.weight[i]
        self.weight[i] = x
        N = len(self.weight)
        while i < N:
            self.fenwick[i] += d
            i |= i + 1

    def sum_weight(self, i):
        """Return weight[0] + ... + weight[i]; 0 when ``i`` is negative."""
        if i < 0:
            return 0
        x = self.fenwick[i]
        i &= i + 1
        while i:
            x += self.fenwick[i - 1]
            i &= i - 1
        return x


def build_tree(i, edges, location):
    """Decompose the subtree rooted at ``i`` into chains.

    ``edges[v]`` must list only v's children, heaviest first. The chain
    starting at ``i`` follows ``edges[v][0]``; every other child starts a new
    chain, built recursively. Records ``location[v] = (chain, index)`` and
    returns the chain that starts at ``i``.
    """
    children = []
    members = [i]
    ed = edges[i]
    while ed:
        # Light children hang off the current tail of the chain.
        for j in range(1, len(ed)):
            child = build_tree(ed[j], edges, location)
            child.pos = len(members) - 1
            children.append(child)
        i = ed[0]
        members.append(i)
        ed = edges[i]
    node = heavy_light_node(len(members))
    for child in children:
        child.parent = node
    for j, v in enumerate(members):
        location[v] = (node, j)
    return node


def build_location(N, edge_list):
    """Root the tree at 0, decompose it, and return the vertex->(chain, index) table.

    The original file had these statements but had lost their enclosing
    function header; this restores them as a callable function.
    """
    edges = [[] for _ in range(N)]
    for x, y in edge_list:
        edges[x].append(y)
        edges[y].append(x)
    # Iterative post-order: the first visit drops the parent from each child's
    # list; the second visit sorts children by subtree size, largest first.
    size = [0] * N
    active = [0]
    while active:
        i = active[-1]
        if size[i] == 0:
            size[i] = 1
            for j in edges[i]:
                edges[j].remove(i)
                active.append(j)
        else:
            active.pop()
            edges[i].sort(key=lambda j: -size[j])
            size[i] = 1 + sum(size[j] for j in edges[i])
    location = [None] * N
    build_tree(0, edges, location)
    return location


def read_tree(N):
    """Read the N-1 edges from standard input and return ``build_location``'s table."""
    edge_list = [tuple(map(int, input().split())) for _ in range(N - 1)]
    return build_location(N, edge_list)


def root_path(i, location):
    """Return the ``(chain, index)`` hops from the root chain down to vertex ``i``.

    The last entry is ``location[i]``. Each earlier entry is the position on an
    ancestor chain where the path towards ``i`` leaves that chain.
    """
    loc = location[i]
    path = [loc]
    loc = loc[0]
    while loc.parent is not None:
        path.append((loc.parent, loc.pos))
        loc = loc.parent
    path.reverse()
    return path


def max_weight(x, y, location_table=None):
    """Return the SUM of vertex values on the path between ``x`` and ``y``.

    The name is kept for existing callers. ``location_table`` defaults to the
    module-level ``location`` built by the script entry point.
    """
    if location_table is None:
        location_table = location
    px = root_path(x, location_table)
    py = root_path(y, location_table)
    # Find the deepest chain shared by both root paths.
    m = 1
    stop = min(len(px), len(py))
    while m < stop and px[m][0] == py[m][0]:
        m += 1
    loc, a = px[m - 1]
    b = py[m - 1][1]
    if a > b:
        a, b = b, a
    w = loc.sum_weight(b) - loc.sum_weight(a - 1)
    # Below the shared chain, each side adds a prefix of every chain it crosses.
    for chain, i in px[m:]:
        w += chain.sum_weight(i)
    for chain, i in py[m:]:
        w += chain.sum_weight(i)
    return w


# Vertex -> (chain, index) table used by ``max_weight``; set by the entry point.
location = None

if __name__ == "__main__":
    import sys

    N, Q = map(int, input().split())
    location = read_tree(N)
    results = []
    for _ in range(Q):
        t, x, y = map(int, input().split())
        if t == 1:
            chain, idx = location[x]
            chain.set_weight(idx, y)
        elif t == 2:
            results.append(max_weight(x, y))
    sys.stdout.write("".join("%d\n" % r for r in results))
```

