Roy and alpha-beta trees


Problem Statement :


Roy has taken a liking to Binary Search Trees (BSTs). He is interested in knowing the number of ways an array A of N integers can be arranged to form a BST. Thus, he tries a few combinations and notes down the numbers at the odd levels and the numbers at the even levels.

You're given two values, alpha and beta. Can you calculate the sum of the Liking of all possible BSTs that can be formed from an array of N integers? The Liking of each BST is defined as follows:

(sum of numbers on even levels * alpha) - (sum of numbers on odd levels * beta)
Note

The root element is at level 0 ( Even )
Elements smaller than or equal to the parent element are placed in the left subtree, and elements greater than or equal to the parent element are placed in the right subtree. (The original problem statement links to an explanation of this convention.)

If the answer is at least 10^9 + 9, output the answer modulo 10^9 + 9.

(If the answer is less than 0, keep adding 10^9 + 9 until the value becomes non-negative.)


Input Format

The first line of input file contains an integer, T, denoting the number of test cases to follow.
Each test case consists of 3 lines.
The first line contains N, the number of integers.
The second line contains two space separated integers, alpha and beta.
The third line contains N space-separated integers, where the i-th integer is A[i].


Output Format

Output T lines. Each line contains the answer to its respective test case.

Constraints

1  <=   T   <=   10
1  <=  N   <=   150
1  <=   A[i]  <=  10^9
1  <=  alpha, beta  <=  10^9



Solution :



title-img


                            Solution in C :

In    C++  :







#include <bits/stdc++.h>
 
using namespace std;


// Modulus required by the problem statement (10^9 + 9).
const long long mod = 1000000009;

// Aggregate over every BST that can be built from one range of the sorted
// input. Named `Stats` rather than `data`: together with <bits/stdc++.h> and
// `using namespace std`, a global `data` is ambiguous with std::data (C++17)
// and fails to compile.
struct Stats {
    long long ways;    // number of distinct BSTs for the range, mod `mod`
    long long sum[2];  // sum[0] / sum[1]: total of values on even / odd levels
                       // over all those BSTs (levels relative to the range's root)

    // Union of two disjoint families of trees: every field simply adds up.
    Stats operator+(const Stats& r) const {
        Stats res;
        res.ways = (ways + r.ways) % mod;
        res.sum[0] = (sum[0] + r.sum[0]) % mod;
        res.sum[1] = (sum[1] + r.sum[1]) % mod;
        return res;
    }
};

// Pairs every tree of `l` with every tree of `r`: each left tree occurs
// r.ways times in the product family, and each right tree occurs l.ways times.
Stats combine(const Stats& l, const Stats& r) {
    Stats res;
    res.sum[0] = (l.sum[0] * r.ways + r.sum[0] * l.ways) % mod;
    res.sum[1] = (l.sum[1] * r.ways + r.sum[1] * l.ways) % mod;
    res.ways = (l.ways * r.ways) % mod;
    return res;
}

Stats dp[200][200];       // memo: dp[l][r] is the aggregate for the range [l, r)
bool computed[200][200];  // whether dp[l][r] holds a valid result for this test

long long a[200];  // input values, sorted ascending

// Returns the aggregate over all BSTs on the half-open sorted range [l, r).
// Uses an explicit `computed` flag instead of treating `ways == 0` as
// "not yet computed".
Stats& solve(int l, int r) {
    Stats& res = dp[l][r];
    if (computed[l][r]) {
        return res;
    }
    res.ways = 0;
    res.sum[0] = res.sum[1] = 0;

    if (l == r) {
        res.ways = 1;  // exactly one (empty) tree
    } else {
        for (int i = l; i < r; ++i) {  // a[i] is the root
            Stats tmp = combine(solve(l, i), solve(i + 1, r));
            // Both subtrees sit one level below the new root, so parity flips.
            std::swap(tmp.sum[0], tmp.sum[1]);
            // The root is on level 0 (even) in every one of the tmp.ways trees.
            tmp.sum[0] = (tmp.sum[0] + tmp.ways * a[i]) % mod;
            res = res + tmp;
        }
    }
    computed[l][r] = true;
    return res;
}

