About the Problem
| Field | Details |
|---|---|
| Setter(s) | Vichitr Gandas |
| Tester(s) | Mayank Pugalia |
| Difficulty | Cakewalk |
| Topics | Observation, Sorting |
| Practice | Link |
| Contest | Link |
Solution Idea
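Make a copy of the original array and sort it. Now compare the sorted array with the original array and count the mismatched indices. A single swap changes at most two positions, so the count must be 0 or 2 (it can never be exactly 1, because both arrays contain the same multiset of values). If it's 2, the two mismatched values must be each other's correct values, so swapping them makes the array sorted. If it's more than 2, we can never make the array sorted with one swap.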
Also consider the case when the array is already sorted. If some element occurs at least twice, we can swap two of its occurrences and the array stays sorted. If there is no such element, all values are distinct, so swapping any two positions breaks the order and the array becomes unsorted.
Complexity Analysis
Time Complexity: $\mathcal{O}(n \log n)$ for sorting.
Space Complexity: $\mathcal{O}(n)$ for the copy of the array.
Codes
Setter's Code
/***************************************************
@author: vichitr
Compiled On: 13th Mar 2021
*****************************************************/
#include<bits/stdc++.h>
#define pb push_back
#define mp make_pair
typedef long long int ll;
using namespace std;
int main()
{
    // Reads T test cases. For each one, decides whether exactly one swap of two
    // distinct positions leaves the array sorted in non-decreasing order.
    int t; cin>>t;
    assert(t <= 1e5);
    // Sum of n over all test cases. Kept in 64 bits: with t and n both up to
    // 1e5 the running total can reach 1e10, which would overflow an int (UB)
    // before the final constraint check below gets a chance to fire.
    long long tot = 0;
    while(t--){
        int n; cin>>n;
        assert(n >= 1 and n <= 1e5);
        tot += n;
        vector<int> a(n);
        for(int i=0;i<n;++i){
            cin>>a[i];
            assert(a[i] >= 0 and a[i] <= 1e9);
        }
        vector<int> b = a;
        sort(b.begin(), b.end());

        // Positions where the array differs from its sorted form. One swap
        // changes at most two positions, so only 0 or 2 mismatches can work.
        int cnt = 0;
        for(int i=0;i<n;i++)
            if(b[i] != a[i])
                cnt++;

        // For an already-sorted array, a swap keeps it sorted only if two equal
        // values are exchanged; in sorted order equal values are adjacent.
        const bool hasDuplicate = adjacent_find(b.begin(), b.end()) != b.end();

        if(cnt == 2 or (cnt == 0 and hasDuplicate))
            cout<<"YES\n";
        else
            cout<<"NO\n";
    }
    assert(tot <= 5e5);
    return 0;
}
Tester's Code
/*
katana_handler
*/
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/assoc_container.hpp>
using namespace std;
using namespace __gnu_pbds;
#define ordered_set_pll tree <pll, null_type,greater<pll>, rb_tree_tag,tree_order_statistics_node_update>
#define ordered_set tree <ld, null_type,greater<ld>, rb_tree_tag,tree_order_statistics_node_update>
/*
* query 1 order_of_key (k) : Number of items strictly smaller than k .
* query 2 find_by_order(k) : K-th element in the set (counting from zero).
* less<ll> means query 1 will return numbers strictly less than k
* greater<ll> means query 1 will return numbers strictly greater than k
*/
#define pb push_back
#define pf push_front
#define MOD 1000000007
#define popb pop_back
#define popf pop_front
#define len(x) (ll)x.size()
#define MAXN 100001
#define mp make_pair
#define endl '\n'
#define ff first
#define ss second
#define tt third
#define mapcl map<char,ll>
#define mapll map<ll,ll>
#define cmp complex<double>
#define pi 3.141592653589793238462643383279502884197169399375105820974944592307816406286
#define inf LLONG_MAX
#define flush fflush(stdout)
#define vll vector<ll>
#define all(v) v.begin(),v.end()
#define fr(i,z,n) for(ll i=z;i<n;i++)
#define sqrt sqrtl
#define cbrt cbrtl
typedef long long ll;
typedef pair<ll,ll> pll;
typedef long double ld;
typedef unsigned long long ull;
// Computes (a^b) mod `mod` by iterative binary exponentiation.
// The base is first reduced into [0, mod), so every product below is < mod^2.
// That fits in 64 bits for mod up to ~3e9 (e.g. MOD = 1e9+7). A negative base
// yields the canonical non-negative residue. Requires mod >= 1 and b >= 0.
long long power(long long a,long long b,long long mod)
{
    long long result = 1 % mod;   // 1 % mod so that mod == 1 yields 0
    a %= mod;
    if (a < 0) a += mod;
    while (b > 0) {
        if (b & 1) result = result * a % mod;
        a = a * a % mod;
        b >>= 1;
    }
    return result;
}
ll mandist(pll a,pll b)
{
    // Manhattan (L1) distance between two lattice points: |dx| + |dy|.
    const ll dx=abs(a.ff-b.ff);
    const ll dy=abs(a.ss-b.ss);
    return dx+dy;
}
// Euclidean distance between two lattice points.
// The differences are taken in long double, so squaring them cannot overflow
// 64-bit integers. The previous ll arithmetic overflowed (UB) once a
// coordinate difference exceeded ~3e9. std::hypot also avoids intermediate
// overflow in the sum of squares.
long double dist2d(std::pair<long long,long long> a,std::pair<long long,long long> b)
{
    const long double dx = static_cast<long double>(a.first) - static_cast<long double>(b.first);
    const long double dy = static_cast<long double>(a.second) - static_cast<long double>(b.second);
    return std::hypot(dx, dy);
}
bool coll(pll p1,pll p2,pll p3)
{
    // Collinearity test: the 2D cross product of (p2 - p1) and (p3 - p2) is
    // zero. NOTE(review): the ll products can overflow for coordinate
    // differences beyond ~3e9.
    const ll lhs=(p3.ss-p2.ss)*(p2.ff-p1.ff);
    const ll rhs=(p2.ss-p1.ss)*(p3.ff-p2.ff);
    return lhs==rhs;
}
// Prints x followed by a newline.
// An explicit template replaces the `auto` parameter, which is only valid
// from C++20 on. Taking x by const reference avoids a copy.
template <class T>
void print(const T& x) {std::cout<<x<<'\n';}
// Prints every element of v followed by a space, then a newline.
// `vector<auto>` is not a valid parameter type in standard C++, so an explicit
// template is used. Passing by const reference avoids copying the vector.
template <class T>
void pv(const std::vector<T>& v) {for(const auto& e:v)std::cout<<e<<" ";std::cout<<'\n';}
// Prints elements v[1..] (skipping index 0, for 1-indexed arrays), each
// followed by a space, then a newline. Templated because `vector<auto>` is not
// valid standard C++; taken by const reference to avoid a copy.
template <class T>
void pv1(const std::vector<T>& v) {for(std::size_t i=1;i<v.size();++i)std::cout<<v[i]<<" ";std::cout<<'\n';}
// Prints the elements of a std::set in iteration order, each followed by a
// space, then a newline. `set<auto>` is not valid standard C++. The template
// also accepts sets with a custom comparator or allocator, and the set is
// taken by const reference instead of being copied.
template <class T, class... Rest>
void pset(const std::set<T, Rest...>& s) {for(const auto& e:s)std::cout<<e<<" ";std::cout<<'\n';}
// Prints every element of v followed by a space, without a trailing newline.
// Templated because `vector<auto>` is not valid standard C++; taken by const
// reference to avoid copying the vector.
template <class T>
void pvsl(const std::vector<T>& v) {for(const auto& e:v)std::cout<<e<<" ";}
// Resizes v to n elements and reads them from std::cin into v[0..n-1].
// An explicit template replaces the non-standard `vector<auto>&` parameter.
template <class T>
void in(std::vector<T>& v, long long n) {v.resize(n);for(long long i=0;i<n;i++)std::cin>>v[i];}
// Resizes v to n+1 elements and reads n values from std::cin into v[1..n]
// (1-indexed; v[0] is left as is). Templated because `vector<auto>&` is not
// valid standard C++.
template <class T>
void in1(std::vector<T>& v, long long n) {v.resize(n+1);for(long long i=1;i<=n;i++)std::cin>>v[i];}
// Sorts a container in ascending order in place. An explicit template
// replaces the `auto&` parameter, which is only valid from C++20 on.
template <class Container>
void sorta(Container& v) {std::sort(v.begin(),v.end());}
// Sorts a container in descending order in place. An explicit template
// replaces the `auto&` parameter, which is only valid from C++20 on.
template <class Container>
void sortd(Container& v) {std::sort(v.begin(),v.end(),std::greater<>());}
void fast()
{
    // Decouple C++ streams from C stdio and untie cin/cout for faster bulk I/O.
    // Only appropriate when printf/scanf are not interleaved with cin/cout.
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);
}
//all variables and functions below this line
ll globalsum=0;
void solve()
{
ll n;
cin>>n;
globalsum+=n;
assert(n<=1e5 && n>=1);
vll v;
in(v,n);
for(ll x:v)
assert(x>=0 && x<=1e9);
vll a=v;
sorta(a);
ll sum=0;
for(ll i=0;i<n;i++)
if(a[i]!=v[i])
sum++;
if(sum==2)
cout<<"YES"<<endl;
else if(sum==0)
{
mapll m;
for(ll x:a)
m[x]++;
for(pll x:m)
if(x.ss>1)
{
cout<<"YES"<<endl;
return;
}
cout<<"NO"<<endl;
}
else
cout<<"NO"<<endl;
}
int main()
{
    // Reads the number of test cases, answers each one, then validates the
    // constraint on the total input size accumulated by solve().
    ll t=1;
    cin>>t;
    while(t-- > 0)
        solve();
    assert(globalsum<=5e5);
    return 0;
}
If you have used any other approach, share your approach in comments!
If you have any doubts, ask them in comments!