Max Transform

Problem Statement :

Transforming data into some other data is typical of a programming job. This problem is about a particular kind of transformation which we'll call the max transform.

Let $A$ be a zero-indexed array of integers. For $0 \le i \le j < \text{length}(A)$, let $A_{i \ldots j}$ denote the subarray of $A$ from index $i$ to index $j$, inclusive.

Let's define the max transform of $A$ as the array obtained by the following procedure:

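Let $B$ be a list, initially empty.
For $k$ from $0$ to $\text{length}(A) - 1$:
For $i$ from $0$ to $\text{length}(A) - k - 1$:
Let $j = i + k$.
Append $\max(A_{i \ldots j})$ to the end of $B$.
Return $B$.
The returned array is defined as the max transform of $A$. We denote it by $S(A)$.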

Complete the function solve that takes an integer array $A$ as input.

Given an array $A$, find the sum of the elements of $S(S(A))$, i.e., the max transform of the max transform of $A$. Since the answer may be very large, only find it modulo $10^9 + 7$.

Input Format

The first line of input contains a single integer $n$ denoting the length of $A$.

The second line contains $n$ space-separated integers $A_0, A_1, \ldots, A_{n-1}$ denoting the elements of $A$.


Constraints

$1 \le n \le 2 \times 10^5$
$1 \le A_i \le 10^6$

Output Format

Print a single line containing a single integer denoting the answer.

Solution :


                            Solution in C :

In   C++  :

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef unsigned int ui;
typedef long double ld;
typedef pair<int, int> ii;
typedef pair<ii, ii> iii;
ll MOD = 1e9 + 7;
const ld E = 1e-9;
#define null NULL
#define ms(x) memset(x, 0, sizeof(x))
#ifndef LOCAL
#define endl "\n"
#define sync ios_base::sync_with_stdio(0); cin.tie(0); cout.tie(0);
#define _read(x) freopen(x, "r", stdin)
#define _write(x) freopen(x, "w", stdout)
#define files(x) _read(x ".in"); _write(x ".out")
#define filesdatsol(x) _read(x ".DAT"); _write(x ".SOL")
#define output _write("output.txt")
#define input _read("input.txt")
#define prev time_prev
#ifndef M_PI
#define M_PI acos(-1)
#define remove tipa_remove
#define next tipa_next
#define left tipa_left
#define right tipa_right
#define mod % MOD
#define y1 hello_world
unsigned char ccc;
bool _minus = false;
// Returns t * t. The closing brace was missing in the listing and is restored.
template<typename T>
inline T sqr(T t) {
    return (t * t);
}
// Reads one decimal integer from stdin into n, one character at a time.
// Reading stops at the first space or newline; a '-' marks the value negative.
// Uses the file-level scratch variables `ccc` and `_minus`.
// NOTE(review): the `break`/`continue` statements and closing braces were lost
// in the listing and are reconstructed here. The EOF check is an addition:
// without it an exhausted stream would loop forever, because EOF stored in an
// unsigned char never compares equal to ' ' or '\n'.
inline void read(ll &n) {
    n = 0;
    _minus = false;
    while (true) {
        const int c = getchar();
        ccc = static_cast<unsigned char>(c);
        if (c == EOF || ccc == ' ' || ccc == '\n')
            break;
        if (ccc == '-') {
            _minus = true;
            continue;
        }
        n = n * 10 + ccc - '0';
    }
    if (_minus)
        n *= -1;
}
// Reads one decimal integer from stdin into n.
// Returns true when the token was terminated by a newline and false when it
// was terminated by a space (or by end of input).
// NOTE(review): braces and the `break`/`continue` statements were lost in the
// listing and are reconstructed. Two deliberate changes from the visible
// fragment: the sign is applied before the newline return (the fragment
// returned from inside the loop, which would skip the negation), and EOF ends
// the loop instead of spinning forever.
inline bool read(int &n) {
    n = 0;
    _minus = false;
    bool ended_line = false;
    while (true) {
        const int c = getchar();
        ccc = static_cast<unsigned char>(c);
        if (c == EOF)
            break;
        if (ccc == ' ' || ccc == '\n') {
            ended_line = (ccc == '\n');
            break;
        }
        if (ccc == '-') {
            _minus = true;
            continue;
        }
        n = n * 10 + ccc - '0';
    }
    if (_minus)
        n *= -1;
    return ended_line;
}
// Digit buffer for write(). Digits are stored at indices 1..kkkk, so 20 slots
// are needed for a 19-digit value (the original 19-element array overflowed).
char wwww[20];
int kkkk;

// Writes y to stdout in decimal using putchar.
// NOTE(review): the `putchar` calls, the `else`, and the closing braces were
// lost in the listing and are reconstructed. The magnitude is computed in
// unsigned arithmetic so the most negative value does not overflow.
inline void write(ll y) {
    unsigned long long x = static_cast<unsigned long long>(y);
    kkkk = 0;
    if (y < 0) {
        putchar('-');
        x = 0ULL - x;
    }
    if (!x) {
        wwww[++kkkk] = '0';
    } else {
        while (x) {
            wwww[++kkkk] = char(x % 10 + '0');
            x /= 10;
        }
    }
    // Digits were produced least-significant first; emit them in reverse.
    for (int i = kkkk; i >= 1; --i)
        putchar(wwww[i]);
}
// Debug switch: `dbg stmt;` runs stmt only when __DEBUG is defined.
// NOTE(review): the `#endif` / `#else` lines were lost in the listing and are
// reconstructed; as written, both `dbg` defines collided and neither
// conditional was closed.
#ifdef LOCAL
//#define __DEBUG
#endif
#ifdef __DEBUG
#define dbg if(1)
#else
#define dbg if(0)
#endif

// Returns 1 + 2 + ... + n, i.e. n * (n + 1) / 2.
// The closing brace was missing in the listing and is restored.
inline ll sum(ll n){
    return (n * (n + 1)) / 2;
}

__int128 ans;

const int MAX = 4e5 + 10;

int ar[MAX];
ll tt[MAX];
ll ttt[MAX];
int n;

// Returns tt[a] reduced modulo MOD. main() fills tt[i] with
// sum(1) + sum(2) + ... + sum(i).
// The closing brace was missing in the listing and is restored.
ll culc(int a){
    return tt[a] % MOD;
}

int t[MAX << 2];

// Builds the segment tree over ar[l..r] rooted at node v.
// Each node stores the INDEX of a maximum element of its range; when both
// halves hold the same value the right half's index is kept (strict `>`).
// NOTE(review): the leaf case's `return` and the closing braces were lost in
// the listing and are restored; without the return the recursion never ends.
void build(int v, int l, int r){
    if(l == r){
        t[v] = l;
        return;
    }
    int x = (l + r) >> 1;
    build(v << 1, l, x);
    build(v << 1 | 1, x + 1, r);
    t[v] = (ar[t[v << 1]] > ar[t[v << 1 | 1]] ? t[v << 1] : t[v << 1 | 1]);
}

// Segment-tree query: returns the index of a maximum of ar over the
// intersection of [l, r] with the node range [tl, tr].
// A node that does not intersect the query returns index 0; solve_ok() stores
// the data in ar[1..n], so ar[0] keeps its zero-initialised value.
// The closing braces were missing in the listing and are restored.
int get_max(int v, int tl, int tr, int l, int r){
    if(l <= tl && tr <= r){
        return t[v];
    }
    if(tr < l || r < tl){
        return 0;
    }
    int x = (tl + tr) >> 1;
    int a = get_max(v << 1, tl, x, l, r);
    int b = get_max(v << 1 | 1, x + 1, tr, l, r);
    return (ar[a] > ar[b] ? a : b);
}

// Two-argument count built from the precomputed tables:
//   sum(a) + (ttt[a + b] - ttt[max(a + b - 2 * min(a, b), 0)])
//          + culc(a - min(a, b)) + culc(b - min(a, b)).
// main() fills ttt[i] = sum(i) + ttt[i - 2]. The local `ans` shadows the
// global accumulator of the same name and is unrelated to it.
// The closing brace was missing in the listing and is restored.
ll culc(int a, int b){
    ll ans = sum(a);
    int r = a + b;
    int l = a + b - 2 * min(a, b);
    l = max(l, 0);
    ans += ttt[r] - ttt[l];
    int e = min(a, b);
    a -= e;
    b -= e;
    ans += culc(a) + culc(b);
    return ans;
}

// Index of a maximum of ar[l..r], queried on the whole tree (root 1, range
// [1, n]). The closing brace was missing in the listing and is restored.
int get_max(int l, int r){
    return get_max(1, 1, n, l, r);
}

// Index of a maximum of ar over the union of [l1, r1] and [l2, r2]; when the
// two range maxima are equal the index from the second range is returned.
// The closing brace was missing in the listing and is restored.
int get_max(int l1, int r1, int l2, int r2){
    int a = get_max(l1, r1);
    int b = get_max(l2, r2);
    return (ar[a] > ar[b] ? a : b);
}

// Recursive accumulation over the contiguous range ar[l..r].
// Splits at the position of the range maximum, recurses into both sides, and
// adds (culc(len) - left result - right result) * ar[pos] to the global `ans`.
// Returns 0 for an empty range, 1 for a single element (after adding that
// element to `ans`), and culc(r - l + 1) otherwise.
// NOTE: `has` may be a negative residue after `%= MOD`; `ans` is a signed
// accumulator, so the final reduction is left to the caller.
// The closing braces were missing in the listing and are restored.
ll solve_1(int l, int r){
    if(l > r)
        return 0;
    if(l == r){
        ans += ar[l];
        return 1;
    }
    int pos = get_max(l, r);
    ll has = culc(r - l + 1);
    has -= solve_1(l, pos - 1);
    has -= solve_1(pos + 1, r);
    has %= MOD;
    ans += has * ar[pos];
    return culc(r - l + 1);
}

// Recursive accumulation over the pair of ranges ar[1..r] (a prefix) and
// ar[l..n] (a suffix). When either side is empty it delegates to solve_1 on
// the other side. Otherwise it locates the maximum over both ranges, subtracts
// the results of the two sub-problems created by removing that position, and
// adds the remaining count times ar[pos] to the global `ans`.
// Returns culc(r, n - l + 1).
// NOTE(review): the listing lost the `else` that separates the "maximum is in
// the prefix" case from the "maximum is in the suffix" case, along with the
// closing braces; they are reconstructed here by symmetry - confirm against
// the original source.
ll solve_2(int r, int l){
    if(r == 0){
        return solve_1(l, n);
    }
    if(l == n + 1){
        return solve_1(1, r);
    }
    int pos = get_max(1, r, l, n);
    ll has = culc(r, n - l + 1);
    if(pos <= r){
        // Maximum lies in the prefix: shrink the prefix, solve what is cut off.
        has -= solve_2(pos - 1, l);
        has -= solve_1(pos + 1, r);
    } else {
        // Maximum lies in the suffix: shrink the suffix, solve what is cut off.
        has -= solve_2(r, pos + 1);
        has -= solve_1(l, pos - 1);
    }
    has %= MOD;
    ans += has * ar[pos];
    return culc(r, n - l + 1);
}

namespace solve_long {
    vector<int> get(vector<int> v){
        vector<int> t;
        int n = (int) v.size();
        for(int k = 0; k < n; k++){
            for(int i = 0; i < n - k; i++){
                int ans = 0;
                for(int j = i; j <= i + k; j++){
                    ans = max(ans, v[j]);
        return t;
    int max_val = 0;
    int get_max(vector<int> &v, int l, int r){
        int pos = l;
        for(int i = l + 1; i <= r && v[pos] != max_val; i++){
            if(v[i] > v[pos]){
                pos = i;
        return pos;
    ll sum(vector<int> &v, int l, int r){
        if(l > r)
            return 0;
        int pos = get_max(v, l, r);
        ll ans = (pos - l + 1) * 1LL * (r - pos + 1) * v[pos];
        return ans + sum(v, l, pos - 1) + sum(v, pos + 1, r);
    ll sum(vector<int> v){
        max_val = 0;
        for(int a : v){
            max_val = max(max_val, a);
        return sum(v, 0, (int) v.size() - 1);
    __int128 solve(vector<int> v){
        return sum((get(v))) % MOD;

// Streams a 128-bit integer in decimal (the standard library provides no
// operator<< for __int128).
// NOTE(review): the digit loop header and closing braces were lost in the
// listing. The reconstruction uses do/while so that 0 prints as "0", and it
// prints a leading '-' for negative values; the magnitude is taken in
// unsigned arithmetic so the most negative value is handled too.
std::ostream& operator<<(std::ostream &os, __int128 a){
    const bool negative = a < 0;
    unsigned __int128 magnitude = static_cast<unsigned __int128>(a);
    if(negative){
        magnitude = 0 - magnitude;
    }
    std::string s;
    do {
        s += static_cast<char>('0' + static_cast<int>(magnitude % 10));
        magnitude /= 10;
    } while(magnitude != 0);
    if(negative){
        s += '-';
    }
    // Digits were collected least-significant first.
    std::reverse(s.begin(), s.end());
    os << s;
    return os;
}

// Fast solution entry point: returns the answer for v modulo MOD, in the
// range [0, MOD).
// Copies v into the 1-based global array ar, builds the segment tree, and
// starts from the global maximum: g is T(T(n)) with T(m) = m * (m + 1) / 2,
// from which the result of solve_2 on the two sides of the maximum is
// subtracted before multiplying by the maximum itself.
// NOTE(review): closing braces were lost in the listing and are restored.
// The final normalisation is an addition: solve_1/solve_2 add terms that can
// be negative residues, so `ans % MOD` alone could be negative.
__int128 solve_ok(vector<int> v){
    n = (int) v.size();
    for(int i = 1; i <= n; i++){
        ar[i] = v[i - 1];
    }
    build(1, 1, n);
    ans = 0;
    __int128 g = n;
    g = (g * (g + 1)) / 2;
    g = (g * (g + 1)) / 2;
    int pos = get_max(1, n);
    g -= solve_2(pos - 1, pos + 1);
    ans += g * ar[pos];
    __int128 result = ans % MOD;
    if(result < 0){
        result += MOD;
    }
    return result;
}

int main() {
    srand((unsigned int) time(NULL));
    cout << fixed;
#ifdef LOCAL
    ll start = (ll) clock();
    for(int i = 1; i < MAX; i++){
        tt[i] = sum(i) + tt[i - 1];
    for(int i = 1; i < MAX; i++){
        ttt[i] = sum(i);
        if(i > 1){
            ttt[i] += ttt[i - 2];
    int n;
    cin >> n;
    vector<int> v(n);
    for(int i = 0; i < n; i++){
        cin >> v[i];
    cout << solve_ok(v) << endl;
#ifdef LOCAL
    cout << "=====" << endl;
    cout << (clock() - start) * 1. / CLOCKS_PER_SEC << endl;
    cout << clock() << endl;
    cout << "=====" << endl;
   // assert(false);

In   Java :

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.InputMismatchException;

/**
 * Java solution for the Max Transform problem: reads n and the array from
 * standard input and prints one integer.
 *
 * NOTE(review): the listing lost every closing brace, several {@code else} /
 * loop headers, the body of the {@code n == 1} case, the final print, the
 * {@code solve()} call in {@code run()}, and the expressions
 * {@code System.in} / {@code is.read(inbuf)}. They are reconstructed below
 * from the surviving statements; confirm against the original source.
 */
public class MaxTransform {
    InputStream is;
    PrintWriter out;
    String INPUT = "";
    int mod = 1000000007;

    /** Reads the input, computes the answer modulo {@code mod}, prints it. */
    void solve()
    {
        int n = ni();
        int[] a = na(n);
        if(n == 1){
            // Reconstructed: for a single element both transforms are [a[0]].
            out.println(a[0] % mod);
            return;
        }
        // Count attributed to the global maximum.
        long all = (long)n*(n+1)/2%mod;
        all = all*(all+1)/2%mod;
        for(int i = 1;i <= n-1;i++){
            all -= (long)(n-i+1+n-i)*(n-i+1+n-i+1)/2;
            if(i < n-1)all += (long)(n-i)*(n-i+1)/2;
            all %= mod;
        }
        int amax = 0;
        for(int v : a)amax = Math.max(amax, v);
        long ans = all*amax;
        ans %= mod;

        // The tree answers minimum queries, so negate to locate maxima.
        int[] b = new int[n];
        for(int i = 0;i < n;i++)b[i] = -a[i];
        SegmentTreeRMQPos st = new SegmentTreeRMQPos(b);
        imos = new long[n+3];
        dfs(0, n, a, st);

        for(int i = 0;i <= n+2;i++){
            imos[i] %= mod;
        }
        // Two prefix-sum passes over the difference array filled by dfs.
        for(int i = 0;i <= n+1;i++){
            imos[i+1] += imos[i];
            imos[i+1] %= mod;
        }
        for(int i = 0;i <= n+1;i++){
            imos[i+1] += imos[i];
            imos[i+1] %= mod;
        }
        for(int i = 1;i <= n;i++){
            ans += (long)i*imos[i];
            ans %= mod;
        }

        // Running maxima from the right end (sufs) and the left end (pres),
        // with their cumulative sums.
        int[] sufs = new int[n];
        for(int i = 0;i < n;i++){
            sufs[i] = a[n-1-i];
            if(i > 0)sufs[i] = Math.max(sufs[i], sufs[i-1]);
        }
        int[] pres = new int[n];
        for(int i = 0;i < n;i++){
            pres[i] = a[i];
            if(i > 0)pres[i] = Math.max(pres[i], pres[i-1]);
        }
        long[] cpres = new long[n+1];
        for(int i = 0;i < n;i++){
            cpres[i+1] = cpres[i] + pres[i];
        }
        long[] csufs = new long[n+1];
        for(int i = 0;i < n;i++){
            csufs[i+1] = csufs[i] + sufs[i];
        }

        long temp = 0;
        for(int i = n-1;i >= 1;i--){
            if(i < n-1)temp += maxsum(sufs[i-1], pres, i+1, n, cpres);
            temp += maxsum(pres[i], sufs, i-1, n, csufs);
            ans += temp;
            ans %= mod;
        }
        ans %= mod;
        if(ans < 0)ans += mod;
        out.println(ans);
    }

    /**
     * Sum of max(v, a[k]) for k in [l, r), modulo {@code mod}, for a
     * non-decreasing array {@code a} whose prefix sums are {@code cum}
     * ({@code cum[k]} is the sum of a[0..k-1]).
     */
    long maxsum(int v, int[] a, int l, int r, long[] cum)
    {
        int ind = Arrays.binarySearch(a, l, r, v);
        if(ind < 0)ind = -ind-1;
        long ret = cum[r] - cum[ind] + (long)(ind-l) * v;
        ret %= mod;
        return ret;
    }

    long[] imos;

    /**
     * Splits [l, r) at the position of its maximum (found through the minimum
     * of the negated array) and records four updates of +/- a[arg] in
     * {@code imos}, then recurses into both sides.
     * NOTE(review): recursion depth equals the split depth, which can reach
     * n for monotone input.
     */
    void dfs(int l, int r, int[] a, SegmentTreeRMQPos st)
    {
        if(l >= r)return;
        st.minx(l, r);
        int arg = st.minpos;
        imos[1] += a[arg];
        imos[arg-l+2] -= a[arg];
        imos[r-arg+1] -= a[arg];
        imos[r-l+2] += a[arg];
        dfs(l, arg, a, st);
        dfs(arg+1, r, a, st);
    }

    /** Segment tree for range-minimum queries that also reports a position. */
    public static class SegmentTreeRMQPos {
        public int M, H, N;
        public int[] st;
        public int[] pos;

        /** Tree over n slots, all initialised to Integer.MAX_VALUE. */
        public SegmentTreeRMQPos(int n)
        {
            N = n;
            M = Integer.highestOneBit(Math.max(N-1, 1))<<2;
            H = M>>>1;
            st = new int[M];
            pos = new int[M];
            for(int i = 0;i < N;i++)pos[H+i] = i;
            Arrays.fill(st, 0, M, Integer.MAX_VALUE);
            for(int i = H-1;i >= 1;i--)propagate(i);
        }

        /** Tree over the values of {@code a}; unused leaves hold MAX_VALUE. */
        public SegmentTreeRMQPos(int[] a)
        {
            N = a.length;
            M = Integer.highestOneBit(Math.max(N-1, 1))<<2;
            H = M>>>1;
            st = new int[M];
            pos = new int[M];
            for(int i = 0;i < N;i++){
                st[H+i] = a[i];
                pos[H+i] = i;
            }
            Arrays.fill(st, H+N, M, Integer.MAX_VALUE);
            for(int i = H-1;i >= 1;i--)propagate(i);
        }

        /** Sets slot {@code pos} to {@code x} and refreshes its ancestors. */
        public void update(int pos, int x)
        {
            st[H+pos] = x;
            for(int i = (H+pos)>>>1;i >= 1;i >>>= 1)propagate(i);
        }

        /** Recomputes node i from its children; ties keep the left child. */
        private void propagate(int i)
        {
            if(st[2*i] <= st[2*i+1]){
                st[i] = st[2*i];
                pos[i] = pos[2*i];
            }else{
                st[i] = st[2*i+1];
                pos[i] = pos[2*i+1];
            }
        }

        public int minpos;
        public int minval;

        /**
         * Iterative minimum over [l, r). Stores the result in
         * {@code minval} / {@code minpos} and returns {@code minval}.
         */
        public int minx(int l, int r){
            minval = Integer.MAX_VALUE;
            minpos = -1;
            if(l >= r)return minval;
            // Consume aligned blocks from the left while they fit.
            while(l != 0){
                int f = l&-l;
                if(l+f > r)break;
                int v = st[(H+l)/f];
                if(v < minval){
                    minval = v;
                    minpos = pos[(H+l)/f];
                }
                l += f;
            }

            // Consume the remaining aligned blocks from the right.
            while(l < r){
                int f = r&-r;
                int v = st[(H+r)/f-1];
                if(v < minval){
                    minval = v;
                    minpos = pos[(H+r)/f-1];
                }
                r -= f;
            }
            return minval;
        }

        /** Recursive minimum over [l, r); same outputs as {@link #minx}. */
        public int min(int l, int r){
            minpos = -1;
            minval = Integer.MAX_VALUE;
            min(l, r, 0, H, 1);
            return minval;
        }

        private void min(int l, int r, int cl, int cr, int cur)
        {
            if(l <= cl && cr <= r){
                if(st[cur] < minval){
                    minval = st[cur];
                    minpos = pos[cur];
                }
            }else{
                int mid = cl+cr>>>1;
                if(cl < r && l < mid)min(l, r, cl, mid, 2*cur);
                if(mid < r && l < cr)min(l, r, mid, cr, 2*cur+1);
            }
        }
    }

    /** Wires up input/output, runs {@link #solve()}, flushes the output. */
    void run() throws Exception
    {
        is = INPUT.isEmpty() ? System.in :
            new ByteArrayInputStream(INPUT.getBytes());
        out = new PrintWriter(System.out);

        long s = System.currentTimeMillis();
        solve();
        out.flush();
        // Timing is printed only for embedded test input, so judged output
        // stays clean (reconstructed guard).
        if(!INPUT.isEmpty())tr(System.currentTimeMillis()-s+"ms");
    }

    public static void main(String[] args)
    throws Exception { new MaxTransform().run(); }

    private byte[] inbuf = new byte[1024];
    public int lenbuf = 0, ptrbuf = 0;

    /** Next input byte, or -1 at end of input. */
    private int readByte()
    {
        if(lenbuf == -1)throw new InputMismatchException();
        if(ptrbuf >= lenbuf){
            ptrbuf = 0;
            try { lenbuf = is.read(inbuf); }
            catch (IOException e) { throw new InputMismatchException(); }
            if(lenbuf <= 0)return -1;
        }
        return inbuf[ptrbuf++];
    }

    private boolean isSpaceChar(int c)
    { return !(c >= 33 && c <= 126); }

    /** Skips separator bytes; returns the first other byte (or -1). */
    private int skip() { int b;
        while((b = readByte()) != -1 &&
            isSpaceChar(b)); return b; }

    private double nd()
    { return Double.parseDouble(ns()); }
    private char nc()
    { return (char)skip(); }

    /** Reads the next token delimited by separator bytes. */
    private String ns()
    {
        int b = skip();
        StringBuilder sb = new StringBuilder();
        while(!(isSpaceChar(b))){
            sb.appendCodePoint(b);
            b = readByte();
        }
        return sb.toString();
    }

    /** Reads up to n characters of the next token. */
    private char[] ns(int n)
    {
        char[] buf = new char[n];
        int b = skip(), p = 0;
        while(p < n && !(isSpaceChar(b))){
            buf[p++] = (char)b;
            b = readByte();
        }
        return n == p ? buf : Arrays.copyOf(buf, p);
    }

    private char[][] nm(int n, int m)
    {
        char[][] map = new char[n][];
        for(int i = 0;i < n;i++)map[i] = ns(m);
        return map;
    }

    /** Reads n ints. */
    private int[] na(int n)
    {
        int[] a = new int[n];
        for(int i = 0;i < n;i++)a[i] = ni();
        return a;
    }

    /** Reads the next (optionally negative) int. */
    private int ni()
    {
        int num = 0, b;
        boolean minus = false;
        while((b = readByte()) != -1 &&
            !((b >= '0' && b <= '9') || b == '-'));
        if(b == '-'){
            minus = true;
            b = readByte();
        }

        while(true){
            if(b >= '0' && b <= '9'){
                num = num * 10 + (b - '0');
            }else{
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    /** Reads the next (optionally negative) long. */
    private long nl()
    {
        long num = 0;
        int b;
        boolean minus = false;
        while((b = readByte()) != -1 && !(
            (b >= '0' && b <= '9') || b == '-'));
        if(b == '-'){
            minus = true;
            b = readByte();
        }

        while(true){
            if(b >= '0' && b <= '9'){
                num = num * 10 + (b - '0');
            }else{
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    private static void tr(Object... o)
    { System.out.println(Arrays.deepToString(o)); }
}

In   C  :

#pragma GCC optimize ("Ofast")
#pragma GCC target ("sse4")
/* scanf/printf are used by main; the include was missing from the listing. */
#include <stdio.h>

/* mod is the prime modulus; _2 is the inverse of 2 modulo mod
   (2 * 500000004 = 1000000008 = mod + 1). */
const int mod = 1000000007, _2 = 500000004;
/* a[1..N] is the input and MX its maximum. i_1[i] is 1 + 2 + ... + i modulo
   mod. st/tp form a stack of indices. mxl/mxr bound the span in which a[i] is
   the maximum. sxl/sxr are running maxima from the left and right ends. */
int N, MX = 0, tp, a[200010], 
i_1[200010], st[200010], mxl[200010],
 mxr[200010], sxl[200010], sxr[200010];
/* M is the total count, CNT the part of it not yet attributed to a value,
   ANS the running answer (all modulo mod). */
long long M, CNT, ANS = 0;
void calc(int w, int x, int y)
if( x < y )
int temp = x;
x = y;
y = temp;
int k;
if( x == y )
k = ( ( (long long)( x + y ) * i_1[y] % mod - 
(long long)x * x % mod ) % mod + mod ) % mod;
k = ( ( (long long)y * ( i_1[x-1] - i_1[y] ) % 
mod + (long long)( x + y ) * i_1[y] % 
mod ) % mod + mod ) % mod;
ANS = ( ANS + (long long)w * k ) % mod;
CNT -= k;
if( CNT < 0 )
CNT += mod;
void calcl(int w, int x, int y)
if( x == 1 || y == 0 )
int k;
if( y < x )
k = i_1[y];
k = ( i_1[x-1] + (long long)( 
    y - x + 1 ) * ( x - 1 ) ) % mod;
ANS = ( ANS + (long long)w * k ) % mod;
CNT -= k;
if( CNT < 0 )
CNT += mod;
void calcr(int w, int x, int y)
if( x == 0 || y == 1 )
int k;
if( y + 1 <= x )
k = i_1[y-1];
k = ( i_1[x] + (long long)( y - x - 1 ) * x ) % mod;
ANS = ( ANS + (long long)w * k ) % mod;
CNT -= k;
if( CNT < 0 )
CNT += mod;
int main()
int p;
scanf("%d", &N);
for( int i = 1 ; i <= N ; i++ )
scanf("%d", &a[i]);
MX = MX > a[i] ? MX : a[i];
M = ( (long long)N * 
( N + 1 ) >> 1 ) % mod;
M = (long long)M * ( M + 1 ) % mod * _2 % mod;
CNT = M;
for( int i = 1 ; i <= N ; i++ )
i_1[i] = ( i_1[i-1] + i ) % mod;
for( int i = 1 ; i <= N ; i++ )
sxl[i] = sxl[i-1] > a[i] ? sxl[i-1] : a[i];
for( int i = N ; i ; i-- )

sxr[i] = sxr[i+1] > a[i] ? sxr[i+1] : a[i];
tp = 0;
for( int i = 1 ; i <= N ; i++ )
while( tp > 0 && a[st[tp]] <= a[i] )
mxl[i] = st[tp] + 1;
mxl[i] = 1;
st[++tp] = i;
tp = 0;
for( int i = N ; i ; i-- )
while( tp > 0 && a[st[tp]] < a[i] )
mxr[i] = st[tp] - 1;
mxr[i] = N;
st[++tp] = i;
for( int i = 1 ; i <= N ; i++ )
calc(a[i], i-mxl[i]+1, mxr[i]-i+1);
p = N;
for( int i = 1 ; i <= N ; i++ )
int g = sxl[i];
while( p > i && sxr[p] < g )
while( p < i )
calcl(g, i, N-p);
p = 1;
for( int i = N ; i ; i-- )
int g = sxr[i];
while( p < i && sxl[p] <= g )
while( p > i )
calcr(g, N-i+1, p-1);
CNT = ( CNT % mod + mod ) % mod;
ANS = ( ANS + (long long)CNT * MX ) % mod;
printf("%lld", ANS);
return 0;

In   Python 3 :


import math
import os
import random
import re
import sys

# Complete the solve function below.

import math
import os
import random
import re
import sys
from decimal import Decimal
def t1(n):
    """Return the n-th triangular number n * (n + 1) / 2 as a Decimal.

    `n` may be an int or a Decimal (countsplit calls t1(t1(n))).

    The product is converted to Decimal before dividing, so an int argument
    never goes through float true division, which silently rounds once the
    result exceeds 2**53.
    """
    return Decimal(n * (n + 1)) / 2

def t2(n):
    """Return n * (n + 1) * (n + 2) / 6 as a Decimal.

    The product is converted to Decimal before dividing, so an int argument
    never goes through float true division, which silently rounds once the
    result exceeds 2**53.
    """
    return Decimal(n * (n + 1) * (n + 2)) / 6

def u2(n):
    """Return n * (n + 2) * (2 * n + 5) / 24 as a Decimal.

    The value is not always an integer (u2(1) is 0.875). The product is
    converted to Decimal before dividing, so an int argument never goes
    through float true division, which drops the fractional part once the
    value is large.
    """
    return Decimal(n * (n + 2) * (2 * n + 5)) / 24

def countzip(a, b):
    """Return u2(a + b) - u2(|a - b|) + t2(|a - b|)."""
    gap = abs(a - b)
    return u2(a + b) - u2(gap) + t2(gap)

def countends(x, n, ex):
    """Return countzip(n, ex) - countzip(x, ex) - countzip(n - 1 - x, 0)."""
    whole = countzip(n, ex)
    before = countzip(x, ex)
    after = countzip(n - 1 - x, 0)
    return whole - before - after

def countsplit(x, n):
    """Return t1(t1(n)) - t1(x) - countzip(n - x - 1, x - 1)."""
    total = t1(t1(n))
    left_part = t1(x)
    zipped = countzip(n - x - 1, x - 1)
    return total - left_part - zipped

# lg[v] is floor(log2(v)) for 1 <= v < 2**K, and lg[0] is 0.
K = 20
lg = [0] + [v.bit_length() - 1 for v in range(1, 1 << K)]

def make_rangemax(A):
    """Build a sparse table over A and return a range-maximum query function.

    The returned ``rangemax(i, j)`` gives the index of a maximum element of
    ``A[i:j]`` (half-open, requires ``i < j``). When the maximum occurs more
    than once, the leftmost index is returned.

    NOTE(review): the listing lost the initialisation of ``mxk``, the
    allocation of each new level, and the end of the ``max(...)`` call (it did
    not parse). They are reconstructed here; confirm the tie-breaking against
    the original source.
    """
    n = len(A)
    assert 1 << K > n

    key = lambda x: A[x]
    # mxk[k][i] is the index of the maximum of A[i : i + 2**k].
    mxk = [list(range(n))]
    for k in range(K - 1):
        width = 1 << k
        if 2 * width > n:
            # No window of length 2**(k + 1) fits; higher levels are unused.
            break
        prev = mxk[k]
        mxk.append([
            max(prev[i], prev[i + width], key=key)
            for i in range(n - 2 * width + 1)
        ])

    def rangemax(i, j):
        # Cover [i, j) with two (possibly overlapping) windows of length 2**k.
        k = lg[j - i]
        return max(mxk[k][i], mxk[k][j - (1 << k)], key=key)

    return rangemax

def brutesolo(A):
    """Sum A[x] * (x - i + 1) * (j - x) over the maximum-split tree of A.

    Each half-open range [i, j) is split at the index x of its maximum into
    [i, x) and [x + 1, j); an explicit stack replaces recursion.
    """
    rangemax = make_rangemax(A)
    total = 0
    pending = [(0, len(A))]
    while pending:
        lo, hi = pending.pop()
        if lo == hi:
            continue
        top = rangemax(lo, hi)
        pending.append((lo, top))
        pending.append((top + 1, hi))
        total += A[top] * (top - lo + 1) * (hi - top)
    return total

def make_brute(A):
    """Return ``(brute, rangemax)`` for A.

    ``brute(i, j)`` walks the maximum-split tree of the half-open range
    [i, j) and sums ``A[x] * countends(x - lo, hi - lo, 0)`` over every
    sub-range [lo, hi) whose maximum is at x. ``rangemax`` is the query
    function built by make_rangemax(A).
    """
    rangemax = make_rangemax(A)

    def brute(i, j):
        total = 0
        pending = [(i, j)]
        while pending:
            lo, hi = pending.pop()
            if lo == hi:
                continue
            top = rangemax(lo, hi)
            pending.append((lo, top))
            pending.append((top + 1, hi))
            total += A[top] * countends(top - lo, hi - lo, 0)
        return total

    return brute, rangemax

def ends(A, B):
    """Combine prefixes A[:i] and B[:j], starting from the full lengths.

    At each step the larger of the two prefix maxima is charged
    ``value * countends(...)``, the part of its own array behind it is handed
    to that array's ``brute``, and the prefix is cut at the maximum. When one
    prefix is empty the other is handed to its ``brute`` directly.

    NOTE(review): the listing lost the two ``else:`` lines, which left the
    general case nested under ``elif j == 0`` and made both branches run one
    after the other. They are restored here; confirm against the original
    source.
    """
    brutea, rangemaxa = make_brute(A)
    bruteb, rangemaxb = make_brute(B)

    stack = [(len(A), len(B))]
    ans = 0
    while stack:
        i, j = stack.pop()
        if i == 0:
            ans += bruteb(0, j)
        elif j == 0:
            ans += brutea(0, i)
        else:
            x = rangemaxa(0, i)
            y = rangemaxb(0, j)
            if A[x] < B[y]:
                # B holds the larger maximum: cut B's prefix at y.
                ans += bruteb(y + 1, j)
                ans += B[y] * countends(y, j, i)
                stack.append((i, y))
            else:
                # A holds the maximum (ties go to A): cut A's prefix at x.
                ans += brutea(x + 1, i)
                ans += A[x] * countends(x, i, j)
                stack.append((x, j))

    return ans

def maxpairs(a):
    """Return the maximum of each pair of adjacent elements of a."""
    return list(map(max, a, a[1:]))

def solve(A):
    """Return the answer for A modulo 10**9 + 7 as an int.

    The array is split at the first index x of its maximum. The result is the
    sum of three parts: brutesolo on A[:x], ends on the reversed right side
    and the adjacent-pair maxima of the left side, and A[x] times
    countsplit(x, n).
    """
    n = len(A)
    x = max(range(n), key=lambda idx: A[idx])
    left_only = brutesolo(A[:x])
    across = ends(A[x + 1:][::-1], maxpairs(A[:x]))
    at_peak = A[x] * countsplit(x, n)
    return int((left_only + across + at_peak) % (10**9 + 7))

if __name__ == '__main__':
    # First input line: the array length. It is read only to consume the
    # line; solve() takes the length from the list itself.
    n = int(input())

    A = list(map(int, input().rstrip().split()))

    result = solve(A)

    # The judge supplies the output path in OUTPUT_PATH. The with-block closes
    # (and therefore flushes) the file; the original never closed it.
    with open(os.environ['OUTPUT_PATH'], 'w') as fptr:
        fptr.write(str(result) + '\n')


