如何在 TensorFlow C++ API 中获取占位符大小

How to get placeholder size in TensorFlow C++ API?

本文关键字:获取 占位符 API TensorFlow C++      更新时间:2023-10-16

我想使用C++来加载TensorFlow模型。我想知道模型输入的大小,这是模型中的占位符。

我用谷歌搜索这个问题,但我只是在堆栈溢出中找到了这个链接:

C++相当于python:tf。Graph.get_tensor_by_name() 在 Tensorflow 中?

虽然我可以得到节点,但是张量流文档没有告诉我如何访问节点的大小。那么有没有人知道这件事呢?

非常感谢!

好的,经过多次尝试。我找到了一个解决方法,它可能很棘手,但效果很好。

首先,我们可以使用以下代码获取占位符节点:

GraphDef mygd = graph_def.graph_def();
for (int i = 0; i < mygd.node_size(); i++)
{
    if (mygd.node(i).name() == input_name)
    {
        auto node = mygd.node(i);
    }
}

然后通过 NodeDef.pd.h(tensorflow/core/framework/node_def.pb.h),我们可以通过如下代码获取 AttrValue:

auto attr = node.attr();

然后通过 attr_value.cc(tensorflow/core/framework/attr_value.cc),我们可以通过如下代码获取形状 attr 值:

tensorflow::AttrValue shape = attr["shape"];

形状 AttrValue 是用于存储形状信息的结构。我们可以通过函数 SummarizeAttrValue in tensorflow/core/framework/attr_value_util.h 获取详细信息。

string size_summary = SummarizeAttrValue(shape);

然后我们可以得到形状的字符串格式,如下所示:

[?,1024]