Hard Disk Drives


Problem Statement:


There are n pairs of hard disk drives (HDDs) in a cluster. Each HDD is located at an integer coordinate on an infinite straight line, and each pair consists of one primary HDD and one backup HDD.

Next, you want to place k computers at integer coordinates on the same infinite straight line. Each pair of HDDs must then be connected to a single computer via wires, but a computer can have any number (even zero) of HDDs connected to it. The length of a wire connecting a single HDD to a computer is the absolute difference between their coordinates on the line. The total length of wire is the sum of the lengths of all wires used to connect HDDs to computers. Note that both the primary and the backup HDD of a pair must connect to the same computer.

Given the locations of n pairs (i.e., primary and backup) of HDDs and the value of k, place all k computers in such a way that the total length of wire needed to connect each pair of HDDs to computers is minimal. Then print the total length on a new line.

Input Format

The first line contains two space-separated integers denoting the respective values of n (the number of pairs of HDDs) and k (the number of computers).
Each line $i$ of the $n$ subsequent lines contains two space-separated integers describing the respective values of $a_i$ (coordinate of the primary HDD) and $b_i$ (coordinate of the backup HDD) for a pair of HDDs.

Constraints

$2 \le k \le n \le 10^5$
$4 \le k \cdot n \le 10^5$
$-10^9 \le a_i, b_i \le 10^9$
Output Format

Print a single integer denoting the minimum total length of wire needed to connect all the pairs of HDDs to computers.



Solution:



title-img


Solutions (no C solution is provided):

In C++:





#include <bits/stdc++.h>

using namespace std;

// Additive aggregate stored both in the splay trees and in the dense Fenwick
// rows. solve() adds one (m, y, 1) value per HDD pair, so a summed node holds:
//   sa = sum of m (= L + R), sb = sum of y (= R), sc = number of pairs.
struct node
{
    long long sa, sb;
    int sc;

    node() : sa(0), sb(0), sc(0) {}

    node(long long a, long long b, int c) : sa(a), sb(b), sc(c) {}

    // Component-wise accumulate.
    node& operator+= (const node& rhs)
    {
        sa += rhs.sa;
        sb += rhs.sb;
        sc += rhs.sc;
        return *this;
    }

    // Component-wise remove.
    node& operator-= (const node& rhs)
    {
        sa -= rhs.sa;
        sb -= rhs.sb;
        sc -= rhs.sc;
        return *this;
    }

    // Equal exactly when all three components match.
    bool operator== (const node& rhs) const
    {
        return !(sa != rhs.sa || sb != rhs.sb || sc != rhs.sc);
    }
};

// Node of a splay tree ordered by the integer key k.
// `o` is the value stored at this node itself; `n` is the sum of `o` over the
// whole subtree rooted here, recomputed by maintain().
struct SplayNode
{
    SplayNode *ch[2], *p;   // ch[0] = left child, ch[1] = right child, p = parent
    node n, o;              // n = subtree aggregate, o = this node's own value
    int k;                  // ordering key
    // Reset to a detached leaf with key k and value o.
    // NOTE(review): k is taken as long long but stored in an int member.
    void init(long long k, node o)
    {
        ch[0]=ch[1]=p=nullptr;
        this->k=k;
        this->n=this->o=o;
    }
} pool[2000000];   // static node storage, handed out by add_point via npool

// Next free slot in pool; add_point wraps it back to 0 when it reaches 2000000.
int npool;

// Recompute n's subtree aggregate from its own value plus its children's
// aggregates, re-pointing each existing child's parent link at n on the way.
void maintain(SplayNode *n)
{
    n->n = n->o;
    for (SplayNode *c : n->ch)
    {
        if (!c)
            continue;
        c->p = n;
        n->n += c->n;
    }
}

// Which side of its parent n hangs on: 0 = left, 1 = right,
// -1 when n has no parent or the parent does not link back to it.
int child(SplayNode *n)
{
    SplayNode *par = n->p;
    if (!par)
        return -1;
    for (int side = 0; side < 2; ++side)
        if (par->ch[side] == n)
            return side;
    return -1;
}

// Rotate n above its parent p, preserving the in-order (key) order.
// n's inner subtree (the one on side d^1) moves under p on side d, and p
// becomes n's child on side d^1. p and n are re-aggregated (maintain also
// fixes the moved children's parent pointers). If p had a parent, n takes
// p's slot there and that grandparent is re-aggregated too.
void rotate_up(SplayNode *n)
{
    SplayNode *p=n->p;
    int d=child(n);     // side of p that n hangs on
    int pd=child(p);    // side of the grandparent that p hangs on (-1 if p is a root)
    assert(d!=-1);
    p->ch[d]=n->ch[d^1];
    n->ch[d^1]=p;
    n->p=p->p;
    maintain(p);        // p is now below n: refresh it first
    maintain(n);        // also sets p->p = n
    if(pd!=-1)
    {
        n->p->ch[pd]=n;
        maintain(n->p);
    }
}

// Bottom-up splay: rotate n upward until it has no parent, i.e. it is the root.
// - zig: the parent is the root, so rotate n once;
// - zig-zag: n and its parent hang on opposite sides, so rotate n twice;
// - zig-zig: same side, so rotate the parent first, then n.
void splay(SplayNode *n)
{
    while(child(n)!=-1)
    {
        if(child(n->p)==-1)
            rotate_up(n);
        else
        {
            if(child(n)!=child(n->p))
                rotate_up(n), rotate_up(n);
            else
                rotate_up(n->p), rotate_up(n);
        }
    }
}

