Hard Disk Drives


Problem Statement :


There are n pairs of hard disk drives (HDDs) in a cluster. Each HDD is located at an integer coordinate on an infinite straight line, and each pair consists of one primary HDD and one backup HDD.

Next, you want to place k computers at integer coordinates on the same infinite straight line. Each pair of HDDs must then be connected to a single computer via wires, but a computer can have any number (even zero) of HDDs connected to it. The length of a wire connecting a single HDD to a computer is the absolute difference between their respective coordinates on the infinite line. The total length of wire is the sum of the lengths of all the wires used to connect HDDs to computers. Note that both the primary and backup HDDs in a pair must connect to the same computer.

Given the locations of n pairs (i.e., primary and backup) of HDDs and the value of k, place all k computers in such a way that the total length of wire needed to connect each pair of HDDs to computers is minimal. Then print the total length on a new line.

Input Format

The first line contains two space-separated integers denoting the respective values of n (the number of pairs of HDDs) and k (the number of computers).
Each line i of the n subsequent lines contains two space-separated integers describing the respective values of ai (coordinate of the primary HDD) and bi (coordinate of the backup HDD) for a pair of HDDs.

Constraints

2 <= k <= n <= 10^5
4 <= k*n <= 10^5
-10^9 <= ai, bi <= 10^9

Output Format

Print a single integer denoting the minimum total length of wire needed to connect all the pairs of HDDs to computers.



Solution :



title-img


                            Solutions in C++ and Java :

In C++ :





#include <bits/stdc++.h>

using namespace std;

// Additive aggregate stored in every tree node / Fenwick cell: two 64-bit
// sums (sa, sb) and a counter (sc). All operators act component-wise.
struct node
{
    long long sa, sb;
    int sc;

    // Zero aggregate; the identity element for +=.
    node() : sa(0), sb(0), sc(0) {}

    node(long long a, long long b, int c) : sa(a), sb(b), sc(c) {}

    // Component-wise accumulate.
    node& operator+= (const node& rhs)
    {
        sc+=rhs.sc;
        sb+=rhs.sb;
        sa+=rhs.sa;
        return *this;
    }

    // Component-wise subtract.
    node& operator-= (const node& rhs)
    {
        sc-=rhs.sc;
        sb-=rhs.sb;
        sa-=rhs.sa;
        return *this;
    }

    // True when all three components match.
    bool operator== (const node& rhs) const
    {
        if(sa!=rhs.sa)
            return false;
        if(sb!=rhs.sb)
            return false;
        return sc==rhs.sc;
    }
};

// One node of a splay tree.
//   ch[0]/ch[1] : left / right child, p : parent (nullptr at the root)
//   o           : this node's own payload
//   n           : aggregate recomputed by maintain() (own payload plus children's aggregates)
//   k           : ordering key used by insert() and ask()
// Nodes are handed out from the static array `pool`.
struct SplayNode
{
    SplayNode *ch[2], *p;
    node n, o;
    int k;

    // Reset the node to a detached leaf with key `k` and payload `o`.
    // The key parameter is long long but is stored in an int member.
    void init(long long k, node o)
    {
        p=nullptr;
        ch[1]=nullptr;
        ch[0]=nullptr;
        this->o=o;
        this->n=this->o;
        this->k=k;
    }
} pool[2000000];

// Index of the next slot of pool[] to hand out; add_point() increments it
// and wraps it back to 0 when it reaches 2000000.
int npool;

// Recompute n's aggregate from its own payload plus the aggregates of its
// existing children, and point each existing child's parent link back at n.
void maintain(SplayNode *n)
{
    n->n=n->o;
    for(SplayNode *kid: n->ch)
    {
        if(!kid)
            continue;
        kid->p=n;
        n->n+=kid->n;
    }
}

// Which child slot of its parent does n occupy?
// Returns 0 (left) or 1 (right); -1 when n has no parent or the parent does
// not list n as a child.
int child(SplayNode *n)
{
    SplayNode *parent=n->p;
    if(!parent)
        return -1;
    if(parent->ch[0]==n)
        return 0;
    return parent->ch[1]==n ? 1 : -1;
}

// Rotate n one level up, over its parent. n must currently be a child of its
// parent (asserted). Aggregates of the two rotated nodes are rebuilt, and if
// the old parent hung below a grandparent, that link and its aggregate are
// refreshed too.
void rotate_up(SplayNode *n)
{
    SplayNode *parent=n->p;
    const int side=child(n);
    const int parentSide=child(parent);
    assert(side!=-1);
    const int opposite=side^1;
    // n's inner subtree moves under the old parent; the old parent moves under n.
    parent->ch[side]=n->ch[opposite];
    n->ch[opposite]=parent;
    n->p=parent->p;
    // Rebuild bottom-up: the old parent is now below n.
    maintain(parent);
    maintain(n);
    if(parentSide==-1)
        return;
    SplayNode *grand=n->p;
    grand->ch[parentSide]=n;
    maintain(grand);
}

// Rotate n upward until child(n) reports that it is no longer anybody's child.
// When n's parent is itself a child, two rotations are done per step: n twice
// if n and its parent hang on different sides, otherwise parent first, then n.
void splay(SplayNode *n)
{
    for(;;)
    {
        const int side=child(n);
        if(side==-1)
            return;
        const int parentSide=child(n->p);
        if(parentSide==-1)
        {
            rotate_up(n);
        }
        else if(side!=parentSide)
        {
            rotate_up(n);
            rotate_up(n);
        }
        else
        {
            rotate_up(n->p);
            rotate_up(n);
        }
    }
}

// Plain BST insertion of detached node t below n: keys smaller than n->k go
// left, everything else (including equal keys) goes right. Aggregates and
// parent links along the descent path are refreshed on the way back up.
void insert(SplayNode *n, SplayNode *t)
{
    const int dir=(t->k<n->k) ? 0 : 1;
    if(n->ch[dir])
        insert(n->ch[dir], t);
    else
        n->ch[dir]=t;
    maintain(n);
}
// Last non-null node visited by the most recent ask(); ask_point() splays it.
SplayNode *lastTouch;