// Reads T test cases (N, alpha, beta, N values) and prints, for each, the sum of
// alpha * (even-level sum) - beta * (odd-level sum) over all BSTs, in [0, mod).
int main() {
#ifdef LOCAL
    std::freopen("input.txt", "r", stdin);
#endif
    int t;
    if (std::scanf("%d", &t) != 1) {
        return 0;
    }

    while (t--) {
        int n;
        std::scanf("%d", &n);
        long long alpha, beta;
        std::scanf("%lld%lld", &alpha, &beta);

        for (int i = 0; i < n; ++i) {
            std::scanf("%lld", a + i);
        }
        // In-order positions of a BST follow sorted order.
        std::sort(a, a + n);

        // Invalidate memo entries that this test case can reach.
        for (int i = 0; i <= n; ++i) {
            for (int j = 0; j <= n; ++j) {
                computed[i][j] = false;
            }
        }

        const Stats& ans = solve(0, n);
        long long v = (ans.sum[0] * (alpha % mod) - ans.sum[1] * (beta % mod)) % mod;
        if (v < 0) {
            v += mod;
        }
        std::printf("%lld\n", v);
    }

    return 0;
}







In   Java  :







import java.io.*;
import java.util.Arrays;

public class Solution {

    /** Modulus required by the problem statement (10^9 + 9). */
    static final long MOD = 1000000009;

    /** Largest N allowed by the constraints; the tables cover tree sizes 0..MAX_N. */
    private static final int MAX_N = 150;

    /**
     * Reads all test cases from {@code in} and prints one answer per line to {@code out}.
     *
     * <p>For each test case the answer is the sum, over every BST that can be built from the
     * N values, of {@code alpha * (sum on even levels) - beta * (sum on odd levels)}, reduced
     * into {@code [0, MOD)}.
     *
     * @param in  token source for the test cases
     * @param out destination for the answers
     * @throws IOException if reading from {@code in} fails
     * @throws IllegalArgumentException if a test case has N outside {@code [0, 150]}
     */
    public static void solve(Input in, PrintWriter out) throws IOException {
        long[][][] d = buildTables(MAX_N);
        int tests = in.nextInt();
        for (int test = 0; test < tests; ++test) {
            int n = in.nextInt();
            if (n < 0 || n > MAX_N) {
                throw new IllegalArgumentException(
                        "N must be in [0, " + MAX_N + "], got " + n);
            }
            long alpha = Math.floorMod(in.nextLong(), MOD);
            // Subtracting beta * x equals adding (MOD - beta) * x modulo MOD.
            long negBeta = (MOD - Math.floorMod(in.nextLong(), MOD)) % MOD;
            long[] xs = new long[n];
            for (int i = 0; i < n; ++i) {
                xs[i] = in.nextLong();
            }
            // The i-th smallest value occupies in-order position i in every BST.
            Arrays.sort(xs);
            long ans = 0;
            for (int i = 0; i < n; ++i) {
                long x = Math.floorMod(xs[i], MOD);
                long evenPart = d[0][n][i] * (x * alpha % MOD) % MOD;
                long oddPart = d[1][n][i] * (x * negBeta % MOD) % MOD;
                ans = (ans + evenPart + oddPart) % MOD;
            }
            out.println(ans);
        }
    }

    /**
     * Builds per-position level-parity counts for BSTs on up to {@code maxN} sorted keys.
     *
     * <p>{@code d[0][i][j]} (resp. {@code d[1][i][j]}) is the number of BSTs on {@code i}
     * keys, modulo {@code MOD}, in which the key at in-order position {@code j} lies on an
     * even (resp. odd) level, the root being level 0.
     *
     * @param maxN largest tree size to tabulate
     * @return the table, indexed {@code [parity][size][position]}
     */
    private static long[][][] buildTables(int maxN) {
        long[] c = new long[maxN + 1]; // c[i]: number of BSTs on i keys (Catalan), mod MOD
        long[][][] d = new long[2][maxN + 1][];
        d[0][0] = d[1][0] = new long[0];
        c[0] = 1;
        for (int i = 1; i <= maxN; ++i) {
            d[0][i] = new long[i];
            d[1][i] = new long[i];
            for (int j = 0; j < i; ++j) { // j: in-order position of the root
                int right = i - j - 1;
                long add = c[j] * c[right] % MOD;
                c[i] = (c[i] + add) % MOD;
                d[0][i][j] = (d[0][i][j] + add) % MOD; // the root is on level 0
                // Left-subtree keys keep their position but move one level down (parity
                // flips); each left shape pairs with c[right] right shapes.
                for (int t = 0; t < j; ++t) {
                    for (int p = 0; p < 2; ++p) {
                        d[p][i][t] = (d[p][i][t] + d[1 - p][j][t] * c[right]) % MOD;
                    }
                }
                // Right-subtree keys shift by j + 1 positions and also flip parity.
                for (int t = 0; t < right; ++t) {
                    for (int p = 0; p < 2; ++p) {
                        d[p][i][j + 1 + t] =
                                (d[p][i][j + 1 + t] + d[1 - p][right][t] * c[j]) % MOD;
                    }
                }
            }
            // Sanity check: in every tree each key is on exactly one parity.
            for (int j = 0; j < i; ++j) {
                if ((d[0][i][j] + d[1][i][j]) % MOD != c[i]) {
                    throw new AssertionError(
                            "parity counts inconsistent at size " + i + ", position " + j);
                }
            }
        }
        return d;
    }

