计算卷积的最快方法

Fastest method for calculating convolution

本文关键字:方法 卷积 计算      更新时间:2023-10-16

有人知道计算卷积的最快方法吗?不幸的是,我处理的矩阵非常大(500x500x200),如果我在MATLAB中使用convn,则需要很长时间(我必须在嵌套循环中迭代此计算)。所以,我使用了卷积和FFT,它现在更快了。但是,我仍在寻找一种更快的方法。知道吗?

如果内核是可分离的,那么通过执行多个顺序1D卷积将实现最大的速度增益。

MathWorks的Steve Eddins在他的博客上描述了当内核在MATLAB环境中是可分离的时,如何利用卷积的结合性来加快卷积。对于P-by-Q内核,与2D卷积相比,执行两个单独和顺序卷积的计算优势是PQ/(P+Q),这对应于9x9内核的4.5x和15x15内核的~11xEDIT:在这篇问答中,我们无意中展示了这种差异;A.

为了弄清楚内核是否是可分离的(即两个向量的外积),博客继续描述如何检查内核是否与SVD可分离,以及如何获得1D内核。他们的例子是2D内核。对于N维可分离卷积的解决方案,请检查此FEX提交。


另一个值得指出的资源是Intel对3D卷积的SIMD(SSE3/SSE4)实现,它包括源代码和演示。该代码适用于16位整数。除非你转向GPU(例如cuFFT),否则很难比英特尔的实现更快,英特尔的实现也包括英特尔MKL。MKL文档的这一页底部有一个3D卷积(单精度浮点)的示例(链接已修复,现在镜像在https://stackoverflow.com/a/27074295/2778484)。

您可以尝试重叠添加和重叠保存方法。它们包括将输入信号分解成更小的块,然后使用上面的任何一种方法。

FFT很可能是最快的方法,我可能错了,尤其是如果你使用MATLAB中的内置例程或C++中的库。除此之外,将输入信号分解成更小的块应该是一个不错的选择

我有两种方法来计算fastconv

2比1 好

1-armadillo你可以用这个代码使用armadillo库来煅烧conv

cx_vec signal(1024,fill::randn);
cx_vec code(300,fill::randn);
cx_vec ans = conv(signal,code);

2-使用fftw-ans-sigpack和armadillo库来煅烧快速conv,这样您就必须在构造函数中初始化代码的fft

FastConvolution::FastConvolution(cx_vec inpCode)
{
    filterCode = inpCode;
    fft_w = NULL;
}

cx_vec FastConvolution::filter(cx_vec inpData)
{
int length = inpData.size()+filterCode.size();
    if((length & (length - 1)) == 0)
    {
    }
    else
    {
        length = pow(2 , (int)log2(length) + 1);
    }
    if(length != fftCode.size())
        initCode(length);
    static cx_vec zeroPadedData;
    if(length!= zeroPadedData.size())
    {
        zeroPadedData.resize(length);
    }
    zeroPadedData.fill(0);
    zeroPadedData.subvec(0,inpData.size()-1) = inpData;

    cx_vec fftSignal = fft_w->fft_cx(zeroPadedData);
    cx_vec mullAns = fftSignal % fftCode;
    cx_vec ans = fft_w->ifft_cx(mullAns);
    return ans.subvec(filterCode.size(),inpData.size()+filterCode.size()-1);
}
void FastConvolution::initCode(int length)
{
    if(fft_w != NULL)
    {
        delete fft_w;
    }
    fft_w = new sp::FFTW(length,FFTW_ESTIMATE);
    cx_vec conjCode(length,fill::zeros);
    fftCode.resize(length);
    for(int i = 0; i < filterCode.size();i++)
    {
        conjCode.at(i) = filterCode.at(filterCode.size() - i - 1);
    }
    conjCode = conj(conjCode);
    fftCode = fft_w->fft_cx(conjCode);
}