// Recursive BST insertion of the detached node t somewhere below n.
// Keys equal to an existing key go to the right subtree. Every node on the
// descent path is re-aggregated on the way back up (which also sets t->p).
void insert(SplayNode *n, SplayNode *t)
{
    const int side = (t->k < n->k) ? 0 : 1;
    if (n->ch[side] != nullptr)
        insert(n->ch[side], t);
    else
        n->ch[side] = t;
    maintain(n);
}
// Last non-null node visited by the most recent ask(); ask_point splays it to
// the root after each query.
SplayNode *lastTouch;
// Sum of `o` over every node in the subtree rooted at n whose key is <= k.
// Returns an all-zero node for an empty subtree.
node ask(SplayNode *n, int k)
{
    if(!n)
        return node();
    lastTouch=n;
    if(k<n->k)
        return ask(n->ch[0], k);        // n and its right subtree are all > k
    node ret=n->o;                      // n itself qualifies ...
    if(n->ch[0])
        ret+=n->ch[0]->n;               // ... and so does its whole left subtree
    ret+=ask(n->ch[1], k);
    return ret;
}

// N = number of HDD pairs, M = number of distinct endpoint coordinates, K = computers.
int N, M, K;
// Input pairs; main normalizes first <= second and sorts by (second, first).
pair<int, int> A[100000];
// Sorted, de-duplicated endpoint coordinates (first M entries are valid).
int X[200000];
// Left / right endpoints of the pairs in A's sorted order
// (main later re-sorts L in descending order).
int L[100000];
int R[100000];
// Per Fenwick row: number of points added during the first solve round.
// Rows with at least MAGIC points use the dense blt arrays afterwards.
int fsize[200001];
// Becomes 1 after the first solve round has filled fsize.
int ok;
// dp[i][j]: DP value with i computers, the last one apparently placed at X[j]
// (see main and solve). 351 rows; K <= 316 under the stated constraints.
vector<long long> dp[351];
// Splay-tree root of each Fenwick row that is not stored densely in blt.
SplayNode* bit[200001];

// Pending divide-and-conquer segment: fill dp for j in [l, r] using
// previous-layer candidates k in [optL, optR].
struct dac
{
    int l, r, optL, optR;

    // Unpack as a tuple (consumed with tie(...) in solve).
    tuple<int, int, int, int> mt() const
    {
        tuple<int, int, int, int> packed(l, r, optL, optR);
        return packed;
    }
};

// A record ordered by m alone. It is used for two things:
//  - HDD pairs (built in main): m = L + R, x = L, y = R, i = input index,
//    xc / yc = compressed indices of L / R;
//  - transition queries (built in solve): m = X[k] + X[j], x = X[k],
//    y = X[j], i = slot in val/val2, xc = k, yc = j.
struct query
{
    int m, x, y, i, xc, yc;

    bool operator< (const query& rhs) const
    {
        return rhs.m > m;
    }
} B[100000];   // the HDD pairs, sorted by m in main

// Once `ok` is set, Fenwick rows whose fsize is at least MAGIC are stored as
// dense arrays in blt instead of as splay trees.
const int MAGIC=1500;
// Per-query scratch for solve: val accumulates each candidate transition cost;
// val2 keeps the point count seen by the first query pass.
long long val[200000];
long long val2[200000];
// Dense Fenwick rows (resized to M+1 in solve), used only for heavy rows.
vector<node> blt[200001];

// Add value n at point (x, y) of the 2D Fenwick structure.
// x is mirrored to M-1-x, so the prefix over rows taken by ask_point covers
// every original x' >= the queried x. For each Fenwick row touched:
// - heavy rows (ok set and fsize >= MAGIC) update the dense blt row along y;
// - all other rows insert a fresh splay node keyed y+1 and splay it to the root.
// During the first round (!ok) the touches are counted in fsize.
void add_point(int x, int y, node n)
{
    x=M-1-x;
    for(int i=x+1; i<=M; i+=i&-i)
    {
        if(ok && fsize[i]>=MAGIC)
        {
            for(int j=y+1; j<=M; j+=j&-j)
                blt[i][j]+=n;
            continue;
        }
        if(!ok)
            fsize[i]++;
        SplayNode *t=&pool[npool++];
        // Pool slots are recycled; solve() drops all splay roots at the start of each round.
        if(npool==2000000)
            npool=0;
        t->init(y+1, n);
        if(bit[i])
        {
            insert(bit[i], t);
            splay(t);
        }
        bit[i]=t;
    }
}

// Zero every dense (blt) Fenwick cell that add_point(x, y, ...) would touch.
// Rows still backed by splay trees are left untouched.
// NOTE(review): currently unused; its only call site in solve is commented out.
void era_point(int x, int y)
{
    if (!ok)
        return;   // before the first round completes, no row is stored densely
    const int mirrored = M - 1 - x;
    for (int row = mirrored + 1; row <= M; row += row & -row)
    {
        if (fsize[row] < MAGIC)
            continue;
        for (int col = y + 1; col <= M; col += col & -col)
            blt[row][col] = node();
    }
}

// Sum of all values added at points (x', y') with x' >= x and y' <= y.
// For splay-backed rows, the last node touched by ask() is splayed to the root
// and becomes the row's new root.
node ask_point(int x, int y)
{
    x=M-1-x;    // same mirroring as add_point
    node n;
    for(int i=x+1; i>0; i-=i&-i)
    {
        if(ok && fsize[i]>=MAGIC)
        {
            for(int j=y+1; j>0; j-=j&-j)
                n+=blt[i][j];
            continue;
        }
        if(bit[i])
        {
            n+=ask(bit[i], y+1);
            bit[i]=lastTouch;
            splay(bit[i]);
        }
    }
    return n;
}

