# Roy and alpha-beta trees

### Problem Statement :

```Roy has taken a liking to Binary Search Trees (BSTs). He is interested in the number of ways an array A of N integers can be arranged to form a BST. He tries a few arrangements and notes down the numbers on the odd levels and the numbers on the even levels.

You're given two values, alpha and beta. Can you calculate the sum of the Liking of all possible BSTs that can be formed from an array of N integers? The Liking of each BST is defined as follows:

(sum of numbers on even levels * alpha) - (sum of numbers on odd levels * beta)
Note

The root element is at level 0 ( Even )
Elements smaller than or equal to the parent element are placed in its left subtree, and elements greater than or equal to the parent element are placed in its right subtree.

If the answer is at least 10^9 + 9, output the answer modulo 10^9 + 9.

(If the answer is negative, keep adding 10^9 + 9 until the value becomes non-negative.)

Input Format

The first line of input file contains an integer, T, denoting the number of test cases to follow.
Each test case consists of 3 lines.
The first line contains N, the number of integers.
The second line contains two space separated integers, alpha and beta.
The third line contains N space-separated integers; the i-th of them is A[i].

Output Format

Output T lines. Each line contains the answer to its respective test case.

Constraints

1  <=   T   <=   10
1  <=  N   <=   150
1  <=   A[i]  <=  10^9
1  <=  alpha, beta  <=  10^9```

### Solution :

```                            ```Solutions in C++, Java, C and Python 3:

In    C++  :

#include <bits/stdc++.h>

using namespace std;

// Modulus required by the problem statement (10^9 + 9); never reassigned.
constexpr long long mod = 1000000009;

// Aggregate over a family of BSTs: `ways` is how many trees the family holds,
// and sum[0] / sum[1] are the totals of the values sitting on even / odd
// levels across all of them. Values are combined modulo `mod`.
struct data {
    long long ways;
    long long sum[2];

    // Union of two disjoint families of trees: every field simply adds up.
    data operator+(const data & r) const {
        const long long even = (sum[0] + r.sum[0]) % mod;
        const long long odd = (sum[1] + r.sum[1]) % mod;
        return data{(ways + r.ways) % mod, {even, odd}};
    }
};

data merge(const data & l, const data & r) {
data res;
res.sum[0] = (l.sum[0] * r.ways + r.sum[0] * l.ways) % mod;
res.sum[1] = (l.sum[1] * r.ways + r.sum[1] * l.ways) % mod;
res.ways = (l.ways * r.ways) % mod;
return res;
}

data dp[200][200];

long long a[200];

data & solve(int l, int r) {
if (dp[l][r].ways) {
return dp[l][r];
}

data & res = dp[l][r];

if (l == r) {
res.ways = 1;
return res;
}

for (int i = l; i < r; ++i) {
/*int j = i;
while (j < r && a[i] == a[j]) {
++j;
}*/
//i = j - 1;
data tmp = merge(solve(l, i), solve(i + 1, r));
swap(tmp.sum[0], tmp.sum[1]);
tmp.sum[0] = (tmp.sum[0] + tmp.ways * a[i]);
res = res + tmp;
}
return res;
}

int main() {

#ifdef LOCAL
freopen("input.txt", "r", stdin);
#else
//freopen("2strings.in", "r", stdin);
//freopen("2strings.out", "w", stdout);
#endif
int t;

scanf("%d", &t);

while (t--) {
int n;
scanf("%d", &n);
for (int i = 0; i <= n; ++i) {
for (int j = 0; j <= n; ++j) {
dp[i][j].ways = dp[i][j].sum[0] = dp[i][j].sum[1] = 0;
}
}
long long alpha, beta;
scanf("%lld%lld", &alpha, &beta);

for (int i = 0; i < n; ++i) {
scanf("%lld", a + i);
}
sort(a, a + n);

data & ans = solve(0, n);
long long v = ans.sum[0] * alpha - ans.sum[1] * beta;
v %= mod;
if (v < 0) {
v += mod;
}
printf("%lld\n", v);
}

return 0;
}

In   Java  :

import java.io.*;
import java.util.Arrays;

