# Subsequence Weighting

### Problem Statement

```
A subsequence of a sequence is a sequence obtained by deleting zero or more elements from it.

You are given a sequence A in which every element is a pair of integers, i.e. A = [(a1, w1), (a2, w2), ..., (aN, wN)].

For a subsequence B = [(b1, v1), (b2, v2), ..., (bM, vM)] of the given sequence:

We call B increasing if b_i < b_{i+1} for every i with 1 <= i < M.
Weight(B) = v1 + v2 + ... + vM.
Task:
Given a sequence, output the maximum weight formed by an increasing subsequence.

Input:
The first line of input contains a single integer T; T test cases follow. The first line of each test case contains an integer N. The next line contains a1, a2, ..., aN separated by single spaces. The next line contains w1, w2, ..., wN separated by single spaces.

Output:
For each test-case output a single integer: The maximum weight of increasing subsequences of the given sequence.

Constraints:
1 <= T <= 5
1 <= N <= 150000
1 <= ai <= 10^9, where i ∈ [1..N]
1 <= wi <= 10^9, where i ∈ [1..N]

Sample Input:

2
4
1 2 3 4
10 20 30 40
8
1 2 3 4 1 2 3 4
10 20 30 40 15 15 15 50
Sample Output:

100
110
```

### Solution

**In C++:**

```cpp

#include <vector>
#include <list>
#include <map>
#include <set>
#include <deque>
#include <queue>
#include <stack>
#include <bitset>
#include <algorithm>
#include <functional>
#include <numeric>
#include <utility>
#include <sstream>
#include <iostream>
#include <iomanip>
#include <cstdio>
#include <cmath>
#include <cstdlib>
#include <ctime>
#include <cstring>
#include <climits>

using namespace std;

// Reads one int from stdin and yields it as an expression value
// (GNU statement-expression extension).
#define GI ({int new_input;scanf("%d",&new_input);new_input;})

// Signed 64-bit integer. It must be signed so that the "%lld" scanf conversions
// used on `ll` variables match their argument type (the previous
// `unsigned long long` did not). It easily holds the largest possible answer,
// 150000 * 10^9 = 1.5e14.
typedef long long ll;

// Recursive max segment tree. Node idx covers a range [b, e] of positions and
// stores the maximum value held anywhere in that range; its children are
// 2*idx and 2*idx+1. 800000 = 4 * 200000 slots, the usual bound for a
// recursive tree over up to 200000 positions.
ll Tree[800000];

// Raises position p to max(current value, val) inside the subtree idx that
// spans [b, e], then recomputes the maxima on the way back up.
// Positions outside [b, e] are ignored.
void updateTree(int b, int e, int p, ll  val, int idx=1) {
    if (p < b || p > e)
        return;
    if (b == e) {
        // Leaf: since b <= p <= e, this is exactly position p.
        Tree[idx] = max(Tree[idx], val);
        return;
    }
    const int mid = (b + e) / 2;
    const int lo_child = idx << 1;
    const int hi_child = lo_child + 1;
    // Only the child whose range contains p can change.
    if (p <= mid)
        updateTree(b, mid, p, val, lo_child);
    else
        updateTree(mid + 1, e, p, val, hi_child);
    Tree[idx] = max(Tree[lo_child], Tree[hi_child]);
}
// Returns the maximum of Tree over the query range [b, e], searching the
// subtree `node` whose range is [start, end]. Subtrees that do not overlap
// the query range contribute 0.
ll query(int b,int e,int start,int end,int node){
    const bool disjoint = (e < start) || (end < b);
    if (disjoint)
        return 0;
    const bool covered = (b <= start) && (end <= e);
    if (covered)
        return Tree[node];
    const int mid = (start + end) >> 1;
    const ll lowerHalf = query(b, e, start, mid, 2 * node);
    const ll upperHalf = query(b, e, mid + 1, end, 2 * node + 1);
    return max(lowerHalf, upperHalf);
}
ll input[200000];
ll w[200000];
map<ll,int>m;
set<ll>s;
int main() {
int t=GI;
while(t--){
m.clear();s.clear();
s.empty();
memset(Tree,0,sizeof Tree);
int n=GI;
for(int i=0;i<n;i++){
scanf("%lld",&input[i]);
s.insert(input[i]);
}
for(int i=0;i<n;i++){
scanf("%lld",&w[i]);
}
int in=1;
set<ll>::iterator it;
for(it=s.begin();it!=s.end();it++){
m[*it]=in;
in++;
}in--;
ll ans=0;
for(int i=0;i<n;i++){
int mapped=m[input[i]];
if(mapped==1){
updateTree(1,in,mapped,w[i],1);
ans=max(ans,w[i]);
}
else{
ll get=query(1,mapped-1,1,in,1);
ans=max(ans,get+w[i]);
updateTree(1,in,mapped,w[i]+get,1);
}
}
cout<<ans<<endl;
}
return  0;
}

```

