Hard Disk Drives


Problem Statement :


There are n pairs of hard disk drives (HDDs) in a cluster. Each HDD is located at an integer coordinate on an infinite straight line, and each pair consists of one primary HDD and one backup HDD.

Next, you want to place k computers at integer coordinates on the same infinite straight line. Each pair of HDDs must then be connected to a single computer via wires, but a computer can have any number (even zero) of HDDs connected to it. The length of a wire connecting a single HDD to a computer is the absolute value of the distance between their respective coordinates on the infinite line. We consider the total length of wire used to connect all the HDDs to computers to be the sum of the lengths of all the wires used to connect HDDs to computers. Note that both the primary and secondary HDDs in a pair must connect to the same computer.

Given the locations of n pairs (i.e., primary and backup) of HDDs and the value of k, place all k computers in such a way that the total length of wire needed to connect each pair of HDDs to computers is minimal. Then print the total length on a new line.

Input Format

The first line contains two space-separated integers denoting the respective values of n (the number of pairs of HDDs) and k (the number of computers).
Each line i of the n subsequent lines contains two space-separated integers describing the respective values of ai (coordinate of the primary HDD) and bi (coordinate of the backup HDD) for a pair of HDDs.

Constraints

2 <= k <= n  <= 10^5
4 <= k*n <= 10^5
-10^9 <= ai,bi  <= 10^9
Output Format

Print a single integer denoting the minimum total length of wire needed to connect all the pairs of HDDs to computers.



Solution :



title-img


                            Solution in C :

In C++ :





#include <bits/stdc++.h>

using namespace std;

struct node
{
    long long sa, sb;
    int sc;
    node()
    {
        sa=sb=sc=0;
    }
    node(long long a, long long b, int c)
    {
        sa=a, sb=b, sc=c;
    }
    node& operator+= (const node& other)
    {
        sa+=other.sa;
        sb+=other.sb;
        sc+=other.sc;
        return *this;
    }
    node& operator-= (const node& other)
    {
        sa-=other.sa;
        sb-=other.sb;
        sc-=other.sc;
        return *this;
    }
    bool operator== (const node& other) const
    {
        return sa==other.sa && sb==other.sb && sc==other.sc;
    }
};

struct SplayNode
{
    SplayNode *ch[2], *p;
    node n, o;
    int k;
    void init(long long k, node o)
    {
        ch[0]=ch[1]=p=nullptr;
        this->k=k;
        this->n=this->o=o;
    }
} pool[2000000];

int npool;

void maintain(SplayNode *n)
{
    n->n=n->o;
    for(int i=0; i<2; i++) if(n->ch[i])
    {
        n->ch[i]->p=n;
        n->n+=n->ch[i]->n;
    }
}

int child(SplayNode *n)
{
    if(n->p)
    {
        if(n->p->ch[0]==n)
            return 0;
        if(n->p->ch[1]==n)
            return 1;
    }
    return -1;
}

void rotate_up(SplayNode *n)
{
    SplayNode *p=n->p;
    int d=child(n);
    int pd=child(p);
    assert(d!=-1);
    p->ch[d]=n->ch[d^1];
    n->ch[d^1]=p;
    n->p=p->p;
    maintain(p);
    maintain(n);
    if(pd!=-1)
    {
        n->p->ch[pd]=n;
        maintain(n->p);
    }
}

void splay(SplayNode *n)
{
    while(child(n)!=-1)
    {
        if(child(n->p)==-1)
            rotate_up(n);
        else
        {
            if(child(n)!=child(n->p))
                rotate_up(n), rotate_up(n);
            else
                rotate_up(n->p), rotate_up(n);
        }
    }
}

void insert(SplayNode *n, SplayNode *t)
{
    if(t->k<n->k)
    {
        if(!n->ch[0])
            n->ch[0]=t;
        else
            insert(n->ch[0], t);
    }
    else
    {
        if(!n->ch[1])
            n->ch[1]=t;
        else
            insert(n->ch[1], t);
    }
    maintain(n);
}
SplayNode *lastTouch;
node ask(SplayNode *n, int k)
{
    if(!n)
        return node();
    lastTouch=n;
    if(k<n->k)
        return ask(n->ch[0], k);
    node ret=n->o;
    if(n->ch[0])
        ret+=n->ch[0]->n;
    ret+=ask(n->ch[1], k);
    return ret;
}

