Mex 函数比相同的 C++ 代码更快

Mex function faster than identical C++ code

本文关键字:c++ 代码 函数 Mex      更新时间:2023-10-16

当我把程序编译成独立的 C++ 可执行文件运行时,它比在 MATLAB 中通过 Mex 函数调用的同一段代码运行得慢。

我尝试了一些示例代码来测试,这证实了我的怀疑:

使用c++:

#include <stdio.h>
#include <ctime>
// Times the nested copy loop over the first 10000 elements of `a` and `b` and
// writes the elapsed CPU time in seconds, followed by a newline, to `out`
// (stdout by default, so existing two-argument calls behave as before).
//
// Both inputs are only read, never modified. `out` must be a valid, writable
// stream.
//
// Benchmarking caveat: barData is written but never read, so an optimizing
// compiler is free to delete both loops and report ~0 s. To time the loop
// itself, make sure its result is actually used (e.g. print a checksum).
void process(const int a[10000], const int b[10000], FILE *out = stdout) {
    const int dim[2] = {1, 10000};
    int barData[20000];
    const std::clock_t begin = std::clock();
    for (int i = 0; i < dim[1]; ++i) {
        for (int j = 0; j < i; ++j) {
            // The first store is overwritten immediately; it is kept so the
            // loop stays identical to the MEX version it is compared against.
            barData[j] = a[i];
            barData[j] = b[i];
        }
    }
    const std::clock_t end = std::clock();
    const double elapsed_secs = double(end - begin) / CLOCKS_PER_SEC;
    // `\n` terminates the line; without the backslash a literal 'n' is printed.
    fprintf(out, "%f\n", elapsed_secs);
}
// Entry point for the plain C++ benchmark.
// The input arrays are value-initialised (all zeros): passing uninitialised
// arrays made process() read indeterminate values, which is undefined
// behaviour in C++.
int main() {
    int a[10000] = {};
    int b[10000] = {};
    process(a, b);
    return 0;
}

使用Mex函数:

#include <stdio.h>
#include <ctime>
#include "mex.h"
void process(const mxArray *first, const mxArray *second) {
    int* a = (int *)mxGetData(first);
    int* b = (int *)mxGetData(second);
    const int *dim = mxGetDimensions(first);
    const int dims[2] = {1,dim[1]*2};
    mxArray* bar = mxCreateNumericArray(2, dims, mxINT64_CLASS, mxREAL);
    int* barData = (int*)mxGetData(bar);
    clock_t begin = clock();
    for (int i = 0; i < dim[1]; i++) {
        for (int j = 0; j < i; j++) {
            barData[j] = a[i];
            barData[j] = b[i];
        }
    }
    clock_t end = clock();
    double elapsed_secs = double(end - begin) / CLOCKS_PER_SEC;
    printf("%fn", elapsed_secs);
}
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) {
    process(prhs[0], prhs[1]);
}

在MATLAB中调用,如下所示:

mex test.cpp -output foo
% The MEX code reads both inputs through int pointers, so pass int32 data;
% rand() returns doubles, whose bytes would otherwise be reinterpreted as ints.
foo(int32(randi(100, 1, 10000)), int32(randi(100, 1, 10000)))

Mex 函数耗时约 0.012 s,而 C++ 程序耗时 0.108 s。数组更大时也是同样的趋势。为什么会这样?有没有办法让 C++ 代码达到 Mex 函数的速度?

正如 @Praetorian 在评论中指出的,你很可能在编译 C++ 代码时没有开启优化。`mex` 命令默认以开启优化的方式编译(只有加上 `-g` 时才会关闭优化),因此两者的差别很可能来自编译选项,而不是 Mex 本身更快。

下面是未经优化(-O0)编译时生成的 LLVM IR(一种类似汇编的中间表示):

; ModuleID = 'test.cpp'
target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128"
target triple = "x86_64-unknown-linux-gnu"
; dim = {1, 10000} from process(), emitted as an internal constant global.
@_ZZ7processPiS_E3dim = internal constant [2 x i32] [i32 1, i32 10000], align 4
; printf format string "%f\n": '%', 'f', newline (0x0A), NUL terminator.
@.str = private unnamed_addr constant [4 x i8] c"%f\0A\00", align 1
; process(int*, int*) compiled without optimization (-O0).
; Every parameter and local (a, b, barData, begin, i, j, end, elapsed_secs)
; gets its own stack slot (alloca) and is stored/reloaded around each use,
; and both loops are kept in full.
; Function Attrs: uwtable
define void @_Z7processPiS_(i32* %a, i32* %b) #0 {
  %1 = alloca i32*, align 8
  %2 = alloca i32*, align 8
  %barData = alloca [20000 x i32], align 16
  %begin = alloca i64, align 8
  %i = alloca i32, align 4
  %j = alloca i32, align 4
  %end = alloca i64, align 8
  %elapsed_secs = alloca double, align 8
  store i32* %a, i32** %1, align 8
  store i32* %b, i32** %2, align 8
  %3 = call i64 @clock() #3
  store i64 %3, i64* %begin, align 8
  store i32 0, i32* %i, align 4
  br label %4
; Outer loop condition: i < dim[1]; dim[1] is reloaded from the constant
; global on every iteration.
; <label>:4                                       ; preds = %34, %0
  %5 = load i32* %i, align 4
  %6 = load i32* getelementptr inbounds ([2 x i32]*             @_ZZ7processPiS_E3dim, i32 0, i64 1), align 4
  %7 = icmp slt i32 %5, %6
  br i1 %7, label %8, label %37
; Inner loop initialisation: j = 0.
; <label>:8                                       ; preds = %4
  store i32 0, i32* %j, align 4
  br label %9
; Inner loop condition: j < i.
; <label>:9                                       ; preds = %30, %8
  %10 = load i32* %j, align 4
  %11 = load i32* %i, align 4
  %12 = icmp slt i32 %10, %11
  br i1 %12, label %13, label %33
; Inner loop body: barData[j] = a[i]; barData[j] = b[i];
; Both stores target the same element, yet both are emitted at -O0.
; <label>:13                                      ; preds = %9
  %14 = load i32* %i, align 4
  %15 = sext i32 %14 to i64
  %16 = load i32** %1, align 8
  %17 = getelementptr inbounds i32* %16, i64 %15
  %18 = load i32* %17, align 4
  %19 = load i32* %j, align 4
  %20 = sext i32 %19 to i64
  %21 = getelementptr inbounds [20000 x i32]* %barData, i32 0, i64 %20
  store i32 %18, i32* %21, align 4
  %22 = load i32* %i, align 4
  %23 = sext i32 %22 to i64
  %24 = load i32** %2, align 8
  %25 = getelementptr inbounds i32* %24, i64 %23
  %26 = load i32* %25, align 4
  %27 = load i32* %j, align 4
  %28 = sext i32 %27 to i64
  %29 = getelementptr inbounds [20000 x i32]* %barData, i32 0, i64 %28
  store i32 %26, i32* %29, align 4
  br label %30
; Inner loop increment: ++j.
; <label>:30                                      ; preds = %13
  %31 = load i32* %j, align 4
  %32 = add nsw i32 %31, 1
  store i32 %32, i32* %j, align 4
  br label %9
; Inner loop exit.
; <label>:33                                      ; preds = %9
  br label %34
; Outer loop increment: ++i.
; <label>:34                                      ; preds = %33
  %35 = load i32* %i, align 4
  %36 = add nsw i32 %35, 1
  store i32 %36, i32* %i, align 4
  br label %4
; After the loops: end = clock(); elapsed_secs = (end - begin) / 1e6
; (the CLOCKS_PER_SEC value on this target); then printf.
; <label>:37                                      ; preds = %4
  %38 = call i64 @clock() #3
  store i64 %38, i64* %end, align 8
  %39 = load i64* %end, align 8
  %40 = load i64* %begin, align 8
  %41 = sub nsw i64 %39, %40
  %42 = sitofp i64 %41 to double
  %43 = fdiv double %42, 1.000000e+06
  store double %43, double* %elapsed_secs, align 8
  %44 = load double* %elapsed_secs, align 8
  %45 = call i32 (i8*, ...)* @printf(i8* getelementptr inbounds ([4 x     i8]* @.str, i32 0, i32 0), double %44)
  ret void
}
; External C library functions called from process(); only their
; declarations appear in this module.
; Function Attrs: nounwind
declare i64 @clock() #1
declare i32 @printf(i8*, ...) #2
; main() at -O0: reserves the two 10000-element stack arrays a and b and
; passes pointers to their first elements to process(). Nothing is stored
; into a or b beforehand, so process() reads uninitialized memory.
; Function Attrs: uwtable
define i32 @main() #0 {
  %1 = alloca i32, align 4
  %a = alloca [10000 x i32], align 16
  %b = alloca [10000 x i32], align 16
  store i32 0, i32* %1
  %2 = getelementptr inbounds [10000 x i32]* %a, i32 0, i32 0
  %3 = getelementptr inbounds [10000 x i32]* %b, i32 0, i32 0
  call void @_Z7processPiS_(i32* %2, i32* %3)
  ret i32 0
}
; Function attribute groups referenced as #0-#3 in the code above.
attributes #0 = { uwtable "less-precise-fpmad"="false" "no-frame-pointer-elim"="true" "no-frame-pointer-elim-non-leaf" "no-infs-fp-math"="false" "no-nans-fp-math"="false" "stack-protector-buffer-size"="8" "unsafe-fp-math"="false" "use-soft-float"="false" }
attributes #1 = { nounwind "less-precise-fpmad"="false" "no-frame-pointer-elim"="true" "no-frame-pointer-elim-non-leaf" "no-infs-fp-math"="false" "no-nans-fp-math"="false" "stack-protector-buffer-size"="8" "unsafe-fp-math"="false" "use-soft-float"="false" }
attributes #2 = { "less-precise-fpmad"="false" "no-frame-pointer-elim"="true" "no-frame-pointer-elim-non-leaf" "no-infs-fp-math"="false" "no-nans-fp-math"="false" "stack-protector-buffer-size"="8" "unsafe-fp-math"="false" "use-soft-float"="false" }
attributes #3 = { nounwind }
!llvm.ident = !{!0}
!0 = !{!"clang version 3.6.2 (tags/RELEASE_362/final)"}

注意 `_Z7processPiS_`(即 `process` 函数)的函数体非常长:每个变量都放在栈上,每次使用都要重新从内存读取,两层循环也都完整保留。

下面是使用 `-O3` 优化后的结果(如今对 C++ 代码开启 `-O3` 通常是安全的):

; ModuleID = 'test.cpp'
target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128"
target triple = "x86_64-unknown-linux-gnu"
; printf format string "%f\n": '%', 'f', newline (0x0A), NUL terminator.
; The dim global is gone: its only use (the loop bound) was optimized away.
@.str = private unnamed_addr constant [4 x i8] c"%f\0A\00", align 1
; process() at -O3: the nested loops have been deleted. barData is a local
; array whose contents are never read, so every store into it is dead; once
; those stores are gone the loops have no effect and are removed. What
; remains is two back-to-back clock() calls and a printf of their difference,
; so the reported time measures (almost) nothing. %a and %b are now marked
; readnone because nothing reads them any more.
; Function Attrs: nounwind uwtable
define void @_Z7processPiS_(i32* nocapture readnone %a, i32* nocapture readnone %b) #0 {
  %1 = tail call i64 @clock() #2
  %2 = tail call i64 @clock() #2
  %3 = sub nsw i64 %2, %1
  %4 = sitofp i64 %3 to double
  %5 = fdiv double %4, 1.000000e+06
  %6 = tail call i32 (i8*, ...)* @printf(i8* getelementptr inbounds ([4 x     i8]* @.str, i64 0, i64 0), double %5)
  ret void
}
; External C library functions; at -O3 printf's format argument carries
; the nocapture readonly attributes.
; Function Attrs: nounwind
declare i64 @clock() #1
; Function Attrs: nounwind
declare i32 @printf(i8* nocapture readonly, ...) #1
; main() at -O3: process() has been inlined, so main is just the same two
; back-to-back clock() calls and the printf of their difference. The arrays
; a and b no longer appear at all.
; Function Attrs: nounwind uwtable
define i32 @main() #0 {
  %1 = tail call i64 @clock() #2
  %2 = tail call i64 @clock() #2
  %3 = sub nsw i64 %2, %1
  %4 = sitofp i64 %3 to double
  %5 = fdiv double %4, 1.000000e+06
  %6 = tail call i32 (i8*, ...)* @printf(i8* getelementptr inbounds ([4 x     i8]* @.str, i64 0, i64 0), double %5) #2
  ret i32 0
}
; Function attribute groups referenced as #0-#2 in the code above.
attributes #0 = { nounwind uwtable "less-precise-fpmad"="false" "no-frame-pointer-elim"="false" "no-infs-fp-math"="false" "no-nans-fp-math"="false" "stack-protector-buffer-size"="8" "unsafe-fp-math"="false" "use-soft-float"="false" }
attributes #1 = { nounwind "less-precise-fpmad"="false" "no-frame-pointer-elim"="false" "no-infs-fp-math"="false" "no-nans-fp-math"="false" "stack-protector-buffer-size"="8" "unsafe-fp-math"="false" "use-soft-float"="false" }
attributes #2 = { nounwind }
!llvm.ident = !{!0}
!0 = !{!"clang version 3.6.2 (tags/RELEASE_362/final)"}

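可以看到,开启优化后两层循环被完全删除:`barData` 写入后从未被读取,这些写操作都是死代码,`process` 里只剩下两次紧挨着的 `clock()` 调用和一次 `printf`。因此优化后测到的时间几乎为零,并不反映循环本身的开销;要做有意义的计时,必须让循环的结果被实际使用(例如输出 `barData` 的校验和)。

注:更符合 C++ 习惯的写法是: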

#include <iostream>
#include <vector>
#include <ctime>
using std::vector;
// Idiomatic variant of the benchmark: prints the elapsed CPU time of the
// nested copy loop in seconds, followed by a newline, to std::cout.
//
// The inputs are taken by const reference, so callers' vectors are no longer
// copied on every call. The loop runs over the elements present in both
// inputs instead of a hard-coded 10000, so shorter inputs are never indexed
// out of bounds.
void process(const std::vector<int>& a, const std::vector<int>& b) {
    const std::vector<int>::size_type n = a.size() < b.size() ? a.size() : b.size();
    std::vector<int> barData(2 * n, 0);
    const std::clock_t begin = std::clock();
    for (std::vector<int>::size_type i = 0; i < n; ++i) {
        for (std::vector<int>::size_type j = 0; j < i; ++j) {
            barData[j] = a[i];  // overwritten immediately; kept to match the other versions
            barData[j] = b[i];
        }
    }
    const std::clock_t end = std::clock();
    // '\n' (not 'n') terminates the line.
    std::cout << double(end - begin) / CLOCKS_PER_SEC << '\n';
}
// Entry point for the idiomatic benchmark: two zero-filled inputs of
// 10000 elements each are handed to process().
int main() {
    const std::vector<int> a(10000, 0);
    const std::vector<int> b(10000, 0);
    process(a, b);
}