Tree Pruning
Problem Statement :
A tree, $t$, has $n$ vertices numbered from $1$ to $n$ and is rooted at vertex $1$. Each vertex $i$ has an integer weight, $w_i$, associated with it, and $t$'s total weight is the sum of the weights of its nodes. A single remove operation removes the subtree rooted at some arbitrary vertex $u$ from tree $t$. Given $t$, perform up to $k$ remove operations so that the total weight of the remaining vertices in $t$ is maximal. Then print $t$'s maximal total weight on a new line.

Note: If $t$'s total weight is already maximal, you may opt to remove $0$ nodes.

Input Format

The first line contains two space-separated integers, $n$ and $k$, respectively. The second line contains $n$ space-separated integers describing the respective weights for each node in the tree, where the $i$-th integer is the weight of the $i$-th vertex. Each of the $n-1$ subsequent lines contains a pair of space-separated integers, $u$ and $v$, describing an edge connecting vertex $u$ to vertex $v$.

Constraints

$2 \le n \le 10^5$, $1 \le k \le 200$, $1 \le i \le n$, $-10^9 \le w_i \le 10^9$

Output Format

Print a single integer denoting the largest total weight of $t$'s remaining vertices.
Solution :
Solutions in C++, Java, C and Python 3:
In C++ :
#include <fstream>
#include <iostream>
#include <vector>
#include <bitset>
using namespace std;
// Array bound for vertex-indexed tables: n <= 10^5, plus slack for the n+1 row of dp.
const int NMAX = 100004;
// Magnitude of the "unreachable" marker (-INF) in dp; real sums are bounded by
// 10^5 * 10^9 = 10^14 in absolute value, far from 2^60.
const long long INF = 1LL<<60;
// Tree: undirected adjacency lists (1-based vertices). Level is not referenced in this program.
vector <int> Tree[NMAX], Level[NMAX];
// dp[i][j]: best weight kept from pre-order positions 1..i-1 using exactly j remove
// operations (-INF when unreachable); 201 columns cover k <= 200.
// sum is not referenced in this program.
long long dp[NMAX][201], sum[NMAX];
// n is shadowed by the local n in main; Father is not referenced.
// v[pos]: vertex at pre-order position pos; First/Last[u]: first/last pre-order position of
// u's subtree; val[u]: weight of u; ind: running pre-order counter written by DFS.
int n, Father[NMAX], v[NMAX], val[NMAX], First[NMAX], Last[NMAX], ind;
// Numbers the vertices of node's subtree in pre-order: First[u] is u's position, v[pos] is the
// vertex at position pos, and Last[u] is the last position inside u's subtree, so the subtree
// of u occupies positions First[u]..Last[u]. Neighbours are visited in Tree[u] order, skipping
// the parent (`father` for the start vertex).
// An explicit stack replaces recursion: a path-shaped tree with ~10^5 vertices would otherwise
// need ~10^5 nested calls and can exhaust the thread's call stack.
void DFS(const int node,const int father){
    vector<int> stackVertex(1, node);
    vector<int> stackParent(1, father);
    vector<size_t> stackNext(1, size_t(0));   // index of the next neighbour to examine
    First[node] = ++ind;
    v[ind] = node;
    while(!stackVertex.empty()){
        const int cur = stackVertex.back();
        size_t &pos = stackNext.back();
        if(pos < Tree[cur].size()){
            const int child = Tree[cur][pos++];   // advance before any push_back invalidates pos
            if(child != stackParent.back()){
                First[child] = ++ind;
                v[ind] = child;
                stackVertex.push_back(child);
                stackParent.push_back(cur);
                stackNext.push_back(0);
            }
        } else {
            // All neighbours handled: the subtree of cur ends at the current counter.
            Last[cur] = ind;
            stackVertex.pop_back();
            stackParent.pop_back();
            stackNext.pop_back();
        }
    }
}
// Reads the tree, linearises it in pre-order (DFS), then runs a DP over pre-order positions:
// at position i the vertex is either kept (move to i+1 and add its weight) or removed together
// with its whole subtree (jump to Last[node]+1, spending one operation). Prints the best weight
// reachable at the end over every operation count 0..k.
int main(){
    int n, k;
    ios::sync_with_stdio(false);
    cin.tie(0);
    cin >> n >> k;
    for(int i = 1; i <= n; ++i)
        cin >> val[i];
    for(int i = 1; i < n; ++i){
        int x, y;
        cin >> x >> y;
        Tree[x].push_back(y);
        Tree[y].push_back(x);
    }
    DFS(1, 0);
    // Initialise every row the DP writes to, including the "past the end" row n+1; leaving
    // it at the global zero would let an unreachable state masquerade as weight 0 (k == 0).
    for(int i = 1; i <= n + 1; ++i)
        for(int j = 0; j <= k; ++j)
            dp[i][j] = -INF;
    dp[1][0] = 0;
    for(int i = 1; i <= n; ++i){
        const int node = v[i];
        for(int j = 0; j <= k; ++j){
            if(dp[i][j] == -INF)
                continue;
            // Keep the vertex at position i.
            dp[i + 1][j] = max(dp[i + 1][j], dp[i][j] + val[node]);
            // Remove its whole subtree with one more operation.
            if(j + 1 <= k)
                dp[Last[node] + 1][j + 1] = max(dp[Last[node] + 1][j + 1], dp[i][j]);
        }
    }
    // For k >= 1, dp[n+1][1] >= 0 (removing the root), so this matches the old 0 floor.
    long long sol = -INF;
    for(int j = 0; j <= k; ++j)
        sol = max(sol, dp[n + 1][j]);
    cout << sol << "\n";
    return 0;
}
In Java :
import java.util.List;
import java.io.IOException;
import java.util.Arrays;
import java.util.InputMismatchException;
import java.util.ArrayList;
import java.io.OutputStream;
import java.io.PrintWriter;
import java.io.InputStream;
/**
 * Entry point for the "Tree Pruning" problem.
 *
 * <p>Originally generated with the CHelper plug-in; the algorithm lives in {@link TreePruning}
 * and input parsing in {@link InputReader}.
 */