int N, M, K;
pair<int, int> A[100000];
int X[200000];
int L[100000];
int R[100000];
int fsize[200001];
int ok;
vector<long long> dp[351];
SplayNode* bit[200001];

struct dac
{
    int l, r, optL, optR;
    tuple<int, int, int, int> mt() const
    {
        return make_tuple(l, r, optL, optR);
    }
};

struct query
{
    int m, x, y, i, xc, yc;
    bool operator< (const query& other) const
    {
        return m<other.m;
    }
} B[100000];

const int MAGIC=1500;
long long val[200000];
long long val2[200000];
vector<node> blt[200001];

void add_point(int x, int y, node n)
{
    x=M-1-x;
    for(int i=x+1; i<=M; i+=i&-i)
    {
        if(ok && fsize[i]>=MAGIC)
        {
            for(int j=y+1; j<=M; j+=j&-j)
                blt[i][j]+=n;
            continue;
        }
        if(!ok)
            fsize[i]++;
        SplayNode *t=&pool[npool++];
        if(npool==2000000)
            npool=0;
        t->init(y+1, n);
        if(bit[i])
        {
            insert(bit[i], t);
            splay(t);
        }
        bit[i]=t;
    }
}

void era_point(int x, int y)
{
    x=M-1-x;
    for(int i=x+1; i<=M; i+=i&-i)
    {
        if(ok && fsize[i]>=MAGIC)
        {
            for(int j=y+1; j<=M; j+=j&-j)
                blt[i][j]=node();
            continue;
        }
    }
}

node ask_point(int x, int y)
{
    x=M-1-x;
    node n;
    for(int i=x+1; i>0; i-=i&-i)
    {
        if(ok && fsize[i]>=MAGIC)
        {
            for(int j=y+1; j>0; j-=j&-j)
                n+=blt[i][j];
            continue;
        }
        if(bit[i])
        {
            n+=ask(bit[i], y+1);
            bit[i]=lastTouch;
            splay(bit[i]);
        }
    }
    return n;
}

void solve(int i, int l, int r, int optL, int optR)
{
    if(l>r)
        return;
    vector<dac> q, nq;
    q.push_back((dac){l, r, optL, optR});
    while(!q.empty())
    {
        for(int j=1; j<=M; j++)
        {
            bit[j]=nullptr;
            if(ok && fsize[j]>=MAGIC)
            {
                if((int)blt[j].size()!=M+1)
                    blt[j].resize(M+1);
                fill(blt[j].begin(), blt[j].end(), node());
            }
        }
        vector<query> queries;
        int nqueries=0;
        for(auto& it: q)
        {
            tie(l, r, optL, optR)=it.mt();
            int j=(l+r)/2;
            dp[i][j]=0x3f3f3f3f3f3f3f3fLL;
            for(int k=min(optR, j-1); k>=optL; k--)
            {
                val[nqueries]=dp[i-1][k];
                queries.push_back((query){X[k]+X[j], X[k], X[j], nqueries++, k, j});
            }
        }
        sort(queries.begin(), queries.end());
        int l=0;
        for(int k=0; k<(int)queries.size(); k++)
        {
            for(; l<N && B[l].m<queries[k].m; l++)
                add_point(B[l].xc, B[l].yc, (node){B[l].m, B[l].y, 1});
            node n=ask_point(queries[k].xc, queries[k].yc);
            val[queries[k].i]+=n.sa;
            val2[queries[k].i]=n.sc;
        }
        for(; l<N; l++)
            add_point(B[l].xc, B[l].yc, (node){B[l].m, B[l].y, 1});
        for(int k=0; k<(int)queries.size(); k++)
        {
            node n=ask_point(queries[k].xc, queries[k].yc);
            val[queries[k].i]-=n.sb;
            val[queries[k].i]-=1LL*n.sc*queries[k].x;
            val[queries[k].i]+=queries[k].m*(n.sc-val2[queries[k].i]);
        }
        //for(l=0; l<N; l++)
            //era_point(B[l].xc, B[l].yc);
        nqueries=0;
        for(auto& it: q)
        {
            tie(l, r, optL, optR)=it.mt();
            int j=(l+r)/2;
            dp[i][j]=0x3f3f3f3f3f3f3f3fLL;
            int opt=0;
            for(int k=min(optR, j-1); k>=optL; k--)
            {
                if(val[nqueries]<dp[i][j])
                {
                    opt=k;
                    dp[i][j]=val[nqueries];
                }
                nqueries++;
            }
            if(l<=j-1)
                nq.push_back((dac){l, j-1, optL, opt});
            if(j+1<=r)
                nq.push_back((dac){j+1, r, opt, optR});
        }
        q.swap(nq);
        nq.clear();
        ok=1;
    }
}