// Sum the payloads of all nodes in the subtree rooted at n whose key is <= k.
// Walks a single root-to-leaf path: whenever the current key is <= k, the
// node's own payload and its whole left aggregate are taken and the walk
// continues to the right; otherwise it continues to the left.
// Returns the zero aggregate for an empty tree (lastTouch is left untouched).
node ask(SplayNode *n, int k)
{
    node acc;
    while(n)
    {
        lastTouch=n;
        if(k<n->k)
        {
            n=n->ch[0];
            continue;
        }
        acc+=n->o;
        if(n->ch[0])
            acc+=n->ch[0]->n;
        n=n->ch[1];
    }
    return acc;
}

// N = number of HDD pairs and K = number of computers (both read in main);
// M = number of distinct coordinates kept in X after main's sort + unique.
int N, M, K;
// Input pairs; main stores each as (smaller, larger) and sorts by larger, then smaller.
pair<int, int> A[100000];
// Endpoint coordinates; after main's compression the first M entries are the sorted distinct values.
int X[200000];
// Smaller / larger endpoint of pair i in A's sorted order
// (L is re-sorted in descending order near the end of main).
int L[100000];
int R[100000];
// fsize[i]: number of add_point() insertions that touched outer Fenwick index i while ok==0.
int fsize[200001];
// Set to 1 by solve() at the end of its first round; never reset afterwards.
int ok;
// dp[i][j]: value for layer i and coordinate index j; main fills layer 1,
// solve() fills layer i from layer i-1.
vector<long long> dp[351];
// bit[i]: root of the splay tree attached to outer Fenwick index i (nullptr when empty).
SplayNode* bit[200001];

// One pending work item of solve(): the index range [l, r] still to be
// filled, together with the candidate range [optL, optR] to search.
struct dac
{
    int l, r, optL, optR;

    // The four fields packed as a tuple, in declaration order.
    tuple<int, int, int, int> mt() const
    {
        return tuple<int, int, int, int>(l, r, optL, optR);
    }
};

// A record ordered by its key m. Used both for the input pairs (array B,
// filled in main with m = x + y) and for the candidate queries built by solve().
//   m      : sort key
//   x, y   : raw coordinates
//   i      : sequence number assigned by the creator
//   xc, yc : indices into the compressed coordinate array
struct query
{
    int m, x, y, i, xc, yc;

    // Strict ordering by m only.
    bool operator< (const query& other) const
    {
        return other.m>m;
    }
} B[100000];

// Threshold on fsize[i]: once ok is set, outer indices with at least this many
// recorded insertions use the dense table blt[i] instead of a splay tree.
const int MAGIC=1500;
// Per-query scratch used by solve(): val accumulates the candidate cost,
// val2 keeps the count obtained in the first sweep.
long long val[200000];
long long val2[200000];
// blt[i]: inner Fenwick array (indices 1..M) for outer index i; sized and cleared by solve().
vector<node> blt[200001];

void add_point(int x, int y, node n)
{
    x=M-1-x;
    for(int i=x+1; i<=M; i+=i&-i)
    {
        if(ok && fsize[i]>=MAGIC)
        {
            for(int j=y+1; j<=M; j+=j&-j)
                blt[i][j]+=n;
            continue;
        }
        if(!ok)
            fsize[i]++;
        SplayNode *t=&pool[npool++];
        if(npool==2000000)
            npool=0;
        t->init(y+1, n);
        if(bit[i])
        {
            insert(bit[i], t);
            splay(t);
        }
        bit[i]=t;
    }
}

// Reset to zero every dense-table cell that add_point(x, y, ...) would have
// touched. Outer indices that are not in dense mode are left alone.
// NOTE(review): the only call site visible in this file is commented out.
void era_point(int x, int y)
{
    for(int i=(M-1-x)+1; i<=M; i+=i&-i)
    {
        if(!ok || fsize[i]<MAGIC)
            continue;
        for(int j=y+1; j<=M; j+=j&-j)
            blt[i][j]=node();
    }
}

// Aggregate of all inserted points whose compressed x is >= the given x and
// whose compressed y is <= the given y (x is mirrored, so this is a prefix
// walk of the outer Fenwick tree). Splay-backed indices are re-rooted at the
// last node touched by the lookup.
node ask_point(int x, int y)
{
    node total;
    for(int i=(M-1-x)+1; i>0; i-=i&-i)
    {
        if(ok && fsize[i]>=MAGIC)
        {
            const vector<node>& table=blt[i];
            for(int j=y+1; j>0; j-=j&-j)
                total+=table[j];
        }
        else if(bit[i])
        {
            total+=ask(bit[i], y+1);
            bit[i]=lastTouch;
            splay(lastTouch);
        }
    }
    return total;
}

