Longest Palindromic Subsequence


Problem Statement :


Steve loves playing with palindromes. He has a string, s, consisting of n lowercase English alphabetic characters (i.e., a through z). He wants to calculate the number of ways to insert exactly 1 lowercase character into string s such that the length of the longest palindromic subsequence of s increases by at least k. Two ways are considered different if either of the following conditions is satisfied:

The positions of insertion are different.
The inserted characters are different.
This means there are exactly 26*(n+1) different ways to insert exactly 1 character into a string of length n: n+1 insertion positions times 26 letters.
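To make the counting rule concrete, here is a small sketch (the helper name countInsertions is our own, not from the problem) that enumerates every (position, letter) pair and also counts how many distinct strings those insertions produce. For s = "aa" it finds 78 ways but only 76 distinct strings, because inserting 'a' at any of the three positions yields "aaa"; under the rule above all 78 are counted.

```cpp
#include <cassert>
#include <cstddef>
#include <set>
#include <string>

// Enumerate every (position, letter) insertion into s.
// Positions run from 0 (before s[0]) to s.size() (after the last character).
// Returns the number of ways and stores the number of distinct resulting strings.
std::size_t countInsertions(const std::string& s, std::size_t& distinctStrings) {
    for (char ch : s) assert(ch >= 'a' && ch <= 'z');  // input is lowercase only
    std::set<std::string> seen;
    std::size_t ways = 0;
    for (std::size_t pos = 0; pos <= s.size(); ++pos) {
        for (char c = 'a'; c <= 'z'; ++c) {
            std::string t = s;
            t.insert(pos, 1, c);  // insert one copy of c at index pos
            seen.insert(t);
            ++ways;
        }
    }
    distinctStrings = seen.size();
    return ways;
}
```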

Given q queries consisting of n, k, and s, print the number of different ways of inserting exactly 1 new lowercase letter into string s such that the length of the longest palindromic subsequence of s increases by at least k.
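A reference answer follows directly from this definition: try all 26*(n+1) insertions and recompute the LPS each time. The sketch below (lpsLength and bruteForceCount are hypothetical names) does exactly that with the classic interval DP. At 26(n+1) LPS computations of O(n^2) each it is far too slow for n = 3000, but it is useful for checking a fast solution on small inputs.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Length of the longest palindromic subsequence of t.
// best[i][j] = LPS of t[i..j]; rows are filled from the bottom up.
int lpsLength(const std::string& t) {
    const int m = static_cast<int>(t.size());
    if (m == 0) return 0;
    std::vector<std::vector<int>> best(m, std::vector<int>(m, 0));
    for (int i = m - 1; i >= 0; --i) {
        best[i][i] = 1;
        for (int j = i + 1; j < m; ++j) {
            if (t[i] == t[j])
                best[i][j] = (i + 1 <= j - 1 ? best[i + 1][j - 1] : 0) + 2;
            else
                best[i][j] = std::max(best[i + 1][j], best[i][j - 1]);
        }
    }
    return best[0][m - 1];
}

// Brute force: count insertions whose new LPS is at least LPS(s) + k.
long long bruteForceCount(const std::string& s, int k) {
    assert(k >= 0);
    const int base = lpsLength(s);
    long long count = 0;
    for (std::size_t pos = 0; pos <= s.size(); ++pos) {
        for (char c = 'a'; c <= 'z'; ++c) {
            std::string t = s;
            t.insert(pos, 1, c);
            if (lpsLength(t) >= base + k) ++count;
        }
    }
    return count;
}
```

For example, bruteForceCount("aab", 2) returns 1: only inserting 'b' at the front, giving "baab", raises the LPS from 2 to 4.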

Input Format

The first line contains a single integer, q, denoting the number of queries. The 2q subsequent lines describe the queries, two lines per query:

1. The first line of a query contains two space-separated integers denoting the respective values of n and k.
2. The second line contains a single string denoting s.
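Since n, k, and s never contain spaces, the input can be read with plain whitespace-delimited extraction and no line handling. A minimal sketch, assuming a hypothetical Query struct and readQueries helper (neither is part of any judge template):

```cpp
#include <cassert>
#include <istream>
#include <sstream>
#include <string>
#include <vector>

// One query: n and k from the first line, s from the second.
struct Query {
    int n = 0;
    int k = 0;
    std::string s;
};

// Reads q followed by q queries from any stream (std::cin in a real run).
std::vector<Query> readQueries(std::istream& in) {
    int q = 0;
    in >> q;
    std::vector<Query> queries;
    for (int i = 0; i < q; ++i) {
        Query query;
        in >> query.n >> query.k >> query.s;
        assert(static_cast<int>(query.s.size()) == query.n);  // sanity check
        queries.push_back(query);
    }
    return queries;
}

// Convenience overload: parse from an in-memory string (handy for tests).
std::vector<Query> readQueries(const std::string& text) {
    std::istringstream in(text);
    return readQueries(in);
}
```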
Constraints
1 <= q <= 10
1 <= n <= 3000
0 <= k <= 50
It is guaranteed that s consists of lowercase English alphabetic letters (i.e., a to z) only.
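Although k may be as large as 50, all three solutions below print 0 whenever k > 2 and 26*(n+1) whenever k = 0. The reasoning: take a palindromic subsequence of the new string. If it does not use the inserted letter, it already existed in s. If the inserted letter is its center, removing it leaves a palindrome of s that is 1 shorter. If the inserted letter is paired with some s[j], removing both leaves a palindrome of s that is 2 shorter. So the LPS never shrinks and grows by at most 2. The sketch below (insertionDeltaRange and lpsViaLcs are our own names) checks this exhaustively on short strings over a small alphabet, computing the LPS as the LCS of a string and its reverse.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

// LPS(t) equals the longest common subsequence of t and its reverse.
int lpsViaLcs(const std::string& t) {
    const std::string r(t.rbegin(), t.rend());
    const std::size_t m = t.size();
    std::vector<std::vector<int>> lcs(m + 1, std::vector<int>(m + 1, 0));
    for (std::size_t i = 1; i <= m; ++i)
        for (std::size_t j = 1; j <= m; ++j)
            lcs[i][j] = (t[i - 1] == r[j - 1]) ? lcs[i - 1][j - 1] + 1
                                               : std::max(lcs[i - 1][j], lcs[i][j - 1]);
    return lcs[m][m];
}

// For every string of length 1..maxLen over the first `alpha` letters and every
// single-letter insertion from the same alphabet, return the smallest and largest
// observed change in LPS length.
std::pair<int, int> insertionDeltaRange(int maxLen, int alpha) {
    assert(maxLen >= 1 && alpha >= 1 && alpha <= 26);
    int lo = 1 << 30, hi = -(1 << 30);
    for (int len = 1; len <= maxLen; ++len) {
        std::vector<int> digits(len, 0);  // base-alpha counter over all strings
        while (true) {
            std::string s;
            for (int d : digits) s.push_back(static_cast<char>('a' + d));
            const int base = lpsViaLcs(s);
            for (std::size_t pos = 0; pos <= s.size(); ++pos)
                for (int c = 0; c < alpha; ++c) {
                    std::string t = s;
                    t.insert(pos, 1, static_cast<char>('a' + c));
                    const int delta = lpsViaLcs(t) - base;
                    lo = std::min(lo, delta);
                    hi = std::max(hi, delta);
                }
            // Advance the counter; stop once every digit has wrapped around.
            int i = 0;
            while (i < len && ++digits[i] == alpha) digits[i++] = 0;
            if (i == len) break;
        }
    }
    return {lo, hi};
}
```

On strings of length up to 6 over {a, b, c} this reports a range of 0 to 2, matching the argument above.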



Solution :






In C++ :





#include <bits/stdc++.h>
#define SZ(X) ((int)(X).size())
#define ALL(X) (X).begin(), (X).end()
#define REP(I, N) for (int I = 0; I < (N); ++I)
#define REPP(I, A, B) for (int I = (A); I < (B); ++I)
#define RI(X) scanf("%d", &(X))
#define RII(X, Y) scanf("%d%d", &(X), &(Y))
#define RIII(X, Y, Z) scanf("%d%d%d", &(X), &(Y), &(Z))
#define DRI(X) int (X); scanf("%d", &X)
#define DRII(X, Y) int X, Y; scanf("%d%d", &X, &Y)
#define DRIII(X, Y, Z) int X, Y, Z; scanf("%d%d%d", &X, &Y, &Z)
#define RS(X) scanf("%s", (X))
#define CASET int ___T, case_n = 1; scanf("%d ", &___T); while (___T-- > 0)
#define MP make_pair
#define PB push_back
#define MS0(X) memset((X), 0, sizeof((X)))
#define MS1(X) memset((X), -1, sizeof((X)))
#define LEN(X) strlen(X)
#define PII pair<int,int>
#define VI vector<int>
#define VPII vector<pair<int,int> >
#define PLL pair<long long,long long>
#define VPLL vector<pair<long long,long long> >
#define F first
#define S second
typedef long long LL;
using namespace std;
const int MOD = 1e9+7;
const int SIZE = 3005;
int dp[SIZE][SIZE];
int dp2[SIZE][SIZE];
char s[SIZE];
int main(){
    CASET{
        DRII(n,K);
        RS(s+1);
        if(K>2)puts("0");
        else if(K==0){
            printf("%d\n",n*26+26);
        }
        else{
            MS0(dp);
            MS0(dp2);
            // dp2[l][r] = length of the LPS of s[l..r] (1-indexed, 0 for empty ranges)
            REPP(i,1,n+1)dp2[i][i]=1;
            REPP(j,1,n){
                for(int k=1;k+j<=n;k++){
                    int ll=k,rr=k+j;
                    if(s[ll]==s[rr])dp2[ll][rr]=max(dp2[ll][rr],dp2[ll+1][rr-1]+2);
                    dp2[ll][rr]=max(dp2[ll][rr],dp2[ll+1][rr]);
                    dp2[ll][rr]=max(dp2[ll][rr],dp2[ll][rr-1]);
                }
            }
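            // dp[i][j] (i < j) = longest palindrome built only from pairs taken from
            // s[1..i] and s[j..n]; ma ends up as the LPS of the whole string
            int ma=0;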
            for(int i=1;i<n;i++){
                for(int j=n;j>i;j--){
                    if(s[i]==s[j])dp[i][j]=dp[i-1][j+1]+2;
                    else dp[i][j]=max(dp[i-1][j],dp[i][j+1]);
                    ma=max(ma,dp[i][j]);
                }
            }
            REPP(i,1,n+1)ma=max(ma,dp[i-1][i+1]+1);
            int an=0;
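            // i = insertion point: the new letter goes between s[i] and s[i+1]
            REP(i,n+1){
                int me=0;
                // new letter as the center, surrounded by pairs from s[1..i] and s[i+1..n]
                me=dp[i][i+1]+1;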
                if(me>=ma+K){
                    an+=26;
                    continue;
                }
                // otherwise pair the new letter with an existing s[j] of the same letter
                bool used[26]={};
                REPP(j,1,n+1){
                    if(j<=i){
                        if(dp2[j+1][i]+dp[j-1][i+1]+2>=ma+K)used[s[j]-'a']=1;
                    }
                    else{
                        if(dp2[i+1][j-1]+dp[i][j+1]+2>=ma+K)used[s[j]-'a']=1;
                    }
                }
                REP(j,26)
                    if(used[j]){
                        an++;
                    }
            }
            printf("%d\n",an);
        }
    }
    return 0;
}
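The C++ solution above rests on two tables: dp2, the LPS of every substring, and dp, the longest palindrome made only of pairs split between a prefix and a suffix. For an insertion point, the best new palindrome either uses the new letter as its center or pairs it with some s[j], contributing 2 plus the inner LPS between them plus the outer pairs. The sketch below restates that idea with 0-based, half-open indices; the names countInsertionsDP, inner and outer are our own. Two (n+1)^2 int tables take about 72 MB at n = 3000; the Java solution uses short arrays to halve that.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <string>
#include <vector>

// Counts (position, letter) insertions that raise the LPS of s by at least k.
// inner[i][j]: LPS of the half-open range s[i..j-1] (0 when i == j).
// outer[i][j] (i <= j): longest palindrome made only of pairs whose left halves
//                        come from s[0..i-1] and right halves from s[j..n-1].
long long countInsertionsDP(const std::string& s, int k) {
    assert(k >= 0);
    const int n = static_cast<int>(s.size());
    if (k == 0) return 26LL * (n + 1);  // the LPS never decreases
    if (k > 2) return 0;                // one insertion adds at most 2

    std::vector<std::vector<int>> inner(n + 1, std::vector<int>(n + 1, 0));
    for (int i = n - 1; i >= 0; --i) {
        inner[i][i + 1] = 1;
        for (int j = i + 2; j <= n; ++j)
            inner[i][j] = (s[i] == s[j - 1]) ? inner[i + 1][j - 1] + 2
                                             : std::max(inner[i + 1][j], inner[i][j - 1]);
    }

    // LCS-style recurrence between the prefix and the reversed suffix.
    std::vector<std::vector<int>> outer(n + 1, std::vector<int>(n + 1, 0));
    for (int i = 1; i <= n; ++i)
        for (int j = n - 1; j >= 0; --j)
            outer[i][j] = (s[i - 1] == s[j]) ? outer[i - 1][j + 1] + 2
                                             : std::max(outer[i - 1][j], outer[i][j + 1]);

    const int target = inner[0][n] + k;
    long long count = 0;
    for (int p = 0; p <= n; ++p) {        // new letter goes before s[p]
        if (outer[p][p] + 1 >= target) {  // as the center, any letter works
            count += 26;
            continue;
        }
        std::array<bool, 26> ok{};        // ok[c]: letter c reaches the target
        for (int j = 0; j < n; ++j) {     // pair the new letter with s[j]
            const int cand = (j < p) ? 2 + inner[j + 1][p] + outer[j][p]
                                     : 2 + inner[p][j] + outer[p][j + 1];
            if (cand >= target) ok[s[j] - 'a'] = true;
        }
        count += std::count(ok.begin(), ok.end(), true);
    }
    return count;
}
```

The work per query is O(n^2) for the tables plus O(n^2) for the scan, the same shape as the C++ and C listings.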








In Java :





import java.util.*;

public class Solution {

    static int longestPalindromicSubsequence(String s, int k)
    {
        int n = s.length();
        
        if (k == 0)
        {
            return (n + 1) * 26;
        }
        
        if (k > 2)
        {
            return 0;
        }
        
        if (n == 1)
        {
            if (k == 1)
            {
                return 2;
            }
            
            return 0;
        }
        
        short pal[][] = computePal(s);
        short rightEnds[][] = computeRightEnds(s);
        short leftEnds[][] = computeLeftEnds(s);
        int sMax = pal[0][n - 1]; // LPS of the whole string
        
        boolean ok[][] = new boolean[n + 1][26];
        
        for (int i = 0; i <= n; i++)
        {
            for (int j = 0; j < n; j++)
            {
                char c = s.charAt(j);
                int min;
                int max;
                int middle = 0;
                
                if (i <= j)
                {
                    if (i < j)
                    {
                        middle = pal[i][j - 1];
                    }
                    
                    min = i;
                    max = j;
                }
                else
                {
                    if (j + 1 <= i - 1)
                    {
                        middle = pal[j + 1][i - 1];
                    }
                    
                    min = j;
                    max = i - 1;
                }
                
                int need = sMax + k - middle - 2;
                
                if (need % 2 == 1)
                {
                    need++;
                }
                
                if (rightEnds[min][need / 2] > max + 1 ||
                   leftEnds[max + 1][need / 2] <  min)
                {
                    ok[i][c - 'a'] = true;
                }
            }
        }
        
        
        if (k == 1)
        {
            if (sMax % 2 == 0)
            {
                for (int i = sMax / 2; i < n; i++)
                {
                    if (rightEnds[i][sMax / 2] >= i)
                    {
                        for (int j = 0; j < 26; j++)
                        {
                            ok[i][j] = true;
                        }
                    }
                }
            }
        }
        
        int total = 0;
        
        for (int i = 0; i <= n; i++)
        {
            for (int j = 0; j < 26; j++)
            {
                if (ok[i][j])
                {
                    total++;
                }
            }
        }
        
        return total;
    }
    
    private static short[][] computeRightEnds(String s)
    {
        short n = (short) s.length();
        short ends[][] = new short[n + 1][n / 2 + 1];

        for (int i = 0; i < ends.length; i++)
        {
            for (int j = 0; j < ends[i].length; j++)
            {
                ends[i][j] = -1;
            }
        }

        ends[0][0] = n;
        
        for (int len = 1; len <= n; len++)
        {
            ends[len][0] = n;
            
            int i = n - 1;
                
            while (i >= 0 && s.charAt(i) != s.charAt(len - 1))
            {
                i--;
            }
            
            for (int c = 1; c <= n / 2 && c <= len; c++)
            {
                ends[len][c] = (short) Math.max(-1, ends[len - 1][c]);
                
                while (i >= ends[len - 1][c - 1])
                {
                    i--;
                    
                    while (i >= 0 && s.charAt(i) != s.charAt(len - 1))
                    {
                        i--;
                    }
                }
                
                if (i >= len)
                {
                    ends[len][c] = (short) Math.max(ends[len][c], i);
                }
            }
        }
        
        return ends;
    }
    
    private static short[][] computeLeftEnds(String s)
    {
        short n = (short) s.length();
        short ends[][] = new short[n + 1][n / 2 + 1];
        
        for (int i = 0; i < ends.length; i++)
        {
            for (int j = 0; j < ends[i].length; j++)
            {
                ends[i][j] = n;
            }
        }
        
        ends[n][0] = -1;
            
        for (int k = n - 1; k >= 0; k--)
        {
            ends[k][0] = -1;

            int i = 0;
                
            while (i < n && s.charAt(i) != s.charAt(k))
            {
                i++;
            }
            
            for (int c = 1; c <= n / 2 && c <= n - k; c++)
            {
                ends[k][c] = (short) Math.min(n, ends[k + 1][c]);
                
                while (i <= ends[k + 1][c - 1])
                {
                    i++;
                    
                    while (i < n && s.charAt(i) != s.charAt(k))
                    {
                        i++;
                    }
                }
                
                if (i < k)
                {
                    ends[k][c] = (short) Math.min(ends[k][c], i);
                }
            }
        }
        
        return ends;
    }
    
    // pal[i][j] = length of the longest palindromic subsequence of s[i..j]
    private static short[][] computePal(String s)
    {
        int n = s.length();
        short pal[][] = new short[n][n];
        
        for (int i = 0; i < n; i++)
        {
            pal[i][i] = 1;
        }
        
        for (int d = 1; d < n; d++)
        {
            for (int i = 0, j = d; j < n; i++, j++)
            {
                if (s.charAt(i) == s.charAt(j))
                {
                    if (d == 1)
                    {
                        pal[i][j] = 2;
                    }
                    else
                    {
                        pal[i][j] = (short) (pal[i + 1][j - 1] + 2);
                    }
                }
                else
                {
                    pal[i][j] = (short) Math.max(pal[i][j - 1], pal[i + 1][j]);
                }
            }
        }

        return pal;
    }

    public static void main(String[] args) {
        Scanner in = new Scanner(System.in);
        int q = in.nextInt();
        for(int a0 = 0; a0 < q; a0++){
            int n = in.nextInt();
            int k = in.nextInt();
            String s = in.next();
            int result = longestPalindromicSubsequence(s, k);
            System.out.println(result);
        }
        in.close();
    }
}








In C :





#include<stdio.h>
#define M 3005
int q, n, k;
char s[M];
int in[M][M], out[M][M];
int max(int a, int b)
{
    return a > b ? a : b;
}
int main()
{
    scanf("%d", &q);
    while(q--)
    {
        scanf("%d %d", &n, &k);
        scanf("%s", s);
        if( k == 0 )
        {
            printf("%d\n", ( n + 1 ) * 26);
            continue;
        }
        if( k > 2 )
        {
            printf("0\n");
            continue;
        }
        /* in[i][j] = length of the LPS of s[i..j], filled by increasing span l = j - i */
        for( int l = 0 ; l < n ; l++ )
        {
            for( int i = 0 ; i + l < n ; i++ )
            {
                int j = i + l;
                if( i == j )
                {
                    in[i][j] = 1;
                }
                else if( s[i] == s[j] )
                {
                    if( i + 1 < j )
                    {
                        in[i][j] = 2 + in[i+1][j-1];
                    }
                    else
                    {
                        in[i][j] = 2;
                    }
                }
                else
                {
                    in[i][j] = max(in[i+1][j], in[i][j-1]);
                }
            }
        }
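        /* out[i][j] (i < j) = longest palindrome built only from pairs whose left halves
           lie in s[0..i] and right halves in s[j..n-1]; filled by decreasing span */
        for( int l = n - 1 ; l >= 0 ; l-- )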
        {
            for( int i = 0 ; i + l < n ; i++ )
            {
                int j = i + l;
                if( i == j )
                {
                    if( 0 < i && j + 1 < n )
                    {
                        out[i][j] = 1 + out[i-1][j+1];
                    }
                    else
                    {
                        out[i][j] = 1;
                    }
                }
                else if( s[i] == s[j] )
                {
                    if( 0 < i && j + 1 < n )
                    {
                        out[i][j] = 2 + out[i-1][j+1];
                    }
                    else
                    {
                        out[i][j] = 2;
                    }
                }
                else
                {
                    out[i][j] = 0;
                    if( 0 < i )
                    {
                        out[i][j] = max(out[i][j], out[i-1][j]);
                    }
                    if( j + 1 < n )
                    {
                        out[i][j] = max(out[i][j], out[i][j+1]);
                    }
                }
            }
        }
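        int cur = in[0][n-1], res = 0;   /* cur = LPS of the original string */
        /* i = insertion point: the new letter goes before s[i] (i == n appends it) */
        for( int i = 0 ; i <= n ; i++ )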
        {
            for( char ch = 'a' ; ch <= 'z' ; ch++ )
            {
                /* start with the new letter as the center, then try pairing it with each s[j] == ch */
                int my = ( i == 0 || i == n ) ? 1 : 1 + out[i-1][i];
                for( int j = 0 ; j < i && my < cur + k ; j++ )
                {
                    if( s[j] == ch )
                    {
                        int cand = 2;
                        if( 0 < j && i < n )
                        {
                            cand += out[j-1][i];
                        }
                        if( j + 1 <= i - 1 )
                        {
                            cand += in[j+1][i-1];
                        }
                        my = max(my, cand);
                    }
                }
                for( int j = i ; j < n && my < cur + k ; j++ )
                {
                    if( s[j] == ch )
                    {
                        int cand = 2;
                        if( 0 < i && j + 1 < n )
                        {
                            cand += out[i-1][j+1];
                        }
                        if( i <= j - 1 )
                        {
                            cand += in[i][j-1];
                        }
                        my = max(my, cand);
                    }
                }
                if( my >= cur + k )
                {
                    res++;
                }
            }
        }
        printf("%d\n", res);
    }
    return 0;
}
                        