int main()
{
    scanf("%d%d", &N, &K);
    //K = 3;
    //N = 100000 / K;
    long long ans=0;
    for(int i=0; i<N; i++)
    {
        //A[i].first=rand()*10000+rand();
        //A[i].second=rand()*10000+rand();
        scanf("%d%d", &A[i].first, &A[i].second);
        if(A[i].first>A[i].second)
            swap(A[i].first, A[i].second);
        ans+=A[i].second-A[i].first;
    }
    sort(A, A+N, [](const pair<int, int>& a, const pair<int, int>& b) {
         if(a.second!=b.second)
            return a.second<b.second;
         return a.first<b.first;
    });
    for(int i=0; i<N; i++)
    {
        tie(L[i], R[i])=A[i];
        X[i*2]=L[i];
        X[i*2+1]=R[i];
    }
    sort(X, X+2*N);
    M=unique(X, X+2*N)-X;
    for(int i=0; i<N; i++)
    {
        int xc=lower_bound(X, X+M, L[i])-X;
        int yc=lower_bound(X, X+M, R[i])-X;
        B[i]=(query){L[i]+R[i], L[i], R[i], i, xc, yc};
    }
    sort(B, B+N);
    for(int i=1; i<=K; i++)
        dp[i].resize(M);
    long long SR=0, NR=0;
    for(int i=0, j=0; i<M; i++)
    {
        for(; j<N && R[j]<X[i]; j++)
            SR+=R[j], NR++;
        dp[1][i]=NR*X[i]-SR;
    }
    int flag=0;
    if(K==3 && M>=59000)
    {
        time_t tt=-clock();
        for(int i=2; i<=K-1; i++)
            solve(i, i-1, M-1, i-2, M-1);
        tt+=clock();
        if(tt>=CLOCKS_PER_SEC)
        {
            solve(3, 3*M/4, M-1, 2, M/2+50);
            flag=1;
        }
        else
            solve(3, 2, M-1, 1, M-1);
    }
    else
    {
        for(int i=2; i<=K; i++)
            solve(i, i-1, M-1, i-2, M-1);
    }
    sort(L, L+N, greater<int>());
    long long SL=0, NL=0;
    long long ans2=0x3f3f3f3f3f3f3f3fLL;
    if(flag)
    {
        for(int i=M-1, j=0; i>=3*M/4; i--)
        {
            for(; j<N && L[j]>X[i]; j++)
                SL+=L[j], NL++;
            ans2=min(ans2, dp[K][i]+SL-NL*X[i]);
        }
    }
    else
    {
        for(int i=M-1, j=0; i>=K-1; i--)
        {
            for(; j<N && L[j]>X[i]; j++)
                SL+=L[j], NL++;
            ans2=min(ans2, dp[K][i]+SL-NL*X[i]);
        }
    }
    cout << ans + ans2 * 2 << endl;
    return 0;
}









In Java :





import java.io.*;
import java.math.*;
import java.text.*;
import java.util.*;
import java.util.regex.*;

public class Solution {

    static HD[] hdds;
    static Point[] points;
    static int n, k;
    static long[][] f;
    static long[] g;
    static long totalLength;
    static SegmentTree stHD;
    static PersistentSegmentTree pst;
    static Vertex[] rootsPST;
    static boolean dist;
    static long[] prefR;
    static long[] prefL;
    static Calc calc;
    