public class Solution {
    /** Reads the problem from standard input and prints the answer to standard output. */
    public static void main(String[] args) {
        InputStream inputStream = System.in;
        OutputStream outputStream = System.out;
        InputReader in = new InputReader(inputStream);
        // try-with-resources flushes and closes the writer even if solving throws.
        try (PrintWriter out = new PrintWriter(outputStream)) {
            TreePruning solver = new TreePruning();
            solver.solve(1, in, out);
        }
    }
}

/**
 * Tree DP for "remove at most k subtrees so that the remaining weight is maximal".
 *
 * <p>For every vertex {@code u}, {@code dp[u][s]} with {@code 0 <= s <= min(cnt[u], k)} is the
 * best weight that can remain in u's subtree using at most {@code s} remove operations inside
 * that subtree (removing {@code u} itself included). Both traversals use explicit stacks so that
 * path-shaped trees with up to 10^5 vertices cannot overflow the JVM call stack.
 */
class TreePruning {
    /** Vertex weights, 0-based. */
    int[] w;
    /** Undirected adjacency lists. */
    List<Integer>[] t;
    /** Children lists of the tree rooted at vertex 0, filled by {@link #prepare}. */
    List<Integer>[] dt;
    /** Subtree sizes, filled by {@link #prepare}. */
    int[] cnt;
    /** Per-vertex DP tables, filled by {@link #dfs}; a child's table is released once merged. */
    long[][] dp;
    /** Maximum number of remove operations. */
    int k;