// Fill DP layer i for coordinate indices j in [l, r] using divide-and-conquer
// optimization, processed breadth-first. Each round takes the midpoint j of
// every pending segment, evaluates all candidates k in [optL, min(optR, j-1)],
// keeps the best one, and splits the segment around it. All transition costs
// of one round are computed offline in a single sweep over the HDD pairs.
//
// The transition cost for (k, j) is dp[i-1][k] plus, for every pair lying
// inside [X[k], X[j]], the distance to the nearer of the computers at X[k] and
// X[j]. Pairs with L+R < X[k]+X[j] are charged L - X[k]; the others X[j] - R.
//
// i          - layer to fill (reads dp[i-1], writes dp[i]).
// l, r       - range of target indices j.
// optL, optR - range of candidate previous indices k.
void solve(int i, int l, int r, int optL, int optR)
{
    if(l>r)
        return;
    vector<dac> q, nq;
    q.push_back(dac{l, r, optL, optR});
    while(!q.empty())
    {
        // Reset the Fenwick structure: drop splay roots, zero the dense rows.
        for(int j=1; j<=M; j++)
        {
            bit[j]=nullptr;
            if(ok && fsize[j]>=MAGIC)
            {
                if((int)blt[j].size()!=M+1)
                    blt[j].resize(M+1);
                fill(blt[j].begin(), blt[j].end(), node());
            }
        }
        // One query per (segment midpoint j, candidate k); val starts at dp[i-1][k].
        vector<query> queries;
        int nqueries=0;
        for(auto& it: q)
        {
            tie(l, r, optL, optR)=it.mt();
            int j=(l+r)/2;
            for(int k=min(optR, j-1); k>=optL; k--)
            {
                val[nqueries]=dp[i-1][k];
                queries.push_back(query{X[k]+X[j], X[k], X[j], nqueries++, k, j});
            }
        }
        sort(queries.begin(), queries.end());
        // Pass 1: only pairs with m < query m are inserted when a query is asked.
        // Add the sum of (L+R) and remember the count of those inner pairs.
        int nb=0;
        for(int k=0; k<(int)queries.size(); k++)
        {
            for(; nb<N && B[nb].m<queries[k].m; nb++)
                add_point(B[nb].xc, B[nb].yc, node(B[nb].m, B[nb].y, 1));
            node n=ask_point(queries[k].xc, queries[k].yc);
            val[queries[k].i]+=n.sa;
            val2[queries[k].i]=n.sc;
        }
        for(; nb<N; nb++)
            add_point(B[nb].xc, B[nb].yc, node(B[nb].m, B[nb].y, 1));
        // Pass 2: every pair is inserted. Subtract R + X[k] for all inner pairs,
        // and add back X[k] + X[j] for the inner pairs not counted in pass 1.
        for(int k=0; k<(int)queries.size(); k++)
        {
            node n=ask_point(queries[k].xc, queries[k].yc);
            val[queries[k].i]-=n.sb;
            val[queries[k].i]-=1LL*n.sc*queries[k].x;
            val[queries[k].i]+=queries[k].m*(n.sc-val2[queries[k].i]);
        }
        // Choose the best candidate per midpoint (ties keep the largest k,
        // since k is scanned downwards with a strict <) and schedule both halves.
        nqueries=0;
        for(auto& it: q)
        {
            tie(l, r, optL, optR)=it.mt();
            int j=(l+r)/2;
            dp[i][j]=0x3f3f3f3f3f3f3f3fLL;
            int opt=optL;   // stays inside [optL, optR] even if nothing beats INF
            for(int k=min(optR, j-1); k>=optL; k--)
            {
                if(val[nqueries]<dp[i][j])
                {
                    opt=k;
                    dp[i][j]=val[nqueries];
                }
                nqueries++;
            }
            if(l<=j-1)
                nq.push_back(dac{l, j-1, optL, opt});
            if(j+1<=r)
                nq.push_back(dac{j+1, r, opt, optR});
        }
        q.swap(nq);
        nq.clear();
        ok=1;   // fsize is complete after the first round
    }
}

// Read n pairs and k, and print the minimum total wire length.
// A pair [lo, hi] served by a computer at c costs (hi - lo) + 2 * dist(c, [lo, hi]).
// `ans` sums the (hi - lo) parts; `ans2` is the minimal sum of distances.
int main()
{
    if(scanf("%d%d", &N, &K)!=2)
        return 0;
    long long ans=0;
    for(int i=0; i<N; i++)
    {
        if(scanf("%d%d", &A[i].first, &A[i].second)!=2)
            return 0;
        if(A[i].first>A[i].second)
            swap(A[i].first, A[i].second);
        ans+=(long long)A[i].second-A[i].first;
    }
    sort(A, A+N, [](const pair<int, int>& a, const pair<int, int>& b) {
        if(a.second!=b.second)
            return a.second<b.second;
        return a.first<b.first;
    });
    for(int i=0; i<N; i++)
    {
        tie(L[i], R[i])=A[i];
        X[i*2]=L[i];
        X[i*2+1]=R[i];
    }
    // Coordinate compression of all endpoints.
    sort(X, X+2*N);
    M=unique(X, X+2*N)-X;
    for(int i=0; i<N; i++)
    {
        int xc=lower_bound(X, X+M, L[i])-X;
        int yc=lower_bound(X, X+M, R[i])-X;
        B[i]=query{L[i]+R[i], L[i], R[i], i, xc, yc};
    }
    sort(B, B+N);
    for(int i=1; i<=K; i++)
        dp[i].resize(M);
    // dp[1][i]: sum of (X[i] - R) over the pairs lying entirely left of X[i].
    long long SR=0, NR=0;
    for(int i=0, j=0; i<M; i++)
    {
        for(; j<N && R[j]<X[i]; j++)
            SR+=R[j], NR++;
        dp[1][i]=NR*X[i]-SR;
    }
    int flag=0;
    if(K==3 && M>=59000)
    {
        // NOTE(review): time-dependent heuristic. If layer 2 took >= 1s of CPU
        // time, layer 3 is only computed for j >= 3*M/4 with k <= M/2+50, so
        // the printed answer may depend on machine speed.
        const clock_t start=clock();
        for(int i=2; i<=K-1; i++)
            solve(i, i-1, M-1, i-2, M-1);
        if(clock()-start>=CLOCKS_PER_SEC)
        {
            solve(3, 3*M/4, M-1, 2, M/2+50);
            flag=1;
        }
        else
            solve(3, 2, M-1, 1, M-1);
    }
    else
    {
        for(int i=2; i<=K; i++)
            solve(i, i-1, M-1, i-2, M-1);
    }
    // Close the last computer at X[i]: add the distance to every pair lying
    // entirely right of it. Only indices computed for layer K are scanned.
    sort(L, L+N, greater<int>());
    const int lowest=flag ? 3*M/4 : K-1;
    long long SL=0, NL=0;
    long long ans2=0x3f3f3f3f3f3f3f3fLL;
    for(int i=M-1, j=0; i>=lowest; i--)
    {
        for(; j<N && L[j]>X[i]; j++)
            SL+=L[j], NL++;
        ans2=min(ans2, dp[K][i]+SL-NL*X[i]);
    }
    cout << ans + ans2 * 2 << '\n';
    return 0;
}