    /** Program entry point: solves stdin and writes answers to stdout. */
    public static void main(String[] args) throws IOException {
        PrintWriter out = new PrintWriter(System.out);
        solve(new Input(new BufferedReader(new InputStreamReader(System.in))), out);
        out.close();
    }

    /** Minimal whitespace-separated token reader. */
    static class Input {
        BufferedReader in;
        StringBuilder sb = new StringBuilder();

        public Input(BufferedReader in) {
            this.in = in;
        }

        public Input(String s) {
            this.in = new BufferedReader(new StringReader(s));
        }

        /**
         * Returns the next token, or {@code null} at end of input.
         *
         * @throws IOException if the underlying reader fails
         */
        public String next() throws IOException {
            sb.setLength(0);
            // Skip leading whitespace and take the first token character.
            while (true) {
                int c = in.read();
                if (c == -1) {
                    return null;
                }
                if (" \n\r\t".indexOf(c) == -1) {
                    sb.append((char) c);
                    break;
                }
            }
            // Accumulate until whitespace or end of input.
            while (true) {
                int c = in.read();
                if (c == -1 || " \n\r\t".indexOf(c) != -1) {
                    break;
                }
                sb.append((char) c);
            }
            return sb.toString();
        }

        public int nextInt() throws IOException {
            return Integer.parseInt(next());
        }

        public long nextLong() throws IOException {
            return Long.parseLong(next());
        }

        public double nextDouble() throws IOException {
            return Double.parseDouble(next());
        }
    }
}








In    C  :






#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

/* Reads one decimal int from stdin. Returns 0 when no integer can be read
 * (end of input or malformed data) rather than an uninitialized value. */
static int read_int() {
    int ret = 0;
    if (scanf("%d", &ret) != 1) {
        return 0;
    }
    return ret;
}

/* Reads n ints from stdin into a newly allocated array; the caller must free()
 * the result. Elements that cannot be read (end of input, malformed data) stay 0.
 * The result is never NULL: the program exits if allocation fails. */
static int *read_int_array(int n) {
    size_t count = n > 0 ? (size_t)n : 0;
    /* calloc zero-fills, so a failed read leaves a defined value; allocate at
     * least one element so a non-NULL pointer is returned even for n <= 0. */
    int *ret = calloc(count > 0 ? count : 1, sizeof(int));
    if (ret == NULL) {
        fputs("read_int_array: out of memory\n", stderr);
        exit(EXIT_FAILURE);
    }
    for (size_t i = 0; i < count; ++i) {
        if (scanf("%d", ret + i) != 1) {
            break; /* further reads would fail as well */
        }
    }
    return ret;
}

/* qsort comparator: orders ints ascending. Compares instead of subtracting,
 * because the difference overflows (undefined behavior) for operands of
 * opposite sign and large magnitude, e.g. INT_MIN versus 1. */
static int intcomp(const void *v1, const void *v2) {
    const int lhs = *(const int *)v1;
    const int rhs = *(const int *)v2;
    return (lhs > rhs) - (lhs < rhs);
}

/* Every sum and count below is kept reduced modulo MOD (10^9 + 9). */
#define MOD 1000000009

/* Aggregate for all BSTs over one range of the sorted input. Levels are
 * counted from that range's own root (level 0, even). */
struct node {
    long long count; /* how many distinct BSTs the range admits */
    long long even;  /* total of values on even levels, over all those BSTs */
    long long odd;   /* total of values on odd levels, over all those BSTs */
};

