带有可选参数的 Boost.Python 构造函数

Boost.Python constructor with optional arguments

本文关键字:Boost Python 构造函数 参数      更新时间:2023-10-16

我想使用 Boost.Python 为带有可选参数的 C++ 构造函数创建一个 Python 包装器。我希望 Python 包装器的行为如下:

class Foo():
  def __init__(self, filename, phase, stages=None, level=0):
    """
    filename -- string
    phase -- int
    stages -- optional list of strings
    level -- optional int
    """
    if stages is None:
      stages = []
    # ...

如何使用 Boost.Python 执行此操作?我不知道如何使用make_constructor来做到这一点,而且我不知道如何使用raw_function制作构造函数。有没有比这更好的文档?

我的具体问题是尝试向这两个构造函数添加两个可选参数(阶段和级别):

https://github.com/BVLC/caffe/blob/rc3/python/caffe/_caffe.cpp#L76-L96

多亏了丹的评论,我找到了一个有效的解决方案。我将在这里复制大部分内容,因为有一些关于如何从bp::object中提取对象的有趣花絮,等等。

// Net constructor
shared_ptr<Net<Dtype> > Net_Init(string param_file, int phase,
    const int level, const bp::object& stages,
    const bp::object& weights_file) {
  CheckFile(param_file);
  // Convert stages from list to vector
  vector<string> stages_vector;
  if (!stages.is_none()) {
      for (int i = 0; i < len(stages); i++) {
        stages_vector.push_back(bp::extract<string>(stages[i]));
      }   
  }   
  // Initialize net 
  shared_ptr<Net<Dtype> > net(new Net<Dtype>(param_file,
      static_cast<Phase>(phase), level, &stages_vector));
  // Load weights
  if (!weights_file.is_none()) {
      std::string weights_file_str = bp::extract<std::string>(weights_file);
      CheckFile(weights_file_str);
      net->CopyTrainedLayersFrom(weights_file_str);
  }   
  return net;
}   
BOOST_PYTHON_MODULE(_caffe) {
  bp::class_<Net<Dtype>, shared_ptr<Net<Dtype> >, boost::noncopyable >("Net",
    bp::no_init)
    .def("__init__", bp::make_constructor(&Net_Init,
          bp::default_call_policies(), (bp::arg("network_file"), "phase",
            bp::arg("level")=0, bp::arg("stages")=bp::object(),
            bp::arg("weights_file")=bp::object())))
}   

生成的签名为:

__init__(boost::python::api::object, std::string network_file, int phase, 
   int level=0, boost::python::api::object stages=None,
   boost::python::api::object weights_file=None)

我可以像这样使用它:

net = caffe.Net('network.prototxt', weights_file='weights.caffemodel',
  phase=caffe.TEST, level=1, stages=['deploy'])

此处的拉取请求中提供了完整代码:https://github.com/BVLC/caffe/pull/3863