In Java:





import java.io.*;
import java.math.*;
import java.text.*;
import java.util.*;
import java.util.regex.*;

/**
 * Java solution.
 *
 * Every pair is an interval [left, right]. Connecting it to a computer at c
 * costs length + 2 * dist(c, [left, right]), so hardDrive() returns
 * totalLength + 2 * (minimal total interval distance). The distance part is
 * minimized with a DP over the pairs sorted by midpoint.
 * NOTE(review): search(), the STEEP / fSteep skips and the minAJ lower bounds
 * are pruning heuristics; their exactness is not established by this code.
 */
public class Solution {

    static HD[] hdds;            // the pairs, 1-based: hdds[1..n]
    static Point[] points;       // all 2n endpoints sorted by coordinate (only when !dist)
    static int n, k;             // number of pairs / number of computers
    static long[][] f;           // f[i][j]: DP value with i computers over hdds[1..j]
    static long[] g;             // g[s]: one-computer cost of the suffix hdds[s..n]
    static long totalLength;     // sum of interval lengths (right - left)
    static SegmentTree stHD;     // merge-sort tree used by getW1
    static PersistentSegmentTree pst;
    static Vertex[] rootsPST;    // rootsPST[j]: endpoints of hdds[1..j], indexed by sortIndex
    static boolean dist;         // true when consecutive sorted intervals do not overlap
    static long[] prefR;         // prefix sums of right endpoints (used by getW2)
    static long[] prefL;         // prefix sums of left endpoints (used by getW2)
    static Calc calc;            // cost function selected in initW()
    
    /** One pair of HDDs, normalized so that left <= right. */
    static class HD {
        int left, right, length;
        Point pL, pR;   // endpoint objects, created in initW() only when !dist
        long mid;       // left + right (twice the midpoint); primary sort key
        int index;      // 1-based position after sorting, set in initW() only when !dist

        public HD(int left, int right) {
            boolean b = left > right;
            this.left = b ? right : left;
            this.right = b ? left : right;
            // Uses the raw parameters; the sum does not depend on their order.
            mid = (left + right);
            length = (this.right - this.left);
        }

        @Override
        public String toString() {
            return left + " " + right + " " + mid;
        }
    }

    /** One endpoint of an HD; sortIndex is its rank among all 2n endpoints. */
    static class Point {
        HD hd;
        int point;
        int sortIndex;

        public Point(HD hd, boolean isLeft) {
            this.hd = hd;
            point = isLeft ? hd.left : hd.right;
        }

        @Override
        public String toString() {
            return hd.left + " " + hd.right + " " + point;
        }
    }

    /** Sum of point-to-interval distances when hdds[start..end] share one computer. */
    static interface Calc {
        public long getW(int start, int end);
    }

    /** Small tuple. search() reuses it as (nR = lower, nL = higher, sumR = best value). */
    static class Pair {
        int nL, nR;
        long sumL, sumR;

        public Pair(int nR, int nL, long sumR, long sumL) {
            this.nR = nR;
            this.nL = nL;
            this.sumR = sumR;
            this.sumL = sumL;
        }
    }

    /** Merge-sort-tree node: sorted endpoints of its range plus their prefix sums. */
    static class Node {

        int[] valueSLe;    // left endpoints of the range, ascending
        int[] valueSRi;    // right endpoints of the range, ascending
        long[] sumParR;    // prefix sums of valueSRi
        long[] sumParL;    // prefix sums of valueSLe
        int size;          // number of intervals in the range
        int nL, nR;        // NOTE(review): unused in the code shown

        public Node() {
        }
    }

    /**
     * Merge-sort tree over hdds[1..n]. Tree position tl maps to hdds[tl + 1].
     * Each node stores its range's left and right endpoints sorted, with prefix sums.
     */
    static class SegmentTree {

        Node[] t;

        SegmentTree(int n) {
            t = new Node[4 * n];
            for (int i = 0; i < 4 * n; i++) {
                t[i] = new Node();
            }
        }

        /** Builds node v for positions [tl, tr]; requires pL/pR to be set on every HD. */
        void build(HD[] a, int v, int tl, int tr) {
            if (tl == tr) {
                t[v].valueSLe = new int[1];
                t[v].valueSRi = new int[1];
                t[v].valueSRi[0] = a[tl + 1].pR.point;
                t[v].valueSLe[0] = a[tl + 1].pL.point;
                t[v].size = 1;
                t[v].sumParL = new long[1];
                t[v].sumParL[0] = a[tl + 1].left;
                t[v].sumParR = new long[1];
                t[v].sumParR[0] = a[tl + 1].right;
            } else {
                int tm = (tl + tr) / 2;
                build(a, v * 2, tl, tm);
                build(a, v * 2 + 1, tm + 1, tr);
                t[v].size = t[v * 2].size + t[v * 2 + 1].size;
                merge(v);
            }
        }

