C++ call and definition mismatch

Keywords: mismatch, definition, call, C++      Updated: 2023-10-16

I am looking at this block of code: https://github.com/pytorch/pytorch/blob/master/torch/csrc/autograd/profiler.cpp#L141

pushCallback(
    [config](const RecordFunction& fn) {
      auto* msg = (fn.seqNr() >= 0) ? ", seq = " : "";
      if (config.report_input_shapes) {
        std::vector<std::vector<int64_t>> inputSizes;
        inputSizes.reserve(fn.inputs().size());
        for (const c10::IValue& input : fn.inputs()) {
          if (!input.isTensor()) {
            inputSizes.emplace_back();
            continue;
          }
          const at::Tensor& tensor = input.toTensor();
          if (tensor.defined()) {
            inputSizes.push_back(input.toTensor().sizes().vec());
          } else {
            inputSizes.emplace_back();
          }
        }
        pushRangeImpl(fn.name(), msg, fn.seqNr(), std::move(inputSizes));
      } else {
        pushRangeImpl(fn.name(), msg, fn.seqNr(), {});
      }
    },
    [](const RecordFunction& fn) {
      if (fn.getThreadId() != 0) {
        // If we've overridden the thread_id on the RecordFunction, then find
        // the eventList that was created for the original thread_id. Then,
        // record the end event on this list so that the block is added to
        // the correct list, instead of to a new list. This should only run
        // when calling RecordFunction::end() in a different thread.
        if (state == ProfilerState::Disabled) {
          return;
        } else {
          std::lock_guard<std::mutex> guard(all_event_lists_map_mutex);
          const auto& eventListIter =
              all_event_lists_map.find(fn.getThreadId());
          TORCH_INTERNAL_ASSERT(
              eventListIter != all_event_lists_map.end(),
              "Did not find thread_id matching ",
              fn.getThreadId());
          auto& eventList = eventListIter->second;
          eventList->record(
              EventKind::PopRange,
              StringView(""),
              fn.getThreadId(),
              state == ProfilerState::CUDA);
        }
      } else {
        popRange();
      }
    },
    config.report_input_shapes);

This call passes only three arguments. But the definition of pushCallback appears to be at https://github.com/pytorch/pytorch/blob/master/torch/csrc/autograd/record_function.cpp#L35 and it takes four parameters.

void pushCallback(
    RecordFunctionCallback start,
    RecordFunctionCallback end,
    bool needs_inputs,
    bool sampled) {
  start_callbacks.push_back(std::move(start));
  end_callbacks.push_back(std::move(end));
  if (callback_needs_inputs > 0 || needs_inputs) {
    ++callback_needs_inputs;
  }
  is_callback_sampled.push_back(sampled);
  if (sampled) {
    ++num_sampled_callbacks;
  }
}

I don't understand why the function call compiles and works this way.

If you look at the header, you will see that the function is declared with four parameters, and the last three have default values:

TORCH_API void pushCallback(
    RecordFunctionCallback start,
    RecordFunctionCallback end = [](const RecordFunction&){},
    bool needs_inputs = false,
    bool sampled = false);

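The default arguments appear only on the declaration, not on the definition. The call in profiler.cpp passes three arguments, so the compiler fills in the fourth, sampled, with its default value of false.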
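
To see the same pattern in isolation, here is a minimal sketch that does not depend on PyTorch. The Callback alias, the Registered struct, and the registered vector are hypothetical stand-ins invented for this example. Only the shape of the pushCallback declaration and definition mirrors the code quoted above.

```cpp
#include <functional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for RecordFunctionCallback (not the PyTorch type).
using Callback = std::function<void(const std::string&)>;

// Declaration, as it would appear in a header.
// The default arguments are written here and only here.
void pushCallback(
    Callback start,
    Callback end = [](const std::string&) {},
    bool needs_inputs = false,
    bool sampled = false);

// Hypothetical bookkeeping so the effect of the defaults can be inspected.
struct Registered {
  Callback start;
  Callback end;
  bool needs_inputs;
  bool sampled;
};
std::vector<Registered> registered;

// Definition, as it would appear in a .cpp file.
// All four parameters are listed, and the defaults are not repeated.
void pushCallback(
    Callback start,
    Callback end,
    bool needs_inputs,
    bool sampled) {
  registered.push_back(
      {std::move(start), std::move(end), needs_inputs, sampled});
}

// Three arguments are passed; the compiler supplies sampled = false
// at the call site, using the defaults from the declaration.
void registerWithThreeArguments() {
  pushCallback(
      [](const std::string&) {},
      [](const std::string&) {},
      true);
}
```

Calling registerWithThreeArguments() appends one entry with needs_inputs set to true and sampled set to false. Repeating the defaults on the definition in the same scope would be a compile error, which is why they are normally written once, on the header declaration.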