// Fill dp[i][j] for j in [l, r] from layer i-1, searching candidates k in
// [optL, optR] with k < j. The recursion is unrolled level by level: every
// round processes all pending segments at once so that their candidate costs
// can be evaluated together by one offline sweep over the pairs in B.
//
// For a candidate (k, j) the value built in val[] is
//   dp[i-1][k] + sum over pairs with xc >= k and yc <= j of
//                (x - X[k])  if x + y <  X[k] + X[j]
//                (X[j] - y)  otherwise
// (the first sweep contributes sum(m) and the count of the "early" pairs, the
// second sweep contributes the terms over all matching pairs).
void solve(int i, int l, int r, int optL, int optR)
{
    if(l>r)
        return;
    vector<dac> q, nq;
    q.push_back((dac){l, r, optL, optR});
    while(!q.empty())
    {
        // Start every round with an empty point structure.
        for(int j=1; j<=M; j++)
        {
            bit[j]=nullptr;
            if(ok && fsize[j]>=MAGIC)
            {
                if((int)blt[j].size()!=M+1)
                    blt[j].resize(M+1);
                fill(blt[j].begin(), blt[j].end(), node());
            }
        }
        // One query per (segment midpoint j, candidate k); val[] starts at dp[i-1][k].
        vector<query> queries;
        int nqueries=0;
        for(auto& it: q)
        {
            tie(l, r, optL, optR)=it.mt();
            int j=(l+r)/2;
            dp[i][j]=0x3f3f3f3f3f3f3f3fLL;
            for(int k=min(optR, j-1); k>=optL; k--)
            {
                val[nqueries]=dp[i-1][k];
                queries.push_back((query){X[k]+X[j], X[k], X[j], nqueries++, k, j});
            }
        }
        sort(queries.begin(), queries.end());
        // NOTE: this local `l` shadows the parameter for the rest of the round;
        // it first serves as the sweep cursor into B and later as tie() scratch.
        int l=0;
        // Sweep 1: for each query only pairs with B.m strictly below the query's m are present.
        for(int k=0; k<(int)queries.size(); k++)
        {
            for(; l<N && B[l].m<queries[k].m; l++)
                add_point(B[l].xc, B[l].yc, (node){B[l].m, B[l].y, 1});
            node n=ask_point(queries[k].xc, queries[k].yc);
            val[queries[k].i]+=n.sa;
            val2[queries[k].i]=n.sc;
        }
        // Insert the remaining pairs so that sweep 2 sees all of them.
        for(; l<N; l++)
            add_point(B[l].xc, B[l].yc, (node){B[l].m, B[l].y, 1});
        // Sweep 2: totals over every matching pair, combined with the sweep-1 count.
        for(int k=0; k<(int)queries.size(); k++)
        {
            node n=ask_point(queries[k].xc, queries[k].yc);
            val[queries[k].i]-=n.sb;
            val[queries[k].i]-=1LL*n.sc*queries[k].x;
            val[queries[k].i]+=queries[k].m*(n.sc-val2[queries[k].i]);
        }
        //for(l=0; l<N; l++)
            //era_point(B[l].xc, B[l].yc);
        // Replay the segments in the same order as above so that nqueries lines
        // up with val[]; pick the best candidate and split each segment in two.
        nqueries=0;
        for(auto& it: q)
        {
            tie(l, r, optL, optR)=it.mt();
            int j=(l+r)/2;
            dp[i][j]=0x3f3f3f3f3f3f3f3fLL;
            int opt=0;
            for(int k=min(optR, j-1); k>=optL; k--)
            {
                // Strict '<' while k descends: ties keep the larger k.
                if(val[nqueries]<dp[i][j])
                {
                    opt=k;
                    dp[i][j]=val[nqueries];
                }
                nqueries++;
            }
            if(l<=j-1)
                nq.push_back((dac){l, j-1, optL, opt});
            if(j+1<=r)
                nq.push_back((dac){j+1, r, opt, optR});
        }
        q.swap(nq);
        nq.clear();
        // From the second round on, heavy outer indices switch to the dense tables.
        ok=1;
    }
}

// Reads N pairs and K, builds the compressed coordinate set, runs the layered
// DP (dp[c][i]: best cost with c computers, the right-most one at X[i],
// counting pairs that end to the left of it) and prints
//   sum(pair lengths) + 2 * (best extra distance).
// Returns 1 on malformed or out-of-range input.
int main()
{
    // The static tables hold at most 100000 pairs and 350 DP layers.
    if(scanf("%d%d", &N, &K)!=2 || N<1 || N>100000 || K<1 || K>350)
        return 1;
    long long ans=0;
    for(int i=0; i<N; i++)
    {
        if(scanf("%d%d", &A[i].first, &A[i].second)!=2)
            return 1;
        if(A[i].first>A[i].second)
            swap(A[i].first, A[i].second);
        // Wire needed even when the computer sits between the two drives.
        ans+=A[i].second-A[i].first;
    }
    sort(A, A+N, [](const pair<int, int>& a, const pair<int, int>& b) {
         if(a.second!=b.second)
            return a.second<b.second;
         return a.first<b.first;
    });
    for(int i=0; i<N; i++)
    {
        tie(L[i], R[i])=A[i];
        X[i*2]=L[i];
        X[i*2+1]=R[i];
    }
    sort(X, X+2*N);
    M=unique(X, X+2*N)-X;
    // The DP places computers on distinct coordinates of X, so it has no state
    // at all when K exceeds M (the final loop below would never run and the
    // 0x3f.. sentinel would be printed). Extra computers beyond one per
    // distinct coordinate cannot lower the cost, so clamping K is exact.
    K=min(K, M);
    for(int i=0; i<N; i++)
    {
        int xc=lower_bound(X, X+M, L[i])-X;
        int yc=lower_bound(X, X+M, R[i])-X;
        B[i]=(query){L[i]+R[i], L[i], R[i], i, xc, yc};
    }
    sort(B, B+N);
    for(int i=1; i<=K; i++)
        dp[i].resize(M);
    // Layer 1: a single computer at X[i] serving every pair that ends before it.
    long long SR=0, NR=0;
    for(int i=0, j=0; i<M; i++)
    {
        for(; j<N && R[j]<X[i]; j++)
            SR+=R[j], NR++;
        dp[1][i]=NR*X[i]-SR;
    }
    int flag=0;
    if(K==3 && M>=59000)
    {
        // NOTE(review): timing-based shortcut from the original submission. If
        // layer 2 took at least one second, layer 3 is only evaluated on the
        // last quarter of the coordinates with a restricted candidate range;
        // this looks like a heuristic and makes the result depend on wall-clock
        // speed — confirm before relying on it.
        time_t tt=-clock();
        for(int i=2; i<=K-1; i++)
            solve(i, i-1, M-1, i-2, M-1);
        tt+=clock();
        if(tt>=CLOCKS_PER_SEC)
        {
            solve(3, 3*M/4, M-1, 2, M/2+50);
            flag=1;
        }
        else
            solve(3, 2, M-1, 1, M-1);
    }
    else
    {
        for(int i=2; i<=K; i++)
            solve(i, i-1, M-1, i-2, M-1);
    }
    // Close the DP: add the pairs that start to the right of the last computer.
    sort(L, L+N, greater<int>());
    long long SL=0, NL=0;
    long long ans2=0x3f3f3f3f3f3f3f3fLL;
    if(flag)
    {
        for(int i=M-1, j=0; i>=3*M/4; i--)
        {
            for(; j<N && L[j]>X[i]; j++)
                SL+=L[j], NL++;
            ans2=min(ans2, dp[K][i]+SL-NL*X[i]);
        }
    }
    else
    {
        for(int i=M-1, j=0; i>=K-1; i--)
        {
            for(; j<N && L[j]>X[i]; j++)
                SL+=L[j], NL++;
            ans2=min(ans2, dp[K][i]+SL-NL*X[i]);
        }
    }
    // Every unit of extra distance is paid by both drives of the pair.
    cout << ans + ans2 * 2 << endl;
    return 0;
}