/*
 * Returns the sum of Liking over every BST that can be built from the `size`
 * values in `array`, where Liking = alpha * (sum of values on even levels)
 * - beta * (sum of values on odd levels), reduced into [0, MOD).
 * `array` is sorted in place. Exits the program if memory runs out.
 */
static long long int solve(int *array,
                           int size, int alpha, int beta) {
    /* No values: nothing contributes, and data[0][size - 1] below would be
     * out of bounds. */
    if (size <= 0) {
        return 0;
    }
    qsort(array, (size_t)size, sizeof(int), intcomp);

    /* data[lo][hi] aggregates all BSTs over the sorted range lo..hi (inclusive):
     * .count trees, with .even/.odd holding the total of values on even/odd
     * levels measured from the range's own root. */
    struct node **data = malloc((size_t)size * sizeof(struct node *));
    if (data == NULL) {
        fputs("solve: out of memory\n", stderr);
        exit(EXIT_FAILURE);
    }
    for (int i = 0; i < size; ++i) {
        data[i] = calloc((size_t)size, sizeof(struct node));
        if (data[i] == NULL) {
            fputs("solve: out of memory\n", stderr);
            exit(EXIT_FAILURE);
        }
    }

    /* Fill ranges in order of increasing length so subranges are ready. */
    for (int s = 0; s < size; ++s) {
        for (int lo = 0; lo + s < size; ++lo) {
            int hi = lo + s;
            struct node *cell = &data[lo][hi];
            for (int root = lo; root <= hi; ++root) {
                /* An empty side contributes exactly one (empty) tree. */
                long long int left_count = 1, right_count = 1;
                long long int left_even = 0, left_odd = 0;
                long long int right_even = 0, right_odd = 0;

                /* Subtrees hang one level below the root, so their even and
                 * odd totals swap roles. */
                if (root > lo) {
                    left_even = data[lo][root - 1].odd;
                    left_odd = data[lo][root - 1].even;
                    left_count = data[lo][root - 1].count;
                }
                if (root < hi) {
                    right_even = data[root + 1][hi].odd;
                    right_odd = data[root + 1][hi].even;
                    right_count = data[root + 1][hi].count;
                }

                long long int count = left_count * right_count % MOD;
                /* Normalize so a negative value cannot yield a negative remainder. */
                long long int key = ((long long int)array[root] % MOD + MOD) % MOD;

                cell->count = (cell->count + count) % MOD;
                /* The root is on level 0 (even) in each of the `count` trees; every
                 * left tree pairs with right_count right trees and vice versa. */
                cell->even = (cell->even + count * key % MOD
                              + left_even * right_count % MOD
                              + right_even * left_count % MOD) % MOD;
                cell->odd = (cell->odd + left_odd * right_count % MOD
                             + right_odd * left_count % MOD) % MOD;
            }
        }
    }

    long long int even = data[0][size - 1].even;
    long long int odd = data[0][size - 1].odd;
    for (int i = 0; i < size; ++i) {
        free(data[i]);
    }
    free(data);

    long long int a = ((long long int)alpha % MOD + MOD) % MOD;
    long long int b = ((long long int)beta % MOD + MOD) % MOD;
    return (even * a % MOD - odd * b % MOD + MOD) % MOD;
}

/* Entry point: for each test case reads N, alpha, beta and the N values, then
 * prints the sum of Liking over all BSTs on its own line. */
int main(int argc, char *argv[]) {
    int remaining = read_int();
    while (remaining-- > 0) {
        int count = read_int();
        int alpha = read_int();
        int beta = read_int();

        int *values = read_int_array(count);
        long long int answer = solve(values, count, alpha, beta);
        printf("%lld\n", answer);
        free(values);
    }
    return 0;
}









In    Python3   :







import sys

# Modulus required by the problem statement.
MOD = 10 ** 9 + 9

# bt[i]: number of distinct BSTs on i sorted keys (Catalan number), mod MOD.
bt = [1]
# oe[i]: dict mapping in-order position k -> [even, odd], the number of BSTs on
# i keys (mod MOD) in which key k lies on an even / odd level (root = level 0).
oe = [{}]