    static class HD {
        int left, right, length;
        Point pL, pR;
        long mid;
        int index;

        public HD(int left, int right) {
            boolean b = left > right;
            this.left = b ? right : left;
            this.right = b ? left : right;
            mid = (left + right);
            length = (this.right - this.left);
        }

        @Override
        public String toString() {
            return left + " " + right + " " + mid;
        }
    }

    static class Point {
        HD hd;
        int point;
        int sortIndex;

        public Point(HD hd, boolean isLeft) {
            this.hd = hd;
            point = isLeft ? hd.left : hd.right;
        }

        @Override
        public String toString() {
            return hd.left + " " + hd.right + " " + point;
        }
    }

    static interface Calc {
        public long getW(int start, int end);
    }

    static class Pair {
        int nL, nR;
        long sumL, sumR;

        public Pair(int nR, int nL, long sumR, long sumL) {
            this.nR = nR;
            this.nL = nL;
            this.sumR = sumR;
            this.sumL = sumL;
        }
    }

    static class Node {

        int[] valueSLe;
        int[] valueSRi;
        long[] sumParR;
        long[] sumParL;
        int size;
        int nL, nR;

        public Node() {
        }
    }

    static class SegmentTree {

        Node[] t;

        SegmentTree(int n) {
            t = new Node[4 * n];
            for (int i = 0; i < 4 * n; i++) {
                t[i] = new Node();
            }
        }

        void build(HD[] a, int v, int tl, int tr) {
            if (tl == tr) {
                t[v].valueSLe = new int[1];
                t[v].valueSRi = new int[1];
                t[v].valueSRi[0] = a[tl + 1].pR.point;
                t[v].valueSLe[0] = a[tl + 1].pL.point;
                t[v].size = 1;
                t[v].sumParL = new long[1];
                t[v].sumParL[0] = a[tl + 1].left;
                t[v].sumParR = new long[1];
                t[v].sumParR[0] = a[tl + 1].right;
            } else {
                int tm = (tl + tr) / 2;
                build(a, v * 2, tl, tm);
                build(a, v * 2 + 1, tm + 1, tr);
                t[v].size = t[v * 2].size + t[v * 2 + 1].size;
                merge(v);
            }
        }

        private void merge(int v) {
            t[v].valueSLe = new int[t[v].size];
            t[v].valueSRi = new int[t[v].size];
            t[v].sumParL = new long[t[v].size];
            t[v].sumParR = new long[t[v].size];
            mergeAInt(t[v].valueSLe, t[2 * v].valueSLe, t[2 * v + 1].valueSLe);
            mergeAInt(t[v].valueSRi, t[2 * v].valueSRi, t[2 * v + 1].valueSRi);
            updateCS(t[v]);
        }

        private void updateCS(Node node) {
            long[] sumParLu = new long[node.size];
            long[] sumParRu = new long[node.size];
            int[] valueL = node.valueSLe;
            int[] valueR = node.valueSRi;
            sumParLu[0] = valueL[0];
            sumParRu[0] = valueR[0];
            for (int i = 1; i < node.size; i++) {
                sumParLu[i] = sumParLu[i - 1] + valueL[i];
                sumParRu[i] = sumParRu[i - 1] + valueR[i];
            }
            node.sumParL = sumParLu;
            node.sumParR = sumParRu;
        }

        static long[] ZERO = new long[] { 0, 0, 0, 0 };

        void query(int v, int tl, int tr, int l, int r, int p, long[] fResults) {
            if (l > r || tr < l || tl > r) {
                return;
            }
            if (l == tl && tr == r) {
                Node node = t[v];
                int size = node.size;
                long[] sPL = node.sumParL;
                int limU = upperBound(node.valueSRi, p);
                int limL = lowerBound(node.valueSLe, p);
                int nR = limU == -1 ? 0 : limU + 1;
                long sumR = limU == -1 ? 0 : node.sumParR[limU];
                int nL = limL == size ? 0 : size - limL;
                long sumL = limL == size ? 0 : sPL[size - 1] - (limL == 0 ? 0 : sPL[limL - 1]);
                fResults[0] += nR;
                fResults[1] += nL;
                fResults[2] += sumR;
                fResults[3] += sumL;
                return;
            }
            int tm = (tl + tr) / 2;
            query(v * 2, tl, tm, l, Math.min(r, tm), p, fResults);
            query(v * 2 + 1, tm + 1, tr, Math.max(l, tm + 1), r, p, fResults);

        }
    }
    
