BZOJ 4456: [ZJOI2016] Traveller — divide and conquer + shortest paths

Problem statement

Little Y is travelling to a new city. The city is laid out as a grid: n roads run east–west and m roads run north–south, and they cross at n × m intersections (i,j) (1 ≤ i ≤ n, 1 ≤ j ≤ m). Roads are in different conditions, so different segments take different amounts of time: going from intersection (i,j) to (i,j+1) takes time r(i,j), and going from (i,j) to (i+1,j) takes time c(i,j). All roads are two-way. Little Y has q questions; each asks for the minimum time needed to travel from intersection (x1,y1) to intersection (x2,y2).
Constraints: n·m ≤ 2×10^4, q ≤ 10^5.

Analysis

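Running a separate shortest-path search for every query is far too slow, so use divide and conquer. Cut the current rectangle with a line across its longer side; the cutting line then has at most √(area) cells. A shortest path that stays inside the rectangle either touches this line or lies entirely on one side of it. Paths that leave the rectangle were already covered at an enclosing level. So, for every cell on the line, run a shortest-path search restricted to the rectangle. Then update every query whose endpoints both lie in the rectangle with dis(a) + dis(b). Finally, recurse into the two halves, passing down only the queries whose endpoints both lie in that half.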
One test case exceeded the time limit on UOJ, but the solution was accepted (AC) on BZOJ.

Code

#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<algorithm>
#include<queue>
using namespace std;

const int N=20005;          // node capacity: point() yields ids 1..n*m, n*m <= 2*10^4
const int M=100005;         // query capacity: q <= 10^5
const int inf=1000000000;   // "unreached" distance; inf+inf still fits in an int

int n,m;                    // grid size: n rows (x coordinate), m columns (y coordinate)
int cnt,last[N];            // adjacency lists: last[v] = newest arc leaving v (0 = none), cnt = arcs used
int dis[N];                 // distances from the source of the latest dij() run
int ans[M],tot;             // best answer found so far per query id, and the number of queries
bool vis[N];                // "settled" flags of the latest dij() run
bool arr[N];                // cells dij() is allowed to enter (the rectangle being solved)

// One query: endpoints (x1,y1) and (x2,y2) plus its original index.
// Named Query rather than `data`: with `using namespace std` under C++17, an
// unqualified `data` would be ambiguous with std::data.
struct Query{int x1,x2,y1,y2,id;}q[M],tmp[M];

// Max-heap of (-distance, node) pairs, i.e. a min-heap on distance for dij().
std::priority_queue<std::pair<int,int> > que;

// Arc pool; every cell has at most 4 neighbours, so at most 4*n*m arcs are stored.
struct edge{int to,next,w;}e[N*10];

int read()
{
    int x=0,f=1;char ch=getchar();
    while (ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
    while (ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
    return x*f;
}

// Map grid cell (row x, column y), both 1-based, to a node id in 1..n*m.
// Cells are numbered column by column: column y occupies ids (y-1)*n+1 .. y*n.
int point(int x,int y)
{
    int columnBase=(y-1)*n;
    return columnBase+x;
}

// Insert the undirected edge u--v of weight w as two directed arcs.
// Each arc is pushed onto the front of its tail's list (last[] holds list
// heads, 0 terminates a list); the u->v arc is stored first.
void addedge(int u,int v,int w)
{
    int ends[2]={u,v};
    for (int k=0;k<2;k++)
    {
        int from=ends[k],dest=ends[1-k];
        ++cnt;
        e[cnt].to=dest;
        e[cnt].w=w;
        e[cnt].next=last[from];
        last[from]=cnt;
    }
}

// Read the whole input and build the grid graph.
// Layout: n m; then n rows of m-1 costs for (i,j)-(i,j+1); then n-1 rows of
// m costs for (i,j)-(i+1,j); then the query count followed by x1 y1 x2 y2
// per query. Every answer starts at inf.
void init()
{
    n=read();
    m=read();
    for (int row=1;row<=n;row++)
        for (int col=1;col+1<=m;col++)
            addedge(point(row,col),point(row,col+1),read());
    for (int row=1;row+1<=n;row++)
        for (int col=1;col<=m;col++)
            addedge(point(row,col),point(row+1,col),read());
    tot=read();
    for (int k=1;k<=tot;k++)
    {
        q[k].x1=read();
        q[k].y1=read();
        q[k].x2=read();
        q[k].y2=read();
        q[k].id=k;
        ans[k]=inf;
    }
}

// Dijkstra from node s over the cells with arr[] set.
// The caller must first set dis=inf and vis=0 for those cells. The heap holds
// (-distance, node), so top() is the closest pending node; outdated entries
// are discarded when popped (lazy deletion). The heap is empty on return.
void dij(int s)
{
    dis[s]=0;
    que.push(make_pair(0,s));
    while (!que.empty())
    {
        int cur=que.top().second;
        que.pop();
        if (vis[cur]) continue;   // already settled via a shorter entry
        vis[cur]=1;
        for (int a=last[cur];a;a=e[a].next)
        {
            int nxt=e[a].to;
            int cand=dis[cur]+e[a].w;
            if (!arr[nxt]||cand>=dis[nxt]) continue;
            dis[nxt]=cand;
            que.push(make_pair(-cand,nxt));
        }
    }
}

// Divide and conquer over the sub-rectangle rows d..u, columns l..r, answering
// the queries stored in q[L..R] (both endpoints of each lie inside it).
//
// Any path inside the rectangle either touches the separator line drawn across
// its longer side, or stays strictly on one side of it. For each cell on the
// separator, run a Dijkstra confined to the rectangle (arr[] marks the allowed
// cells) and relax every query with dis[a]+dis[b]. Queries with both endpoints
// on one side are handed to that half. Queries that straddle the line are
// already exact and are dropped.
//
// arr[] marks exactly this rectangle while its own Dijkstra runs, and is
// cleared before recursing, so each child's searches see only its own cells.
void solve(int d,int u,int l,int r,int L,int R)
{
    if (L>R) return;

    // Cut across the longer side so the separator is the shorter one
    // (its length is at most sqrt(area)).
    bool cutCol=(r-l+1>=u-d+1);
    int mid=cutCol?(l+r)/2:(d+u)/2;
    int lineLo=cutCol?d:l,lineHi=cutCol?u:r;   // extent of the separator line

    for (int i=d;i<=u;i++)
        for (int j=l;j<=r;j++)
            arr[point(i,j)]=1;

    for (int k=lineLo;k<=lineHi;k++)
    {
        for (int x=d;x<=u;x++)
            for (int y=l;y<=r;y++)
                dis[point(x,y)]=inf,vis[point(x,y)]=0;
        dij(cutCol?point(k,mid):point(mid,k));
        // Best path through this separator cell. Both terms are finite for
        // valid input, since every endpoint lies in the (connected) rectangle.
        for (int j=L;j<=R;j++)
            ans[q[j].id]=min(ans[q[j].id],dis[point(q[j].x1,q[j].y1)]+dis[point(q[j].x2,q[j].y2)]);
    }

    // Clear the marks before recursing so children cannot wander into cells
    // outside their own rectangle.
    for (int i=d;i<=u;i++)
        for (int j=l;j<=r;j++)
            arr[point(i,j)]=0;

    // The cut side has length 1 only for a single cell, which cannot be split.
    if (cutCol?(l==r):(d==u)) return;

    int s1=L-1,s2=R+1;
    for (int i=L;i<=R;i++)
    {
        int a=cutCol?q[i].y1:q[i].x1;
        int b=cutCol?q[i].y2:q[i].x2;
        if (a<=mid&&b<=mid) tmp[++s1]=q[i];
        else if (a>mid&&b>mid) tmp[--s2]=q[i];
    }
    for (int i=L;i<=s1;i++) q[i]=tmp[i];
    for (int i=s2;i<=R;i++) q[i]=tmp[i];

    if (cutCol)
    {
        solve(d,u,l,mid,L,s1);
        solve(d,u,mid+1,r,s2,R);
    }
    else
    {
        solve(d,mid,l,r,L,s1);
        solve(mid+1,u,l,r,s2,R);
    }
}

// Read the input, solve all queries over the whole grid, print one answer per line.
int main()
{
    init();
    solve(1,n,1,m,1,tot);   // whole grid: rows 1..n, columns 1..m, queries 1..tot
    for (int k=1;k<=tot;++k)
        printf("%d\n",ans[k]);
    return 0;
}

Posted on Fri, 01 May 2020 20:17:00 -0700 by DanAuito