基于另一个向量的相似度对一个向量进行条件平均

Conditional Averaging of a vector based on similarities in another vector c++

本文关键字:向量 一个 条件 另一个 相似      更新时间:2023-10-16

最好给出一个例子。

假设向量A包含:

A = {3  ,2 ,1 ,4  ,6 ,3 ,8 ,4}

和向量B由:

B = {1.5,2 ,2 ,1.5,3 ,3 ,3 ,2}

向量B的唯一值是{1.5, 2, 3}

我希望结果向量RESULT为:

RESULT[0] = Average(A given B=1.5) = Average(3,4)
RESULT[1] = Average(A given B=2 )  = Average(2,1,4)
RESULT[2] = Average(A given B=3 )  = Average(6,3,8)

计算这个最有效的方法是什么?我自己的方法是遍历B的唯一元素,对于它们中的每一个,遍历每个B值,试图匹配那个唯一的数字,并在每个匹配中不断求和向量A的相应元素,同时计算匹配的数量,这样我就可以找到平均值。

这太慢了。因为我的向量A有8M个元素,向量B由0.5M个唯一值组成。

这是一个懒惰的想法:以锁步方式遍历两个向量,并将结果聚合到一个单独的容器中。例如:

#include <cassert>
#include <cmath>
#include <iostream>
#include <map>
#include <utility>
std::map<double, std::pair<int, std::size_t>> m;
assert(A.size() == B.size());
for (std::size_t i = 0; i != A.size(); ++i)
{
    assert(!std::isnan(B[i]));
    auto & p = m[B[i]];
    p.first += A[i];
    p.second += 1;
}

最后只报告结果:

for (const auto & p : m)
    std::cout << "Average for bin " << p.first << " is "
              << static_cast<double>(p.second.first) / p.second.second
              << "n";

(注意键值不能是NaN:在有序映射中,NaN不是严格排序的一部分;在无序映射中,它不与自身相等。)

您可以使用(哈希)表进行循环:参见Live On Coliru

int main()
{
    vector<int>    A = {3  ,2 ,1 ,4  ,6 ,3 ,8 ,4};
    vector<double> B = {1.5,2 ,2 ,1.5,3 ,3 ,3 ,2};
    assert(A.size() == B.size());
    struct accum { 
        uintmax_t sum = 0; 
        size_t number_of_samples = 0; 
        void sample(int val) { sum += val; ++number_of_samples; }
    };
    map<double, accum> average_state;
    for(size_t i = 0; i<B.size(); ++i)
        average_state[B[i]].sample(A[i]);
    for(auto& entry : average_state)
    {
        accum& state = entry.second;
        double average = static_cast<double>(state.sum) / state.number_of_samples;
        std::cout << "Bucket " << entry.first << "taverage of " << state.number_of_samples << " samples:t" << average << "n";
    }
}

打印

Bucket 1.5  average of 2 samples:   3.5
Bucket 2    average of 3 samples:   2.33333
Bucket 3    average of 3 samples:   5.66667