Counting the Ways


Problem Statement :


Little Walter likes playing with his toy scales. He has N types of weights. The ith weight type has weight ai. There are infinitely many weights of each type.

Recently, Walter defined a function, F(X), denoting the number of different ways to combine several weights so that their total weight is equal to X. Two ways are considered different if some type is used a different number of times in one way than in the other.

For example, if there are 3 types of weights with corresponding weights 1, 1, and 2, then there are 4 ways to get a total weight of 2:

1. Use 2 weights of type 1.
2. Use 2 weights of type 2.
3. Use 1 weight of type 1 and 1 weight of type 2.
4. Use 1 weight of type 3.

Given N, L, R, and a1, a2, ..., aN, can you find the value of F(L)+F(L+1)+...+F(R), modulo 10^9 + 7?
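For small X, F(X) can be tabulated with the standard unbounded-knapsack counting recurrence. The sketch below is only a brute-force reference, not one of the solutions: `count_ways` and `brute_sum` are hypothetical names. It reproduces the example above and can be used to check the faster solutions on small ranges.

```python
MOD = 10**9 + 7

def count_ways(weights, limit):
    """Return [F(0), ..., F(limit)] modulo MOD.

    Each weight *type* is processed once; iterating j upwards lets that type
    be reused any number of times (unbounded knapsack). Equal weight values
    are kept as separate types, as in the example.
    """
    f = [1] + [0] * limit            # F(0) = 1: the empty combination
    for w in weights:
        for j in range(w, limit + 1):
            f[j] = (f[j] + f[j - w]) % MOD
    return f

def brute_sum(weights, lo, hi):
    """F(lo) + ... + F(hi) modulo MOD; only practical for small hi."""
    f = count_ways(weights, hi)
    return sum(f[lo:hi + 1]) % MOD

print(count_ways([1, 1, 2], 2))      # [1, 2, 4]
print(brute_sum([1, 1, 2], 1, 2))    # 6
```

This needs O(N * R) time, which is hopeless for R up to 10^17; that is why the solutions below exploit extra structure.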

Input Format

The first line contains a single integer, N, denoting the number of types of weights.
The second line contains N space-separated integers describing the values of a1,a2,...,aN, respectively.
The third line contains two space-separated integers denoting the respective values of L and R.

Constraints
1 <= N <= 10
0 < ai <= 10^5
a1*a2*...*aN <= 10^5
1 <= L <= R <= 10^17
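The product constraint is what makes R up to 10^17 tractable. Let m = lcm(a1, ..., aN), which divides a1*...*aN <= 10^5, and let G(X) = F(0) + ... + F(X). Restricted to X = r + q*m for a fixed residue r, G is a polynomial in q of degree at most N, and the Java and Python solutions below rely on this. The following sketch checks the property numerically by confirming that the (N+1)-th finite differences of every residue class vanish; the weights [2, 3] and the helper names are illustrative assumptions.

```python
from math import gcd

def prefix_counts(weights, limit):
    """Exact G(X) = F(0) + ... + F(X) for X = 0..limit (no modulus)."""
    f = [1] + [0] * limit
    for w in weights:
        for j in range(w, limit + 1):
            f[j] += f[j - w]
    g, running = [], 0
    for value in f:
        running += value
        g.append(running)
    return g

def finite_difference(seq, order):
    """Apply the forward-difference operator `order` times."""
    for _ in range(order):
        seq = [b - a for a, b in zip(seq, seq[1:])]
    return seq

weights = [2, 3]                     # illustrative instance, N = 2
m = 1
for w in weights:
    m = m * w // gcd(m, w)           # m = lcm(2, 3) = 6
n = len(weights)
g = prefix_counts(weights, m * 10)
for r in range(m):
    samples = g[r::m]                # G(r), G(r + m), G(r + 2m), ...
    # A polynomial of degree <= n has identically zero (n+1)-th differences.
    assert all(d == 0 for d in finite_difference(samples, n + 1)), r
print("each residue class of G is a polynomial of degree <=", n)
```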



Solution :






In C++ :





#include <bits/stdc++.h>
using namespace std;

const long long MOD = 1e9 + 7;

int n;
int a[10];
long long L, R;

const int N = 202000;
int dp0[N];
int dp1[N];

inline void add(int &x, int y) {
	x += y;
	if (x >= MOD) x -= MOD;
}

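// Returns F(0) + F(1) + ... + F(v) modulo MOD, i.e. the number of tuples
// (c_1, ..., c_n) with c_1*a_1 + ... + c_n*a_n <= v.
// The counts c_i are decided one binary digit at a time (bit 0 first);
// dp0[j] is the number of partial choices whose carry is j.
long long solve(long long v) {
	bitset<62> s(v);  // binary digits of v; v <= 10^17 < 2^62
	memset(dp0, 0, sizeof(dp0));
	dp0[0] = 1;

	for (int k = 0; k < 62; k++) {
		// 0/1 knapsack: bit k of each c_i adds either 0 or a[i] to the carry.
		for (int i = 0; i < n; i++) {
			for (int j = N - a[i] - 1; j >= 0; j--) {
				add(dp0[j + a[i]], dp0[j]);
			}
		}
		// Subtract bit k of v; a carry of 0 stays 0.
		if (s[k]) {
			add(dp0[0], dp0[1]);
			for (int i = 1; i < N - 1; i++) {
				dp0[i] = dp0[i + 1];
			}
		}
		// Move to the next bit: the new carry is ceil(carry / 2).
		memset(dp1, 0, sizeof(dp1));
		for (int i = 0; i < N; i++) {
			add(dp1[(i + 1) / 2], dp0[i]); 
		}
		swap(dp0, dp1);
	}

	// Carry 0 after all 62 bits means the total does not exceed v.
	return dp0[0];
}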

int main() {
	cin >> n;
	for (int i = 0; i < n; i++) cin >> a[i];
	cin >> L >> R;

	int ans = solve(R) - solve(L - 1);
	if (ans < 0) ans += MOD;
	cout << ans << endl;
}
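One way to read the C++ `solve(v)` is as a digit DP over the binary digits of the counts c_i. After bit k, the carry is ceil((S_k - v_k) / 2^(k+1)), where S_k is the weighted sum using only bits 0..k of every c_i and v_k = v mod 2^(k+1); since S_k - v_k > -2^(k+1), the carry is never negative, and a final carry of 0 means the whole sum is at most v. The Python transcription below is a sketch under that reading: `count_at_most` is a hypothetical name, it keeps the carries in a dictionary instead of a fixed array, and it counts exactly without the modulus.

```python
from collections import defaultdict

def count_at_most(weights, v):
    """Exact number of tuples (c_1..c_N) with sum(c_i * a_i) <= v,
    i.e. F(0) + ... + F(v), decided bit by bit from the least significant bit."""
    dp = {0: 1}                          # nothing decided yet: carry 0, one way
    for k in range(v.bit_length()):      # every counted c_i is <= v < 2**bits
        # 0/1 knapsack over the types: bit k of c_i is 0 or 1.
        for w in weights:
            nxt = defaultdict(int)
            for carry, ways in dp.items():
                nxt[carry] += ways       # bit k of this c_i is 0
                nxt[carry + w] += ways   # bit k of this c_i is 1
            dp = nxt
        bit = (v >> k) & 1
        halved = defaultdict(int)
        for carry, ways in dp.items():
            # (x + 1) // 2 equals ceil(x / 2) for every integer x, including -1.
            halved[(carry - bit + 1) // 2] += ways
        dp = halved
    # Carry 0 at the end means the full sum does not exceed v.
    return dp.get(0, 0)

print(count_at_most([1, 1, 2], 2))       # 7 = F(0) + F(1) + F(2)
```

The number of distinct carries stays below roughly the sum of the weights, so the work per bit is about N times that sum, independent of v; the C++ code gets the same bound from its fixed array of size 202000.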









In Java :





import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.InputMismatchException;

public class F {
    InputStream is;
    PrintWriter out;
    String INPUT = "";
    int mod = 1000000007;
    
    void solve()
    {
        int n = ni();
        int[] a = na(n);
        long L = nl(), R = nl();
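        // Z covers 12 samples per residue class: (pe - 1) + 11 * pe < 12 * pe <= 1.2e6.
        int Z = 1200000;
        long[] dp = new long[Z+1];
        // Starting from all ones makes dp[x] = F(0) + ... + F(x) once every weight is added.
        Arrays.fill(dp, 1);
        long pe = 1; // product of the weights, a multiple of the period of F
        for(int v : a){
            pe *= v;
            // Unbounded knapsack step: type v may be used any number of times.
            for(int i = 0;i+v <= Z;i++){
                dp[i+v] += dp[i];
                if(dp[i+v] >= mod)dp[i+v] -= mod;
            }
        }
        int[][] fif = enumFIF(30, mod);
        long ret = 0;
        // On the residue class of R mod pe, the prefix sum is a polynomial of degree <= n
        // in R / pe, so 12 samples (n <= 10) determine it; interpolate at R / pe.
        {
            long[] y = new long[12];
            for(int i = 0;i < 12;i++){
                y[i] = dp[(int)(R%pe+i*pe)];
            }
            ret += guessDirectly(mod, R/pe, fif, y);
        }
        // Same for L - 1; the difference is F(L) + ... + F(R).
        {
            long[] y = new long[12];
            for(int i = 0;i < 12;i++){
                y[i] = dp[(int)((L-1)%pe+i*pe)];
            }
            ret -= guessDirectly(mod, (L-1)/pe, fif, y);
        }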
        if(ret < 0)ret += mod;
        out.println(ret);
    }
    
    public static int[][] enumFIF(int n, int mod) {
        int[] f = new int[n + 1];
        int[] invf = new int[n + 1];
        f[0] = 1;
        for (int i = 1; i <= n; i++) {
            f[i] = (int) ((long) f[i - 1] * i % mod);
        }
        long a = f[n];
        long b = mod;
        long p = 1, q = 0;
        while (b > 0) {
            long c = a / b;
            long d;
            d = a;
            a = b;
            b = d % b;
            d = p;
            p = q;
            q = d - c * q;
        }
        invf[n] = (int) (p < 0 ? p + mod : p);
        for (int i = n - 1; i >= 0; i--) {
            invf[i] = (int) ((long) invf[i + 1] * (i + 1) % mod);
        }
        return new int[][] { f, invf };
    }
    
    public static long guessDirectly(long mod, long x, int[][] fif, long... y)
    {
        int n = y.length;
        if(0 <= x && x < n){
            return y[(int)x];
        }else if(x % mod - (n-1) <= 0){
            long mul = 1;
            for(int i = 0;i < n;i++){
                if((x-i)%mod == 0)continue;
                mul = mul * ((x-i)%mod) % mod;
            }
            long s = 0;
            long sig = 1;
            long big = 8L*mod*mod;
            for(int i = n-1;i >= 0;i--){
                if((x-i)%mod == 0){
                    s += fif[1][i] % mod * fif[1][n-1-i] % mod * y[i] * sig;
                    if(s >= big)s -= big;
                    if(s <= -big)s += big;
                }
                sig = -sig;
            }
            s %= mod;
            if(s < 0)s += mod;
            s = s * mul % mod;
            return s;
        }else{
            long mul = 1;
            for(int i = 0;i < n;i++){
                mul = mul * ((x-i)%mod)%mod;
            }
            long s = 0;
            long sig = 1;
            long big = 8L*mod*mod;
            for(int i = n-1;i >= 0;i--){
                s += invl(x-i, mod) * fif[1][i] % mod * fif[1][n-1-i] % mod * y[i] * sig;
                if(s >= big)s -= big;
                if(s <= -big)s += big;
                sig = -sig;
            }
            s %= mod;
            if(s < 0)s += mod;
            s = s * mul % mod;
            return s;
        }
    }

    public static long invl(long a, long mod) {
        long b = mod;
        long p = 1, q = 0;
        while (b > 0) {
            long c = a / b;
            long d;
            d = a;
            a = b;
            b = d % b;
            d = p;
            p = q;
            q = d - c * q;
        }
        return p < 0 ? p + mod : p;
    }
    void run() throws Exception
    {
        is = INPUT.isEmpty() ? System.in : new ByteArrayInputStream(INPUT.getBytes());
        out = new PrintWriter(System.out);
        
        long s = System.currentTimeMillis();
        solve();
        out.flush();
        if(!INPUT.isEmpty())tr(System.currentTimeMillis()-s+"ms");
    }
    
    public static void main(String[] args) throws Exception {
        new F().run();
    }
    
    private byte[] inbuf = new byte[1024];
    private int lenbuf = 0, ptrbuf = 0;
    
    private int readByte()
    {
        if(lenbuf == -1)throw new InputMismatchException();
        if(ptrbuf >= lenbuf){
            ptrbuf = 0;
            try { lenbuf = is.read(inbuf); } catch (IOException e) { throw new InputMismatchException(); }
            if(lenbuf <= 0)return -1;
        }
        return inbuf[ptrbuf++];
    }
    
    private boolean isSpaceChar(int c) { return !(c >= 33 && c <= 126); }
    private int skip() { int b; while((b = readByte()) != -1 && isSpaceChar(b)); return b; }
    
    private double nd() { return Double.parseDouble(ns()); }
    private char nc() { return (char)skip(); }
    
    private String ns()
    {
        int b = skip();
        StringBuilder sb = new StringBuilder();
        while(!(isSpaceChar(b))){ // when nextLine, (isSpaceChar(b) && b != ' ')
            sb.appendCodePoint(b);
            b = readByte();
        }
        return sb.toString();
    }
    
    private char[] ns(int n)
    {
        char[] buf = new char[n];
        int b = skip(), p = 0;
        while(p < n && !(isSpaceChar(b))){
            buf[p++] = (char)b;
            b = readByte();
        }
        return n == p ? buf : Arrays.copyOf(buf, p);
    }
    
    private char[][] nm(int n, int m)
    {
        char[][] map = new char[n][];
        for(int i = 0;i < n;i++)map[i] = ns(m);
        return map;
    }
    
    private int[] na(int n)
    {
        int[] a = new int[n];
        for(int i = 0;i < n;i++)a[i] = ni();
        return a;
    }
    
    private int ni()
    {
        int num = 0, b;
        boolean minus = false;
        while((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'));
        if(b == '-'){
            minus = true;
            b = readByte();
        }
        
        while(true){
            if(b >= '0' && b <= '9'){
                num = num * 10 + (b - '0');
            }else{
                return minus ? -num : num;
            }
            b = readByte();
        }
    }
    
    private long nl()
    {
        long num = 0;
        int b;
        boolean minus = false;
        while((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'));
        if(b == '-'){
            minus = true;
            b = readByte();
        }
        
        while(true){
            if(b >= '0' && b <= '9'){
                num = num * 10 + (b - '0');
            }else{
                return minus ? -num : num;
            }
            b = readByte();
        }
    }
    
    private static void tr(Object... o) { System.out.println(Arrays.deepToString(o)); }
}
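`guessDirectly` evaluates, modulo the prime, the Lagrange polynomial through y[0..n-1] at the nodes 0..n-1, using the identity prod_{j != i} (i - j) = i! * (n-1-i)! * (-1)^(n-1-i). The Python sketch below shows the same idea. It multiplies prefix and suffix products of (x - j) instead of dividing by (x - i), which sidesteps the separate branch the Java code needs when some x - i is divisible by the modulus. `interpolate_at` and `prefix_sum_mod` are hypothetical names; `prefix_sum_mod` mirrors the Java setup (period = product of the weights, table initialised to ones) on a small scale.

```python
MOD = 10**9 + 7

def interpolate_at(ys, x, mod=MOD):
    """P(x) mod a prime, where P has degree < len(ys) and P(i) = ys[i]."""
    n = len(ys)
    if 0 <= x < n:
        return ys[x] % mod
    # prefix[i] = prod_{j < i} (x - j), suffix[i] = prod_{j >= i} (x - j)
    prefix = [1] * (n + 1)
    suffix = [1] * (n + 1)
    for i in range(n):
        prefix[i + 1] = prefix[i] * ((x - i) % mod) % mod
    for i in range(n - 1, -1, -1):
        suffix[i] = suffix[i + 1] * ((x - i) % mod) % mod
    fact = [1] * n
    for i in range(1, n):
        fact[i] = fact[i - 1] * i % mod
    total = 0
    for i in range(n):
        num = prefix[i] * suffix[i + 1] % mod          # prod_{j != i} (x - j)
        den = fact[i] * fact[n - 1 - i] % mod          # |prod_{j != i} (i - j)|
        term = ys[i] * num % mod * pow(den, mod - 2, mod) % mod
        total += -term if (n - 1 - i) % 2 else term    # sign (-1)^(n-1-i)
    return total % mod

def prefix_sum_mod(weights, v, mod=MOD):
    """F(0) + ... + F(v) mod `mod` via samples on the residue class of v."""
    period = 1
    for w in weights:
        period *= w                  # product of the weights, as in the Java code
    n = len(weights)
    limit = (n + 1) * period         # room for n + 1 samples per residue class
    g = [1] * limit                  # all ones: the knapsack then yields prefix sums
    for w in weights:
        for j in range(w, limit):
            g[j] = (g[j] + g[j - w]) % mod
    r = v % period
    ys = [g[r + i * period] for i in range(n + 1)]
    return interpolate_at(ys, v // period, mod)

print(prefix_sum_mod([1, 1, 2], 2))                                      # 7
print((prefix_sum_mod([1, 1, 2], 2) - prefix_sum_mod([1, 1, 2], 0)) % MOD)  # 6
```

As in the Java `solve()`, the requested sum is the difference of two prefix sums, one at R and one at L - 1.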









In Python3 :





MOD = 10**9 + 7

def lcm(lst):
    ans = 1
    for x in lst:
        ans = ans*x//gcd(ans, x)
    return ans

def gcd(a,b):
    if a<b:
        a, b = b, a
    while b > 0:
        a, b = b, a%b
    return a

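def getsoltable(a, m, MOD=MOD):
    # Unbounded-knapsack counting table of length len(a)*m. Since the driver
    # code prepends a weight of 1 to a, soltable[x] = F(0) + F(1) + ... + F(x).
    soltable = [1] + [0] * (len(a)*m-1)
    for x in a:
        oldsoltable = soltable
        soltable = list(soltable)
        for i in range(x, len(soltable)):
            soltable[i] = (oldsoltable[i] + soltable[i - x]) % MOD
    return soltable

def countsols(const, soltable, lcm):
    # On the residue class of const mod lcm, the table is a polynomial in
    # const // lcm of degree < len(a): sample it there and evaluate.
    offset = const % lcm
    pts = soltable[offset::lcm]
    assert len(pts) == len(a)  # a is the global weight list read below
    coef = polycoef(pts)
    return polyval(coef, const//lcm)

def polycoef(pts):
    # Newton forward-difference coefficients c_i (exact integers), so that the
    # polynomial through pts[0], pts[1], ... is sum_i c_i * C(x, i).
    coef = []
    for x, y in enumerate(pts):
        fact = descpower = 1
        for i, c in enumerate(coef):
            y -= descpower*c//fact
            descpower *= x - i
            fact *= i + 1
        coef.append(y)
    return coef

def polyval(coef, x):
    # Evaluates sum_i coef[i] * x(x-1)...(x-i+1) / i! modulo MOD.
    ans = 0
    fact = descpower = 1
    for i, c in enumerate(coef):
        ans += c * descpower * pow(fact, MOD-2, MOD)
        descpower = descpower * (x - i) % MOD
        fact *= i + 1
    return ans % MOD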
        
n = int(input())
a = [1] + [int(fld) for fld in input().strip().split()]
L, R = [int(fld) for fld in input().strip().split()]
m = lcm(a)
soltable = getsoltable(a, m)
print((countsols(R, soltable, m) - countsols(L-1, soltable, m)) % MOD)
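`polycoef` and `polyval` use Newton's forward-difference form: the coefficients are the forward differences c_i = (i-th difference of the samples at 0), and the polynomial is sum_i c_i * C(x, i). The sketch below computes the same coefficients more directly with a difference table and evaluates with `math.comb` (Python 3.8+); `newton_coefficients` and `newton_eval` are hypothetical names, not part of the solution above.

```python
from math import comb

def newton_coefficients(pts):
    """Forward differences c_i of the samples pts[0], pts[1], ... taken at 0."""
    coef, row = [], list(pts)
    while row:
        coef.append(row[0])
        row = [b - a for a, b in zip(row, row[1:])]
    return coef

def newton_eval(coef, x, mod=10**9 + 7):
    """sum_i coef[i] * C(x, i) modulo `mod`, for a non-negative integer x."""
    return sum(c * comb(x, i) for i, c in enumerate(coef)) % mod

# Samples of 3q^2 + 2q + 1 at q = 0, 1, 2.
print(newton_coefficients([1, 6, 17]))                  # [1, 5, 6]
print(newton_eval(newton_coefficients([1, 6, 17]), 10))  # 321
```

Python's exact big integers make C(x, i) safe even for x near 10^17, since i never exceeds N; `polyval` instead reduces modulo MOD as it goes and divides by i! with a modular inverse.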
                        