    /**
     * Reads one test case and prints the maximal remaining weight.
     *
     * @param testNumber unused test index (CHelper convention)
     * @param in source of n, k, the n weights and the n-1 edges (1-based vertex ids)
     * @param out destination for the answer
     */
    @SuppressWarnings("unchecked") // generic array creation for the adjacency lists
    public void solve(int testNumber, InputReader in, PrintWriter out) {
        int n = in.nextInt();
        k = in.nextInt();
        dp = new long[n][];
        w = new int[n];
        t = new List[n];
        dt = new List[n];
        cnt = new int[n];
        for (int i = 0; i < n; i++) {
            w[i] = in.nextInt();
            t[i] = new ArrayList<>();
            dt[i] = new ArrayList<>();
        }
        for (int i = 0; i < n - 1; i++) {
            int u = in.nextInt() - 1;
            int v = in.nextInt() - 1;
            t[u].add(v);
            t[v].add(u);
        }
        prepare(0, -1);
        dfs(0);
        long res = Long.MIN_VALUE;
        for (long value : dp[0]) {
            res = Math.max(res, value);
        }
        out.println(res);
    }

    /**
     * Fills {@code dp[x]} for every vertex x in u's subtree, children before parents.
     * Requires {@link #prepare} to have populated {@code dt} and {@code cnt}.
     */
    void dfs(int u) {
        // Pre-order over the children lists; each subtree vertex is pushed once, so cnt[u] slots suffice.
        int[] stack = new int[cnt[u]];
        int[] order = new int[cnt[u]];
        int top = 0;
        int visited = 0;
        stack[top++] = u;
        while (top > 0) {
            int x = stack[--top];
            order[visited++] = x;
            for (int v : dt[x]) {
                stack[top++] = v;
            }
        }
        // Reverse pre-order guarantees every child is finished before its parent.
        for (int idx = visited - 1; idx >= 0; idx--) {
            dp[order[idx]] = buildTable(order[idx]);
        }
    }

    /** Combines the children's tables of {@code x} with x's weight and the option of removing x. */
    private long[] buildTable(int x) {
        // acc[s]: best total over the children merged so far using at most s removals. Its
        // capacity grows with the merged subtree sizes (capped at k), keeping the DP O(n * k).
        long[] acc = {0L};
        for (int v : dt[x]) {
            long[] child = dp[v];
            int mergedCap = Math.min(acc.length - 1 + child.length - 1, k);
            long[] merged = new long[mergedCap + 1];
            Arrays.fill(merged, Long.MIN_VALUE);
            // Both inputs are non-decreasing in s, so splitting exactly s = i + j is enough;
            // every slot is reached (i = min(s, acc.length - 1)), so no sentinel survives.
            for (int i = 0; i < acc.length; i++) {
                for (int j = 0; j < child.length && i + j <= mergedCap; j++) {
                    merged[i + j] = Math.max(merged[i + j], acc[i] + child[j]);
                }
            }
            acc = merged;
            dp[v] = null; // merged; drop it to keep peak memory low
        }
        int cap = Math.min(cnt[x], k);
        long[] table = new long[cap + 1];
        for (int s = 0; s <= cap; s++) {
            // Budget the children cannot use is simply left unspent.
            table[s] = acc[Math.min(s, acc.length - 1)] + w[x];
            if (s >= 1) {
                // One operation can remove x's whole subtree, leaving 0.
                table[s] = Math.max(table[s], 0L);
            }
        }
        return table;
    }

    /**
     * Roots the subtree of {@code u} whose parent is {@code p} (-1 for the root): appends the
     * children of each vertex to {@code dt} in adjacency order and computes {@code cnt}.
     */
    void prepare(int u, int p) {
        if (p != -1) {
            dt[p].add(u);
        }
        int n = t.length;
        int[] parent = new int[n];
        int[] stack = new int[n];
        int[] order = new int[n];
        int top = 0;
        int visited = 0;
        parent[u] = p;
        stack[top++] = u;
        while (top > 0) {
            int x = stack[--top];
            order[visited++] = x;
            for (int v : t[x]) {
                if (v != parent[x]) {
                    parent[v] = x;
                    dt[x].add(v);
                    stack[top++] = v;
                }
            }
        }
        // Subtree sizes bottom-up: reverse pre-order visits children before parents.
        for (int idx = visited - 1; idx >= 0; idx--) {
            int x = order[idx];
            cnt[x] = 1;
            for (int v : dt[x]) {
                cnt[x] += cnt[v];
            }
        }
    }
}