**In Java:**

```java

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.StreamTokenizer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map.Entry;
import java.util.Scanner;
import java.util.SortedMap;
import java.util.TreeMap;

public class Solution {

    /**
     * Reads T test cases from standard input and prints, for each, the maximum total weight of a
     * strictly increasing subsequence.
     *
     * <p>Input is tokenized with {@link StreamTokenizer} rather than {@link Scanner}. A run can
     * contain up to 5 * 2 * 150000 numbers, and Scanner's regex-based parsing is a well-known
     * bottleneck at that volume. Every input value is at most 10^9, so the tokenizer's
     * double-valued numbers convert to {@code int} exactly.
     *
     * @throws IOException if standard input cannot be read or a number is missing
     */
    public static void main(String[] args) throws IOException {
        StreamTokenizer in = new StreamTokenizer(
                new BufferedReader(new InputStreamReader(System.in, StandardCharsets.US_ASCII)));
        int nProb = nextInt(in);
        StringBuilder out = new StringBuilder();
        for (int k = 1; k <= nProb; ++k) {
            int n = nextInt(in);
            int[] a = new int[n];
            int[] w = new int[n];
            for (int i = 0; i < n; ++i) {
                a[i] = nextInt(in);
            }
            for (int i = 0; i < n; ++i) {
                w[i] = nextInt(in);
            }
            out.append(solve(a, w)).append(System.lineSeparator());
        }
        System.out.print(out);
        System.out.flush();
    }

    /** Reads the next numeric token, failing loudly if input ends or is not a number. */
    private static int nextInt(StreamTokenizer in) throws IOException {
        if (in.nextToken() != StreamTokenizer.TT_NUMBER) {
            throw new IOException("expected a number but found " + in);
        }
        return (int) in.nval;
    }

    /**
     * Returns the maximum weight of a strictly increasing subsequence of the pairs (a[k], w[k]).
     *
     * <p>{@code frontier} maps a value x to the best weight of an increasing subsequence ending
     * at x. It keeps only non-dominated entries, so weights strictly increase with keys (given
     * positive weights). Consequences:
     *
     * <ul>
     *   <li>the entry just below a[k] is the best predecessor;
     *   <li>entries at keys >= a[k] whose weight does not exceed the new candidate are useless
     *       and are removed.
     * </ul>
     *
     * Each element is inserted and removed at most once, giving O(n log n) overall.
     *
     * @param a sequence values; must be at least as long as {@code w} is indexed
     * @param w element weights (positive per the problem constraints)
     * @return the maximum weight, or 0 for empty input
     */
    private static long solve(int[] a, int[] w) {
        long best = 0;
        TreeMap<Integer, Long> frontier = new TreeMap<>();
        for (int k = 0; k < a.length; ++k) {
            Entry<Integer, Long> below = frontier.lowerEntry(a[k]);
            long candidate = (below == null ? 0L : below.getValue()) + w[k];
            // Remove entries at keys >= a[k] that the candidate dominates. Weights
            // increase along the tail, so stop at the first larger weight.
            Iterator<Long> tail = frontier.tailMap(a[k], true).values().iterator();
            while (tail.hasNext() && tail.next() <= candidate) {
                tail.remove();
            }
            // If a[k] survived, it already holds a larger weight; keep it.
            frontier.putIfAbsent(a[k], candidate);
            best = Math.max(best, candidate);
        }
        return best;
    }
}

```