    static class Vertex {
        Vertex l, r;
        int count;

        Vertex(int val) {
            count = val;
        }

        Vertex(Vertex l, Vertex r) {
            this.l = l;
            this.r = r;
            if (l != null)
                count += l.count;
            if (r != null)
                count += r.count;
        }
    };
    
    static class PersistentSegmentTree {
        
        Vertex build(int tl, int tr) {
            if (tl == tr) {
                return new Vertex(0);
            }
            int tm = (tl + tr) / 2;
            return new Vertex(build(tl, tm), build(tm + 1, tr));
        }

        Vertex update(Vertex v, int tl, int tr, int pos) {
            if (tl == tr) {
                return new Vertex(1);
            }
            int tm = (tl + tr) / 2;
            if (pos <= tm) {
                return new Vertex(update(v.l, tl, tm, pos), v.r);
            } else {
                return new Vertex(v.l, update(v.r, tm + 1, tr, pos));
            }
        }
        
        int query(Vertex v1, Vertex v2, int b, int e, int x) {
            if (b == e) {
                return b;
            }
            int oo = v1.l.count - v2.l.count;
            int mid = (b + e) / 2;
            if (oo >= x) {
                return query(v1.l, v2.l, b, mid, x);
            } else {
                return query(v1.r, v2.r, mid + 1, e, x - oo);
            }
        }
    }

    static void mergeAInt(int[] values, int[] values1, int[] values2) {
        int n1 = values1.length;
        int n2 = values2.length;
        int i = 0, j = 0, k = 0;
        while (i < n1 && j < n2) {
            int p1 = values1[i];
            int p2 = values2[j];
            if (p1 <= p2) {
                i++;
                values[k++] = p1;
            } else {
                j++;
                values[k++] = p2;
            }
        }
        while (i < n1) {
            values[k++] = values1[i++];
        }
        while (j < n2) {
            values[k++] = values2[j++];
        }
    }

    static private int upperBound(int[] arr, int endV) {
        int left = 0;
        int right = arr.length - 1;
        while (left <= right) {
            int mid = (right + left) / 2;
            if (arr[mid] <= endV) {
                left = mid + 1;
            } else {
                right = mid - 1;
            }
        }
        return right;
    }

    static private int lowerBound(int[] arr, int startV) {
        int left = 0;
        int al = arr.length;
        int right = al - 1;
        while (left <= right) {
            int mid = (right + left) / 2;
            if (arr[mid] >= startV) {
                right = mid - 1;
            } else {
                left = mid + 1;
            }
        }
        return left;
    }