In Java :





import java.io.*;
import java.math.*;
import java.text.*;
import java.util.*;
import java.util.regex.*;

public class Solution {

    // Pairs, 1-based (main fills indices 1..n); initW sorts them by mid, then left.
    static HD[] hdds;
    // All 2n endpoints sorted by coordinate (built in initW when dist is false).
    static Point[] points;
    // n = number of pairs, k = number of computers (parsed in main).
    static int n, k;
    // DP table of size (k+1) x (n+1); initW fills f[1][j], hardDrive fills higher rows.
    static long[][] f;
    // g[s]: value of the SADS accumulator after adding pairs n, n-1, ..., s (filled in hardDrive).
    static long[] g;
    // Sum of (right - left) over all pairs.
    static long totalLength;
    // Segment tree over the sorted pairs, queried by getW1.
    static SegmentTree stHD;
    // Persistent counting tree over endpoint ranks; rootsPST[j] is the version after pairs 1..j.
    static PersistentSegmentTree pst;
    static Vertex[] rootsPST;
    // Stays true only if, after sorting, no pair's right end exceeds the next pair's left end.
    static boolean dist;
    // Prefix sums of right / left ends over the sorted pairs (allocated in initW only when dist is true).
    static long[] prefR;
    static long[] prefL;
    // Segment-cost function chosen by initW: getW1 when dist is false, getW2 otherwise.
    static Calc calc;
    
    /**
     * One primary/backup pair, normalised so that left <= right.
     * mid is left + right, length is right - left. pL, pR and index are
     * assigned later by initW.
     */
    static class HD {
        int left, right, length;
        Point pL, pR;
        long mid;
        int index;

        public HD(int left, int right) {
            this.left = Math.min(left, right);
            this.right = Math.max(left, right);
            this.mid = left + right;
            this.length = this.right - this.left;
        }

        @Override
        public String toString() {
            StringBuilder sb = new StringBuilder();
            sb.append(left).append(" ").append(right).append(" ").append(mid);
            return sb.toString();
        }
    }

    /**
     * One endpoint of a pair: the owning pair, the coordinate (left or right
     * end, chosen by the constructor flag) and the rank assigned after sorting.
     */
    static class Point {
        HD hd;
        int point;
        int sortIndex;

        public Point(HD hd, boolean isLeft) {
            this.hd = hd;
            if (isLeft) {
                this.point = hd.left;
            } else {
                this.point = hd.right;
            }
        }

        @Override
        public String toString() {
            StringBuilder sb = new StringBuilder();
            sb.append(hd.left).append(" ").append(hd.right).append(" ").append(point);
            return sb.toString();
        }
    }

    /** Cost function for a contiguous run of sorted pairs, start..end inclusive (1-based). */
    interface Calc {
        long getW(int start, int end);
    }

    /**
     * Plain holder for two counts and two sums. Note the constructor order:
     * (nR, nL, sumR, sumL). search() reuses it as (lower, higher, value, 0).
     */
    static class Pair {
        int nL, nR;
        long sumL, sumR;

        public Pair(int nR, int nL, long sumR, long sumL) {
            this.sumL = sumL;
            this.sumR = sumR;
            this.nL = nL;
            this.nR = nR;
        }
    }

    /** Per-vertex payload of SegmentTree: sorted endpoint lists with their prefix sums. */
    static class Node {

        // Left endpoints of the pairs covered by this vertex, ascending.
        int[] valueSLe;
        // Right endpoints of the pairs covered by this vertex, ascending.
        int[] valueSRi;
        // sumParR[i] = valueSRi[0] + ... + valueSRi[i].
        long[] sumParR;
        // sumParL[i] = valueSLe[0] + ... + valueSLe[i].
        long[] sumParL;
        // Number of pairs covered by this vertex.
        int size;
        // Not referenced anywhere in the visible code.
        int nL, nR;

        public Node() {
        }
    }

    /**
     * Segment tree over the sorted pairs in which every vertex stores the
     * left and right endpoints of its range in ascending order, plus prefix
     * sums of both lists. Position i of the tree corresponds to a[i + 1].
     */
    static class SegmentTree {

        Node[] t;

        SegmentTree(int n) {
            t = new Node[4 * n];
            for (int i = 0; i < 4 * n; i++) {
                t[i] = new Node();
            }
        }

        /** Build vertex v for positions tl..tr from the 1-based pair array a. */
        void build(HD[] a, int v, int tl, int tr) {
            if (tl == tr) {
                t[v].valueSLe = new int[1];
                t[v].valueSRi = new int[1];
                t[v].valueSRi[0] = a[tl + 1].pR.point;
                t[v].valueSLe[0] = a[tl + 1].pL.point;
                t[v].size = 1;
                t[v].sumParL = new long[1];
                t[v].sumParL[0] = a[tl + 1].left;
                t[v].sumParR = new long[1];
                t[v].sumParR[0] = a[tl + 1].right;
            } else {
                int tm = (tl + tr) / 2;
                build(a, v * 2, tl, tm);
                build(a, v * 2 + 1, tm + 1, tr);
                t[v].size = t[v * 2].size + t[v * 2 + 1].size;
                merge(v);
            }
        }

        /** Merge the children's sorted lists into v and rebuild v's prefix sums. */
        private void merge(int v) {
            t[v].valueSLe = new int[t[v].size];
            t[v].valueSRi = new int[t[v].size];
            mergeAInt(t[v].valueSLe, t[2 * v].valueSLe, t[2 * v + 1].valueSLe);
            mergeAInt(t[v].valueSRi, t[2 * v].valueSRi, t[2 * v + 1].valueSRi);
            // updateCS allocates and installs the prefix-sum arrays itself, so
            // none are pre-allocated here.
            updateCS(t[v]);
        }