**In C:**

```c

#include <stdio.h>
#include <stdlib.h>
#include <time.h>

/* Treap node: a BST on key x and a max-heap on the random priority p
 * (see merge). y holds the weight that solve() associates with key x. */
typedef struct treap {
    int x, p;
    long long y;
    struct treap *l, *r;   /* children; r also chains nodes on the free list td */
}* Treap;

/* Free list of recycled nodes, chained through ->r and filled by dump(). */
Treap td = NULL;

/* Returns a leaf node with key x, value y and a rand() priority. A node from
 * the free list td is reused when available; otherwise one is malloc'd.
 * Exits with an error message if malloc fails, instead of dereferencing NULL. */
Treap newTreap(int x, long long y) {
    Treap t;
    if (td) {
        t = td;
        td = td->r;
    }
    else {
        t = (Treap) malloc(sizeof(struct treap));
        if (t == NULL) {
            fprintf(stderr, "newTreap: out of memory\n");
            exit(EXIT_FAILURE);
        }
    }
    t->x = x;
    t->y = y;
    t->p = rand();
    t->l = t->r = NULL;
    return t;
}

/* Pushes every node of t onto the free list td (post-order), so later
 * newTreap() calls reuse them instead of allocating. */
void dump(Treap t) {
    if (t) {
        dump(t->l);
        dump(t->r);
        t->r = td;
        td = t;
    }
}

/* Joins treaps l and r into one, returning the new root. Callers pass an l
 * whose keys all precede r's (as produced by split). The node with the higher
 * priority p becomes the root; ties go to r. */
Treap merge(Treap l, Treap r) {
    if (l == NULL || r == NULL)
        return l != NULL ? l : r;
    if (l->p <= r->p) {
        r->l = merge(l, r->l);
        return r;
    }
    l->r = merge(l->r, r);
    return l;
}

/* Splits t into *l and *r.
 *   d != 0: *l receives the nodes with x <  v, *r the rest (split by key).
 *   d == 0: *l receives the nodes with y <= v, *r the rest (split by value);
 *           this is only a clean split when y grows with in-order position,
 *           which is how solve() uses it. */
void split(Treap t, Treap *l, Treap *r, long long v, int d) {
    int to_left;
    if (t == NULL) {
        *l = NULL;
        *r = NULL;
        return;
    }
    to_left = d ? (t->x < v) : (t->y <= v);
    if (to_left) {
        /* t and its left subtree stay left; keep splitting the right subtree. */
        split(t->r, &t->r, r, v, d);
        *l = t;
    }
    else {
        /* t and its right subtree go right; keep splitting the left subtree. */
        split(t->l, l, &t->l, v, d);
        *r = t;
    }
}

/* Returns the node with the largest key in t (the end of its right spine),
 * or NULL when t is empty. */
Treap rightmost(Treap t) {
    while (t != NULL && t->r != NULL)
        t = t->r;
    return t;
}

/* Returns the node with the smallest key in t (the end of its left spine),
 * or NULL when t is empty. */
Treap leftmost(Treap t) {
    while (t != NULL && t->l != NULL)
        t = t->l;
    return t;
}

/* Reads one test case (N, then a_1..a_N, then w_1..w_N) from stdin. Returns
 * the maximum weight of a strictly increasing subsequence.
 *
 * The treap keeps nodes (x, y) where y is the best weight of an increasing
 * subsequence ending at value x. Because w_i >= 1, y strictly increases with
 * x. Splitting by key therefore finds the best predecessor (rightmost node with
 * x < a_i). Splitting the remainder by value then isolates the nodes that the
 * new candidate dominates. The sentinel (0, 0) guarantees a predecessor exists,
 * which assumes every a_i >= 1, as the constraints state. */
long long solve() {
    int n, i;
    long long v;
    scanf("%d", &n);
    /* Heap instead of VLAs: with N up to 150000 the two arrays need ~1.2 MB,
     * which can overflow a small default stack. */
    size_t cap = n > 0 ? (size_t) n : 1;
    int *a = (int *) malloc(cap * sizeof *a);
    int *w = (int *) malloc(cap * sizeof *w);
    if (a == NULL || w == NULL) {
        fprintf(stderr, "solve: out of memory\n");
        exit(EXIT_FAILURE);
    }
    for (i = -1; ++i < n; scanf("%d", a + i));
    for (i = -1; ++i < n; scanf("%d", w + i));
    Treap r = newTreap(0, 0), l, m;
    for (i = -1; ++i < n;) {
        /* l: keys < a[i]; r: keys >= a[i]. */
        split(r, &l, &r, a[i], 1);
        m = rightmost(l);
        v = w[i] + m->y;
        /* m: nodes with keys >= a[i] but value <= v (dominated); r: the rest. */
        split(r, &m, &r, v, 0);
        if (m) {
            dump(m);
            m = newTreap(a[i], v);
        }
        else {
            /* Nothing dominated: insert unless key a[i] already has a larger value. */
            m = leftmost(r);
            if (!m || m->x > a[i])
                m = newTreap(a[i], v);
            else
                m = NULL;
        }
        l = merge(l, m);
        r = merge(l, r);
    }
    /* The largest key carries the largest value. */
    v = rightmost(r)->y;
    dump(r);
    free(a);
    free(w);
    return v;
}

/* Entry point: seeds rand() (the source of treap priorities), reads the
 * number of test cases, and prints one answer per case. */
int main() {
    int cases;
    srand(time(NULL));
    scanf("%d", &cases);
    for (; cases--;)
        printf("%lld\n", solve());
    return 0;
}

```

