Counting the Ways


Problem Statement :


Little Walter likes playing with his toy scales. He has N types of weights. The i-th weight type has weight a_i. There are infinitely many weights of each type.

Recently, Walter defined a function, F(X), denoting the number of different ways to combine several weights so that their total weight is equal to X. Two ways are considered different if there is a type that is used a different number of times in the two ways.

For example, if there are 3 types of weights with corresponding weights 1, 1, and 2, then there are 4 ways to get a total weight of 2:

1. Use 2 weights of type 1.
2. Use 2 weights of type 2.
3. Use 1 weight of type 1 and 1 weight of type 2.
4. Use 1 weight of type 3.
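The definition of F can be checked directly for small X with an unbounded-knapsack count. The Python sketch below is only an illustration (the function name `count_ways` is hypothetical, not part of the original solutions); it treats types as distinct even when their weights are equal, which is why the two weight-1 types give separate ways.

```python
def count_ways(weights, x):
    """F(x) for the given weight list (count_ways is a hypothetical helper name).

    Counts tuples of non-negative counts c_i with sum(c_i * weights[i]) == x;
    types are distinct even when their weights are equal.
    """
    ways = [1] + [0] * x              # ways[t]: combinations with total exactly t
    for w in weights:                 # bring in one weight type at a time
        for t in range(w, x + 1):     # ascending t lets a type be reused any number of times
            ways[t] += ways[t - w]
    return ways[x]


print(count_ways([1, 1, 2], 2))       # 4, the four ways listed above
```

Note that F(0) = 1: the empty combination has total weight 0.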
Given N, L, R, and a_1, a_2, ..., a_N, can you find the value of F(L) + F(L+1) + ... + F(R)? Because this value can be very large, print it modulo 10^9 + 7.

Input Format

The first line contains a single integer, N, denoting the number of types of weights.
The second line contains N space-separated integers describing the values of a_1, a_2, ..., a_N, respectively.
The third line contains two space-separated integers denoting the respective values of L and R.

Constraints
1 <= N <= 10
0 < a_i <= 10^5
a_1 * a_2 * ... * a_N <= 10^5
1 <= L <= R <= 10^17
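The product constraint is what the interpolation-based Java and Python solutions further down rely on. Let G(v) = F(0) + F(1) + ... + F(v), the number of combinations with total weight at most v, and let P = a_1 * a_2 * ... * a_N. The generating function of G is 1 / ((1 - x)(1 - x^a_1)...(1 - x^a_N)), and every factor divides 1 - x^P, so for each remainder r the values G(r), G(r + P), G(r + 2P), ... are the values of a polynomial in t of degree at most N. The Python sketch below (helper names hypothetical) checks this numerically on the example weights by confirming that the (N+1)-th forward differences vanish.

```python
from itertools import accumulate


def prefix_counts(weights, limit):
    """G[v] for v = 0..limit: combinations of `weights` with total <= v."""
    ways = [1] + [0] * limit
    for w in weights:
        for t in range(w, limit + 1):
            ways[t] += ways[t - w]
    return list(accumulate(ways))       # running sum of F turns "== v" into "<= v"


def forward_difference(values, order):
    """Apply the forward-difference operator `order` times."""
    for _ in range(order):
        values = [b - a for a, b in zip(values, values[1:])]
    return values


weights = [1, 1, 2]                     # the example from the statement
n, period = len(weights), 1
for w in weights:
    period *= w                         # P = a_1 * ... * a_N
G = prefix_counts(weights, 20 * period)
for r in range(period):
    samples = G[r::period]              # G(r), G(r + P), G(r + 2P), ...
    # a polynomial of degree <= n has vanishing (n + 1)-th differences
    assert all(d == 0 for d in forward_difference(samples, n + 1))
print("ok")
```

Because P <= 10^5, only a few multiples of P need to be tabulated before the polynomial on each residue class is fully determined.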



Solution :



In C++ :

#include <bits/stdc++.h>
using namespace std;

const long long MOD = 1e9 + 7;

int n;
int a[10];
long long L, R;

const int N = 202000; // the excess never exceeds 2 * (a_1 + ... + a_N) <= 200018
int dp0[N];
int dp1[N];

inline void add(int &x, int y) {
	x += y;
	if (x >= MOD) x -= MOD;
}

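// Returns, modulo MOD, the number of count vectors (c_1, ..., c_n), c_i >= 0,
// with c_1*a_1 + ... + c_n*a_n <= v (the empty combination included).
// The counts are built one binary digit at a time. Before bit k is processed,
// dp0[d] is the number of choices for bits 0..k-1 of every c_i whose excess
// ceil((T - (v mod 2^k)) / 2^k) equals d, where T is the weight of those bits.
long long solve(long long v) {
	bitset<62> s(v);
	memset(dp0, 0, sizeof(dp0));
	dp0[0] = 1;

	for (int k = 0; k < 62; k++) {
		// Bit k of each c_i is 0 or 1: a 0/1 knapsack (j runs downward)
		// that adds a[i], in units of 2^k, to the excess.
		for (int i = 0; i < n; i++) {
			for (int j = N - a[i] - 1; j >= 0; j--) {
				add(dp0[j + a[i]], dp0[j]);
			}
		}
		// Subtract bit k of v. An excess of -1 is merged into 0,
		// since both become 0 after the halving step below.
		if (s[k]) {
			add(dp0[0], dp0[1]);
			for (int i = 1; i < N - 1; i++) {
				dp0[i] = dp0[i + 1];
			}
		}
		// Switch to units of 2^(k+1): the excess d becomes ceil(d / 2).
		memset(dp1, 0, sizeof(dp1));
		for (int i = 0; i < N; i++) {
			add(dp1[(i + 1) / 2], dp0[i]);
		}
		swap(dp0, dp1);
	}

	// Excess 0 after all 62 bits means the total weight is at most v.
	return dp0[0];
}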

int main() {
	cin >> n;
	for (int i = 0; i < n; i++) cin >> a[i];
	cin >> L >> R;

	// solve(v) counts totals 0..v, so the difference counts totals L..R.
	int ans = solve(R) - solve(L - 1);
	if (ans < 0) ans += MOD;
	cout << ans << endl;
}
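The C++ solve(v) never enumerates totals: it fixes the binary digits of every count c_i from the least significant bit upward and only tracks how far the partial weight overshoots the matching low bits of v. The Python sketch below re-implements that idea with exact integers instead of arithmetic modulo 10^9 + 7, so it can be compared against small brute-force counts; the names `count_at_most`, `bits`, and `cap` are hypothetical. The answer to a query would then be count_at_most(a, R) - count_at_most(a, L - 1).

```python
def count_at_most(weights, v, bits=62):
    """Number of count vectors c with sum(c_i * weights[i]) <= v, built one
    binary digit of every c_i at a time (hypothetical helper, exact integers)."""
    cap = 2 * sum(weights) + 2            # the excess never exceeds 2 * sum(weights)
    dp = [0] * cap
    dp[0] = 1                             # before bit 0 the excess is 0
    for k in range(bits):
        # bit k of each c_i is 0 or 1: a 0/1 knapsack on the excess (j runs downward)
        for w in weights:
            for j in range(cap - 1 - w, -1, -1):
                dp[j + w] += dp[j]
        if (v >> k) & 1:
            # subtract bit k of v; an excess of -1 halves to 0, so merge it into 0
            dp = [dp[0] + dp[1]] + dp[2:] + [0]
        # switch to units of 2^(k+1): excess e becomes ceil(e / 2)
        halved = [0] * cap
        for e, ways in enumerate(dp):
            halved[(e + 1) // 2] += ways
        dp = halved
    return dp[0]


print(count_at_most([1, 1, 2], 2))        # F(0) + F(1) + F(2) = 1 + 2 + 4
```

The merge of excess -1 into 0 is safe because ceil(-1/2) and ceil(0/2) are both 0, so the two states are indistinguishable from the next bit onward.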

In Java :

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.InputMismatchException;

public class F {
    InputStream is;
    PrintWriter out;
    String INPUT = "";
    int mod = 1000000007;
    
    void solve()
    {
        int n = ni();
        int[] a = na(n);
        long L = nl(), R = nl();
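        // dp[x] starts at 1 (only the empty combination); after the loop,
        // dp[x] is the number of combinations with total weight <= x, mod mod.
        int Z = 1200000;
        long[] dp = new long[Z+1];
        Arrays.fill(dp, 1);
        long pe = 1; // product of all weights
        for(int v : a){
            pe *= v;
            // ascending i: unbounded knapsack, type v may be used any number of times
            for(int i = 0;i+v <= Z;i++){
                dp[i+v] += dp[i];
                if(dp[i+v] >= mod)dp[i+v] -= mod;
            }
        }
        // For a fixed residue r, dp[r + t*pe] is a polynomial in t of degree <= n <= 10,
        // so the 12 samples t = 0..11 determine it; guessDirectly evaluates it at
        // t = x / pe by Lagrange interpolation modulo mod.
        int[][] fif = enumFIF(30, mod);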
        long ret = 0;
        {
            long[] y = new long[12];
            for(int i = 0;i < 12;i++){
                y[i] = dp[(int)(R%pe+i*pe)];
            }
            ret += guessDirectly(mod, R/pe, fif, y);
        }
        {
            long[] y = new long[12];
            for(int i = 0;i < 12;i++){
                y[i] = dp[(int)((L-1)%pe+i*pe)];
            }
            ret -= guessDirectly(mod, (L-1)/pe, fif, y);
        }
        if(ret < 0)ret += mod;
        out.println(ret);
    }
    
    public static int[][] enumFIF(int n, int mod) {
        int[] f = new int[n + 1];
        int[] invf = new int[n + 1];
        f[0] = 1;
        for (int i = 1; i <= n; i++) {
            f[i] = (int) ((long) f[i - 1] * i % mod);
        }
        long a = f[n];
        long b = mod;
        long p = 1, q = 0;
        while (b > 0) {
            long c = a / b;
            long d;
            d = a;
            a = b;
            b = d % b;
            d = p;
            p = q;
            q = d - c * q;
        }
        invf[n] = (int) (p < 0 ? p + mod : p);
        for (int i = n - 1; i >= 0; i--) {
            invf[i] = (int) ((long) invf[i + 1] * (i + 1) % mod);
        }
        return new int[][] { f, invf };
    }
    
    public static long guessDirectly(long mod, long x, int[][] fif, long... y)
    {
        int n = y.length;
        if(0 <= x && x < n){
            return y[(int)x];
        }else if(x % mod - (n-1) <= 0){
            long mul = 1;
            for(int i = 0;i < n;i++){
                if((x-i)%mod == 0)continue;
                mul = mul * ((x-i)%mod) % mod;
            }
            long s = 0;
            long sig = 1;
            long big = 8L*mod*mod;
            for(int i = n-1;i >= 0;i--){
                if((x-i)%mod == 0){
                    s += fif[1][i] % mod * fif[1][n-1-i] % mod * y[i] * sig;
                    if(s >= big)s -= big;
                    if(s <= -big)s += big;
                }
                sig = -sig;
            }
            s %= mod;
            if(s < 0)s += mod;
            s = s * mul % mod;
            return s;
        }else{
            long mul = 1;
            for(int i = 0;i < n;i++){
                mul = mul * ((x-i)%mod)%mod;
            }
            long s = 0;
            long sig = 1;
            long big = 8L*mod*mod;
            for(int i = n-1;i >= 0;i--){
                s += invl(x-i, mod) * fif[1][i] % mod * fif[1][n-1-i] % mod * y[i] * sig;
                if(s >= big)s -= big;
                if(s <= -big)s += big;
                sig = -sig;
            }
            s %= mod;
            if(s < 0)s += mod;
            s = s * mul % mod;
            return s;
        }
    }

    public static long invl(long a, long mod) {
        long b = mod;
        long p = 1, q = 0;
        while (b > 0) {
            long c = a / b;
            long d;
            d = a;
            a = b;
            b = d % b;
            d = p;
            p = q;
            q = d - c * q;
        }
        return p < 0 ? p + mod : p;
    }
    void run() throws Exception
    {
        is = INPUT.isEmpty() ? System.in : new ByteArrayInputStream(INPUT.getBytes());
        out = new PrintWriter(System.out);
        
        long s = System.currentTimeMillis();
        solve();
        out.flush();
        if(!INPUT.isEmpty())tr(System.currentTimeMillis()-s+"ms");
    }
    
    public static void main(String[] args) throws Exception { new F().run(); }
    
    private byte[] inbuf = new byte[1024];
    private int lenbuf = 0, ptrbuf = 0;
    
    private int readByte()
    {
        if(lenbuf == -1)throw new InputMismatchException();
        if(ptrbuf >= lenbuf){
            ptrbuf = 0;
            try { lenbuf = is.read(inbuf); } catch (IOException e) { throw new InputMismatchException(); }
            if(lenbuf <= 0)return -1;
        }
        return inbuf[ptrbuf++];
    }
    
    private boolean isSpaceChar(int c) { return !(c >= 33 && c <= 126); }
    private int skip() { int b; while((b = readByte()) != -1 && isSpaceChar(b)); return b; }
    
    private double nd() { return Double.parseDouble(ns()); }
    private char nc() { return (char)skip(); }
    
    private String ns()
    {
        int b = skip();
        StringBuilder sb = new StringBuilder();
        while(!(isSpaceChar(b))){ // when nextLine, (isSpaceChar(b) && b != ' ')
            sb.appendCodePoint(b);
            b = readByte();
        }
        return sb.toString();
    }
    
    private char[] ns(int n)
    {
        char[] buf = new char[n];
        int b = skip(), p = 0;
        while(p < n && !(isSpaceChar(b))){
            buf[p++] = (char)b;
            b = readByte();
        }
        return n == p ? buf : Arrays.copyOf(buf, p);
    }
    
    private char[][] nm(int n, int m)
    {
        char[][] map = new char[n][];
        for(int i = 0;i < n;i++)map[i] = ns(m);
        return map;
    }
    
    private int[] na(int n)
    {
        int[] a = new int[n];
        for(int i = 0;i < n;i++)a[i] = ni();
        return a;
    }
    
    private int ni()
    {
        int num = 0, b;
        boolean minus = false;
        while((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'));
        if(b == '-'){
            minus = true;
            b = readByte();
        }
        
        while(true){
            if(b >= '0' && b <= '9'){
                num = num * 10 + (b - '0');
            }else{
                return minus ? -num : num;
            }
            b = readByte();
        }
    }
    
    private long nl()
    {
        long num = 0;
        int b;
        boolean minus = false;
        while((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'));
        if(b == '-'){
            minus = true;
            b = readByte();
        }
        
        while(true){
            if(b >= '0' && b <= '9'){
                num = num * 10 + (b - '0');
            }else{
                return minus ? -num : num;
            }
            b = readByte();
        }
    }
    
    private static void tr(Object... o) { System.out.println(Arrays.deepToString(o)); }
}
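In the Java solution, guessDirectly evaluates, modulo a prime, the polynomial of degree at most 11 that passes through the 12 samples y[0..11] at t = 0..11. The Python sketch below shows the same kind of Lagrange evaluation on equally spaced nodes 0..n-1; it uses prefix and suffix products of (x - j) instead of dividing by x - i, which is a common variant rather than a transcription of guessDirectly. The name `lagrange_at` is hypothetical.

```python
MOD = 10**9 + 7


def lagrange_at(ys, x, mod=MOD):
    """Evaluate at x the polynomial of degree < len(ys) with p(i) = ys[i]
    for i = 0..len(ys)-1, modulo a prime `mod` (hypothetical helper)."""
    n = len(ys)
    if 0 <= x < n:
        return ys[x] % mod
    # prefix[i] = (x-0)...(x-(i-1)), suffix[i] = (x-i)...(x-(n-1)), both mod `mod`
    prefix = [1] * (n + 1)
    suffix = [1] * (n + 1)
    for i in range(n):
        prefix[i + 1] = prefix[i] * (x - i) % mod
    for i in range(n - 1, -1, -1):
        suffix[i] = suffix[i + 1] * (x - i) % mod
    fact = [1] * n
    for i in range(1, n):
        fact[i] = fact[i - 1] * i % mod
    total = 0
    for i, y in enumerate(ys):
        num = prefix[i] * suffix[i + 1] % mod   # product over j != i of (x - j)
        den = fact[i] * fact[n - 1 - i] % mod   # |product over j != i of (i - j)|
        term = y * num * pow(den, mod - 2, mod) % mod
        # the sign of the product over j != i of (i - j) is (-1)^(n-1-i)
        total += -term if (n - 1 - i) % 2 else term
    return total % mod


print(lagrange_at([1, 2, 5], 10))           # samples of t^2 + 1, so 101
```

Avoiding the division means no special case is needed when x is congruent to one of the nodes modulo the prime, which is the situation the middle branch of guessDirectly handles separately.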

In Python3 :

MOD = 10**9 + 7

def lcm(lst):
    ans = 1
    for x in lst:
        ans = ans*x//gcd(ans, x)
    return ans

def gcd(a,b):
    if a<b:
        a, b = b, a
    while b > 0:
        a, b = b, a%b
    return a

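def getsoltable(a, m, MOD=MOD):
    # soltable[i] = number of ways to write i as a sum of weights from a
    # (unbounded knapsack, modulo MOD). Because a starts with the slack
    # weight 1, this is the number of real combinations with total <= i.
    soltable = [1] + [0] * (len(a)*m-1)
    for x in a:
        oldsoltable = soltable
        soltable = list(soltable)
        for i in range(x, len(soltable)):
            soltable[i] = (oldsoltable[i] + soltable[i - x]) % MOD
    return soltable

def countsols(const, soltable, lcm):
    # Within the residue class const % lcm the table is a polynomial in
    # const // lcm of degree < len(a), so len(a) samples determine it.
    offset = const % lcm
    pts = soltable[offset::lcm]
    assert len(pts) == len(a)
    coef = polycoef(pts)
    return polyval(coef, const//lcm)

def polycoef(pts):
    # Newton forward-difference form: pts[x] == sum(coef[i] * C(x, i)),
    # where descpower / fact is exactly the binomial coefficient C(x, i).
    coef = []
    for x, y in enumerate(pts):
        fact = descpower = 1
        for i, c in enumerate(coef):
            y -= descpower*c//fact
            descpower *= x - i
            fact *= i + 1
        coef.append(y)
    return coef

def polyval(coef, x):
    # Evaluates sum(coef[i] * C(x, i)) modulo MOD; i! is inverted with
    # Fermat's little theorem, pow(fact, MOD-2, MOD).
    ans = 0
    fact = descpower = 1
    for i, c in enumerate(coef):
        ans += c * descpower * pow(fact, MOD-2, MOD)
        descpower = descpower * (x - i) % MOD
        fact *= i + 1
    return ans % MOD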
        
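n = int(input())
# Prepend a weight of 1 (a slack variable) so that the table counts
# combinations with total weight <= X instead of exactly X.
a = [1] + [int(fld) for fld in input().strip().split()]
L, R = [int(fld) for fld in input().strip().split()]
m = lcm(a)
soltable = getsoltable(a, m)
# F(L) + ... + F(R) = G(R) - G(L - 1), where G(X) counts totals <= X
print((countsols(R, soltable, m) - countsols(L-1, soltable, m)) % MOD)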
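The Python solution prepends a weight of 1 to the list a. That extra type acts as unused slack: a combination with total exactly X that includes s slack units corresponds to a combination of the real weights with total X - s, so the table holds the prefix sums G(X) = F(0) + ... + F(X). The sketch below (the function name `exact_counts` is hypothetical) checks this on the example weights and then forms F(1) + F(2) as G(2) - G(0).

```python
from itertools import accumulate


def exact_counts(weights, limit):
    """ways[x] = number of combinations of `weights` with total exactly x."""
    ways = [1] + [0] * limit
    for w in weights:
        for x in range(w, limit + 1):
            ways[x] += ways[x - w]
    return ways


real = [1, 1, 2]                              # the example weights
with_slack = exact_counts([1] + real, 10)     # same trick as a = [1] + ...
assert with_slack == list(accumulate(exact_counts(real, 10)))

# F(L) + ... + F(R) = G(R) - G(L - 1); here L = 1, R = 2
print(with_slack[2] - with_slack[0])          # 6 = F(1) + F(2) = 2 + 4
```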

View More Similar Problems

Pair Sums

Given an array, we define its value to be the value obtained by following these instructions: Write down all pairs of numbers from this array. Compute the product of each pair. Find the sum of all the products. For example, for the array [7, 2, -1, 2], note that (7, 2) is listed twice, once for each occurrence of 2. Given an array of integers, find the largest v


Lazy White Falcon

White Falcon just solved the data structure problem below using heavy-light decomposition. Can you help her find a new solution that doesn't require implementing any fancy techniques? There are 2 types of query operations that can be performed on a tree: 1 u x: Assign x as the value of node u. 2 u v: Print the sum of the node values in the unique path from node u to node v. Given a tree wi


Ticket to Ride

Simon received the board game Ticket to Ride as a birthday present. After playing it with his friends, he decides to come up with a strategy for the game. There are n cities on the map and n - 1 road plans. Each road plan consists of the following: Two cities which can be directly connected by a road. The length of the proposed road. The entire road plan is designed in such a way that if o


Heavy Light White Falcon

Our lazy white falcon finally decided to learn heavy-light decomposition. Her teacher gave an assignment for her to practice this new technique. Please help her by solving this problem. You are given a tree with N nodes and each node's value is initially 0. The problem asks you to operate the following two types of queries: "1 u x" assign x to the value of the node . "2 u v" print the maxim


Number Game on a Tree

Andy and Lily love playing games with numbers and trees. Today they have a tree consisting of n nodes and n -1 edges. Each edge i has an integer weight, wi. Before the game starts, Andy chooses an unordered pair of distinct nodes, ( u , v ), and uses all the edge weights present on the unique path from node u to node v to construct a list of numbers. For example, in the diagram below, Andy


Heavy Light 2 White Falcon

White Falcon was amazed by what she can do with heavy-light decomposition on trees. As a result, she wants to improve her expertise on heavy-light decomposition. Her teacher gave her another assignment which requires path updates. As always, White Falcon needs your help with the assignment. You are given a tree with N nodes and each node's value Vi is initially 0. Let's denote the path fr