        /** Recompute both prefix-sum arrays of a vertex from its sorted value lists. */
        private void updateCS(Node node) {
            long[] sumParLu = new long[node.size];
            long[] sumParRu = new long[node.size];
            int[] valueL = node.valueSLe;
            int[] valueR = node.valueSRi;
            sumParLu[0] = valueL[0];
            sumParRu[0] = valueR[0];
            for (int i = 1; i < node.size; i++) {
                sumParLu[i] = sumParLu[i - 1] + valueL[i];
                sumParRu[i] = sumParRu[i - 1] + valueR[i];
            }
            node.sumParL = sumParLu;
            node.sumParR = sumParRu;
        }

        // Not referenced anywhere in the visible code.
        static long[] ZERO = new long[] { 0, 0, 0, 0 };

        /**
         * Accumulate into fResults, over the pairs at positions l..r:
         * [0] += number of right endpoints <= p, [1] += number of left
         * endpoints >= p, [2] += sum of those right endpoints, [3] += sum of
         * those left endpoints.
         */
        void query(int v, int tl, int tr, int l, int r, int p, long[] fResults) {
            if (l > r || tr < l || tl > r) {
                return;
            }
            if (l == tl && tr == r) {
                Node node = t[v];
                int size = node.size;
                long[] sPL = node.sumParL;
                // Last index whose right endpoint is <= p (-1 if none).
                int limU = upperBound(node.valueSRi, p);
                // First index whose left endpoint is >= p (size if none).
                int limL = lowerBound(node.valueSLe, p);
                int nR = limU == -1 ? 0 : limU + 1;
                long sumR = limU == -1 ? 0 : node.sumParR[limU];
                int nL = limL == size ? 0 : size - limL;
                long sumL = limL == size ? 0 : sPL[size - 1] - (limL == 0 ? 0 : sPL[limL - 1]);
                fResults[0] += nR;
                fResults[1] += nL;
                fResults[2] += sumR;
                fResults[3] += sumL;
                return;
            }
            int tm = (tl + tr) / 2;
            query(v * 2, tl, tm, l, Math.min(r, tm), p, fResults);
            query(v * 2 + 1, tm + 1, tr, Math.max(l, tm + 1), r, p, fResults);

        }
    }
    
    /**
     * Vertex of the persistent counting tree. A leaf carries its count
     * directly; an inner vertex's count is the sum over its non-null children.
     */
    static class Vertex {
        Vertex l, r;
        int count;

        Vertex(int val) {
            this.count = val;
        }

        Vertex(Vertex l, Vertex r) {
            this.l = l;
            this.r = r;
            int leftCount = (l == null) ? 0 : l.count;
            int rightCount = (r == null) ? 0 : r.count;
            this.count = leftCount + rightCount;
        }
    }
    
    /** Persistent counting tree over positions tl..tr; every update creates a new version. */
    static class PersistentSegmentTree {

        /** Build an all-zero tree covering tl..tr. */
        Vertex build(int tl, int tr) {
            if (tl == tr) {
                return new Vertex(0);
            }
            int tm = (tl + tr) / 2;
            Vertex left = build(tl, tm);
            Vertex right = build(tm + 1, tr);
            return new Vertex(left, right);
        }

        /** New version of v in which the leaf at pos has count 1; untouched subtrees are shared. */
        Vertex update(Vertex v, int tl, int tr, int pos) {
            if (tl == tr) {
                return new Vertex(1);
            }
            int tm = (tl + tr) / 2;
            if (pos > tm) {
                return new Vertex(v.l, update(v.r, tm + 1, tr, pos));
            }
            return new Vertex(update(v.l, tl, tm, pos), v.r);
        }

        /**
         * Position of the x-th (1-based) leaf counted in version v1 but not in
         * version v2, descending over the range b..e.
         */
        int query(Vertex v1, Vertex v2, int b, int e, int x) {
            while (b != e) {
                int inLeft = v1.l.count - v2.l.count;
                int mid = (b + e) / 2;
                if (inLeft >= x) {
                    v1 = v1.l;
                    v2 = v2.l;
                    e = mid;
                } else {
                    v1 = v1.r;
                    v2 = v2.r;
                    b = mid + 1;
                    x -= inLeft;
                }
            }
            return b;
        }
    }

    /**
     * Merge two ascending arrays into values (which must hold at least
     * values1.length + values2.length elements). On ties the element from
     * values1 is written first.
     */
    static void mergeAInt(int[] values, int[] values1, int[] values2) {
        final int n1 = values1.length;
        final int n2 = values2.length;
        int a = 0;
        int b = 0;
        for (int out = 0; out < n1 + n2; out++) {
            boolean takeFirst;
            if (a >= n1) {
                takeFirst = false;
            } else if (b >= n2) {
                takeFirst = true;
            } else {
                takeFirst = values1[a] <= values2[b];
            }
            if (takeFirst) {
                values[out] = values1[a];
                a++;
            } else {
                values[out] = values2[b];
                b++;
            }
        }
    }

    /**
     * Binary search in an ascending array: index of the last element that is
     * <= endV, or -1 when every element is greater (or the array is empty).
     */
    static private int upperBound(int[] arr, int endV) {
        int lo = 0;
        int hi = arr.length - 1;
        int lastNotGreater = -1;
        while (lo <= hi) {
            int probe = (hi + lo) / 2;
            if (arr[probe] > endV) {
                hi = probe - 1;
            } else {
                lastNotGreater = probe;
                lo = probe + 1;
            }
        }
        return lastNotGreater;
    }

    /**
     * Binary search in an ascending array: index of the first element that is
     * >= startV, or arr.length when every element is smaller.
     */
    static private int lowerBound(int[] arr, int startV) {
        int lo = 0;
        int hi = arr.length - 1;
        int firstNotSmaller = arr.length;
        while (lo <= hi) {
            int probe = (hi + lo) / 2;
            if (arr[probe] < startV) {
                lo = probe + 1;
            } else {
                firstNotSmaller = probe;
                hi = probe - 1;
            }
        }
        return firstNotSmaller;
    }