        /** Merges both children's sorted arrays into node v, then rebuilds its prefix sums. */
        private void merge(int v) {
            t[v].valueSLe = new int[t[v].size];
            t[v].valueSRi = new int[t[v].size];
            t[v].sumParL = new long[t[v].size];
            t[v].sumParR = new long[t[v].size];
            mergeAInt(t[v].valueSLe, t[2 * v].valueSLe, t[2 * v + 1].valueSLe);
            mergeAInt(t[v].valueSRi, t[2 * v].valueSRi, t[2 * v + 1].valueSRi);
            updateCS(t[v]);
        }

        /** Recomputes node's prefix sums from its sorted endpoint arrays. */
        private void updateCS(Node node) {
            long[] sumParLu = new long[node.size];
            long[] sumParRu = new long[node.size];
            int[] valueL = node.valueSLe;
            int[] valueR = node.valueSRi;
            sumParLu[0] = valueL[0];
            sumParRu[0] = valueR[0];
            for (int i = 1; i < node.size; i++) {
                sumParLu[i] = sumParLu[i - 1] + valueL[i];
                sumParRu[i] = sumParRu[i - 1] + valueR[i];
            }
            node.sumParL = sumParLu;
            node.sumParR = sumParRu;
        }

        // NOTE(review): unused in the code shown.
        static long[] ZERO = new long[] { 0, 0, 0, 0 };

        /**
         * For the intervals at tree positions [l, r], adds into fResults:
         * [0] count of right endpoints <= p, [1] count of left endpoints >= p,
         * [2] sum of those right endpoints, [3] sum of those left endpoints.
         */
        void query(int v, int tl, int tr, int l, int r, int p, long[] fResults) {
            if (l > r || tr < l || tl > r) {
                return;
            }
            if (l == tl && tr == r) {
                Node node = t[v];
                int size = node.size;
                long[] sPL = node.sumParL;
                int limU = upperBound(node.valueSRi, p);   // last right endpoint <= p
                int limL = lowerBound(node.valueSLe, p);   // first left endpoint >= p
                int nR = limU == -1 ? 0 : limU + 1;
                long sumR = limU == -1 ? 0 : node.sumParR[limU];
                int nL = limL == size ? 0 : size - limL;
                long sumL = limL == size ? 0 : sPL[size - 1] - (limL == 0 ? 0 : sPL[limL - 1]);
                fResults[0] += nR;
                fResults[1] += nL;
                fResults[2] += sumR;
                fResults[3] += sumL;
                return;
            }
            int tm = (tl + tr) / 2;
            query(v * 2, tl, tm, l, Math.min(r, tm), p, fResults);
            query(v * 2 + 1, tm + 1, tr, Math.max(l, tm + 1), r, p, fResults);

        }
    }
    
    /** Persistent segment-tree vertex; count = number of marked leaves below it. */
    static class Vertex {
        Vertex l, r;
        int count;

        Vertex(int val) {
            count = val;
        }

        Vertex(Vertex l, Vertex r) {
            this.l = l;
            this.r = r;
            if (l != null)
                count += l.count;
            if (r != null)
                count += r.count;
        }
    };
    
    /** Persistent counting tree over endpoint ranks, used for k-th-smallest queries. */
    static class PersistentSegmentTree {
        
        /** Empty tree (all counts 0) over positions [tl, tr]. */
        Vertex build(int tl, int tr) {
            if (tl == tr) {
                return new Vertex(0);
            }
            int tm = (tl + tr) / 2;
            return new Vertex(build(tl, tm), build(tm + 1, tr));
        }

        /** New version of v with leaf pos set to count 1 (path copying). */
        Vertex update(Vertex v, int tl, int tr, int pos) {
            if (tl == tr) {
                return new Vertex(1);
            }
            int tm = (tl + tr) / 2;
            if (pos <= tm) {
                return new Vertex(update(v.l, tl, tm, pos), v.r);
            } else {
                return new Vertex(v.l, update(v.r, tm + 1, tr, pos));
            }
        }
        
        /** Position of the x-th smallest (1-based) leaf marked in v1 but not in v2. */
        int query(Vertex v1, Vertex v2, int b, int e, int x) {
            if (b == e) {
                return b;
            }
            int oo = v1.l.count - v2.l.count;
            int mid = (b + e) / 2;
            if (oo >= x) {
                return query(v1.l, v2.l, b, mid, x);
            } else {
                return query(v1.r, v2.r, mid + 1, e, x - oo);
            }
        }
    }

    /** Merges the ascending arrays values1 and values2 into values. */
    static void mergeAInt(int[] values, int[] values1, int[] values2) {
        int n1 = values1.length;
        int n2 = values2.length;
        int i = 0, j = 0, k = 0;
        while (i < n1 && j < n2) {
            int p1 = values1[i];
            int p2 = values2[j];
            if (p1 <= p2) {
                i++;
                values[k++] = p1;
            } else {
                j++;
                values[k++] = p2;
            }
        }
        while (i < n1) {
            values[k++] = values1[i++];
        }
        while (j < n2) {
            values[k++] = values2[j++];
        }
    }

    /** Index of the last element <= endV in ascending arr, or -1 if none. */
    static private int upperBound(int[] arr, int endV) {
        int left = 0;
        int right = arr.length - 1;
        while (left <= right) {
            int mid = (right + left) / 2;
            if (arr[mid] <= endV) {
                left = mid + 1;
            } else {
                right = mid - 1;
            }
        }
        return right;
    }

