使用 STL 的 Dijkstra 算法实现

Dijkstra's Algo Implementation using STL

本文关键字:算法 实现 Dijkstra STL 使用      更新时间:2023-10-16
#include<bits/stdc++.h>
using namespace std;
#define ll long long
vector<pair<ll,ll> >v[100005];
ll dis[100005];
bool visited[100005];
multiset<pair <ll,ll> > s;
int main(){
    ll n,m,from,next,weight,i;
    cin>>n>>m;
    for(i=1;i<=n;i++){
        v[i].clear();
        dis[i]=2e9;
    }
    for(i=1;i<=m;i++){
        cin>>from>>next>>weight;
        v[from].push_back(make_pair(next,weight));
        v[next].push_back(make_pair(from,weight));
    }
    dis[1]=0;
    s.insert({0,1});
    memset(visited,false,sizeof(visited));
    while(!s.empty()){
        pair<ll,ll>p= *s.begin();
        s.erase(s.begin());
        ll x=p.second;
        ll wei=p.first;
        if(visited[x]) continue;
        for(i=0;i<v[x].size();i++){
            ll e=v[x][i].first;
            ll w=v[x][i].second;
            if(dis[x]+w < dis[e]){
                dis[e]=dis[x]+w;
                s.insert({dis[e],e});
            }
        }
    }
    for(i=2;i<=m;i++)
     cout<<dis[i]<<" ";
}

我有一个 Dijkstra 算法的 C++ 实现,但我猜它并不能在所有情况下正确工作(尤其是较大的测试用例)。有人能帮我修复它吗?我遗漏了什么,或者哪里实现得不对?代码输出每个顶点到源顶点(即顶点 1)的最短距离。

你从未写入 visited 数组,因此同一个顶点的边可能被扫描多次。简单的修复方法是在 `if(visited[x]) continue;` 之后添加一行:
visited[x] = true;

这是我的解法,在图上的时间复杂度为 O(N):

