你能在递归 lambda 中捕获引用吗?

Can you capture a reference in a recursive lambda?

本文关键字:引用 递归 lambda      更新时间:2023-10-16

我已经完全解决了HackerRank(https://www.hackerrank.com/challenges/ctci-recursive-staircase/problem(上的一个特定问题,使用带有记忆的递归解决方案:

std::map<int, int> memoize;
int davis_staircase(int n) {
    if (n == 0) {
        return 1;
    } else if (n < 0) {
        return 0;
    }
    auto find = memoize.find(n);
    if (find != memoize.end()) {
        return find->second;
    }
    int num_steps = davis_staircase(n - 1) + davis_staircase(n - 2) + davis_staircase(n - 3);
    memoize[n] = num_steps;
    return num_steps;
}

我想隐藏我用作查找的全局std::map(不使用类(,并认为我会尝试创建一个可以递归调用的 lambda,也可以通过引用捕获缓存/映射。我尝试了以下方法:

int davis_staircase_2(int n) {
    std::map<int, int> memo;
    //auto recurse = [&memo](int n) -> int {                    // attempt (1)
    //std::function<int(int)> recurse = [&memo](int n) -> int { // attempt (2)
    std::function<int(std::map<int, int>&, int)> recurse = [](std::map<int, int>& memo, int n) -> int { // attempt (3)
        if (n == 0) {
            return 1;
        } else if (n < 0) {
            return 0;
        }
        auto find = memo.find(n);
        if (find != memo.end()) {
            return find->second;
        }
        //int num_steps = recurse(n - 1) + recurse(n - 2) + recurse(n - 3); // attempt (1) or (2)
        int num_steps = recurse(memo, n - 1) + recurse(memo, n - 2) + recurse(memo, n - 3); // attempt (3)
        memo[n] = num_steps;
        return num_steps;
    };
    //return recurse(n); // attempt (1) or (2)
    return recurse(memo, n); // attempt (3)
}

我在上面交错了 3 次略有不同的尝试,但我无法编译任何尝试。我想做的事情可能吗?

我在MacOS上使用clang:

Apple LLVM version 10.0.0 (clang-1000.10.44.4)
Target: x86_64-apple-darwin18.2.0
Thread model: posix

你忘了捕获recurse,所以你的代码可能是

std::function<int(int)> recurse = [&recurse, &memo](int n) -> int { // attempt (2)

std::function<int(int)> recurse = [&](int n) -> int { // attempt (2)

同样,对于// attempt (3)

std::function<int(std::map<int, int>&, int)> recurse = [&recurse](std::map<int, int>& memo, int n) -> int { // attempt (3)

// attempt (1)不能按原样修复,因为在定义recurse类型之前就使用了它。

要在没有 std::function 的情况下执行此操作,您可以使用 Y 组合器(对于泛型 lambda 需要 C++14(:

int davis_staircase_2(int n) {
    std::map<int, int> memo;
    auto recurse = [&memo](auto self, int n) -> int { // attempt (4)
        if (n == 0) {
            return 1;
        } else if (n < 0) {
            return 0;
        }
        auto find = memo.find(n);
        if (find != memo.end()) {
            return find->second;
        }
        int num_steps = self(self, n - 1) + self(self, n - 2) + self(self, n - 3); // attempt (4)
        memo[n] = num_steps;
        return num_steps;
    };
    return recurse(recurse, n); // attempt (4)
}

你不需要递归函数...

int stepPerms(int n) {
  std::map<int, int> memoize;
  memoize[-2] = 0;
  memoize[-1] = 0;
  memoize[0] = 1;
 for(int i=1;i<=n;++i)
 {
   memoize[i] = memoize[i - 1] + memoize[i - 2] + memoize[i-3];
 }
 return memoize[n];
}

你可以在没有类型擦除的情况下执行递归 lambda(std::function(。这是使用通用 lambda 的方式:

auto recurse = [](auto lambda) {
    return [lambda](auto&&... args) {
        return lambda(lambda, std::forward<decltype(args)>(args)...);
    };
};
auto my_recursive_lambda = recurse([](auto self, std::map<int, int>& memo, int n) {
    if (n == 0) {
        return 1;
    } else if (n < 0) {
        return 0;
    }
    auto find = memo.find(n);
    if (find != memo.end()) {
        return find->second;
    }
    int num_steps = self(self, memo, n - 1) + self(self, memo, n - 2) + self(self, memo, n - 3);
    memo[n] = num_steps;
    return num_steps;
});
my_recursive_lambda(memo, n); // magic!

如果你真的需要这个 c++11,你将需要std::function

auto recurse = std::function<int(std::map<int, int>&, int)>{};
recurse = [&recurse](std::map<int, int>& memo, int n) {
    // same as you tried.
}

或者,如果您放弃了便利性,则可以手动滚动您的 lambda 类型:

struct {
    auto operator()(std::map<int, int>& memo, int n) -> int {
        auto&& recurse = *this;
        if (n == 0) {
            return 1;
        } else if (n < 0) {
            return 0;
        }
        auto find = memo.find(n);
        if (find != memo.end()) {
            return find->second;
        }
        //int num_steps = recurse(n - 1) + recurse(n - 2) + recurse(n - 3); // attempt (1) or (2)
        int num_steps = recurse(memo, n - 1) + recurse(memo, n - 2) + recurse(memo, n - 3); // attempt (3)
        memo[n] = num_steps;
        return num_steps;
    }
} recurse{};