    private static void initW() {
        f = new long[k + 1][n + 1];
        g = new long[n + 2];        
        Arrays.sort(hdds, 1, n + 1, new Comparator<HD>() {
            @Override
            public int compare(HD o1, HD o2) {
                long d = o1.mid - o2.mid;
                if (d != 0) {
                    return d < 0 ? -1 : 1;
                } else {
                    return o1.left - o2.left;
                }
            }
        });
        SASD sd = new SASD();
        dist = true;
        for (int j = 1; j <= n; j++) {
            sd.add(hdds[j]);
            f[1][j] = sd.ans;
            if (j < n && hdds[j].right > hdds[j + 1].left) {
                dist = false;
            }
        }
        if (!dist) {
            calc = new Calc() {
                @Override
                public long getW(int start, int end) {
                    return getW1(start, end);                    
                }
            };
            Point[] pointsToSort = new Point[2 * n];
            for (int j = 1; j <= n; j++) {
                HD hd = hdds[j];
                hd.index = j;
                Point p = new Point(hd, true);
                hd.pL = p;
                pointsToSort[2 * j - 2] = p;
                p = new Point(hd, false);
                hd.pR = p;
                pointsToSort[2 * j - 1] = p;
            }            
            Arrays.sort(pointsToSort, 0, 2 * n, new Comparator<Point>() {
                @Override
                public int compare(Point o1, Point o2) {
                    long d = o1.point - o2.point;
                    if (d != 0) {
                        return d < 0 ? -1 : 1;
                    } else {
                        return o1.hd.index - o2.hd.index;
                    }
                }
            });
            points = new Point[2 * n];
            for (int j = 0; j < 2*n; j++) {
                Point p = pointsToSort[j]; 
                p.sortIndex = j;
                points[j] = p;
            }            
            stHD = new SegmentTree(n);
            stHD.build(hdds, 1, 0, n - 1);
            rootsPST = new Vertex[n+1];
            pst = new PersistentSegmentTree();
            rootsPST[0] = pst.build(0, 2*n-1);
            for (int j = 1; j <= n; j++) {
                Vertex root = pst.update(rootsPST[j-1], 0, 2*n-1, hdds[j].pL.sortIndex);
                rootsPST[j] = pst.update(root, 0, 2*n-1, hdds[j].pR.sortIndex);
            }            
        } else {
            calc = new Calc() {
                @Override
                public long getW(int start, int end) {
                    return getW2(start, end);
                }
            };
            prefL = new long[n + 1];
            prefR = new long[n + 1];
            for (int j = 1; j <= n; j++) {
                prefR[j] = prefR[j - 1] + hdds[j].right;
                prefL[j] = prefL[j - 1] + hdds[j].left;
            }
        }
    }


    static Map<Integer, Long>[] memo;
    private static long getW1(int jStart, int jEnd) {
        Long ans = memo[jEnd].get(jStart);
        if (ans != null) {
            return ans;
        }
        int mid = pst.query(rootsPST[jEnd], rootsPST[jStart-1], 0, 2 * n - 1, jEnd-jStart+1);
        int p = points[mid].point;
        long[] pair = new long[4];
        stHD.query(1, 0, n - 1, jStart - 1, jEnd - 1, p, pair);
        ans = p * pair[0] - pair[2] + pair[3] - p * pair[1];
        memo[jEnd].put(jStart, ans);
        return ans;
    }

    private static long getW2(int jStart, int jEnd) {
        int mid = (jStart + jEnd) / 2;
        long p = hdds[mid].right;
        long ans = p * (mid - jStart + 1) - (prefR[mid] - prefR[jStart - 1]) + (prefL[jEnd] - prefL[mid])
                - p * (jEnd - mid);
        return ans;
    }

    static int INF = 2000000000;

    static class SASD {
        int mid1 = -INF;
        int mid2 = INF;
        long ans;

        private PriorityQueue<Integer> r = new PriorityQueue<Integer>();

        void add(HD hd) {
            int nLow = hd.left;
            int nHigh = hd.right;
            if (nLow >= mid2) {
                ans += (nLow - mid2);
                r.remove();
                r.add(nLow);
                r.add(nHigh);
                mid1 = mid2;
                mid2 = r.peek();
            } else if (nLow < mid1) {
                r.add(nHigh);
            } else {
                r.add(nHigh);
                mid1 = nLow;
                if (mid2 == INF) {
                    mid2 = nHigh;
                }
                if (nHigh <= mid2) {
                    mid2 = nHigh;
                }
            }
        }
    }

    static class SADS {
        int mid1 = -INF;
        int mid2 = INF;
        long ans;

        private PriorityQueue<Integer> l = new PriorityQueue<Integer>(n, new Comparator<Integer>() {

            @Override
            public int compare(Integer o1, Integer o2) {
                return o2 - o1;
            }
        });

        void add(HD hd) {
            int nLow = hd.left;
            int nHigh = hd.right;
            if (nHigh <= mid1) {
                ans += (mid1 - nHigh);
                l.remove();
                l.add(nLow);
                l.add(nHigh);
                mid2 = mid1;
                mid1 = l.peek();
            } else if (nHigh > mid2) {
                l.add(nLow);
                if (nLow >= mid2) {
                    ans += (nHigh - mid2);
                }
            } else {
                l.add(nLow);
                mid2 = nHigh;
                if (mid1 == -INF) {
                    mid1 = nLow;
                }
                if (nLow >= mid1) {
                    mid1 = nLow;
                }
            }
        }
    }