#include <stdio.h>
#include <stdlib.h>
#include <string.h>

     typedef long long ll;
    /*
     * Fast reader for one signed decimal int from stdin (uses POSIX
     * getchar_unlocked). Skips every character that is neither a digit nor
     * '-', accepts one optional leading '-', then consumes the digits.
     * If EOF is reached before any digit, *x is set to 0.
     * NOTE(review): no overflow check -- values outside int range are UB.
     */
    void fs_int(int *x)
    {
        int c = getchar_unlocked();
        int neg = 0;
        int val = 0;

        /* Skip separators; stop at EOF instead of spinning forever
         * (EOF is negative and would otherwise satisfy the skip test). */
        while (c != EOF && c != '-' && (c < '0' || c > '9'))
            c = getchar_unlocked();

        if (c == '-') {
            neg = 1;
            c = getchar_unlocked();
        }
        while (c >= '0' && c <= '9') {
            val = val * 10 + (c - '0');
            c = getchar_unlocked();
        }
        *x = neg ? -val : val;
    }
    /* One directed half of an undirected edge. main() stores each input edge
     * twice (once per endpoint) in the flat array l[], and entries are chained
     * into per-vertex singly linked adjacency lists; index 0 ends a list. */
    typedef struct {
        int next;   /* index in l[] of the next edge of the same vertex; 0 = end of list */
        int val;    /* vertex at the other end of this edge (index into tr[]) */
        int d;      /* edge weight, added to the distance in get_sum */
    }List;
    /* Per-vertex state, indexed by vertex number (1-based) in tr[]. */
    typedef struct 
    {
        int parent;  /* parent vertex; only written by set_depth */
        int shrt;    /* shortest distance from vertex 1 found so far; 0 = not reached yet (also the value for vertex 1) */
        ll count;    /* number of incoming edges seen that achieve shrt; reset to 1 when shrt improves */
        int on_reg;  /* 1 while the vertex is waiting in get_sum's work list */
        int ch;      /* index in l[] of the first edge of this vertex's adjacency list; 0 = none */
    } Node;
    /* Modulus applied to the product of counts printed by main. */
    #define MOD 1000000007
    /*
     * Shortest distances from vertex 1, plus, per vertex, the number of
     * incoming edges that achieve that distance.
     *
     * This is a label-correcting search (in the spirit of Bellman-Ford/SPFA):
     * the work list `reg` holds vertices whose outgoing edges must be scanned,
     * and a vertex is queued again whenever its distance strictly improves,
     * so a vertex may be expanded more than once.
     *
     * On return, for every vertex v reached from vertex 1:
     *   tr[v].shrt  - length of a shortest path from vertex 1;
     *   tr[v].count - number of edges (u, v) found with shrt[u] + d == shrt[v].
     * Vertex 1 itself and unreached vertices keep shrt == 0, count == 0.
     *
     * Preconditions (not checked here):
     *   - tr[] is zeroed for every vertex that can be reached;
     *   - tr[v].ch and l[].next chain the adjacency lists, index 0 ends a list.
     * NOTE(review): shrt == 0 doubles as "not reached yet", so this assumes
     * strictly positive edge weights; distances are int and may overflow for
     * very large weights -- confirm against the input limits.
     *
     * The work list is static, reused across calls and never freed. The
     * process exits if the work list cannot be (re)allocated.
     * Returns 0; the results are delivered through tr[].
     */
    ll get_sum(Node *tr,List *l)
    {
        static int *reg = NULL, sz = 1000;
        int n = 0;

        if (!reg) {
            reg = malloc(sizeof(*reg) * (size_t)sz);
            if (!reg) {
                fputs("get_sum: out of memory\n", stderr);
                exit(EXIT_FAILURE);
            }
        }
        reg[n++] = 1;                       /* start from the source, vertex 1 */

        while (n) {
            /* Entries [0, fix) are the vertices to scan in this round; vertices
             * queued while scanning are kept after them for the next round. */
            int fix = n;
            for (int i = 0; i < fix; i++) {
                Node *t = &tr[reg[i]];
                /* O(1) removal of reg[i]: move the last entry of this round into
                 * slot i (and re-examine slot i via i--), then move the last entry
                 * of the whole list into the hole at the end of this round's range. */
                reg[i--] = reg[--fix];
                reg[fix] = reg[--n];
                t->on_reg = 0;

                for (int j = t->ch; j; j = l[j].next) {
                    if (l[j].val == 1)      /* never relax back into the source */
                        continue;
                    Node *t2 = &tr[l[j].val];
                    int cur_d = t->shrt + l[j].d;

                    if (t2->shrt != 0 && t2->shrt < cur_d)
                        continue;           /* a strictly shorter path is already known */
                    if (t2->shrt == cur_d) {
                        t2->count++;        /* another edge achieving the same length */
                        continue;
                    }
                    /* First time reached (shrt == 0) or a strictly shorter path:
                     * restart the count and make sure t2 gets (re)scanned. */
                    t2->shrt = cur_d;
                    t2->count = 1;
                    if (!t2->on_reg) {
                        if (n >= sz) {
                            int *grown = realloc(reg, sizeof(*reg) * (size_t)sz * 2);
                            if (!grown) {
                                fputs("get_sum: out of memory\n", stderr);
                                exit(EXIT_FAILURE);
                            }
                            reg = grown;
                            sz *= 2;
                        }
                        reg[n++] = l[j].val;
                        t2->on_reg = 1;
                    }
                }
            }
        }
        return 0;
    }
    typedef long long ll;
    /*
     * Walks the adjacency lists depth-first from rt, recording in tr[v].parent
     * the vertex each v was reached from (rt gets `parent`).
     * cd is the depth of rt: it is passed down as cd + 1 but never stored.
     * Only the edge leading straight back to the parent is skipped, so this
     * assumes the part of the graph reachable from rt is a tree; on a graph
     * with cycles the recursion would not terminate.
     * Not called in the code shown (its call in main is commented out).
     */
    void set_depth(Node *tr, List *l, int rt,int cd,int parent)
    {
        tr[rt].parent = parent;

        int edge = tr[rt].ch;
        while (edge != 0) {
            int child = l[edge].val;
            if (child != parent) {
                set_depth(tr, l, child, cd + 1, rt);
            }
            edge = l[edge].next;
        }
    }
    /*
     * Reads t test cases. Each case gives n (vertices 1..n) and q undirected
     * weighted edges "u v d". Runs get_sum() from vertex 1 and prints, modulo
     * MOD, the product of tr[i].count over i = 2..n: the number of ways to pick,
     * for every vertex except 1, one incoming edge that lies on a shortest path.
     * A vertex not reached from 1 has count 0, which makes the product 0.
     * Returns 1 (after a message on stderr) if a test case exceeds the fixed
     * array capacity or names a vertex outside 1..n.
     */
    int main ()
    {
        enum { MAX_NODES = 100005, MAX_EDGES = 200005 };
        /* static: these arrays total several MB and could overflow the stack
         * as automatic variables. */
        static Node tr[MAX_NODES];
        static List l[MAX_EDGES];
        int t, n, q;

        fs_int(&t);
        while (t--)
        {
            fs_int(&n);
            fs_int(&q);
            /* Edge slots 1..2q are used; slot 0 is the end-of-list marker. */
            if (n < 0 || n >= MAX_NODES || q < 0 || q > (MAX_EDGES - 1) / 2) {
                fprintf(stderr, "test case too large: n=%d q=%d\n", n, q);
                return 1;
            }
            /* get_sum only touches tr[1..n] (and tr[1], the source, even when
             * n == 0), so clear that prefix rather than the whole array on every
             * test case. Every l[] slot used below is fully overwritten. */
            memset(tr, 0, sizeof(tr[0]) * (size_t)(n < 1 ? 2 : n + 1));

            int il = 1;
            for (int i = 0; i < q; i++)
            {
                int u, v, d;
                fs_int(&u);
                fs_int(&v);
                fs_int(&d);
                if (u < 1 || u > n || v < 1 || v > n) {
                    fprintf(stderr, "vertex out of range: %d %d\n", u, v);
                    return 1;
                }
                /* Store the undirected edge twice, pushing one entry onto the
                 * front of each endpoint's adjacency list. */
                List *tl = &l[il];
                tl->next = tr[u].ch;
                tl->val = v;
                tl->d = d;
                tr[u].ch = il++;

                tl = &l[il];
                tl->next = tr[v].ch;
                tl->val = u;
                tl->d = d;
                tr[v].ch = il++;
            }
            get_sum(tr, l);

            /* res < MOD and count % MOD < MOD, so the product fits in ll. */
            ll res = 1;
            for (int i = 2; i <= n; i++)
                res = res * (tr[i].count % MOD) % MOD;
            printf("%lld\n", res);
        }
        return 0;
    }

你需要关注的是 get_sum() 函数。它类似于广度优先搜索:在图中相当于沿着同心圆逐层向外检查,从而避免无用的传播。待处理的顶点保存在名为 reg 的数组中,每一步都从这些顶点向前扩展。注意,与普通 BFS 不同,当某个顶点找到更短的距离时,它会被重新放入 reg 并再次扩展,因此同一个顶点可能被处理多次。至于实际效率,你可以在 Ways 这道竞赛题上自行验证。