/** Minimal buffered reader of whitespace-separated ASCII integers from an {@link InputStream}. */
class InputReader {
    final InputStream is;
    final byte[] buf = new byte[1024];
    int pos;
    int size;

    public InputReader(InputStream is) {
        this.is = is;
    }

    /**
     * Returns the next optionally signed decimal integer.
     *
     * @throws InputMismatchException on a non-digit character, or when reading past end of input
     */
    public int nextInt() {
        int c = read();
        while (isWhitespace(c))
            c = read();
        int sign = 1;
        if (c == '-') {
            sign = -1;
            c = read();
        }
        int res = 0;
        do {
            if (c < '0' || c > '9')
                throw new InputMismatchException();
            res = res * 10 + c - '0';
            c = read();
        } while (!isWhitespace(c));
        return res * sign;
    }

    /**
     * Returns the next byte as 0..255, or -1 when the stream yields no more bytes; a further
     * call after the stream reported end of input throws {@link InputMismatchException}.
     */
    int read() {
        if (size == -1)
            throw new InputMismatchException();
        if (pos >= size) {
            pos = 0;
            try {
                size = is.read(buf);
            } catch (IOException e) {
                throw new InputMismatchException();
            }
            if (size <= 0)
                return -1;
        }
        return buf[pos++] & 255;
    }

    /** Treats ASCII blanks and the end-of-input marker -1 as token separators. */
    static boolean isWhitespace(int c) {
        return c == ' ' || c == '\n' || c == '\r' || c == '\t' || c == -1;
    }
}
In C :
#include <stdio.h>
#include <stdlib.h>
/* One cell of a singly linked adjacency list: x is the neighbouring vertex (0-based). */
typedef struct _node{
int x;
struct _node *next;
} node;
void insert_edge(int x,int y);
void dfs(int x);
long long max(long long x,long long y);
/* a: input weights by vertex; b: weights in dfs pre-order; size[i]: subtree size of the vertex
   at pre-order position i; trace: visited flags used by dfs; NN: next free pre-order position. */
int a[100000],b[100000],size[100000],trace[100000]={0},NN=0;
/* dp[i][j]: best kept weight over the first i pre-order positions with budget j removals
   (main seeds every j with the no-removal prefix sum, so larger budgets never score lower). */
long long dp[100001][201];
/* table[v]: head of v's adjacency list (NULL when v has no neighbours yet). */
node *table[100000]={0};
/* Reads the tree, linearises it in pre-order with dfs(), then runs a DP over pre-order
   prefixes: the vertex at position i-1 is either kept (dp[i-1][j] + b[i-1] -> dp[i][j]) or
   removed together with its subtree, which skips to prefix i+size[i-1]-1 using one more
   operation. Prints dp[N][K] followed by a newline. Returns 1 on malformed or out-of-range
   input instead of reading past the fixed-size arrays. */
int main(){
    int N,K,x,y,i,j;
    long long sum;
    if(scanf("%d%d",&N,&K)!=2)
        return 1;
    /* The global arrays hold at most 100000 vertices and 201 dp columns. */
    if(N<1||N>100000||K<0||K>200)
        return 1;
    for(i=0;i<N;i++)
        if(scanf("%d",a+i)!=1)
            return 1;
    for(i=0;i<N-1;i++){
        if(scanf("%d%d",&x,&y)!=2||x<1||x>N||y<1||y>N)
            return 1;
        insert_edge(x-1,y-1);
    }
    dfs(0);
    for(i=0;i<=K;i++)
        dp[0][i]=0;
    /* Baseline: keeping every vertex of the prefix is feasible for every budget j. */
    for(i=1,sum=0;i<=N;i++){
        sum+=b[i-1];
        for(j=0;j<=K;j++)
            dp[i][j]=sum;
    }
    for(i=1;i<=N;i++){
        int after=i+size[i-1]-1; /* prefix length once the subtree at position i-1 is skipped */
        for(j=0;j<=K;j++){
            if(j!=K)
                dp[after][j+1]=max(dp[after][j+1],dp[i-1][j]);
            dp[i][j]=max(dp[i][j],dp[i-1][j]+b[i-1]);
        }
    }
    printf("%lld\n",dp[N][K]);
    return 0;
}
/* Pushes a new cell for neighbour `to` onto the head of from's adjacency list.
   Exits with an error message if memory cannot be allocated. */