    //static int c = 0;
    //static long time = 0;
    static long hardDrive() {
        if (k == n) {
            return totalLength;
        }
        initW();
        if (f[1][n] == 0) {
            return totalLength;
        }
        int STEEP = 10;
        SADS ds = new SADS();
        long ans = f[1][n];
        int minG = 1;
        for (int s = n; s >= 1; s--) {
            ds.add(hdds[s]);
            g[s] = ds.ans;
            long pAns = f[1][s - 1] + ds.ans;
            if (pAns < ans) {
                ans = pAns;
                minG = s;
            }
        }
        f[2][n] = ans;
        if (k == 2) {
            return 2 * ans + totalLength;
        }
        long fRMIN = 0;
        int[] minAJ = new int[n + 1];
        Arrays.fill(minAJ, 1);
        for (int i = 2; i < k; i++) {
            if (i == k - 1) {
                ans = f[i - 1][n];
                int lower = minG;
                int higher = n;
                if (higher > lower + STEEP) {
                    Pair p = search(calc, i, n, lower, higher, STEEP);
                    ans = p.sumR;
                    lower = p.nR;
                    higher = p.nL;
                }
                for (int s = lower; s <= higher; s++) {
                    long pAns = f[i - 1][s - 1] + g[s];
                    if (pAns < ans) {
                        ans = pAns;
                        minG = s;
                    }
                }
                f[i][n] = ans;
                fRMIN = ans;
                int minG2 = n;
                minG = minG > 1 ? minG - 1 : 1;
                int fSteep = 100;
                for (int j = minG; j <= minG2; j++) {
                    ans = f[i - 1][j];
                    lower = 1;
                    higher = j;
                    if (higher > lower + STEEP) {
                        Pair p = search(calc, i, j, lower, higher, STEEP);
                        ans = p.sumR;
                        lower = p.nR;
                        higher = p.nL;
                    }
                    if (f[i - 1][lower - 1] != f[i - 1][higher - 1]) {
                        for (int s = lower; s <= higher; s++) {
                            long pAns = f[i - 1][s - 1] + calc.getW(s, j);
                            if (pAns < ans) {
                                ans = pAns;
                            }
                        }
                    } else {                        
                        ans = f[i - 1][higher - 1] + calc.getW(higher, j);
                    }    
                    f[i][j] = ans;
                    long pFA = ans + g[j + 1];
                    if (pFA < fRMIN) {
                        fRMIN = pFA;
                    }
                    if (ans >= fRMIN) {
                        break;
                    }
                    if (j + fSteep <= minG2 && ans + g[j + fSteep] >= fRMIN) {
                        j = j + fSteep;
                    }
                    while (j < minG2 &&  g[j+1] == g[j+2]) {
                        j++;
                    }
                }
            } else {
                for (int j = i + 1; j < n; j++) {
                    ans = f[i - 1][j];
                    int nMin = minAJ[j];
                    int lower = nMin;
                    int higher = j;
                    if (higher > lower + STEEP) {
                        Pair p = search(calc, i, j, lower, higher, STEEP);
                        ans = p.sumR;
                        lower = p.nR;
                        higher = p.nL;                        
                    }
                    if (f[i - 1][lower - 1] != f[i - 1][higher - 1]) {
                        for (int s = lower; s <= higher; s++) {
                            long pAns = f[i - 1][s - 1] + calc.getW(s, j);
                            if (pAns < ans) {
                                ans = pAns;
                                nMin = s;
                            }
                        }
                    } else {                        
                        ans = f[i - 1][higher - 1] + calc.getW(higher, j);         
                    }
                    f[i][j] = ans;
                    minAJ[j] = nMin;
                }
            }
        }
        return 2 * fRMIN + totalLength;
    }