public class Solution {

final static long MOD = 1000000009;

public static void solve(Input in,
PrintWriter out) throws IOException {
final int maxn = 150;
long[] c = new long[maxn + 1];
long[][][] d = new long[2][maxn + 1][];
d[0][0] = d[1][0] = new long[0];
c[0] = 1;
for (int i = 1; i <= maxn; ++i) {
d[0][i] = new long[i];
d[1][i] = new long[i];
for (int j = 0; j < i; ++j) {
long add = c[j] * c[i - j - 1];
c[i] = (c[i] + add) % MOD;
d[0][i][j] = (d[0][i][j] + add) % MOD;
for (int t = 0; t < j; ++t) {
for (int l = 0; l < 2; ++l) {
d[l][i][t] = (d[l][i][t] + d[1
- l][j][t] * c[i - j - 1]) % MOD;
}
}
for (int t = 0; t < i - j - 1; ++t) {
for (int l = 0; l < 2; ++l) {
d[l][i][j + 1 + t] = (d[l][i][j + 1 + t]
+ d[1 - l][i - j - 1][t] * c[j]) % MOD;
}
}
}
for (int j = 0; j < i; ++j) {
if ((d[0][i][j] + d[1][i][j]) % MOD != c[i]) {
throw new AssertionError();
}
}
}
int tests = in.nextInt();
for (int test = 0; test < tests; ++test) {
int n = in.nextInt();
long a = in.nextLong();
long b = MOD - in.nextLong();
long[] xs = new long[n];
for (int i = 0; i < n; ++i) {
xs[i] = in.nextLong();
}
Arrays.sort(xs);
long ans = 0;
for (int i = 0; i < n; ++i) {
long xa = xs[i] * a % MOD;
long xb = xs[i] * b % MOD;
ans = (
ans + d[0][n][i] * xa + d[1][n][i] * xb) % MOD;
}
out.println(ans);
}
}

public static void main(
String[] args) throws IOException {
PrintWriter out = new PrintWriter(System.out);
out.close();
}

static class Input {
StringBuilder sb = new StringBuilder();

this.in = in;
}

public Input(String s) {
}

public String next() throws IOException {
sb.setLength(0);
while (true) {
if (c == -1) {
return null;
}
if (" \n\r\t".indexOf(c) == -1) {
sb.append((char)c);
break;
}
}
while (true) {
if (c == -1 || " \n\r\t".indexOf(c) != -1) {
break;
}
sb.append((char)c);
}
return sb.toString();
}

public int nextInt() throws IOException {
return Integer.parseInt(next());
}

public long nextLong() throws IOException {
return Long.parseLong(next());
}

public double nextDouble() throws IOException {
return Double.parseDouble(next());
}
}
}

In    C  :

#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

/* Reads a single int from stdin. Returns 0 if no integer could be parsed.
 * NOTE(review): signature reconstructed; the header line was missing. */
int read_int(void) {
    int ret = 0;
    if (scanf("%d", &ret) != 1) {
        return 0;
    }
    return ret;
}

/* Reads n ints from stdin into a newly allocated array.
 * The caller must free() the result. Returns NULL if n <= 0 or the
 * allocation fails; entries that cannot be parsed are set to 0.
 * NOTE(review): signature reconstructed; the header line was missing. */
int *read_int_array(int n) {
    if (n <= 0) {
        return NULL;
    }
    int *ret = malloc((size_t)n * sizeof(int));
    if (ret == NULL) {
        return NULL;
    }
    for (int i = 0; i < n; ++i) {
        if (scanf("%d", ret + i) != 1) {
            ret[i] = 0;
        }
    }
    return ret;
}

/* qsort comparator for ints in ascending order. It compares instead of
 * subtracting, because `a - b` overflows (undefined behaviour) when the
 * operands have opposite signs and large magnitudes, e.g. INT_MIN vs 1. */
static int intcomp(const void *v1, const void *v2) {
    const int a = *(const int *)v1;
    const int b = *(const int *)v2;
    return (a > b) - (a < b);
}

/* Modulus required by the problem statement: 10^9 + 9. */
#define MOD 1000000009

/* Aggregate over all BSTs on one slice of the sorted input: `odd` and `even`
 * total the values on odd / even levels across those trees, and `count` is
 * how many trees there are. */
struct node {
    long long int odd, even, count;
};

/*
 * Returns the total Liking -- alpha * (sum of values on even levels) minus
 * beta * (sum of values on odd levels) -- over every BST that can be built
 * from `array`, reduced into [0, MOD).
 *
 * NOTE: sorts `array` in place.
 *
 * Table: data[i][k] aggregates every BST on the sorted slice array[i..k]
 * (inclusive). `count` is the number of such trees, and `even` / `odd` are
 * the sums of the values on even / odd levels (root = level 0) across all of
 * them, mod MOD. Exits the process if the table cannot be allocated.
 */
