Union-Find Problem Sets
The union-find (disjoint-set) data structure is used extensively to solve problems involving disjoint sets. Two sets are disjoint if they have no elements in common. It supports the following operations:
- Find: find the representative of the set that contains a given element
- Merge (union): merge two disjoint sets into a single set
NB: To keep find fast, merge by rank: always make the root with the higher rank the representative of the merged set. This is analogous to attaching the shorter tree as a child of the taller one, which keeps the trees shallow. Combined with path compression (re-pointing every node visited during a find directly at the root), the amortized time per operation drops to O(α(n)), where α is the inverse Ackermann function.
This is useful for problems involving relationships between elements, especially those that can be modelled as undirected graphs; detecting a cycle in an undirected graph is one such example. To detect a cycle in a directed graph, topological sort is used instead.
This wiki collects several well-defined problems solved with the union-find data structure.
Graph Valid Tree (178):
// Returns the representative (root) of the set containing `index`.
// parent[x] == x marks a root. Every node on the walked path is re-pointed
// directly at the root (path compression), so later finds on those nodes are O(1).
int find(vector<int> &parent, int index){
    // First pass: follow parent links up to the root.
    int root = index;
    while(parent[root] != root){
        root = parent[root];
    }
    // Second pass: path compression.
    while(parent[index] != root){
        const int next = parent[index];
        parent[index] = root;
        index = next;
    }
    return root;
}
// Merges the sets containing index1 and index2 using union by rank.
// The root of lower rank is attached under the root of higher rank; on a tie,
// root1 goes under root2 and root2's rank grows by one.
void unionSet(vector<int> &parent,vector<int> &rank, int index1, int index2){
    int root1 = find(parent, index1);
    int root2 = find(parent, index2);
    // Already in the same set. Without this guard the tie branch below would
    // make root2 its own parent and inflate rank[root2] for no reason.
    if(root1 == root2){
        return;
    }
    if(rank[root1] < rank[root2]){
        parent[root1] = root2;
    } else if(rank[root1] > rank[root2]){
        parent[root2] = root1;
    } else {
        parent[root1] = root2;
        rank[root2]++;
    }
}
// Returns true when the undirected graph on nodes 0..n-1 with the given edges is a tree.
// Two conditions are checked:
//  - acyclic: an edge whose endpoints already share a root closes a cycle;
//  - connected: after all unions at most one root remains.
// n == 0 with no edges yields zero roots and is accepted.
bool validTree(int n, vector<vector<int>> &edges) {
    vector<int> parent(n), rank(n,0);
    for(int i = 0; i < n; i++){
        parent[i] = i;
    }
    for(const vector<int> &edge : edges){
        const int node1 = edge[0], node2 = edge[1];
        if(find(parent, node1) == find(parent, node2)){
            return false;
        }
        unionSet(parent, rank, node1, node2);
    }
    // Count remaining set representatives; more than one means the graph is disconnected.
    int rootNodes = 0;
    for(int i = 0; i < n; i++){
        if(parent[i] == i){
            rootNodes++;
        }
    }
    return rootNodes <= 1;
}
Number of Islands II (434):
// Returns the representative cell of the island containing `node`.
// A cell with no entry in parentMap is a root. Every cell on the walked path is
// re-pointed directly at the root (path compression).
pair<int,int> findRoot(pair<int,int> node){
    // Follow parent links; each hop costs a single map lookup.
    pair<int,int> root = node;
    for(auto it = parentMap.find(root); it != parentMap.end(); it = parentMap.find(root)){
        root = it->second;
    }
    // Path compression: every non-root cell on the path has an entry in parentMap.
    while(node != root){
        auto link = parentMap.find(node);
        const pair<int,int> next = link->second;
        link->second = root;
        node = next;
    }
    return root;
}
void unionSets(pair<int,int> index1, pair<int,int> index2){
pair<int,int> root1 = findRoot(index1);
pair<int,int> root2 = findRoot(index2);
if(root1 == root2){
return;
}if(rankMap[root1] < rankMap[root2]){
parentMap[root1] = root2;
} else if(rankMap[root1] > rankMap[root2]){
parentMap[root2] = root1;
} else {
parentMap[root1] = root2;
rankMap[root2]++;
}
groups--;
}
vector<int> numIslands2(int n, int m, vector<Point> &operators) {
// write your code here
parentMap.clear();
rankMap.clear();
groups = 0;
set<pair<int,int>> visitedPairSet;
vector<int> groupArray;
for(int i=0;i<operators.size();i++){
pair<int,int> node = make_pair(operators[i].x, operators[i].y);
if(visitedPairSet.find(node) != visitedPairSet.end()){
groupArray.push_back(groups);
continue;
}
groups++;
if(visitedPairSet.find( make_pair(node.first-1,node.second) ) != visitedPairSet.end() ){
unionSets(node, make_pair(node.first-1,node.second) );
}
if(visitedPairSet.find( make_pair(node.first+1,node.second) ) != visitedPairSet.end() ){
unionSets(node, make_pair(node.first+1,node.second) );
}
if(visitedPairSet.find( make_pair(node.first,node.second-1) ) != visitedPairSet.end() ){
unionSets(node, make_pair(node.first,node.second-1) );
}
if(visitedPairSet.find( make_pair(node.first,node.second+1) ) != visitedPairSet.end() ){
unionSets(node, make_pair(node.first,node.second+1) );
}
visitedPairSet.insert(node);
groupArray.push_back(groups);
}
return groupArray;
}
private:
// Parent link for each cell merged under another cell; a cell with no entry is a root.
map<pair<int,int>, pair<int,int>> parentMap;
// Union-by-rank value per root cell (missing entries read as 0 via operator[]).
map<pair<int,int>, int> rankMap;
// Current island count: incremented per new cell in numIslands2, decremented per merge in unionSets.
int groups;
Accounts Merge (1070):
// Returns the root account index for `index`; parentArray[x] == x marks a root.
// Every node on the walked path is re-pointed directly at the root (path compression).
int findRoot(vector<int> &parentArray, int index){
    int root = index;
    while(parentArray[root] != root){
        root = parentArray[root];
    }
    while(parentArray[index] != root){
        const int next = parentArray[index];
        parentArray[index] = root;
        index = next;
    }
    return root;
}
// Union by rank of the sets containing index1 and index2.
// The taller tree's root becomes the parent; on a tie index1's root wins and gains one rank.
void unionSets(vector<int> &parentArray, vector<int> &rankArray, int index1, int index2){
    const int rootA = findRoot(parentArray, index1);
    const int rootB = findRoot(parentArray, index2);
    if(rootA == rootB){
        return;
    }
    const bool bIsTaller = rankArray[rootB] > rankArray[rootA];
    const int keeper = bIsTaller ? rootB : rootA;
    const int absorbed = bIsTaller ? rootA : rootB;
    parentArray[absorbed] = keeper;
    if(rankArray[keeper] == rankArray[absorbed]){
        ++rankArray[keeper];
    }
}
// Merges accounts that share at least one email.
// Each row is assumed to be {name, email1, email2, ...}. Accounts sharing an email
// are unioned. The output has one row per resulting group: the root account's name
// followed by the group's distinct emails in sorted order. Groups are emitted in
// increasing root-index order. Rows without emails never contribute a group.
vector<vector<string>> accountsMerge(vector<vector<string>> &accounts) {
    const int total = accounts.size();
    vector<int> parentArray(total), rankArray(total, 0);
    for(int i = 0; i < total; i++){
        parentArray[i] = i;
    }
    // email -> first account index where it appeared; later sightings union back to it.
    map<string, int> emailMapper;
    for(int i = 0; i < total; i++){
        for(size_t j = 1; j < accounts[i].size(); j++){
            // One lookup: insert fails (second == false) when the email was already seen.
            auto inserted = emailMapper.insert(make_pair(accounts[i][j], i));
            if(!inserted.second){
                unionSets(parentArray, rankArray, i, inserted.first->second);
            }
        }
    }
    // root account index -> all emails of its group (std::set keeps them sorted and unique).
    map<int, set<string>> mergedMapper;
    for(int i = 0; i < total; i++){
        const int root = findRoot(parentArray, i);
        for(size_t j = 1; j < accounts[i].size(); j++){
            mergedMapper[root].insert(accounts[i][j]);
        }
    }
    vector<vector<string>> result;
    for(auto const &pr : mergedMapper){
        result.emplace_back();
        vector<string> &row = result.back();
        row.reserve(pr.second.size() + 1);
        // A root present here owns at least one email, so accounts[pr.first] is non-empty.
        row.push_back(accounts[pr.first][0]);
        row.insert(row.end(), pr.second.begin(), pr.second.end());
    }
    return result;
}
Maximum Connected Area (261):
// Returns the representative cell of the component containing `index`.
// A cell with no entry in parentMap is a root. Every cell on the walked path is
// re-pointed directly at the root (path compression).
pair<int,int> findRoot(pair<int,int> index){
    pair<int,int> root = index;
    for(auto it = parentMap.find(root); it != parentMap.end(); it = parentMap.find(root)){
        root = it->second;
    }
    while(index != root){
        auto link = parentMap.find(index);
        const pair<int,int> next = link->second;
        link->second = root;
        index = next;
    }
    return root;
}
void mergeSet(pair<int,int> index1, pair<int,int> index2){
// cout<<"merging indexs1 : ("<<index1.first<<","<<index1.second<<") ("<<index2.first<<","<<index2.second<<")"<<endl;
pair<int,int> root1=findRoot(index1), root2=findRoot(index2);
int rank1=rankMap[root1], rank2=rankMap[root2];
if(root1==root2){
return;
}
// cout<<"merging indexs2 : ("<<root1.first<<","<<root1.second<<") ("<<root2.first<<","<<root2.second<<")"<<endl;
if(rank1 > rank2){
parentMap[root2] = root1;
groupSizeMap[root1] += groupSizeMap[root2];
} else if(rank1 < rank2){
parentMap[root1] = root2;
groupSizeMap[root2] += groupSizeMap[root1];
} else {
parentMap[root2] = root1;
groupSizeMap[root1] += groupSizeMap[root2];
rankMap[root1]++;
}
}
// True when (i, j) lies inside the grid and holds a land cell (value 1).
// The column bound is checked against row i's own length, so ragged rows are handled.
bool isValid(vector<vector<int>> &grid, int i,int j){
    const bool rowExists = i >= 0 && static_cast<size_t>(i) < grid.size();
    return rowExists
        && j >= 0 && static_cast<size_t>(j) < grid[i].size()
        && grid[i][j] == 1;
}
// Returns the largest area obtainable when at most one non-land cell becomes land.
// For a land cell (value 1) the candidate is the size of its own component. For any
// other cell it is 1 plus the sizes of the distinct land components orthogonally
// adjacent to it. An empty matrix yields 0.
int maxArea(vector<vector<int>> &matrix) {
    // Discard union-find state from any previous call.
    parentMap.clear();
    rankMap.clear();
    groupSizeMap.clear();
    const int rows = matrix.size();
    // Every land cell starts as a singleton component of size 1.
    for(int i = 0; i < rows; i++){
        for(int j = 0; j < static_cast<int>(matrix[i].size()); j++){
            if(matrix[i][j] == 1){
                groupSizeMap[make_pair(i, j)] = 1;
            }
        }
    }
    // Union each land cell with its land neighbours above and to the left. In
    // row-major order every adjacent land pair is reached from its lower/right cell,
    // so unions toward below/right would only repeat work already done.
    for(int i = 0; i < rows; i++){
        for(int j = 0; j < static_cast<int>(matrix[i].size()); j++){
            if(matrix[i][j] != 1){
                continue;
            }
            const pair<int,int> node = make_pair(i, j);
            if(isValid(matrix, i - 1, j)){
                mergeSet(node, make_pair(i - 1, j));
            }
            if(isValid(matrix, i, j - 1)){
                mergeSet(node, make_pair(i, j - 1));
            }
        }
    }
    // Orthogonal neighbour offsets: up, down, left, right.
    const int dRow[4] = {-1, 1, 0, 0};
    const int dCol[4] = {0, 0, -1, 1};
    int best = 0;
    for(int i = 0; i < rows; i++){
        for(int j = 0; j < static_cast<int>(matrix[i].size()); j++){
            int area = 0;
            if(matrix[i][j] == 1){
                area = groupSizeMap[findRoot(make_pair(i, j))];
            } else {
                // This cell plus each distinct neighbouring component. A component
                // touching the cell on several sides must be counted only once.
                area = 1;
                set<pair<int,int>> countedRoots;
                for(int d = 0; d < 4; d++){
                    const int ni = i + dRow[d];
                    const int nj = j + dCol[d];
                    if(!isValid(matrix, ni, nj)){
                        continue;
                    }
                    const pair<int,int> root = findRoot(make_pair(ni, nj));
                    if(countedRoots.insert(root).second){
                        area += groupSizeMap[root];
                    }
                }
            }
            best = max(best, area);
        }
    }
    return best;
}
private:
// Parent link for each cell merged under another cell; a cell with no entry is a root.
map<pair<int,int>, pair<int,int>> parentMap;
// Union-by-rank value per root cell (missing entries read as 0 via operator[]).
map<pair<int,int>, int> rankMap;
// Number of land cells in the component rooted at each cell; only meaningful for roots.
map<pair<int,int>, int> groupSizeMap;
Making A Large Island (1391):
// Returns the representative cell of the island containing `index`.
// A cell with no entry in parentMap is a root. Every cell on the walked path is
// re-pointed directly at the root (path compression).
pair<int,int> findRoot(pair<int,int> index){
    pair<int,int> root = index;
    for(auto it = parentMap.find(root); it != parentMap.end(); it = parentMap.find(root)){
        root = it->second;
    }
    while(index != root){
        auto link = parentMap.find(index);
        const pair<int,int> next = link->second;
        link->second = root;
        index = next;
    }
    return root;
}
void mergeSet(pair<int,int> index1, pair<int,int> index2){
// cout<<"merging index1: ("<<index1.first<<","<<index1.second<<") ("<<index2.first<<","<<index2.second<<")"<<endl;
pair<int,int> root1=findRoot(index1), root2=findRoot(index2);
int rank1=rankMap[root1], rank2=rankMap[root2];
if(root1==root2){
return;
}
// cout<<"merging index2: ("<<root1.first<<","<<root1.second<<") ("<<root2.first<<","<<root2.second<<")"<<endl;
if(rank1 > rank2){
parentMap[root2] = root1;
groupSizeMap[root1] += groupSizeMap[root2];
} else if(rank1 < rank2){
parentMap[root1] = root2;
groupSizeMap[root2] += groupSizeMap[root1];
} else {
parentMap[root2] = root1;
groupSizeMap[root1] += groupSizeMap[root2];
rankMap[root1]++;
}
}
// True when (i, j) lies inside the grid and holds a land cell (value 1).
// The column bound uses row i's own length, so ragged rows are handled.
bool isValid(vector<vector<int>> &grid, int i, int j){
    if(i < 0 || static_cast<size_t>(i) >= grid.size()){
        return false;
    }
    const vector<int> &row = grid[i];
    return j >= 0 && static_cast<size_t>(j) < row.size() && row[j] == 1;
}
// Returns the largest island area obtainable when at most one non-land cell becomes land.
// For a land cell (value 1) the candidate is the size of its own island. For any other
// cell it is 1 plus the sizes of the distinct islands orthogonally adjacent to it.
// An empty grid yields 0.
int largestIsland(vector<vector<int>> &grid) {
    // Discard union-find state from any previous call.
    parentMap.clear();
    rankMap.clear();
    groupSizeMap.clear();
    const int rows = grid.size();
    // Seed each cell's component size with its grid value; land cells (1) start as singletons.
    for(int i = 0; i < rows; i++){
        for(int j = 0; j < static_cast<int>(grid[i].size()); j++){
            groupSizeMap[make_pair(i, j)] = grid[i][j];
        }
    }
    // Union each land cell with its land neighbours above and to the left. In
    // row-major order every adjacent land pair is reached from its lower/right cell,
    // so unions toward below/right would only repeat work already done.
    for(int i = 0; i < rows; i++){
        for(int j = 0; j < static_cast<int>(grid[i].size()); j++){
            if(grid[i][j] != 1){
                continue;
            }
            const pair<int,int> node = make_pair(i, j);
            if(isValid(grid, i - 1, j)){
                mergeSet(node, make_pair(i - 1, j));
            }
            if(isValid(grid, i, j - 1)){
                mergeSet(node, make_pair(i, j - 1));
            }
        }
    }
    // Orthogonal neighbour offsets: up, down, left, right.
    const int dRow[4] = {-1, 1, 0, 0};
    const int dCol[4] = {0, 0, -1, 1};
    int best = 0;
    for(int i = 0; i < rows; i++){
        for(int j = 0; j < static_cast<int>(grid[i].size()); j++){
            int area = 0;
            if(grid[i][j] == 1){
                area = groupSizeMap[findRoot(make_pair(i, j))];
            } else {
                // This cell plus each distinct neighbouring island, each counted once.
                area = 1;
                set<pair<int,int>> countedRoots;
                for(int d = 0; d < 4; d++){
                    const int ni = i + dRow[d];
                    const int nj = j + dCol[d];
                    if(!isValid(grid, ni, nj)){
                        continue;
                    }
                    const pair<int,int> root = findRoot(make_pair(ni, nj));
                    if(countedRoots.insert(root).second){
                        area += groupSizeMap[root];
                    }
                }
            }
            best = max(best, area);
        }
    }
    return best;
}
private:
// Parent link for each cell merged under another cell; a cell with no entry is a root.
map<pair<int,int>, pair<int,int>> parentMap;
// rankMap: union-by-rank value per root cell.
// groupSizeMap: component size keyed by cell, only meaningful for roots.
map<pair<int,int>, int> rankMap, groupSizeMap;
Minimum Number of Visited Lattices in a Matrix (3709) — solved with breadth-first search rather than union-find:
// True when (i, j) is a position inside the (possibly ragged) grid.
bool isValid(vector<vector<int>> &grid, int i, int j){
    return i >= 0 && static_cast<size_t>(i) < grid.size()
        && j >= 0 && static_cast<size_t>(j) < grid[i].size();
}
// Breadth-first search from (0, 0). From (r, c) holding value v, the moves are to any
// (r, c + k) or (r + k, c) with 1 <= k <= v that lies inside the grid.
// Returns the number of cells on a shortest path to the last cell of the last row,
// counting both endpoints, or -1 if it is unreachable or the grid is empty.
int minimumVisitedLattices(vector<vector<int>> &grid) {
    // Empty input used to index grid[0][0] and compute size()-1 on an empty vector.
    if(grid.empty() || grid[0].empty()){
        return -1;
    }
    const int rows = grid.size();
    const int lastRow = rows - 1;
    const int lastCol = static_cast<int>(grid[lastRow].size()) - 1;
    // Per-row visited flags sized to each row, so ragged grids are supported.
    vector<vector<char>> seen(rows);
    for(int r = 0; r < rows; r++){
        seen[r].assign(grid[r].size(), 0);
    }
    // FIFO queue of ((row, col), cells visited so far); `head` marks the front.
    vector<pair<pair<int,int>, int>> que;
    size_t head = 0;
    que.push_back(make_pair(make_pair(0, 0), 1));
    seen[0][0] = 1;
    while(head < que.size()){
        // Copy out the front before push_back can reallocate the vector.
        const int r = que[head].first.first;
        const int c = que[head].first.second;
        const int steps = que[head].second;
        ++head;
        if(r == lastRow && c == lastCol){
            return steps;
        }
        const int reach = grid[r][c];
        // Clamp the jump range to the grid so large cell values do not spin on
        // offsets that can never be valid.
        const int maxRight = min(reach, static_cast<int>(grid[r].size()) - 1 - c);
        for(int k = 1; k <= maxRight; k++){
            const int nc = c + k;
            if(!seen[r][nc]){
                seen[r][nc] = 1;
                que.push_back(make_pair(make_pair(r, nc), steps + 1));
            }
        }
        const int maxDown = min(reach, lastRow - r);
        for(int k = 1; k <= maxDown; k++){
            const int nr = r + k;
            if(c < static_cast<int>(grid[nr].size()) && !seen[nr][c]){
                seen[nr][c] = 1;
                que.push_back(make_pair(make_pair(nr, c), steps + 1));
            }
        }
    }
    return -1;
}
The Minimum String After Swapping (3604):
// Returns the root of `index`'s set; parentArray[x] == -1 marks a root.
// Every node on the walked path is re-pointed directly at the root (path compression).
int findRoot(int index){
    int root = index;
    while(parentArray[root] != -1){
        root = parentArray[root];
    }
    while(index != root){
        const int next = parentArray[index];
        parentArray[index] = root;
        index = next;
    }
    return root;
}
// Union by rank of the sets holding index1 and index2.
// On a rank tie, index1's root survives and gains one rank.
void mergeSet(int index1,int index2){
    const int first = findRoot(index1);
    const int second = findRoot(index2);
    const int firstRank = rankArray[first];
    const int secondRank = rankArray[second];
    if(first == second){
        return;
    }
    const bool secondWins = secondRank > firstRank;
    const int keeper = secondWins ? second : first;
    const int absorbed = secondWins ? first : second;
    parentArray[absorbed] = keeper;
    if(firstRank == secondRank){
        rankArray[keeper]++;
    }
}
// Positions linked (directly or transitively) by `pairs` form groups. Within each
// group the characters are sorted and written back in increasing position order.
// This gives each group its smallest arrangement, assuming any permutation within a
// group is reachable through the given swaps.
string minStringAfterSwap(string &s, vector<vector<int>> &pairs) {
    const int len = s.size();
    // assign (not resize) so state from a previous call on this object is discarded:
    // resize only initialises newly added elements.
    parentArray.assign(len, -1);
    rankArray.assign(len, 0);
    for(const vector<int> &swapPair : pairs){
        mergeSet(swapPair[0], swapPair[1]);
    }
    // Resolve each position's root once and collect that group's characters.
    vector<int> rootOf(len);
    map<int, vector<char>> posMapper;
    for(int i = 0; i < len; i++){
        rootOf[i] = findRoot(i);
        posMapper[rootOf[i]].push_back(s[i]);
    }
    for(auto &pr : posMapper){
        sort(pr.second.begin(), pr.second.end());
    }
    // Per-group cursor into its sorted characters (operator[] starts it at 0).
    map<int, size_t> nextIndex;
    string result = s;
    for(int i = 0; i < len; i++){
        const int root = rootOf[i];
        result[i] = posMapper[root][nextIndex[root]++];
    }
    return result;
}
private:
// parentArray[i] == -1 marks i as a set root; otherwise it is i's parent index.
// rankArray holds union-by-rank values. minStringAfterSwap sizes both to s.size().
vector<int> parentArray,rankArray;
Minimize Malware Spread (1718):
// Returns the root of `index`'s component; parentArray[x] == -1 marks a root.
// Every node on the walked path is re-pointed directly at the root (path compression).
int findRoot(int index){
    int root = index;
    while(parentArray[root] != -1){
        root = parentArray[root];
    }
    while(index != root){
        const int next = parentArray[index];
        parentArray[index] = root;
        index = next;
    }
    return root;
}
// Union by rank of the components holding index1 and index2.
// The surviving root's groupSize absorbs the other root's size.
// On a rank tie, index1's root survives and gains one rank.
void mergeSet(int index1, int index2){
    const int first = findRoot(index1);
    const int second = findRoot(index2);
    const int firstRank = rankArray[first];
    const int secondRank = rankArray[second];
    if(first == second){
        return;
    }
    const bool secondWins = secondRank > firstRank;
    const int keeper = secondWins ? second : first;
    const int absorbed = secondWins ? first : second;
    parentArray[absorbed] = keeper;
    groupSize[keeper] += groupSize[absorbed];
    if(firstRank == secondRank){
        rankArray[keeper]++;
    }
}
// graph is an adjacency matrix (graph[i][j] == 1 links i and j).
// Selection rule:
//  - among nodes of `initial` that are the only initial node in their connected
//    component, return the one whose component is largest, breaking ties by the
//    smallest node id;
//  - if no such node exists, return the smallest node in `initial`.
// `initial` is no longer sorted in place. Returns -1 when `initial` is empty
// (previously undefined behaviour).
int minMalwareSpread(vector<vector<int>> &graph, vector<int> &initial) {
    if(initial.empty()){
        return -1;
    }
    const int nodes = graph.size();
    // assign (not resize) so state from a previous call on this object is discarded.
    parentArray.assign(nodes, -1);
    rankArray.assign(nodes, 0);
    groupSize.assign(nodes, 1);
    for(int i = 0; i < nodes; i++){
        for(int j = 0; j < static_cast<int>(graph[i].size()); j++){
            if(graph[i][j] == 1){
                mergeSet(i, j);
            }
        }
    }
    // How many initial nodes fall into each component root.
    map<int,int> infectedPerRoot;
    for(int node : initial){
        infectedPerRoot[findRoot(node)]++;
    }
    int answer = *min_element(initial.begin(), initial.end());
    int bestSaved = 0;
    for(int node : initial){
        const int root = findRoot(node);
        if(infectedPerRoot[root] != 1){
            continue;
        }
        const int saved = groupSize[root];
        if(saved > bestSaved || (saved == bestSaved && node < answer)){
            bestSaved = saved;
            answer = node;
        }
    }
    return answer;
}
private:
// parentArray[i] == -1 marks i as a component root; otherwise it is i's parent index.
// rankArray holds union-by-rank values.
// groupSize[r] is the node count of the component rooted at r (meaningful only for roots).
vector<int> parentArray, rankArray, groupSize;