HDU 1394 Minimum Inversion Number(线段树,逆序数)
题目大意:
给一个n个数的序列a1, a2, ..., an ,这些数的范围是0~n-1, 可以把前面m个数移动到后面去,形成新序列:
a1, a2, ..., an-1, an (where m = 0 - the initial seqence)
a2, a3, ..., an, a1 (where m = 1)
a3, a4, ..., an, a1, a2 (where m = 2)
...
an, a1, a2, ..., an-1 (where m = n-1)
求这些序列中,逆序数最少的是多少?
分析与总结:
(1) 初看这题时,觉得眼熟,于是翻看了交题记录,原来去年做过 = =,不过那时是直接暴力求出逆序数水过的...
然后就想这样用线段树来优化。所谓的求逆序数,其实就是序列中的每一个数,它前面的比他大的数的数量之和。
用线段树记录下各个数,区间的值【a,b】表示数字a~b的已经出现了多少次,所以对于ai,只需要查询【ai, n】有多少个(ai之前的比ai大的有多少个),就是代表ai有多少个逆序数了。
(2) 求出a1, a2, ..., an-1, an的逆序数之后,就可以递推求出其他序列的逆序数。 假设要把a1移动到an之后,那么我们把这个过程拆分成两步:
1. 把a1去除掉。通过观察可以发现,(a1-1)是0~n-1中比a1小的数字的个数,由于a1在序列的第一个所以a1之后共有(a1-1)个比a1小,所以形成了(a1-1)对逆序数,当去除掉a1时,原序列的逆序数总数也就减少了(a1-1)个逆序数。
2. 把a1加到an之后。0~n-1中,比a1大的数共有(n-a1)个数,由于a1现在在最后一个,也就是它前面共有(n-a1)个数比它大,即增加了(n-a1)对逆序数。
综合1,2两步, 设原序列逆序数为sum, 当把原序列第一个移动到最后位置时,逆序数变为:sum = sum-(ai-1)+(n-ai);
代码:
[cpp]
#include<iostream>
#include<cstdio>
#include<cstring>
#define lson(x) (x<<1)
#define rson(x) (lson(x)|1)
using namespace std;
const int MAX_NODE = 5005 << 2;
int arr[MAX_NODE];
struct node{
int left, right;
int num;
int mid(){return (left+right)>>1;}
bool buttom(){return left==right;}
};
class SegTree{
public:
void build(int cur,int left,int right){
t[cur].left = left;
t[cur].right = right;
if(left == right){
t[cur].num = 0;
return;
}
int m = t[cur].mid();
build(lson(cur),left,m);
build(rson(cur),m+1,right);
push_up(cur);
}
void update(int cur,int data){
++t[cur].num;
if(t[cur].buttom()){
return;
}
int m = t[cur].mid();
if(data <= m)
update(lson(cur),data);
else
update(rson(cur),data);
}
int query(int cur,int left,int right){
if(t[cur].left==left && t[cur].right==right){
return t[cur].num;
}
int m=t[cur].mid();
if(right <= m)
return query(lson(cur),left,right);
else if(left > m)
return query(rson(cur),left,right);
else
return query(lson(cur),left,m)+query(rson(cur),m+1,right);
}
private:
void push_up(int cur){
t[cur].num = t[lson(cur)].num+t[rson(cur)].num;
}
node t[MAX_NODE];
};
SegTree st;
int main(){
int n,x;
while(~scanf("%d",&n)){
st.build(1,1,n);
int sum=0;
for(int i=0; i<n; ++i){
scanf("%d",&arr[i]);
++arr[i];
int tmp = st.query(1,arr[i],n);
sum += tmp;
st.update(1,arr[i]);
}
int _min = sum;
for(int i=0; i<n-1; ++i){
sum = sum-(arr[i]-1)+(n-arr[i]);
if(sum < _min) _min=sum;
}
printf("%d/n",_min);
}
return 0;
}