**In Python 3:**

```python

import os
import sys
import bisect
# Complete the solve function below.
def solve(a, w):
    """Return the maximum total weight of a strictly increasing subsequence.

    ``a[i]`` is the value and ``w[i]`` the weight of element i (weights are
    positive per the problem constraints).  Values are coordinate-compressed
    to ranks 1..K.  A Fenwick (binary indexed) tree over those ranks answers
    "best weight of an increasing subsequence ending at a value of rank < r"
    in O(log K).  Element i extends the best such subsequence ending at a
    strictly smaller value, so its total is ``prefix_max(rank - 1) + w[i]``.

    Runs in O(N log N) time and O(N) memory.  Returns 0 for empty input.
    """
    rank = {value: r for r, value in enumerate(sorted(set(a)), start=1)}
    size = len(rank)
    tree = [0] * (size + 1)  # 1-based Fenwick array; tree[0] is unused
    best = 0
    for value, weight in zip(a, w):
        r = rank[value]
        # Prefix maximum over ranks 1..r-1, i.e. strictly smaller values.
        prev = 0
        i = r - 1
        while i > 0:
            if tree[i] > prev:
                prev = tree[i]
            i -= i & -i
        total = prev + weight
        if total > best:
            best = total
        # Record `total` at rank r; every Fenwick node covering r keeps the max.
        i = r
        while i <= size:
            if tree[i] < total:
                tree[i] = total
            i += i & -i
    return best
if __name__ == '__main__':
    # OUTPUT_PATH is supplied by the judging environment. The ``with`` block
    # guarantees the file is flushed and closed even if a test case fails to parse.
    with open(os.environ['OUTPUT_PATH'], 'w') as fptr:
        t = int(input())
        for _ in range(t):
            int(input())  # N: consumed; the lengths of the next two lines imply it
            a = list(map(int, input().rstrip().split()))
            w = list(map(int, input().rstrip().split()))
            result = solve(a, w)
            fptr.write(str(result) + '\n')
```