    /** Index of the first element >= startV in ascending arr, or arr.length if none. */
    static private int lowerBound(int[] arr, int startV) {
        int left = 0;
        int al = arr.length;
        int right = al - 1;
        while (left <= right) {
            int mid = (right + left) / 2;
            if (arr[mid] >= startV) {
                right = mid - 1;
            } else {
                left = mid + 1;
            }
        }
        return left;
    }

    /**
     * Sorts hdds by (mid, left), fills f[1][*] with prefix one-computer costs, and
     * picks the cost function. If consecutive intervals do not overlap, getW2 is
     * used (prefix sums); otherwise getW1 is used (merge-sort tree plus a
     * persistent tree for medians).
     */
    private static void initW() {
        f = new long[k + 1][n + 1];
        g = new long[n + 2];        
        Arrays.sort(hdds, 1, n + 1, new Comparator<HD>() {
            @Override
            public int compare(HD o1, HD o2) {
                long d = o1.mid - o2.mid;
                if (d != 0) {
                    return d < 0 ? -1 : 1;
                } else {
                    return o1.left - o2.left;
                }
            }
        });
        SASD sd = new SASD();
        dist = true;
        for (int j = 1; j <= n; j++) {
            sd.add(hdds[j]);
            f[1][j] = sd.ans;
            if (j < n && hdds[j].right > hdds[j + 1].left) {
                dist = false;
            }
        }
        if (!dist) {
            calc = new Calc() {
                @Override
                public long getW(int start, int end) {
                    return getW1(start, end);                    
                }
            };
            // Rank all 2n endpoints by coordinate (ties broken by pair index).
            Point[] pointsToSort = new Point[2 * n];
            for (int j = 1; j <= n; j++) {
                HD hd = hdds[j];
                hd.index = j;
                Point p = new Point(hd, true);
                hd.pL = p;
                pointsToSort[2 * j - 2] = p;
                p = new Point(hd, false);
                hd.pR = p;
                pointsToSort[2 * j - 1] = p;
            }            
            Arrays.sort(pointsToSort, 0, 2 * n, new Comparator<Point>() {
                @Override
                public int compare(Point o1, Point o2) {
                    long d = o1.point - o2.point;
                    if (d != 0) {
                        return d < 0 ? -1 : 1;
                    } else {
                        return o1.hd.index - o2.hd.index;
                    }
                }
            });
            points = new Point[2 * n];
            for (int j = 0; j < 2*n; j++) {
                Point p = pointsToSort[j]; 
                p.sortIndex = j;
                points[j] = p;
            }            
            stHD = new SegmentTree(n);
            stHD.build(hdds, 1, 0, n - 1);
            // rootsPST[j] marks the ranks of both endpoints of hdds[1..j].
            rootsPST = new Vertex[n+1];
            pst = new PersistentSegmentTree();
            rootsPST[0] = pst.build(0, 2*n-1);
            for (int j = 1; j <= n; j++) {
                Vertex root = pst.update(rootsPST[j-1], 0, 2*n-1, hdds[j].pL.sortIndex);
                rootsPST[j] = pst.update(root, 0, 2*n-1, hdds[j].pR.sortIndex);
            }            
        } else {
            calc = new Calc() {
                @Override
                public long getW(int start, int end) {
                    return getW2(start, end);
                }
            };
            prefL = new long[n + 1];
            prefR = new long[n + 1];
            for (int j = 1; j <= n; j++) {
                prefR[j] = prefR[j - 1] + hdds[j].right;
                prefL[j] = prefL[j - 1] + hdds[j].left;
            }
        }
    }


    // memo[jEnd] maps jStart -> getW1(jStart, jEnd).
    static Map<Integer, Long>[] memo;
    /**
     * One-computer cost of hdds[jStart..jEnd] (overlapping case). The computer is
     * placed at the m-th smallest of the group's 2m endpoints (m = group size).
     * The result is the sum of distances to intervals lying wholly left or right of it.
     */
    private static long getW1(int jStart, int jEnd) {
        Long ans = memo[jEnd].get(jStart);
        if (ans != null) {
            return ans;
        }
        int mid = pst.query(rootsPST[jEnd], rootsPST[jStart-1], 0, 2 * n - 1, jEnd-jStart+1);
        int p = points[mid].point;
        long[] pair = new long[4];
        stHD.query(1, 0, n - 1, jStart - 1, jEnd - 1, p, pair);
        ans = p * pair[0] - pair[2] + pair[3] - p * pair[1];
        memo[jEnd].put(jStart, ans);
        return ans;
    }

    /**
     * One-computer cost of hdds[jStart..jEnd] when consecutive intervals do not
     * overlap. The computer sits at hdds[mid].right; intervals up to mid are
     * charged p - right and later ones left - p, computed with prefix sums.
     */
    private static long getW2(int jStart, int jEnd) {
        int mid = (jStart + jEnd) / 2;
        long p = hdds[mid].right;
        long ans = p * (mid - jStart + 1) - (prefR[mid] - prefR[jStart - 1]) + (prefL[jEnd] - prefL[mid])
                - p * (jEnd - mid);
        return ans;
    }

    static int INF = 2000000000;

    /**
     * Adds intervals one at a time (in hdds order, i.e. by increasing mid) and
     * keeps in `ans` the running one-computer cost of the prefix.
     * [mid1, mid2] is tracked as the current optimal range; `r` is a min-heap of
     * right endpoints. NOTE(review): correctness of the update rules relies on
     * the sort order and is not evident from this code alone.
     */
    static class SASD {
        int mid1 = -INF;
        int mid2 = INF;
        long ans;

        private PriorityQueue<Integer> r = new PriorityQueue<Integer>();