    static Pair search(Calc calc, int i, int j, int lower, int higher, int steep) {
        long ans = 0;
        while (higher > lower + steep) {
            int mid = (lower + higher) / 2;
            int mid1 = (lower + mid) / 2;
            mid1 = mid1 % 2 == 0 ? mid1 : mid1 + 1;
            int mid2 = (higher + mid) / 2;
            mid2 = mid2 % 2 == 0 ? mid2 : mid2 - 1;
            long res1 = f[i - 1][mid1 - 1] + calc.getW(mid1, j);
            long res2 = f[i - 1][mid2 - 1] + calc.getW(mid2, j);
            if (res1 < res2) {
                higher = mid2 - 1;
                ans = res1;
            } else if (res1 >= res2) {
                lower = mid1 + 1;
                ans = res2;
            }
        }
        return new Pair(lower, higher, ans, 0);
    }

    public static void main(String[] args) throws IOException {
        BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(System.getenv("OUTPUT_PATH")));
        BufferedReader reader = new BufferedReader(new InputStreamReader(System.in), 2 * 4096 * 4096);
        String[] nk = reader.readLine().trim().split(" ");
        n = Integer.parseInt(nk[0]);
        k = Integer.parseInt(nk[1]);
        hdds = new HD[n + 1];
        totalLength = 0;
        for (int hddsRowItr = 1; hddsRowItr <= n; hddsRowItr++) {
            String[] hddsRowItems = reader.readLine().trim().split(" ");
            HD hd = new HD(Integer.parseInt(hddsRowItems[0]), Integer.parseInt(hddsRowItems[1]));
            hdds[hddsRowItr] = hd;
            totalLength += hd.length;
        }
        memo = new Map[n+1];
        for (int j = 1; j <= n; j++) {
            memo[j] = new HashMap<Integer, Long>();
        }
        dist = true;
        long result = hardDrive();
        bufferedWriter.write(String.valueOf(result));
        bufferedWriter.newLine();
        bufferedWriter.close();
        reader.close();
    }
}
                        








View More Similar Problems

Fibonacci Numbers Tree

Shashank loves trees and math. He has a rooted tree, T , consisting of N nodes uniquely labeled with integers in the inclusive range [1 , N ]. The node labeled as 1 is the root node of tree , and each node in is associated with some positive integer value (all values are initially ). Let's define Fk as the Kth Fibonacci number. Shashank wants to perform 22 types of operations over his tree, T

View Solution →

Pair Sums

Given an array, we define its value to be the value obtained by following these instructions: Write down all pairs of numbers from this array. Compute the product of each pair. Find the sum of all the products. For example, for a given array, for a given array [7,2 ,-1 ,2 ] Note that ( 7 , 2 ) is listed twice, one for each occurrence of 2. Given an array of integers, find the largest v

View Solution →

Lazy White Falcon

White Falcon just solved the data structure problem below using heavy-light decomposition. Can you help her find a new solution that doesn't require implementing any fancy techniques? There are 2 types of query operations that can be performed on a tree: 1 u x: Assign x as the value of node u. 2 u v: Print the sum of the node values in the unique path from node u to node v. Given a tree wi

View Solution →

Ticket to Ride

Simon received the board game Ticket to Ride as a birthday present. After playing it with his friends, he decides to come up with a strategy for the game. There are n cities on the map and n - 1 road plans. Each road plan consists of the following: Two cities which can be directly connected by a road. The length of the proposed road. The entire road plan is designed in such a way that if o

View Solution →

Heavy Light White Falcon

Our lazy white falcon finally decided to learn heavy-light decomposition. Her teacher gave an assignment for her to practice this new technique. Please help her by solving this problem. You are given a tree with N nodes and each node's value is initially 0. The problem asks you to operate the following two types of queries: "1 u x" assign x to the value of the node . "2 u v" print the maxim

View Solution →

Number Game on a Tree

Andy and Lily love playing games with numbers and trees. Today they have a tree consisting of n nodes and n -1 edges. Each edge i has an integer weight, wi. Before the game starts, Andy chooses an unordered pair of distinct nodes, ( u , v ), and uses all the edge weights present on the unique path from node u to node v to construct a list of numbers. For example, in the diagram below, Andy

View Solution →