正确的方法来创建thread_safe shared_ptr没有锁

Correct way to create thread_safe shared_ptr without a lock?

本文关键字:shared safe ptr thread 方法 创建      更新时间:2023-10-16

我正在尝试创建一个线程安全的shared_ptr类。我的用例是shared_ptr属于类的对象,并且行为有点像单例(CreateIfNotExist函数可以由任何线程在任何时间点运行)。

本质上,如果指针为空,第一个设置该值的线程获胜,并且同时创建该指针的所有其他线程都使用获胜线程的值。

这是我到目前为止所做的(注意,唯一有问题的函数是CreateIfNotExist()函数,其余的用于测试目的):

#include <memory>
#include <iostream>
#include <thread>
#include <vector>
#include <mutex>
struct A {
    A(int a) : x(a) {}
    int x;
};
struct B {
    B() : test(nullptr) {}
    void CreateIfNotExist(int val) {
        std::shared_ptr<A> newPtr = std::make_shared<A>(val);
        std::shared_ptr<A> _null = nullptr;
        std::atomic_compare_exchange_strong(&test, &_null, newPtr);
    }
    std::shared_ptr<A> test;
};
int gRet = -1;
std::mutex m;
void Func(B* b, int val) {
    b->CreateIfNotExist(val);
    int ret =  b->test->x;
    if(gRet == -1) {
        std::unique_lock<std::mutex> l(m);
        if(gRet == -1) {
            gRet = ret;
        }
    }
    if(ret != gRet) {
        std::cout << " FAILED " << std::endl;
    }
}
int main() {
    B b;
    std::vector<std::thread> threads;
    for(int i = 0; i < 10000; ++i) {
        threads.clear();
        for(int i = 0; i < 8; ++i) threads.emplace_back(&Func, &b, i);
        for(int i = 0; i < 8; ++i) threads[i].join();
    }
}

这样做正确吗?是否有更好的方法来确保所有线程调用CreateIfNotExist()在同一时间都使用相同的shared_ptr?

可以这样写:

struct B {
  void CreateIfNotExist(int val) {
    std::call_once(test_init,
                   [this, val](){test = std::make_shared<A>(val);});
  }
  std::shared_ptr<A> test;
  std::once_flag test_init;
};