        void add(HD hd) {
            int nLow = hd.left;
            int nHigh = hd.right;
            if (nLow >= mid2) {
                ans += (nLow - mid2);
                r.remove();
                r.add(nLow);
                r.add(nHigh);
                mid1 = mid2;
                mid2 = r.peek();
            } else if (nLow < mid1) {
                r.add(nHigh);
            } else {
                r.add(nHigh);
                mid1 = nLow;
                if (mid2 == INF) {
                    mid2 = nHigh;
                }
                if (nHigh <= mid2) {
                    mid2 = nHigh;
                }
            }
        }
    }

    /**
     * Mirror of SASD for suffixes: intervals are added from hdds[n] down to
     * hdds[1], with `l` a max-heap of left endpoints.
     * NOTE(review): unlike SASD's mirrored branch, the `nHigh > mid2` branch adds
     * (nHigh - mid2) when nLow >= mid2. The distance to such an interval would be
     * nLow - mid2, so verify this asymmetry.
     */
    static class SADS {
        int mid1 = -INF;
        int mid2 = INF;
        long ans;

        private PriorityQueue<Integer> l = new PriorityQueue<Integer>(n, new Comparator<Integer>() {

            @Override
            public int compare(Integer o1, Integer o2) {
                return o2 - o1;
            }
        });

        void add(HD hd) {
            int nLow = hd.left;
            int nHigh = hd.right;
            if (nHigh <= mid1) {
                ans += (mid1 - nHigh);
                l.remove();
                l.add(nLow);
                l.add(nHigh);
                mid2 = mid1;
                mid1 = l.peek();
            } else if (nHigh > mid2) {
                l.add(nLow);
                if (nLow >= mid2) {
                    ans += (nHigh - mid2);
                }
            } else {
                l.add(nLow);
                mid2 = nHigh;
                if (mid1 == -INF) {
                    mid1 = nLow;
                }
                if (nLow >= mid1) {
                    mid1 = nLow;
                }
            }
        }
    }

    //static int c = 0;
    //static long time = 0;
    /**
     * Computes the answer: totalLength + 2 * (minimal sum of distances).
     * Trivial cases: k == n (each pair alone), or f[1][n] == 0 (one common point
     * covers all intervals). Otherwise the steps are:
     * - g[s] is the suffix cost and f[2][n] = min_s f[1][s-1] + g[s];
     * - layers 2..k-2 fill f[i][j] for j in (i, n);
     * - layer k-1 also closes the last computer over hdds[j+1..n] into fRMIN.
     * Candidate splits are narrowed by search() and early-exit heuristics.
     */
    static long hardDrive() {
        if (k == n) {
            return totalLength;
        }
        initW();
        if (f[1][n] == 0) {
            return totalLength;
        }
        int STEEP = 10;   // window below which search() stops and a linear scan takes over
        SADS ds = new SADS();
        long ans = f[1][n];
        int minG = 1;
        for (int s = n; s >= 1; s--) {
            ds.add(hdds[s]);
            g[s] = ds.ans;
            long pAns = f[1][s - 1] + ds.ans;
            if (pAns < ans) {
                ans = pAns;
                minG = s;
            }
        }
        f[2][n] = ans;
        if (k == 2) {
            return 2 * ans + totalLength;
        }
        long fRMIN = 0;
        int[] minAJ = new int[n + 1];   // lower bound on the best split for each j (heuristic)
        Arrays.fill(minAJ, 1);
        for (int i = 2; i < k; i++) {
            if (i == k - 1) {
                // f[i][n]: the last of the i computers serves a suffix hdds[s..n].
                ans = f[i - 1][n];
                int lower = minG;
                int higher = n;
                if (higher > lower + STEEP) {
                    Pair p = search(calc, i, n, lower, higher, STEEP);
                    ans = p.sumR;
                    lower = p.nR;
                    higher = p.nL;
                }
                for (int s = lower; s <= higher; s++) {
                    long pAns = f[i - 1][s - 1] + g[s];
                    if (pAns < ans) {
                        ans = pAns;
                        minG = s;
                    }
                }
                f[i][n] = ans;
                fRMIN = ans;
                // Best total over j: f[k-1][j] + g[j+1] (k-th computer serves hdds[j+1..n]).
                int minG2 = n;
                minG = minG > 1 ? minG - 1 : 1;
                int fSteep = 100;
                for (int j = minG; j <= minG2; j++) {
                    ans = f[i - 1][j];
                    lower = 1;
                    higher = j;
                    if (higher > lower + STEEP) {
                        Pair p = search(calc, i, j, lower, higher, STEEP);
                        ans = p.sumR;
                        lower = p.nR;
                        higher = p.nL;
                    }
                    if (f[i - 1][lower - 1] != f[i - 1][higher - 1]) {
                        for (int s = lower; s <= higher; s++) {
                            long pAns = f[i - 1][s - 1] + calc.getW(s, j);
                            if (pAns < ans) {
                                ans = pAns;
                            }
                        }
                    } else {                        
                        ans = f[i - 1][higher - 1] + calc.getW(higher, j);
                    }    
                    f[i][j] = ans;
                    long pFA = ans + g[j + 1];
                    if (pFA < fRMIN) {
                        fRMIN = pFA;
                    }
                    // Heuristic early exits / skips.
                    if (ans >= fRMIN) {
                        break;
                    }
                    if (j + fSteep <= minG2 && ans + g[j + fSteep] >= fRMIN) {
                        j = j + fSteep;
                    }
                    while (j < minG2 &&  g[j+1] == g[j+2]) {
                        j++;
                    }
                }
            } else {
                for (int j = i + 1; j < n; j++) {
                    ans = f[i - 1][j];
                    int nMin = minAJ[j];
                    int lower = nMin;
                    int higher = j;
                    if (higher > lower + STEEP) {
                        Pair p = search(calc, i, j, lower, higher, STEEP);
                        ans = p.sumR;
                        lower = p.nR;
                        higher = p.nL;                        
                    }
                    if (f[i - 1][lower - 1] != f[i - 1][higher - 1]) {
                        for (int s = lower; s <= higher; s++) {
                            long pAns = f[i - 1][s - 1] + calc.getW(s, j);
                            if (pAns < ans) {
                                ans = pAns;
                                nMin = s;
                            }
                        }
                    } else {                        
                        ans = f[i - 1][higher - 1] + calc.getW(higher, j);         
                    }
                    f[i][j] = ans;
                    minAJ[j] = nMin;
                }
            }
        }
        return 2 * fRMIN + totalLength;
    }