def _extend_tables(n):
    """Grow ``bt`` and ``oe`` in place so they cover every tree size 0..n."""
    for i in range(len(bt), n + 1):
        counts = {}
        total = 0
        for j in range(i):  # j: in-order position of the root
            left, right = bt[j], bt[i - j - 1]
            # Left-subtree keys keep their position but drop one level, so
            # parity flips; each left shape pairs with `right` right shapes.
            for k, (e, o) in oe[j].items():
                ce, co = counts.get(k, (0, 0))
                counts[k] = [(ce + o * right) % MOD, (co + e * right) % MOD]
            # Right-subtree keys shift by j + 1 positions and also flip parity.
            for k, (e, o) in oe[i - j - 1].items():
                ce, co = counts.get(k + j + 1, (0, 0))
                counts[k + j + 1] = [(ce + o * left) % MOD, (co + e * left) % MOD]
            # The root is on level 0 (even) in all left * right trees.
            ce, co = counts.get(j, (0, 0))
            counts[j] = [(ce + left * right) % MOD, co]
            total += (left * right) % MOD
        bt.append(total % MOD)
        oe.append(counts)


def liking_sum(values, alpha, beta):
    """Return the sum of Liking over all BSTs built from ``values``, mod MOD.

    Liking of one tree is ``alpha * (sum on even levels) - beta * (sum on odd
    levels)``. The result is normalized into ``[0, MOD)``.
    """
    keys = sorted(values)  # in-order positions follow sorted order
    size = len(keys)
    _extend_tables(size)
    even = odd = 0
    for k, (ways_even, ways_odd) in oe[size].items():
        even = (even + ways_even * keys[k]) % MOD
        odd = (odd + ways_odd * keys[k]) % MOD
    # Python's % is non-negative for a positive modulus, so no manual fix-up.
    return (even * alpha - odd * beta) % MOD


def main():
    """Read all test cases from stdin (whitespace-separated) and print answers."""
    tokens = iter(sys.stdin.read().split())
    for _ in range(int(next(tokens))):
        n = int(next(tokens))
        alpha, beta = int(next(tokens)), int(next(tokens))
        values = [int(next(tokens)) for _ in range(n)]
        print(liking_sum(values, alpha, beta))


if __name__ == "__main__":
    main()
                        








View More Similar Problems

Jesse and Cookies

Jesse loves cookies. He wants the sweetness of all his cookies to be greater than value K. To do this, Jesse repeatedly mixes the two cookies with the least sweetness. He creates a special combined cookie with: sweetness = (1 × least sweet cookie + 2 × 2nd least sweet cookie). He repeats this procedure until all the cookies in his collection have a sweetness >= K. You are given Jesse's cookies. Print t

View Solution →

Find the Running Median

The median of a set of integers is the midpoint value of the data set for which an equal number of integers are less than and greater than the value. To find the median, you must first sort your set of integers in non-decreasing order, then: If your set contains an odd number of elements, the median is the middle element of the sorted sample. In the sorted set { 1, 2, 3 } , 2 is the median.

View Solution →

Minimum Average Waiting Time

Tieu owns a pizza restaurant and he manages it in his own way. While in a normal restaurant, a customer is served by following the first-come, first-served rule, Tieu simply minimizes the average waiting time of his customers. So he gets to decide who is served first, regardless of how sooner or later a person comes. Different kinds of pizzas take different amounts of time to cook. Also, once h

View Solution →

Merging Communities

People connect with each other in a social network. A connection between Person I and Person J is represented as M I J. When two persons belonging to different communities connect, the net effect is the merger of both communities to which I and J belong. At the beginning, there are N people representing N communities. Suppose persons 1 and 2 connected and later 2 and 3 connected, then 1, 2 and 3 w

View Solution →

Components in a graph

There are 2 * N nodes in an undirected graph, and a number of edges connecting some nodes. In each edge, the first value will be between 1 and N, inclusive. The second node will be between N + 1 and 2 * N, inclusive. Given a list of edges, determine the size of the smallest and largest connected components that have 2 or more nodes. A node can have any number of connections. The highest node valu

View Solution →

Kundu and Tree

Kundu is a true tree lover. A tree is a connected graph having N vertices and N-1 edges. Today when he got a tree, he colored each edge with either red (r) or black (b). He is interested in knowing how many triplets (a, b, c) of vertices there are such that there is at least one red edge on each of the three paths, i.e. from vertex a to b, vertex b to c and vertex c to a. Note that

View Solution →