    /**
     * Allocates the DP tables, sorts the pairs by (mid, left), fills f[1][j]
     * with the SASD accumulator after the first j pairs, and selects the
     * segment-cost function: getW2 with prefix sums when no pair's right end
     * exceeds the next pair's left end (dist stays true), otherwise getW1
     * backed by the segment tree and the persistent tree built here.
     */
    private static void initW() {
        f = new long[k + 1][n + 1];
        g = new long[n + 2];
        Arrays.sort(hdds, 1, n + 1, new Comparator<HD>() {
            @Override
            public int compare(HD o1, HD o2) {
                long d = o1.mid - o2.mid;
                if (d != 0) {
                    return d < 0 ? -1 : 1;
                } else {
                    return o1.left - o2.left;
                }
            }
        });
        SASD sd = new SASD();
        dist = true;
        for (int j = 1; j <= n; j++) {
            sd.add(hdds[j]);
            f[1][j] = sd.ans;
            // Any pair reaching past the start of its successor disables the prefix-sum shortcut.
            if (j < n && hdds[j].right > hdds[j + 1].left) {
                dist = false;
            }
        }
        if (!dist) {
            calc = new Calc() {
                @Override
                public long getW(int start, int end) {
                    return getW1(start, end);
                }
            };
            // Create both endpoints of every pair and remember the pair's sorted position.
            Point[] pointsToSort = new Point[2 * n];
            for (int j = 1; j <= n; j++) {
                HD hd = hdds[j];
                hd.index = j;
                Point p = new Point(hd, true);
                hd.pL = p;
                pointsToSort[2 * j - 2] = p;
                p = new Point(hd, false);
                hd.pR = p;
                pointsToSort[2 * j - 1] = p;
            }
            // Order endpoints by coordinate, ties by the owning pair's index.
            Arrays.sort(pointsToSort, 0, 2 * n, new Comparator<Point>() {
                @Override
                public int compare(Point o1, Point o2) {
                    long d = o1.point - o2.point;
                    if (d != 0) {
                        return d < 0 ? -1 : 1;
                    } else {
                        return o1.hd.index - o2.hd.index;
                    }
                }
            });
            points = new Point[2 * n];
            for (int j = 0; j < 2*n; j++) {
                Point p = pointsToSort[j];
                p.sortIndex = j;
                points[j] = p;
            }
            stHD = new SegmentTree(n);
            stHD.build(hdds, 1, 0, n - 1);
            // Version j of the persistent tree marks the endpoint ranks of pairs 1..j.
            rootsPST = new Vertex[n+1];
            pst = new PersistentSegmentTree();
            rootsPST[0] = pst.build(0, 2*n-1);
            for (int j = 1; j <= n; j++) {
                Vertex root = pst.update(rootsPST[j-1], 0, 2*n-1, hdds[j].pL.sortIndex);
                rootsPST[j] = pst.update(root, 0, 2*n-1, hdds[j].pR.sortIndex);
            }
        } else {
            calc = new Calc() {
                @Override
                public long getW(int start, int end) {
                    return getW2(start, end);
                }
            };
            prefL = new long[n + 1];
            prefR = new long[n + 1];
            for (int j = 1; j <= n; j++) {
                prefR[j] = prefR[j - 1] + hdds[j].right;
                prefL[j] = prefL[j - 1] + hdds[j].left;
            }
        }
    }


    // memo[jEnd] maps jStart to the cached getW1(jStart, jEnd); allocated in main.
    static Map<Integer, Long>[] memo;

    /**
     * Cost of the run of sorted pairs jStart..jEnd (1-based, inclusive) for a
     * position p chosen as the (jEnd - jStart + 1)-th smallest of the run's
     * endpoints: sum of (p - right) over right ends <= p plus sum of
     * (left - p) over left ends >= p. Results are memoised.
     */
    private static long getW1(int jStart, int jEnd) {
        Map<Integer, Long> cache = memo[jEnd];
        Long cached = cache.get(jStart);
        if (cached != null) {
            return cached;
        }
        int pairCount = jEnd - jStart + 1;
        int rank = pst.query(rootsPST[jEnd], rootsPST[jStart - 1], 0, 2 * n - 1, pairCount);
        int p = points[rank].point;
        // agg = { #right <= p, #left >= p, sum(right <= p), sum(left >= p) }
        long[] agg = new long[4];
        stHD.query(1, 0, n - 1, jStart - 1, jEnd - 1, p, agg);
        long cost = p * agg[0] - agg[2] + agg[3] - p * agg[1];
        cache.put(jStart, cost);
        return cost;
    }

    /**
     * Prefix-sum version of the segment cost, used when dist is true. The
     * position p is the right end of the middle pair of jStart..jEnd; pairs up
     * to the middle pay (p - right), pairs after it pay (left - p).
     */
    private static long getW2(int jStart, int jEnd) {
        final int mid = (jStart + jEnd) / 2;
        final long p = hdds[mid].right;
        final long costUpToMid = p * (mid - jStart + 1) - (prefR[mid] - prefR[jStart - 1]);
        final long costAfterMid = (prefL[jEnd] - prefL[mid]) - p * (jEnd - mid);
        return costUpToMid + costAfterMid;
    }

    // Sentinel larger in magnitude than any sum of two input coordinates.
    static int INF = 2000000000;

    /**
     * Incremental accumulator fed with pairs in ascending sorted order
     * (initW uses it to fill f[1][j]). It keeps a min-heap of endpoints and
     * two bounds mid1 <= mid2; `ans` grows only when a new pair starts at or
     * beyond mid2.
     * NOTE(review): presumably this tracks the one-computer cost of the prefix
     * via a sliding median window — confirm against a brute-force check.
     */
    static class SASD {
        int mid1 = -INF;
        int mid2 = INF;
        long ans;

        // Min-heap (natural ordering).
        private PriorityQueue<Integer> r = new PriorityQueue<Integer>();

        void add(HD hd) {
            int nLow = hd.left;
            int nHigh = hd.right;
            if (nLow >= mid2) {
                // The new pair lies entirely at or right of mid2: pay the gap,
                // drop the heap minimum and advance both bounds.
                ans += (nLow - mid2);
                r.remove();
                r.add(nLow);
                r.add(nHigh);
                mid1 = mid2;
                mid2 = r.peek();
            } else if (nLow < mid1) {
                // Starts left of mid1: only the right end is recorded.
                r.add(nHigh);
            } else {
                // Starts inside [mid1, mid2): tighten the bounds around it.
                r.add(nHigh);
                mid1 = nLow;
                if (mid2 == INF) {
                    mid2 = nHigh;
                }
                if (nHigh <= mid2) {
                    mid2 = nHigh;
                }
            }
        }
    }