static void push_arc(int from,int to){
    node *cell=(node*)malloc(sizeof(node));
    if(cell==NULL){
        fprintf(stderr,"insert_edge: out of memory\n");
        exit(EXIT_FAILURE);
    }
    cell->x=to;
    cell->next=table[from];
    table[from]=cell;
}
/* Adds the undirected edge {x, y}: y goes onto x's list first, then x onto y's list. */
void insert_edge(int x,int y){
    push_arc(x,y);
    push_arc(y,x);
}
/* Pre-order walk from x over unvisited vertices: appends each vertex's weight to b[] at
   position NN and, once its subtree is finished, stores the subtree size in size[position].
   Neighbours are taken in adjacency-list order, exactly as the recursive version did.
   An explicit stack replaces recursion so a path of up to 100000 vertices cannot overflow
   the call stack; depth never exceeds the number of vertices. */
void dfs(int x){
    static int start_pos[100000];   /* pre-order position of each stacked vertex */
    static node *cursor[100000];    /* next adjacency cell to examine per stacked vertex */
    int top=0;
    start_pos[0]=NN;
    trace[x]=1;
    b[NN++]=a[x];
    cursor[0]=table[x];
    while(top>=0){
        node *t=cursor[top];
        if(t!=NULL){
            cursor[top]=t->next;
            if(!trace[t->x]){
                int y=t->x;
                top++;
                start_pos[top]=NN;
                trace[y]=1;
                b[NN++]=a[y];
                cursor[top]=table[y];
            }
        }else{
            size[start_pos[top]]=NN-start_pos[top];
            top--;
        }
    }
}
/* Returns the larger of x and y. */
long long max(long long x,long long y){
    if(y>=x)
        return y;
    return x;
}
In Python3 :
#!/bin/python3
import os
import sys
#
# Complete the treePrunning function below.
#
from collections import defaultdict
# Sentinel for unreachable DP states. Real totals satisfy |sum| <= 10^5 * 10^9 = 10^14,
# so every reachable value stays strictly above it.
INF = -(1e15)


def dfs(x, f, g, k, weights):
    """Compute the removal DP table for the subtree rooted at ``x``.

    Args:
        x: root of the subtree.
        f: parent of ``x`` (excluded from the walk); -1 for the tree root.
        g: adjacency mapping, vertex -> list of neighbours.
        k: maximum number of remove operations.
        weights: vertex weights indexed by vertex id.

    Returns:
        A list ``dp`` of length ``k + 1`` where ``dp[j]`` is the best total weight left in
        x's subtree after exactly ``j`` remove operations that never remove ``x`` itself,
        or ``INF`` when ``j`` operations are not possible.

    The walk is iterative: a path-shaped tree can be ~10^5 levels deep, far beyond
    Python's default recursion limit.
    """
    # Pass 1: pre-order walk recording each vertex's parent.
    parent = {x: f}
    order = []
    stack = [x]
    while stack:
        u = stack.pop()
        order.append(u)
        for n in g[u]:
            if n != parent[u]:
                parent[n] = u
                stack.append(n)
    # Pass 2: reverse pre-order finishes every child before its parent; children are
    # merged in adjacency order, as the recursive version did.
    tables = {}
    for u in reversed(order):
        dpc = [INF] * (k + 1)
        dpc[0] = weights[u]
        for n in g[u]:
            if n == parent[u]:
                continue
            dpc = _merge_child(dpc, tables.pop(n), k)
        tables[u] = dpc
    return tables[x]