    /**
     * Narrows [lower, higher] for the best split s of f[i-1][s-1] + getW(s, j).
     * Each step compares two probe points and keeps the side of the better one,
     * which assumes the objective is unimodal in s (heuristic).
     * Returns Pair(nR = lower, nL = higher, sumR = the better probe's value from
     * the last step). Only called when higher > lower + steep, so the loop runs
     * at least once.
     */
    static Pair search(Calc calc, int i, int j, int lower, int higher, int steep) {
        long ans = 0;
        while (higher > lower + steep) {
            int mid = (lower + higher) / 2;
            int mid1 = (lower + mid) / 2;
            mid1 = mid1 % 2 == 0 ? mid1 : mid1 + 1;
            int mid2 = (higher + mid) / 2;
            mid2 = mid2 % 2 == 0 ? mid2 : mid2 - 1;
            long res1 = f[i - 1][mid1 - 1] + calc.getW(mid1, j);
            long res2 = f[i - 1][mid2 - 1] + calc.getW(mid2, j);
            if (res1 < res2) {
                higher = mid2 - 1;
                ans = res1;
            } else if (res1 >= res2) {
                lower = mid1 + 1;
                ans = res2;
            }
        }
        return new Pair(lower, higher, ans, 0);
    }

    /**
     * Reads "n k" and n lines "a b" from stdin, then writes the answer to the file
     * named by the OUTPUT_PATH environment variable.
     * NOTE(review): split(" ") fails on repeated spaces, and a missing
     * OUTPUT_PATH makes FileWriter throw.
     */
    public static void main(String[] args) throws IOException {
        BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(System.getenv("OUTPUT_PATH")));
        BufferedReader reader = new BufferedReader(new InputStreamReader(System.in), 2 * 4096 * 4096);
        String[] nk = reader.readLine().trim().split(" ");
        n = Integer.parseInt(nk[0]);
        k = Integer.parseInt(nk[1]);
        hdds = new HD[n + 1];
        totalLength = 0;
        for (int hddsRowItr = 1; hddsRowItr <= n; hddsRowItr++) {
            String[] hddsRowItems = reader.readLine().trim().split(" ");
            HD hd = new HD(Integer.parseInt(hddsRowItems[0]), Integer.parseInt(hddsRowItems[1]));
            hdds[hddsRowItr] = hd;
            totalLength += hd.length;
        }
        memo = new Map[n+1];
        for (int j = 1; j <= n; j++) {
            memo[j] = new HashMap<Integer, Long>();
        }
        dist = true;
        long result = hardDrive();
        bufferedWriter.write(String.valueOf(result));
        bufferedWriter.newLine();
        bufferedWriter.close();
        reader.close();
    }
}
                        








View More Similar Problems

Swap Nodes [Algo]

A binary tree is a tree which is characterized by one of the following properties: It can be empty (null). It contains a root node only. It contains a root node with a left subtree, a right subtree, or both. These subtrees are also binary trees. In-order traversal is performed as follows: traverse the left subtree, visit the root, then traverse the right subtree. For this in-order traversal, start from

View Solution →

Kitty's Calculations on a Tree

Kitty has a tree, $T$, consisting of $n$ nodes where each node is uniquely labeled from $1$ to $n$. Her friend Alex gave her $q$ sets, where each set contains $k$ distinct nodes. Kitty needs to calculate the following expression on each set: $\left(\sum_{\{u,v\}} u \cdot v \cdot dist(u, v)\right) \bmod (10^9 + 7)$, where: $\{u, v\}$ denotes an unordered pair of nodes belonging to the set, and $dist(u, v)$ denotes the number of edges on the unique (shortest) path between nodes $u$ and $v$.

View Solution →

Is This a Binary Search Tree?

For the purposes of this challenge, we define a binary tree to be a binary search tree with the following ordering requirements: The data value of every node in a node's left subtree is less than the data value of that node. The data value of every node in a node's right subtree is greater than the data value of that node. Given the root node of a binary tree, can you determine if it's also a

View Solution →

Square-Ten Tree

The square-ten tree decomposition of an array is defined as follows: The lowest ($0$-th) level of the square-ten tree consists of single array elements in their natural order. The $k$-th level (starting from $1$) of the square-ten tree consists of subsequent array subsegments of length $10^{2^{k-1}}$ in their natural order. Thus, the $1$-st level contains subsegments of length $10^{2^{1-1}} = 10$, the $2$-nd level contains subsegments of length $10^{2^{2-1}} = 100$, the $3$-rd level contains subsegments of length $10^{2^{3-1}} = 10000$, and so on.

View Solution →

Balanced Forest

Greg has a tree of nodes containing integer data. He wants to insert a node with some non-zero integer value somewhere into the tree. His goal is to be able to cut two edges and have the values of each of the three new trees sum to the same amount. This is called a balanced forest. Being frugal, the data value he inserts should be minimal. Determine the minimal amount that a new node can have to a

View Solution →

Jenny's Subtrees

Jenny loves experimenting with trees. Her favorite tree has $n$ nodes connected by $n - 1$ edges, and each edge is $1$ unit in length. She wants to cut a subtree (i.e., a connected part of the original tree) of radius $r$ from this tree by performing the following two steps: 1. Choose a node, $x$, from the tree. 2. Cut a subtree consisting of all nodes which are not further than $r$ units from node $x$.

View Solution →