Kundu and Tree

Problem Statement :

Kundu is a true tree lover. A tree is a connected graph with N vertices and N-1 edges. Today he got a tree and colored each edge either red (r) or black (b). He wants to know how many triplets (a, b, c) of vertices exist such that each of the three paths (from a to b, from b to c, and from c to a) contains at least one red edge. Note that (a, b, c), (b, a, c) and all other permutations are considered the same triplet.

If the answer is greater than $10^9 + 7$, print the answer modulo (%) $10^9 + 7$.

Input Format
The first line contains an integer N, the number of vertices in the tree.
Each of the next N-1 lines describes an edge: two space-separated integers (the edge's endpoints) followed by the edge's color. The color is a lowercase English letter, either r (red) or b (black).

Output Format
Print a single number i.e. the number of triplets.

Constraints
$1 \le N \le 10^5$
Vertices are numbered from 1 to N.

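Sample Input

5
1 2 b
2 3 r
3 4 r
4 5 b

Sample Output

4

Explanation: the black edges split the tree into the components {1, 2}, {3} and {4, 5}. A triplet is valid exactly when its three vertices lie in different components. That gives 2 · 1 · 2 = 4 triplets: (1, 3, 4), (1, 3, 5), (2, 3, 4) and (2, 3, 5).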

Solution :


                            Solutions in C++, Java, C and Python 3 :

In C++ :

// NDEBUG is defined before the includes so any assert() would compile to a
// no-op; this solution does not use assert().
#define NDEBUG
#include <algorithm>
#include <iostream>
#include <vector>
using namespace std;

// repeat(n) executes the following statement n times; the counter repc is
// internal to the macro and not meant to be read by the loop body.
#define repeat(n) for (int repc = (n); repc > 0; --repc)
// 64-bit integer alias used for the triplet counts computed in main().
typedef long long int64;

// Disjoint-set forest over the elements 0..n-1 using union by rank and path
// compression. For a root r, size[r] is the number of elements in r's set
// (entries for non-roots are stale and must not be read).
struct UnionFind {
  int n;
  std::vector<int> dad, rank, size;

  // Allocates storage for n elements and makes each one a singleton set.
  UnionFind(int n) : n(n), dad(n), rank(n), size(n) { reset(); }

  // Turns every element back into its own singleton set.
  void reset() {
    for (int i = 0; i < n; ++i) {
      dad[i] = i;
    }
    std::fill(rank.begin(), rank.end(), 0);
    std::fill(size.begin(), size.end(), 1);
  }

  // Returns the root of a's set. Every node on the path from a is re-pointed
  // straight at the root (iteratively, so deep chains cannot blow the stack).
  int find(int a) {
    int top = a;
    while (top != dad[top]) {
      top = dad[top];
    }
    while (a != top) {
      int next = dad[a];
      dad[a] = top;
      a = next;
    }
    return top;
  }

  // Merges the sets containing a and b.
  // Returns 1 if they were in different sets (a merge happened), else 0.
  int union_find(int a, int b) {
    a = find(a);
    b = find(b);
    if (a == b) {
      return 0;
    }
    if (rank[a] > rank[b]) {
      dad[b] = a;
      size[a] += size[b];
    } else {
      dad[a] = b;
      size[b] += size[a];
      // Only equal-rank merges increase the resulting tree's rank.
      if (rank[a] == rank[b]) {
        ++rank[b];
      }
    }
    return 1;
  }
};

int main() {

  int n;
  cin >> n;

  UnionFind uf(n);
  repeat (n-1) {
    int a, b;
    char ch;
    cin >> a >> b >> ws >> ch;
    if (ch == 'b') {
      uf.union_find(a-1, b-1);

  auto choose3 = [](int x) { return int64(x) * (x-1) * (x-2) / 6; };
  auto choose2 = [](int x) { return int64(x) * (x-1) / 2; };

  int64 ans = choose3(n);
  for (int i=0; i<n; ++i) {
    if (uf.find(i) == i) {
      int c = uf.size[i];
      ans -= choose3(c);
      ans -= choose2(c) * (n-c);
  cout << ans % 1000000007 << '\n';
  return 0;

In Java :

import java.io.*;
import java.util.ArrayList;

public class Solution {

    /** Modulus required by the problem statement (10^9 + 7). */
    private static final long MOD = 1_000_000_007L;

    /**
     * Reads a tree from {@code in} and prints the number of unordered vertex triplets whose three
     * connecting paths each contain at least one red edge, modulo 10^9 + 7.
     *
     * <p>Dropping the red edges splits the tree into black-connected components. Two vertices
     * are separated by a red edge exactly when they lie in different components. The answer is
     * therefore the sum, over every three distinct components, of the product of their sizes
     * (the third elementary symmetric sum of the component sizes).
     *
     * @param in token reader positioned at the start of the input
     * @param out destination for the single-line answer; this method does not flush it
     * @throws IOException if reading from {@code in} fails
     */
    public static void solve(Input in, PrintWriter out) throws IOException {
        int n = in.nextInt();
        // Adjacency lists over black edges only; red edges never join a component.
        @SuppressWarnings("unchecked")
        ArrayList<Integer>[] edges = new ArrayList[n];
        for (int i = 0; i < n; ++i) {
            edges[i] = new ArrayList<>();
        }
        for (int i = 0; i < n - 1; ++i) {
            int a = in.nextInt() - 1;
            int b = in.nextInt() - 1;
            if ("b".equals(in.next())) {
                edges[a].add(b);
                edges[b].add(a);
            }
        }
        boolean[] col = new boolean[n];
        int[] stack = new int[n];
        // c1, c2, c3: first three elementary symmetric sums of the component sizes seen so far.
        long c1 = 0, c2 = 0, c3 = 0;
        for (int i = 0; i < n; ++i) {
            if (!col[i]) {
                long c = componentSize(i, edges, col, stack);
                // Update the highest order first so each sum uses the previous prefix's values.
                c3 = (c3 + c * c2) % MOD;
                c2 = (c2 + c * c1) % MOD;
                c1 = (c1 + c) % MOD;
            }
        }
        out.println(c3);
    }

    /**
     * Marks every vertex reachable from {@code start} in {@code col} and returns how many there
     * are. The traversal is iterative, so long chains cannot overflow the call stack.
     *
     * @param start unvisited vertex to start from
     * @param edges adjacency lists of the black-edge graph
     * @param col visited flags; updated in place
     * @param stack scratch buffer of length at least {@code edges.length}
     * @return number of vertices in {@code start}'s component
     */
    private static int componentSize(
            int start, ArrayList<Integer>[] edges, boolean[] col, int[] stack) {
        int top = 0;
        stack[top++] = start;
        col[start] = true;
        int size = 0;
        while (top > 0) {
            int v = stack[--top];
            ++size;
            for (int w : edges[v]) {
                // Mark on push so each vertex enters the stack at most once (fits in n slots).
                if (!col[w]) {
                    col[w] = true;
                    stack[top++] = w;
                }
            }
        }
        return size;
    }

    /** Entry point: solves one test read from standard input and writes to standard output. */
    public static void main(String[] args) throws IOException {
        // try-with-resources flushes the buffered writer; otherwise nothing would be printed.
        try (PrintWriter out = new PrintWriter(System.out)) {
            solve(new Input(new BufferedReader(new InputStreamReader(System.in))), out);
        }
    }

    /** Minimal whitespace-delimited token reader. */
    static class Input {
        final BufferedReader in;
        final StringBuilder sb = new StringBuilder();

        public Input(BufferedReader in) {
            this.in = in;
        }

        public Input(String s) {
            this.in = new BufferedReader(new StringReader(s));
        }

        /**
         * Returns the next token separated by space, tab, CR or LF.
         *
         * @return the token, or {@code null} at end of input
         * @throws IOException if the underlying reader fails
         */
        public String next() throws IOException {
            sb.setLength(0);
            // Skip separators up to the first token character.
            while (true) {
                int c = in.read();
                if (c == -1) {
                    return null;
                }
                if (" \n\r\t".indexOf(c) == -1) {
                    sb.append((char) c);
                    break;
                }
            }
            // Accumulate until the next separator or end of input.
            while (true) {
                int c = in.read();
                if (c == -1 || " \n\r\t".indexOf(c) != -1) {
                    break;
                }
                sb.append((char) c);
            }
            return sb.toString();
        }

        public int nextInt() throws IOException {
            return Integer.parseInt(next());
        }

        public long nextLong() throws IOException {
            return Long.parseLong(next());
        }

        public double nextDouble() throws IOException {
            return Double.parseDouble(next());
        }
    }
}

In C :

#include <stdio.h>

#define mod 1000000007
#define ll long long

/* Union-find state over vertices 1..N:
   nodes[i] - parent link (a root points to itself),
   sz[i]    - set size, valid only when i is a root (used for union by size),
   hash[i]  - after the edge loop in main, the number of vertices whose black
              component is rooted at i (0 for non-roots). */
ll nodes[100005], hash[100005], sz[100005];

/* Returns the root of i's set, halving the path on the way up. */
ll root(ll i)
{
    while (nodes[i] != i) {
        nodes[i] = nodes[nodes[i]];
        i = nodes[i];
    }
    return i;
}

/* Merges the sets containing p and q, attaching the smaller under the larger. */
void unon_ins(ll p, ll q)
{
    ll i = root(p), j = root(q);
    if (i == j)
        return;
    if (sz[i] < sz[j]) {
        nodes[i] = j;
        sz[j] += sz[i];
    } else {
        nodes[j] = i;
        sz[i] += sz[j];
    }
}

/* Reads the tree, unions the endpoints of black edges, and prints the sum over
   every three distinct black components of the product of their sizes, mod p.
   That sum is the number of triplets separated pairwise by red edges. */
int main(void)
{
    ll n, i, x, y;
    char ch;
    /* Running elementary symmetric sums e1, e2, e3 of the component sizes. */
    ll s1 = 0, s2 = 0, s3 = 0;

    if (scanf("%lld", &n) != 1)
        return 0;
    for (i = 1; i <= n; i++) {
        nodes[i] = i;
        sz[i] = 1;
        hash[i] = 0;
    }
    for (i = 0; i < n - 1; i++) {
        if (scanf("%lld %lld %c", &x, &y, &ch) != 3)
            break;
        if (ch == 'b')
            unon_ins(x, y);
    }
    for (i = 1; i <= n; i++)
        hash[root(i)]++;
    for (i = 1; i <= n; i++) {
        if (hash[i] == 0)
            continue;
        long long t = hash[i];
        /* Update e3 before e2 before e1 so each uses the previous prefix's sums. */
        s3 = (s3 + t * s2) % mod;
        s2 = (s2 + t * s1) % mod;
        s1 = (s1 + t) % mod;
    }
    printf("%lld\n", s3);
    return 0;
}

In Python3 :

# Read the vertex count, then set up a disjoint-set forest over 0..n-1:
# p holds parent links (every vertex starts as its own root), rank holds the
# per-root values union() compares to pick the new parent, and size holds the
# number of vertices in each root's set.
n = int(input())

p = [v for v in range(n)]
rank = [0 for _ in range(n)]
size = [1 for _ in range(n)]

def get(v):
	"""Return the root of v's set in the global forest ``p``.

	Every vertex visited on the way up is afterwards re-pointed directly at
	the root (path compression), so later lookups from those vertices are short.
	"""
	stack = []
	while p[v] != v:
		stack.append(v)  # remember the path so it can be compressed below
		v = p[v]
	for u in stack:
		p[u] = v
	return v

def union(v1, v2):
	"""Merge the sets containing v1 and v2 (union by rank).

	Does nothing when both already share a root. Otherwise the root with the
	larger rank becomes the parent and absorbs the other set's size.
	"""
	v1 = get(v1)
	v2 = get(v2)
	if v1 == v2:
		return
	if rank[v1] < rank[v2]:
		v1, v2 = v2, v1
	size[v1] += size[v2]
	p[v2] = v1
	# Only merging two trees of equal rank makes the result taller.
	if rank[v1] == rank[v2]:
		rank[v1] += 1

# Join the endpoints of every black edge, so each set in the forest becomes a
# maximal group of vertices connected by black edges alone.
for _ in range(n - 1):
	x, y, col = input().split()
	if col != 'b':
		continue
	union(int(x) - 1, int(y) - 1)

# Sizes of the black components: one entry per forest root (p[r] == r).
a = [sz for r, sz in enumerate(size) if p[r] == r]

MOD = 10**9 + 7


def solve(a):
	"""Return the sum over i < j < k of a[i] * a[j] * a[k], reduced mod MOD.

	With ``a`` holding the black-component sizes, this counts the triplets whose
	vertices lie in three different components, i.e. triplets where every
	connecting path crosses a red edge.
	"""
	s1 = s2 = s3 = 0  # running sums of single elements, pairs and triples so far
	for x in a:
		# Update the triples first so each sum still refers to the previous prefix.
		s3 = (s3 + x * s2) % MOD
		s2 = (s2 + x * s1) % MOD
		s1 = (s1 + x) % MOD
	return s3


print(solve(a))


