使用c++ mex函数从matlab中获取输入参数

getting input parameters from matlab using C++ mex Function

本文关键字:获取 输入 参数 matlab c++ mex 函数 使用      更新时间:2023-10-16
double learning_rate = 1;
int training_epochs = 1;
int k = 1;
int train_S = 6;
int test_S = 6;
int visible_E = 6;
int hidden_E = 6;
// training data
int train_X[6][6] = {
    {1, 1, 1, 0, 0, 0},
    {1, 0, 1, 0, 0, 0},
    {1, 1, 1, 0, 0, 0},
    {0, 0, 1, 1, 1, 0},
    {0, 0, 1, 1, 1, 0},
    {0, 0, 1, 1, 1, 0}
};

上面的代码是输入参数我给我的函数。但是我想把它们转换成我的mexFunction中的一个函数,然后简单地调用它们。matlab端有如下

clear *
close all
clc
%% Load the data

X=    [ 1, 1, 1, 0, 0, 0; ...
        1, 0, 1, 0, 0, 0; ...
        1, 1, 1, 0, 0, 0; ...
        0, 0, 1, 1, 1, 0; ...
        0, 0, 1, 1, 1, 0; ...
        0, 0, 1, 1, 1, 0];
%% Define Parameters
numHiddenUnits = 6;
numIterations = 1000;
kCD = 1;
%% Compute the RBM
x = RBM(X, numHiddenUnits, numIterations, kCD);

标量输入参数相当简单。矩阵输入有点棘手,因为它们使用老式的Fortran列主顺序,您可能需要在将数据发送到函数之前对其进行转置。下面是一个例子,你需要填入空格:

/*=========================================================
 * Built on:
 * matrixDivide.c - Example for illustrating how to use 
 * LAPACK within a C MEX-file.
 *
 * This is a MEX-file for MATLAB.
 * Copyright 2009 The MathWorks, Inc.
 *=======================================================*/
/* $Revision: 1.1.6.2 $ $Date: 2009/05/18 19:50:18 $ */
#include "mex.h"
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
    double * pX, * pNumHiddenUnits, * pNumIter, * pkCD; /* pointers to inputs */
    double * pOutput;  /* output arugments */
    mwSignedIndex m,n;     /* matrix dimensions */ 
    int i, j;
      /* Check for proper number of arguments. */
    if ( nrhs != 4)
    {
        mexErrMsgIdAndTxt("MATLAB:RBM:rhs",
            "This function requires 4 inputs.");
    }
    pX = mxGetPr(prhs[0]); /* pointer to first input, X matrix */
    pNumHiddenUnits = mxGetPr(prhs[1]); /* pointer to second input, scalar hidden units */
    pNumIter = mxGetPr(prhs[2]); /* pointer to third input, scalar number of iterations */
    pkCD = mxGetPr(prhs[3]); /* pointer to third input, scalar kCD */
    /* dimensions of input matrix */
    m = (mwSignedIndex)mxGetM(prhs[0]);  
    n = (mwSignedIndex)mxGetN(prhs[0]);
    /* Validate input arguments */
    if (m < 1 && n < 1)
    {
        mexErrMsgIdAndTxt("MATLAB:RBM:notamatrix",
            "X must be a matrix.");
    }
    plhs[0] = mxCreateDoubleMatrix(m, n, mxREAL);
    pOutput = mxGetPr(plhs[0]);
    for (i = 0; i < n; ++i)
    {
      for (j = 0; j < m; ++j)
      {
         int index = j * n + i;
         pOutput[index] = pX[i * m + j];
      }
    }
  }
  /* */