Swap Sort - Editorial

About the Problem

Setter(s) Vichitr Gandas
Tester(s) Mayank Pugalia
Difficulty Cakewalk
Topics Observation, Sorting
Practice Link
Contest Link

Solution Idea

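Make a copy of the original array and sort it. Then compare the sorted copy with the original array and count the indices where they differ. If exactly 2 indices differ, we can swap those two elements and the array becomes sorted. If more than 2 indices differ, a single swap can never sort it. (Exactly 1 mismatch is impossible, because both arrays contain the same multiset of values.)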
Also consider the case where the array is already sorted (0 mismatches): the one swap must still be made. If some element occurs at least twice, we can swap two of its occurrences and the array stays sorted. If all elements are distinct, any swap makes the array unsorted, so the answer is NO.

Complexity Analysis

Time Complexity: $\mathcal{O}(n \log n)$ per test case, for sorting.
Space Complexity: $\mathcal{O}(n)$ for the sorted copy.

Codes

Setter's Code
/***************************************************

@author: vichitr
Compiled On: 13th Mar 2021

*****************************************************/

#include<bits/stdc++.h>
#define pb push_back
#define mp make_pair
typedef long long int ll;
using namespace std;

// Returns true iff exactly one swap of two positions can leave `a` sorted in
// non-decreasing order.
//  * If `a` differs from its sorted copy in exactly 2 positions, swapping
//    those two elements sorts it. (Exactly 1 mismatch is impossible: both
//    arrays hold the same multiset of values.)
//  * If `a` is already sorted, the swap keeps it sorted only when two equal
//    values can be exchanged; in a sorted array equal values are adjacent.
//  * Any other number of mismatches cannot be fixed by a single swap.
static bool canSortWithOneSwap(const vector<int>& a)
{
    vector<int> b = a;
    sort(b.begin(), b.end());

    int mismatches = 0;
    for (size_t i = 0; i < a.size(); ++i)
        if (a[i] != b[i])
            ++mismatches;

    if (mismatches == 2)
        return true;
    if (mismatches == 0)
        return adjacent_find(a.begin(), a.end()) != a.end();
    return false;
}

// Reads t test cases, each an array of n values, and prints YES/NO per case.
// The asserts validate the input constraints (t <= 1e5, 1 <= n <= 1e5,
// 0 <= a[i] <= 1e9, sum of n <= 5e5).
int main()
{
    // Only cin/cout are used, so detaching from C stdio is safe and speeds
    // up reading the input.
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int t; cin >> t;
    assert(t <= 1e5);
    int tot = 0;
    while (t--) {
        int n; cin >> n;
        assert(n >= 1 && n <= 1e5);  // lower bound matches the tester's check
        tot += n;
        vector<int> a(n);
        for (int& x : a) {
            cin >> x;
            assert(x >= 0 && x <= 1e9);
        }
        cout << (canSortWithOneSwap(a) ? "YES\n" : "NO\n");
    }
    assert(tot <= 5e5);
    return 0;
}
Tester's Code
/*
	katana_handler
*/
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/assoc_container.hpp>
using namespace std;
using namespace __gnu_pbds;
// Order-statistics trees from GNU pb_ds (relies on the pb_ds includes and the
// `using namespace __gnu_pbds;` above). Both use greater<> as the comparator,
// so elements are kept in descending order. ordered_set holds long double
// (ld) keys; ordered_set_pll holds pll pairs. Neither macro is used by
// solve() or main() in this program.
// NOTE(review): tree_order_statistics_node_update is documented as living in
// <ext/pb_ds/tree_policy.hpp>; the second pb_ds include above repeats
// assoc_container.hpp and was probably meant to be tree_policy.hpp — confirm.
#define ordered_set_pll tree <pll, null_type,greater<pll>, rb_tree_tag,tree_order_statistics_node_update>
#define ordered_set tree <ld, null_type,greater<ld>, rb_tree_tag,tree_order_statistics_node_update>
/*
 * Order-statistics queries on the trees above:
 *   order_of_key(k)  : number of stored items that come strictly before k
 *                      in the tree's order.
 *   find_by_order(k) : iterator to the k-th item in that order (0-based).
 * With less<T>, "strictly before k" means strictly smaller than k; with
 * greater<T> (as in both macros above) it means strictly greater than k.
*/
// Container/utility shorthands.
#define			pb            	push_back
#define			pf              push_front
#define         MOD             1000000007
#define			popb            pop_back
#define         popf            pop_front
// NOTE: the argument is not parenthesised, so len(a+b) expands to
// (ll)a+b.size(); only pass a plain name.
#define         len(x)          (ll)x.size()    
#define         MAXN            100001
#define         mp              make_pair
// Rewrites every later `endl` token to '\n', so `cout << x << endl` does not
// flush. A later `std::endl` would expand to `std::'\n'` and not compile.
#define         endl            '\n'
#define         ff              first
#define         ss              second
// std::pair / std::tuple have no `third` member; this only works with a type
// that defines one.
#define         tt              third
#define         mapcl           map<char,ll>
#define         mapll           map<ll,ll>
#define         cmp             complex<double>
// Unsuffixed, so this is a double literal: digits beyond double precision
// are lost.
#define         pi              3.141592653589793238462643383279502884197169399375105820974944592307816406286
// Largest long long; adding anything positive to it overflows.
#define         inf             LLONG_MAX
// Hijacks the identifier `flush`: e.g. a later `cout.flush()` would expand
// to `cout.fflush(stdout)()` and fail to compile.
#define         flush           fflush(stdout)
#define         vll             vector<ll>
#define         all(v)          v.begin(),v.end()
#define         fr(i,z,n)       for(ll i=z;i<n;i++)
// Every later sqrt/cbrt token becomes the long double sqrtl/cbrtl.
#define         sqrt            sqrtl
#define         cbrt            cbrtl
typedef         long long       ll;
typedef         pair<ll,ll>     pll;
typedef         long double     ld;
typedef			unsigned long long ull;
// ---- Generic helpers --------------------------------------------------------
// Container helpers are ordinary templates rather than `auto` / `vector<auto>`
// parameters, which standard C++17 does not accept.

// Modular exponentiation: returns a^b mod `mod`, in [0, mod).
// Requires b >= 0 and mod >= 1. `a` is reduced first, so every intermediate
// product is below mod^2 (no overflow while mod is at most about 3e9).
long long power(long long a, long long b, long long mod)
{
    a %= mod;
    if (a < 0)
        a += mod;
    long long result = 1 % mod;  // 1 % mod gives 0 when mod == 1
    while (b > 0) {
        if (b & 1)
            result = result * a % mod;
        a = a * a % mod;
        b >>= 1;
    }
    return result;
}

// Manhattan distance |dx| + |dy| between two integer points.
long long mandist(std::pair<long long, long long> a, std::pair<long long, long long> b)
{
    return std::abs(a.first - b.first) + std::abs(a.second - b.second);
}

// Euclidean distance between two integer points. The squares are formed in
// long double, so large coordinates do not overflow long long.
long double dist2d(std::pair<long long, long long> a, std::pair<long long, long long> b)
{
    const long double dx = static_cast<long double>(a.first - b.first);
    const long double dy = static_cast<long double>(a.second - b.second);
    return sqrtl(dx * dx + dy * dy);
}

// True iff p1, p2, p3 are collinear: the cross product of (p2 - p1) and
// (p3 - p2) is zero. Exact integer arithmetic; the products fit in long long
// while coordinate differences stay below about 3e9 in magnitude.
bool coll(std::pair<long long, long long> p1, std::pair<long long, long long> p2,
          std::pair<long long, long long> p3)
{
    return (p3.second - p2.second) * (p2.first - p1.first) ==
           (p2.second - p1.second) * (p3.first - p2.first);
}

// Prints x followed by a newline (no flush).
template <class T>
void print(const T& x)
{
    std::cout << x << '\n';
}

// Prints every element of v, each followed by a space, then a newline.
template <class T>
void pv(const std::vector<T>& v)
{
    for (std::size_t i = 0; i < v.size(); ++i)
        std::cout << v[i] << " ";
    std::cout << '\n';
}

// Like pv, but skips index 0 (for 1-indexed vectors filled by in1).
template <class T>
void pv1(const std::vector<T>& v)
{
    for (std::size_t i = 1; i < v.size(); ++i)
        std::cout << v[i] << " ";
    std::cout << '\n';
}

// Prints every element of s in set order, each followed by a space, then a newline.
template <class T>
void pset(const std::set<T>& s)
{
    for (const auto& e : s)
        std::cout << e << " ";
    std::cout << '\n';
}

// Like pv, but without the trailing newline.
template <class T>
void pvsl(const std::vector<T>& v)
{
    for (std::size_t i = 0; i < v.size(); ++i)
        std::cout << v[i] << " ";
}

// Resizes v to n elements and reads them from stdin into v[0..n-1].
template <class T>
void in(std::vector<T>& v, long long n)
{
    v.resize(n);
    for (long long i = 0; i < n; ++i)
        std::cin >> v[i];
}

// Resizes v to n+1 elements and reads n values into v[1..n] (1-indexed).
template <class T>
void in1(std::vector<T>& v, long long n)
{
    v.resize(n + 1);
    for (long long i = 1; i <= n; ++i)
        std::cin >> v[i];
}

// Sorts any container with begin()/end() in ascending order.
template <class C>
void sorta(C& v)
{
    std::sort(v.begin(), v.end());
}

// Sorts any container with begin()/end() in descending order.
template <class C>
void sortd(C& v)
{
    std::sort(v.begin(), v.end(), std::greater<>());
}

// Detaches C++ streams from C stdio and unties cin from cout for faster I/O;
// only safe when the program does not mix printf/scanf with cin/cout.
void fast()
{
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);
}
//all variables and functions below this line
ll globalsum=0;
void solve()
{
	ll n;
	cin>>n;
	globalsum+=n;
	assert(n<=1e5 && n>=1);
	vll v;
	in(v,n);
	for(ll x:v)
	assert(x>=0 && x<=1e9);
	vll a=v;
	sorta(a);
	ll sum=0;
	for(ll i=0;i<n;i++)
	if(a[i]!=v[i])
	sum++;
	if(sum==2)
	cout<<"YES"<<endl;
	else if(sum==0)
	{
		mapll m;
		for(ll x:a)
		m[x]++;
		for(pll x:m)
		if(x.ss>1)
		{
			cout<<"YES"<<endl;
			return;
		}
		cout<<"NO"<<endl;
	}
	else
	cout<<"NO"<<endl;
}
// Entry point: reads the number of test cases, runs solve() for each, and
// finally checks the global constraint on the sum of n.
int main()
{
	// Only cin/cout are used, so detaching them from C stdio is safe and
	// speeds up reading the input.
	ios_base::sync_with_stdio(false);
	cin.tie(nullptr);
	ll t=1;
	cin>>t;
	assert(t<=1e5);  // same bound the setter's solution checks
	for(ll i=1;i<=t;i++)
		solve();
	assert(globalsum<=5e5);
}

If you used a different approach, share it in the comments!
If you have any doubts, ask them in the comments!

2 Likes