    /**
     * Mirror image of SASD: fed with pairs in descending sorted order
     * (hardDrive uses it to fill g[s]). It keeps a max-heap of endpoints and
     * two bounds mid1 <= mid2; `ans` grows when a new pair ends at or before
     * mid1, or lies entirely at or beyond mid2.
     * NOTE(review): presumably the one-computer cost of the suffix — confirm.
     */
    static class SADS {
        int mid1 = -INF;
        int mid2 = INF;
        long ans;

        // Max-heap (reversed comparator), initial capacity n.
        private PriorityQueue<Integer> l = new PriorityQueue<Integer>(n, new Comparator<Integer>() {

            @Override
            public int compare(Integer o1, Integer o2) {
                return o2 - o1;
            }
        });

        void add(HD hd) {
            int nLow = hd.left;
            int nHigh = hd.right;
            if (nHigh <= mid1) {
                // The new pair lies entirely at or left of mid1: pay the gap,
                // drop the heap maximum and move both bounds down.
                ans += (mid1 - nHigh);
                l.remove();
                l.add(nLow);
                l.add(nHigh);
                mid2 = mid1;
                mid1 = l.peek();
            } else if (nHigh > mid2) {
                l.add(nLow);
                // Only pairs that also start at or beyond mid2 add cost here.
                if (nLow >= mid2) {
                    ans += (nHigh - mid2);
                }
            } else {
                // Ends inside (mid1, mid2]: tighten the bounds around it.
                l.add(nLow);
                mid2 = nHigh;
                if (mid1 == -INF) {
                    mid1 = nLow;
                }
                if (nLow >= mid1) {
                    mid1 = nLow;
                }
            }
        }
    }

    //static int c = 0;
    //static long time = 0;
    /**
     * Computes the answer printed by main: totalLength plus twice the best
     * extra distance found by the layered DP over the sorted pairs.
     * f[i][j] is the value for the first j pairs using i groups; g[s] is the
     * SADS value of the suffix s..n, used to close the last group.
     * NOTE(review): the candidate ranges are narrowed by search() and several
     * early exits / skips (STEEP, fSteep, equal-g runs). These look like
     * heuristics rather than proven bounds — confirm before reuse.
     */
    static long hardDrive() {
        // One computer per pair: only the pair lengths remain.
        if (k == n) {
            return totalLength;
        }
        initW();
        // A single group already costs nothing extra.
        if (f[1][n] == 0) {
            return totalLength;
        }
        // Ranges wider than STEEP are first narrowed by search().
        int STEEP = 10;
        SADS ds = new SADS();
        long ans = f[1][n];
        int minG = 1;
        // Two groups: prefix 1..s-1 (f[1]) plus suffix s..n (g); minG remembers the best split.
        for (int s = n; s >= 1; s--) {
            ds.add(hdds[s]);
            g[s] = ds.ans;
            long pAns = f[1][s - 1] + ds.ans;
            if (pAns < ans) {
                ans = pAns;
                minG = s;
            }
        }
        f[2][n] = ans;
        if (k == 2) {
            return 2 * ans + totalLength;
        }
        long fRMIN = 0;
        // minAJ[j]: best split start found for column j in the previous layer, reused as a lower bound.
        int[] minAJ = new int[n + 1];
        Arrays.fill(minAJ, 1);
        for (int i = 2; i < k; i++) {
            if (i == k - 1) {
                // Last DP layer: first f[i][n] via the suffix costs g, starting from the previous best split.
                ans = f[i - 1][n];
                int lower = minG;
                int higher = n;
                if (higher > lower + STEEP) {
                    Pair p = search(calc, i, n, lower, higher, STEEP);
                    ans = p.sumR;
                    lower = p.nR;
                    higher = p.nL;
                }
                for (int s = lower; s <= higher; s++) {
                    long pAns = f[i - 1][s - 1] + g[s];
                    if (pAns < ans) {
                        ans = pAns;
                        minG = s;
                    }
                }
                f[i][n] = ans;
                fRMIN = ans;
                int minG2 = n;
                minG = minG > 1 ? minG - 1 : 1;
                int fSteep = 100;
                // Then f[i][j] + g[j+1] for columns from minG upward, tracking the minimum in fRMIN.
                for (int j = minG; j <= minG2; j++) {
                    ans = f[i - 1][j];
                    lower = 1;
                    higher = j;
                    if (higher > lower + STEEP) {
                        Pair p = search(calc, i, j, lower, higher, STEEP);
                        ans = p.sumR;
                        lower = p.nR;
                        higher = p.nL;
                    }
                    if (f[i - 1][lower - 1] != f[i - 1][higher - 1]) {
                        for (int s = lower; s <= higher; s++) {
                            long pAns = f[i - 1][s - 1] + calc.getW(s, j);
                            if (pAns < ans) {
                                ans = pAns;
                            }
                        }
                    } else {
                        // Previous layer is flat over the window: only the right-most split is evaluated.
                        ans = f[i - 1][higher - 1] + calc.getW(higher, j);
                    }
                    f[i][j] = ans;
                    long pFA = ans + g[j + 1];
                    if (pFA < fRMIN) {
                        fRMIN = pFA;
                    }
                    // Stop once the prefix alone is no better than the best total so far.
                    if (ans >= fRMIN) {
                        break;
                    }
                    // Jump ahead when a column fSteep further on could not beat fRMIN either.
                    if (j + fSteep <= minG2 && ans + g[j + fSteep] >= fRMIN) {
                        j = j + fSteep;
                    }
                    // Skip columns whose suffix cost does not change.
                    while (j < minG2 &&  g[j+1] == g[j+2]) {
                        j++;
                    }
                }
            } else {
                // Intermediate layer: fill f[i][j] for every column i+1..n-1.
                for (int j = i + 1; j < n; j++) {
                    ans = f[i - 1][j];
                    int nMin = minAJ[j];
                    int lower = nMin;
                    int higher = j;
                    if (higher > lower + STEEP) {
                        Pair p = search(calc, i, j, lower, higher, STEEP);
                        ans = p.sumR;
                        lower = p.nR;
                        higher = p.nL;
                    }
                    if (f[i - 1][lower - 1] != f[i - 1][higher - 1]) {
                        for (int s = lower; s <= higher; s++) {
                            long pAns = f[i - 1][s - 1] + calc.getW(s, j);
                            if (pAns < ans) {
                                ans = pAns;
                                nMin = s;
                            }
                        }
                    } else {
                        ans = f[i - 1][higher - 1] + calc.getW(higher, j);
                    }
                    f[i][j] = ans;
                    minAJ[j] = nMin;
                }
            }
        }
        return 2 * fRMIN + totalLength;
    }

