Maximize the Number of Equivalent Pairs After Swaps - Google Top Interview Questions

Problem Statement :

You are given two lists of integers, A and B, of the same length.

You are also given a two-dimensional list of integers C where each element is of the form [i, j] which means that you can swap A[i] and A[j] as many times as you want.

Return the maximum number of indices i such that A[i] = B[i] after performing any sequence of the allowed swaps.


n ≤ 100,000 where n is the length of A and B

m ≤ 100,000 where m is the length of C

Example 1


A = [1, 2, 3, 4]

B = [2, 1, 4, 3]

C = [
    [0, 1],
    [2, 3]
]

Output: 4

Explanation: We can swap A[0] with A[1], then A[2] with A[3], giving A = [2, 1, 4, 3], which equals B at every index.

Solution :


                        Solution in C++ :

// Disjoint-set (union-find) over the integers 0..n-1, using path compression
// in find() and union by rank in unify().
class UnionFind {
    vector<int> parents, rank;

public:
    // Creates n singleton sets {0}, {1}, ..., {n-1}.
    // Both vectors must be sized up front; indexing an empty vector is UB.
    UnionFind(int n) : parents(n), rank(n, 1) {
        for (int i = 0; i < n; i++) {
            parents[i] = i;
        }
    }

    // Returns the representative (root) of the set containing `node`.
    int find(int node) {
        int root = node;
        while (root != parents[root]) {
            root = parents[root];
        }

        // Path compression: re-point every node on the walked path at the root.
        while (node != root) {
            int next = parents[node];
            parents[node] = root;
            node = next;
        }
        return root;
    }

    // Merges the sets containing `a` and `b`; no-op if they are already joined.
    void unify(int a, int b) {
        int rootA = find(a);
        int rootB = find(b);
        if (rootA == rootB) return;

        // Union by rank: attach the lower-rank tree beneath the higher-rank one.
        if (rank[rootA] > rank[rootB]) {
            parents[rootB] = rootA;
        } else if (rank[rootB] > rank[rootA]) {
            parents[rootA] = rootB;
        } else {
            parents[rootB] = rootA;
            rank[rootA]++;  // equal ranks: the merged tree may be one level deeper
        }
    }

    // Returns a copy of the raw parent array. Entries are not necessarily
    // roots; use find() to obtain an index's representative.
    vector<int> get_parents_array() {
        return parents;
    }
};

// Returns the maximum number of indices i with A[i] == B[i] reachable by
// applying the swaps listed in C (each [i, j] may be used any number of times).
//
// Indices connected (directly or transitively) by swap pairs form a group whose
// A-values can be permuted arbitrarily, so each group contributes
// sum over values v of min(#v in A-group, #v in B-group).
//
// Time: O(N + M * alpha(N)) expected (hash maps); Space: O(N).
int solve(vector<int>& A, vector<int>& B, vector<vector<int>>& C) {
    int n = A.size();
    UnionFind union_find(n);

    for (vector<int>& edge : C) {
        union_find.unify(edge[0], edge[1]);  // Do unions to form groups
    }

    // Map each group's root to the list of indices in that group.
    unordered_map<int, vector<int>> grp_map;
    for (int i = 0; i < n; i++) {
        grp_map[union_find.find(i)].push_back(i);
    }

    int count = 0;
    for (auto& grp : grp_map) {
        vector<int>& indices = grp.second;
        unordered_map<int, int> value_map;

        // Multiset of A-values available inside this group.
        for (int idx : indices) {
            value_map[A[idx]]++;
        }
        // Each B-value consumes one matching A-value; a result >= 0 after the
        // decrement means a copy was still available.
        for (int idx : indices) {
            if (--value_map[B[idx]] >= 0) {
                count++;
            }
        }
    }
    return count;
}

                        Solution in Java :

import java.util.*;

class Solution {
    class DisjointSet {
        int node;
        DisjointSet parent;
        public DisjointSet(int val) {
            this.node = val;
            this.parent = this;

    private Map<Integer, DisjointSet> map = new HashMap();
    private Map<Integer, List<Integer>> swappableMap = new HashMap();
    public int solve(int[] A, int[] B, int[][] C) {
        int count = 0;
        if (A.length == 0 || B.length == 0)
            return 0;
        for (int i = 0; i < A.length; i++) map.put(i, new DisjointSet(i));

        for (int[] arr : C) {
            int idx1 = arr[0];
            int idx2 = arr[1];
            union(idx1, idx2);

        for (int i = 0; i < A.length; i++) {
            DisjointSet set = map.get(i);
            DisjointSet par = find(set);
            swappableMap.computeIfAbsent(par.node, k -> new ArrayList()).add(i);

        for (int key : swappableMap.keySet()) {
            List<Integer> list = swappableMap.get(key);

            Map<Integer, Integer> freq1 = new HashMap();
            Map<Integer, Integer> freq2 = new HashMap();

            for (int i = 0; i < list.size(); i++) {
                int idx = list.get(i);
                freq1.put(A[idx], freq1.getOrDefault(A[idx], 0) + 1);
                freq2.put(B[idx], freq2.getOrDefault(B[idx], 0) + 1);
            for (int num : freq1.keySet()) {
                count += (Math.min(freq1.get(num), freq2.getOrDefault(num, 0)));
        return count;

    private void union(int idx1, int idx2) {
        DisjointSet set1 = map.get(idx1);
        DisjointSet set2 = map.get(idx2);

        DisjointSet f1 = find(set1);
        DisjointSet f2 = find(set2);

        if (f1.node == f2.node)
        f1.parent = f2;

    private DisjointSet find(DisjointSet set) {
        if (set.parent == set)
            return set;
        return set.parent = find(set.parent);

                        Solution in Python : 
class Solution:
    def solve(self, A, B, edges):
        """Return the max number of indices i with A[i] == B[i] after swaps.

        Each pair [u, v] in ``edges`` lets A[u] and A[v] be swapped any number
        of times, so every connected component of the swap graph can permute
        its A-values freely. A component contributes, per value, the smaller
        of its counts in A and in B.

        Args:
            A: values that may be swapped.
            B: target values; assumed to be the same length as A.
            edges: iterable of [u, v] index pairs into A.

        Returns:
            The number of matching positions achievable.
        """
        from collections import Counter

        N = len(A)
        graph = [[] for _ in range(N)]
        for u, v in edges:
            graph[u].append(v)
            graph[v].append(u)

        ans = 0
        seen = [False] * N
        for u in range(N):
            if seen[u]:
                continue
            # BFS: a for-loop over a list also visits items appended during
            # iteration, so `queue` ends up holding the whole component.
            queue = [u]
            seen[u] = True
            for node in queue:
                for nei in graph[node]:
                    if not seen[nei]:
                        seen[nei] = True
                        queue.append(nei)

            # Each A-value in the component consumes one equal B-value if any remain.
            count = Counter(B[i] for i in queue)
            for i in queue:
                if count[A[i]]:
                    count[A[i]] -= 1
                    ans += 1

        return ans

