Python C++扩展 - 内存泄漏或访问冲突

Python C++ extension - memory leak or access violation

本文关键字:泄漏 访问冲突 内存 C++ 扩展 Python      更新时间:2023-10-16

我写了一个 Python C++ 扩展,但是它的一个函数有问题。此扩展提供的函数将 2 个数组作为输入,并生成一个作为输出。

我只留下了函数代码的相关部分

float* forward(float* input, float* kernels, npy_intp* input_dims, npy_intp* kernels_dims){
    float* output = new float[output_size];
    //some irrelevant matrix operation code
    return output;
}

和包装器:

static PyObject *module_forward(PyObject *self, PyObject *args)
{
    PyObject *input_obj, *kernels_obj;
    if (!PyArg_ParseTuple(args, "OO", &input_obj, &kernels_obj))
        return NULL;
    PyObject *input_array = PyArray_FROM_OTF(input_obj, NPY_FLOAT, NPY_IN_ARRAY);
    PyObject *kernels_array = PyArray_FROM_OTF(kernels_obj, NPY_FLOAT, NPY_IN_ARRAY);
    if (input_array == NULL || kernels_array == NULL) {
        Py_XDECREF(input_array);
        Py_XDECREF(kernels_array);
        return NULL;
    }

    float *input = (float*)PyArray_DATA(input_array);
    float *kernels = (float*)PyArray_DATA(kernels_array);
    npy_intp *input_dims = PyArray_DIMS(input_array);
    npy_intp *kernels_dims = PyArray_DIMS(kernels_array);
    /////////THE ACTUAL FUNCTION
    float* output = forward(input, kernels, input_dims, kernels_dims);

    Py_DECREF(input_array);
    Py_DECREF(kernels_array);
    npy_intp output_dims[4] = {input_dims[0], input_dims[1]-kernels_dims[0]+1, input_dims[2]-kernels_dims[1]+1, kernels_dims[3]};
    PyObject* ret_output = PyArray_SimpleNewFromData(4, output_dims, NPY_FLOAT, output);
    delete output;//<-----THE PROBLEMATIC LINE////////////////////////////
    PyObject *ret = Py_BuildValue("O", ret_output);
    Py_DECREF(ret_output);
    return ret;
}

我突出显示的删除运算符是魔术发生的地方:没有它,此函数会泄漏内存,并且由于内存访问违规而崩溃。

有趣的是,我编写了另一个方法,它返回两个数组。因此,该函数返回一个指向两个浮点*元素的浮点**:

float** gradients = backward(input, kernels, grads, input_dims, kernel_dims, PyArray_DIMS(grads_array));
Py_DECREF(input_array);
Py_DECREF(kernels_array);
Py_DECREF(grads_array);
PyObject* ret_g_input = PyArray_SimpleNewFromData(4, input_dims, NPY_FLOAT, gradients[0]);
PyObject* ret_g_kernels = PyArray_SimpleNewFromData(4, kernel_dims, NPY_FLOAT, gradients[1]);
delete gradients[0];
delete gradients[1];
delete gradients;
PyObject* ret_list = PyList_New(0);
PyList_Append(ret_list, ret_g_input);
PyList_Append(ret_list, ret_g_kernels);
PyObject *ret = Py_BuildValue("O", ret_list);
Py_DECREF(ret_g_input);
Py_DECREF(ret_g_kernels);
return ret;

请注意,第二个示例完美运行,没有崩溃或内存泄漏,同时在数组构建到 PyArray 对象后仍对数组调用delete

有人可以启发我这里发生了什么吗?

来自PyArray_SimpleNewFromData文档:

围绕给定指针指向的数据创建一个数组包装器

如果你创建一个带有PyArray_SimpleNewFromData数组,它将围绕你给它的数据创建一个包装器,而不是制作副本。这意味着它包装的数据必须比数组更耐用。 delete数据违反了这一点。

您有以下几种选择:

  • 您可以以不同的方式创建数组,这样您就不会只是围绕原始数据进行包装。
  • 您可以仔细控制对阵列的访问,并确保在delete数据之前结束其生存期。
  • 您可以创建一个拥有数据的 Python 对象,并在对象的生存期结束时delete数据,并使用 PyArray_SetBaseObject 设置数组对该对象的base,以便数组保持所有者对象的活动状态,直到数组本身死亡。