如何保存RNN的状态以从TF中的图中提交

How do I save the states of an RNN to file from a Graph in TF?

本文关键字:TF 何保存 提交 状态 保存 RNN      更新时间:2023-10-16

以下代码来自TensorFlow服务API:

// Implementation of Predict using the SavedModel SignatureDef format.
Status SavedModelPredict(const RunOptions& run_options, ServerCore* core,
                         const PredictRequest& request,
                         PredictResponse* response) {
  // Validate signatures.
  ServableHandle<SavedModelBundle> bundle;
  TF_RETURN_IF_ERROR(core->GetServableHandle(request.model_spec(), &bundle));
  const string signature_name = request.model_spec().signature_name().empty()
                                    ? kDefaultServingSignatureDefKey
                                    : request.model_spec().signature_name();
  auto iter = bundle->meta_graph_def.signature_def().find(signature_name);
  if (iter == bundle->meta_graph_def.signature_def().end()) {
    return errors::FailedPrecondition(strings::StrCat(
        "Serving signature key "", signature_name, "" not found."));
  }
  SignatureDef signature = iter->second;
  MakeModelSpec(request.model_spec().name(), signature_name,
                bundle.id().version, response->mutable_model_spec());
  std::vector<std::pair<string, Tensor>> input_tensors;
  std::vector<string> output_tensor_names;
  std::vector<string> output_tensor_aliases;
  TF_RETURN_IF_ERROR(PreProcessPrediction(signature, request, &input_tensors,
                                          &output_tensor_names,
                                          &output_tensor_aliases));
  std::vector<Tensor> outputs;
  RunMetadata run_metadata;
  TF_RETURN_IF_ERROR(bundle->session->Run(run_options, input_tensors,
                                          output_tensor_names, {}, &outputs,
                                          &run_metadata));
  return PostProcessPredictionResult(signature, output_tensor_aliases, outputs,
                                     response);
}

此代码使用存储的模型运行预测。就我而言,该存储的模型是RNN。

在实际运行的以下行之后:

TF_RETURN_IF_ERROR(bundle->session->Run(run_options, input_tensors,
                                      output_tensor_names, {}, &outputs,
                                      &run_metadata));

我想将RNN的状态保存到文件/内存,以便在每个预测后以后访问它们。我认为可以通过变量访问这些状态:

bundle->meta_graph_def

但不清楚如何确切访问RNN的状态值,然后将它们保存到文件中。

您必须通过会话而不是事实来获得值。在你的行中

bundle->session->Run(run_options, input_tensors,
                                  output_tensor_names, {}, &outputs,
                                  &run_metadata)

您不仅应该请求预测的结果,还应要求状态张量。请注意,该状态的值未在 session->Run的特定调用之后存储,仅存储变量,rnn的状态是计算值并在返回结果后删除。

我不使用C 接口,所以请原谅我在Python中提供代码示例,希望它仍然有用(我敢打赌,这不是您第一次必须忍受阅读Python示例(,在Python中,这看起来像:

prediction, state = sess.run([prediction_tensor, state_tensor], feed_dict=...)