static long long int solve(int *array,
                           int size, int alpha, int beta) {
    if (size <= 0) {
        /* Only the empty tree exists and its Liking is 0. This also avoids
         * reading data[0][size - 1] out of bounds below. */
        return 0;
    }

    qsort(array, size, sizeof(int), intcomp);

    struct node **data = malloc(size * sizeof(struct node *));
    if (data == NULL) {
        fputs("solve: out of memory\n", stderr);
        exit(EXIT_FAILURE);
    }
    for (int i = 0; i < size; ++i) {
        data[i] = calloc(size, sizeof(struct node));
        if (data[i] == NULL) {
            fputs("solve: out of memory\n", stderr);
            exit(EXIT_FAILURE);
        }
    }

    /* Fill by increasing slice length s so both sub-slices of a split are
     * already complete when they are read. */
    for (int s = 0; s < size; ++s) {
        for (int i = 0; i < size - s; ++i) {
            struct node *cell = &data[i][i + s]; /* slice array[i..i+s] */
            for (int j = i; j <= i + s; ++j) {   /* array[j] is the root */
                long long int left_count = 1, right_count = 1;
                long long int left_part_even = 0, right_part_even = 0;
                long long int left_part_odd = 0, right_part_odd = 0;

                /* Left subtree array[i..j-1] hangs one level deeper, so its
                 * odd-level total counts as even here and vice versa. */
                if (j != i) {
                    left_part_even = data[i][j - 1].odd;
                    left_part_odd = data[i][j - 1].even;
                    left_count = data[i][j - 1].count;
                }

                /* Right subtree array[j+1..i+s], with the same level flip. */
                if (j != i + s) {
                    right_part_even = data[j + 1][i + s].odd;
                    right_part_odd = data[j + 1][i + s].even;
                    right_count = data[j + 1][i + s].count;
                }

                /* Every left shape pairs with every right shape. */
                long long int count = left_count * right_count % MOD;
                cell->count = (cell->count + count) % MOD;

                /* The root sits on level 0 (even) in each of those trees. */
                cell->even = (cell->even + count * array[j]) % MOD;

                /* Each side's level totals repeat once per shape of the other side. */
                cell->even = (cell->even
                              + right_part_even * left_count % MOD
                              + left_part_even * right_count % MOD) % MOD;
                cell->odd = (cell->odd
                             + right_part_odd * left_count % MOD
                             + left_part_odd * right_count % MOD) % MOD;
            }
        }
    }

    long long int even = data[0][size - 1].even;
    long long int odd = data[0][size - 1].odd;
    for (int i = 0; i < size; ++i) {
        free(data[i]);
    }
    free(data);

    long long int val = (even * alpha) % MOD;
    /* May be negative: C's % keeps the sign of the dividend. */
    val = (val - odd * beta) % MOD;
    return (val + MOD) % MOD;
}

/* Reads T test cases (N, then alpha and beta, then N integers) from stdin
 * and prints one answer per line. Returns 1 on malformed input or allocation
 * failure. */
int main(int argc, char *argv[]) {
    (void)argc;
    (void)argv;

    int t = 0;
    if (scanf("%d", &t) != 1) {
        return 0;
    }
    for (int i = 0; i < t; ++i) {
        int n = 0, alpha = 0, beta = 0;
        if (scanf("%d %d %d", &n, &alpha, &beta) != 3 || n <= 0) {
            return 1;
        }

        int *array = malloc(n * sizeof(int));
        if (array == NULL) {
            return 1;
        }
        for (int j = 0; j < n; ++j) {
            if (scanf("%d", array + j) != 1) {
                free(array);
                return 1;
            }
        }

        printf("%lld\n", solve(array, n, alpha, beta));
        free(array);
    }
    return 0;
}

In    Python3   :

MOD = 10 ** 9 + 9
N = 151  # tables cover tree sizes 0..150, the maximum N in the constraints


def build_tables(limit=N):
    """Precompute BST statistics for every tree size in ``range(limit)``.

    Returns ``(bt, oe)`` where ``bt[i]`` is the number of BST shapes on ``i``
    ordered keys (mod MOD), and ``oe[i]`` maps a 0-based rank ``k`` to
    ``[even, odd]``: how many of those shapes place the k-th smallest key on
    an even / odd level (root = level 0), mod MOD.
    """
    bt = [1]
    oe = [{}]
    for size in range(1, limit):
        total = 0
        d = {}
        for root in range(size):
            left_ways = bt[root]
            right_ways = bt[size - root - 1]
            # Left-subtree keys keep their rank but hang one level deeper
            # (even <-> odd); each pairs with every right shape.
            for k, (e, o) in oe[root].items():
                de, do = d.get(k, [0, 0])
                d[k] = [(de + o * right_ways) % MOD, (do + e * right_ways) % MOD]
            # Right-subtree ranks shift past the left subtree and the root.
            for k, (e, o) in oe[size - root - 1].items():
                de, do = d.get(k + root + 1, [0, 0])
                d[k + root + 1] = [(de + o * left_ways) % MOD,
                                   (do + e * left_ways) % MOD]
            # The root is on level 0 (even) in every left/right combination.
            de, do = d.get(root, [0, 0])
            d[root] = [(de + left_ways * right_ways) % MOD, do]
            total += (left_ways * right_ways) % MOD
        bt.append(total % MOD)
        oe.append(d)
    return bt, oe


def total_liking(stats, alpha, beta, values):
    """Sum of Liking over all BSTs built from ``values``, reduced into [0, MOD).

    ``stats`` must be ``oe[len(values)]`` from :func:`build_tables`; the
    order of ``values`` does not matter because they are sorted here.
    """
    ordered = sorted(values)
    even = odd = 0
    for k, (de, do) in stats.items():
        even = (even + de * ordered[k]) % MOD
        odd = (odd + do * ordered[k]) % MOD
    return ((even * alpha) % MOD + (MOD - (odd * beta) % MOD)) % MOD


def main():
    _, oe = build_tables(N)
    for _ in range(int(input())):
        stats = oe[int(input())]
        a, b = map(int, input().split())
        values = list(map(int, input().split()))
        print(total_liking(stats, a, b, values))


if __name__ == "__main__":
    main()
```