    /**
     * Narrows the split range [lower, higher] for column j of layer i until it
     * is at most `steep` wide. Each step probes two even positions (one in
     * each half) with f[i-1][s-1] + calc.getW(s, j) and discards the side of
     * the worse probe; ties discard the left side.
     * Returns Pair(nR = lower, nL = higher, sumR = value of the last kept probe, sumL = 0).
     */
    static Pair search(Calc calc, int i, int j, int lower, int higher, int steep) {
        long kept = 0;
        while (higher > lower + steep) {
            final int centre = (lower + higher) / 2;
            int leftProbe = (lower + centre) / 2;
            if (leftProbe % 2 != 0) {
                leftProbe = leftProbe + 1;
            }
            int rightProbe = (higher + centre) / 2;
            if (rightProbe % 2 != 0) {
                rightProbe = rightProbe - 1;
            }
            final long leftValue = f[i - 1][leftProbe - 1] + calc.getW(leftProbe, j);
            final long rightValue = f[i - 1][rightProbe - 1] + calc.getW(rightProbe, j);
            if (leftValue < rightValue) {
                higher = rightProbe - 1;
                kept = leftValue;
            } else {
                lower = leftProbe + 1;
                kept = rightValue;
            }
        }
        return new Pair(lower, higher, kept, 0);
    }

    /**
     * Reads "n k" and n coordinate pairs from standard input, runs
     * hardDrive() and writes the result followed by a newline.
     * Output goes to the file named by the OUTPUT_PATH environment variable,
     * or to standard output when that variable is not set.
     */
    @SuppressWarnings("unchecked")
    public static void main(String[] args) throws IOException {
        String outputPath = System.getenv("OUTPUT_PATH");
        BufferedReader reader = new BufferedReader(new InputStreamReader(System.in), 2 * 4096 * 4096);
        // FileWriter(null) would throw; fall back to stdout for local runs.
        BufferedWriter bufferedWriter = new BufferedWriter(
                outputPath != null ? new FileWriter(outputPath) : new OutputStreamWriter(System.out));
        try {
            // Split on any whitespace run so repeated spaces or tabs do not break parsing.
            String[] nk = reader.readLine().trim().split("\\s+");
            n = Integer.parseInt(nk[0]);
            k = Integer.parseInt(nk[1]);
            hdds = new HD[n + 1];
            totalLength = 0;
            for (int hddsRowItr = 1; hddsRowItr <= n; hddsRowItr++) {
                String[] hddsRowItems = reader.readLine().trim().split("\\s+");
                HD hd = new HD(Integer.parseInt(hddsRowItems[0]), Integer.parseInt(hddsRowItems[1]));
                hdds[hddsRowItr] = hd;
                totalLength += hd.length;
            }
            // One memo map per segment end, used by getW1.
            memo = new Map[n + 1];
            for (int j = 1; j <= n; j++) {
                memo[j] = new HashMap<Integer, Long>();
            }
            dist = true;
            long result = hardDrive();
            bufferedWriter.write(String.valueOf(result));
            bufferedWriter.newLine();
        } finally {
            // Always flush/close the writer, then the reader, even on failure.
            try {
                bufferedWriter.close();
            } finally {
                reader.close();
            }
        }
    }
}
                        








View More Similar Problems

Tree: Postorder Traversal

Complete the postorder function in the editor below. It received 1 parameter: a pointer to the root of a binary tree. It must print the values in the tree's postorder traversal as a single line of space-separated values. Input Format Our test code passes the root node of a binary tree to the postorder function. Constraints 1 <= Nodes in the tree <= 500 Output Format Print the

View Solution →

Tree: Inorder Traversal

In this challenge, you are required to implement inorder traversal of a tree. Complete the inorder function in your editor below, which has 1 parameter: a pointer to the root of a binary tree. It must print the values in the tree's inorder traversal as a single line of space-separated values. Input Format Our hidden tester code passes the root node of a binary tree to your inOrder func

View Solution →

Tree: Height of a Binary Tree

The height of a binary tree is the number of edges between the tree's root and its furthest leaf. For example, the following binary tree is of height : image Function Description Complete the getHeight or height function in the editor. It must return the height of a binary tree as an integer. getHeight or height has the following parameter(s): root: a reference to the root of a binary

View Solution →

Tree : Top View

Given a pointer to the root of a binary tree, print the top view of the binary tree. The set of nodes visible when the tree is seen from above is called the top view of the tree. For example : 1 \ 2 \ 5 / \ 3 6 \ 4 Top View : 1 -> 2 -> 5 -> 6 Complete the function topView and print the resulting values on a single line separated by space.

View Solution →

Tree: Level Order Traversal

Given a pointer to the root of a binary tree, you need to print the level order traversal of this tree. In level-order traversal, nodes are visited level by level from left to right. Complete the function levelOrder and print the values in a single line separated by a space. For example: 1 \ 2 \ 5 / \ 3 6 \ 4 F

View Solution →

Binary Search Tree : Insertion

You are given a pointer to the root of a binary search tree and values to be inserted into the tree. Insert the values into their appropriate position in the binary search tree and return the root of the updated binary tree. You just have to complete the function. Input Format You are given a function, Node * insert (Node * root ,int data) { } Constraints No. of nodes in the tree <

View Solution →