def _merge_child(dpc, dpn, k):
    """Fold child table ``dpn`` into parent table ``dpc`` (max-plus over split removal counts).

    Also allows spending one operation to remove the child's whole subtree. Finite entries of
    both tables form a prefix, so each loop stops at the first ``INF``.
    """
    dptmp = [INF] * (k + 1)
    for i in range(k + 1):
        if dpc[i] == INF:
            break
        for j in range(0, k - i + 1):
            if dpn[j] == INF:
                break
            dptmp[i + j] = max(dptmp[i + j], dpc[i] + dpn[j])
        if i + 1 <= k:
            dptmp[i + 1] = max(dptmp[i + 1], dpc[i])
    return dptmp
def treePrunning(k, weights, edges):
    """Return the maximum total weight left after at most ``k`` subtree removals.

    Args:
        k: maximum number of remove operations.
        weights: vertex weights; ``weights[i]`` belongs to vertex ``i + 1``.
        edges: iterable of 1-based ``(u, v)`` pairs describing the tree rooted at vertex 1.
    """
    g = defaultdict(list)
    for u, v in edges:
        g[u - 1].append(v - 1)
        g[v - 1].append(u - 1)
    best = max(dfs(0, -1, g, k, weights))
    # dfs never removes the root; with at least one operation the whole tree can be
    # removed, leaving weight 0. With k == 0 nothing may be removed.
    return max(best, 0) if k >= 1 else best
if __name__ == '__main__':
    # The with-block closes the output file even if parsing or solving raises.
    with open(os.environ['OUTPUT_PATH'], 'w') as fptr:
        nk = input().split()
        n = int(nk[0])
        k = int(nk[1])
        weights = list(map(int, input().rstrip().split()))
        tree = [list(map(int, input().rstrip().split())) for _ in range(n - 1)]
        result = treePrunning(k, weights, tree)
        fptr.write(str(result) + '\n')
View More Similar Problems
Cycle Detection
A linked list is said to contain a cycle if any node is visited more than once while traversing the list. Given a pointer to the head of a linked list, determine if it contains a cycle. If it does, return 1. Otherwise, return 0. Example: head refers to the list 1 -> 2 -> 3 -> NULL. The numbers shown are the node numbers, not their data values. There is no cycle in this list, so return 0. head refers …
View Solution →Find Merge Point of Two Lists
This challenge is part of a tutorial track by MyCodeSchool Given pointers to the head nodes of 2 linked lists that merge together at some point, find the node where the two lists merge. The merge point is where both lists point to the same node, i.e. they reference the same memory location. It is guaranteed that the two head nodes will be different, and neither will be NULL. If the lists share
View Solution →Inserting a Node Into a Sorted Doubly Linked List
Given a reference to the head of a doubly-linked list and an integer, data, create a new DoublyLinkedListNode object having data value data and insert it at the proper location to maintain the sort. Example: head refers to the list 1 <-> 2 <-> 4 -> NULL and data = 3. Return a reference to the new list: 1 <-> 2 <-> 3 <-> 4 -> NULL. Function Description: Complete the sortedInsert function …
View Solution →Reverse a doubly linked list
This challenge is part of a tutorial track by MyCodeSchool Given the pointer to the head node of a doubly linked list, reverse the order of the nodes in place. That is, change the next and prev pointers of the nodes so that the direction of the list is reversed. Return a reference to the head node of the reversed list. Note: The head node might be NULL to indicate that the list is empty.
View Solution →Tree: Preorder Traversal
Complete the preorder function in the editor below, which has 1 parameter: a pointer to the root of a binary tree. It must print the values in the tree's preorder traversal as a single line of space-separated values. Input Format Our test code passes the root node of a binary tree to the preOrder function. Constraints 1 <= Nodes in the tree <= 500 Output Format Print the tree's
View Solution →Tree: Postorder Traversal
Complete the postorder function in the editor below. It receives 1 parameter: a pointer to the root of a binary tree. It must print the values in the tree's postorder traversal as a single line of space-separated values. Input Format: Our test code passes the root node of a binary tree to the postorder function. Constraints: 1 <= Nodes in the tree <= 500. Output Format: Print the …
View Solution →