K-means

K-means 是我们最常用的基于欧式距离的聚类算法,其认为两个目标的距离越近,相似度越大。

算法步骤:

1.选择初始化的k个样本作为初始聚类中心 : a = a 1 , a 2 , . . . a k a=a_1,a_2,...a_k a=a1,a2,...ak;

2.针对数据集中每个样本 x i x_i xi,计算它到k个聚类中心的距离,并将其分到距离最小的聚类中心所对应的类中;

3.针对每个类别 a j a_j aj,重新计算它的聚类中心 a j = 1 c i ∑ x ∈ c i x a_j=\frac{1}{c_i}\sum_{x\in{c_i}}x aj=ci1xcix(即属于该类的所有样本的质心)

4.重复上面2,3两步操作,直到达到某个终止条件(迭代次数,最小误差变化等)

伪码:

获取数据 n 个 m 维的数据 //n为样本点数
随机生成 K 个 m 维的点 //k为簇的数目
while(t) //t为迭代次数
    for(int i=0;i < n;i++)
        for(int j=0;j < k;j++)
            计算点 i 到类 j 的距离
    for(int i=0;i < k;i++)
        1. 找出所有属于自己这一类的所有数据点
        2. 把自己的坐标修改为这些数据点的中心点坐标
end

2.1 优点

  • 容易理解,聚类效果不错,虽然是局部最优, 但往往局部最优就够了;
  • 处理大数据集的时候,该算法可以保证较好的伸缩性;
  • 当簇近似高斯分布的时候,效果非常不错;
  • 算法复杂度低。

2.2 缺点

  • K 值需要人为设定,不同 K 值得到的结果不一样;
  • 对初始的簇中心敏感,不同选取方式会得到不同结果;
  • 对异常值敏感;
  • 样本只能归为一类,不适合多分类任务;
  • 不适合太离散的分类、样本类别不平衡的分类、非凸形状的分类。

Faiss之聚类源码解析

三个步骤:初始化,实例化量化索引,训练

1.初始化:

/** Class for the clustering parameters. Can be passed to the
 * constructor of the Clustering object.
 */
struct ClusteringParameters {
    int niter;          ///< clustering iterations 每一次聚类需要迭代的次数
    int nredo;          ///< redo clustering this many times and keep best  训练的时候,聚类的次数
 
    bool spherical;     ///< do we want normalized centroids? 是否需要归一化
    bool int_centroids; ///< round centroids coordinates to integer
    
    int min_points_per_centroid; ///< otherwise you get a warning  每一簇最小样本数,低于这个数会warning,但是还是会继续
    int max_points_per_centroid;  ///< to limit size of dataset每一簇最大样本数,超过这个数会采样。
};

struct Clustering: ClusteringParameters {
    typedef Index::idx_t idx_t;
    size_t d;              ///< dimension of the vectors 参与聚类的向量维度
    size_t k;              ///< nb of centroids 聚类中心的个数
 
    /** centroids (k * d)
     * if centroids are set on input to train, they will be used as initialization
     */ 
    std::vector<float> centroids; // 聚类中心向量构成的一维vector
}

2.实例化Flat索引

Flat索引是faiss中最简单的索引,其实可以看成待搜索数据组成的一个list。用该索引搜索最近的top n个向量时,采用堆的方式来搜索。

Flat索引在聚类中的作用为:聚类过程中,需要找到离每个样本点最近的聚类中心,而一般聚类中心不会太多,使用最简单的flat索引就已足够。所以这里利用一个Flat索引存储聚类中心,通过样本在Flat索引中找寻top1 的点,即为该样本最近的聚类中心。

3.训练

void Clustering::train_encoded (idx_t nx, const uint8_t *x_in,
                                const Index * codec, Index & index,
                                const float *weights) 

3.1数据校验

首先,必须满足以下条件才能继续:

1. 训练的数据集的数量不得低于聚类中心数

2. Flat的维数必须与训练数据集的维数一致

若训练数据集的总量小于kmin_points_per_centroid, 则发起warning,不过聚类还会继续。如果总量大于kmax_points_per_centroid,则采样k*max_points_per_centroid个数据来替代原有的训练数据集。

if (nx > k * max_points_per_centroid) {
        uint8_t *x_new;
        float *weights_new;
        nx = subsample_training_set (*this, nx, x, line_size, weights,
                                &x_new, &weights_new);
        del1.reset (x_new); x = x_new;
        del3.reset (weights_new); weights = weights_new;
    } else if (nx < k * min_points_per_centroid) {
        fprintf (stderr,
                 "WARNING clustering %" PRId64 " points to %zd centroids: "
                 "please provide at least %" PRId64 " training points\n",
                 nx, k, idx_t(k) * min_points_per_centroid);
    }

3.2 初始化并迭代更新聚类中心

for (int redo = 0; redo < nredo; redo++) { // nredo 前文介绍过,聚类的次数
 
        ...... // 每一次聚类之前,初始化聚类中心
        for (int i = 0; i < niter; i++) {
            ...... // 迭代更新聚类中心
        }
    }

1)初始化聚类中心:随机打散, 取k个点,拷贝至聚类中心数组,并添加到flat index。

    rand_perm (perm.data(), nx, seed + 1 + redo * 15486557L);
     for (int i = n_input_centroids; i < k ; i++) {
           memcpy (&centroids[i * d], x + perm[i] * line_size, line_size);
     }
    index.add (k, centroids.data());

2)迭代更新聚类中心

        index.search (nx, reinterpret_cast<const float *>(x), 1, dis.get(), assign.get());
       // accumulate error
       err = 0;
       for (int j = 0; j < nx; j++) {
              err += dis[j];
        }

通过flat index查询距离每个向量最近的聚类中心点,并将聚类中心id存在assign数组中,与聚类中心的距离存在dis数组中。比如序号为3的向量最近的聚类中心id为1,则assigin[3] = 1,dis[3]则是两者的距离。err是总偏差,保存所有的样本点到最近的聚类中心距离之和,用于最后寻找最优的聚类中心结果。

// 计算每一簇的中心,并更新 
/*计算每个簇的所有向量的总和,以及向量个数,得到每一簇的均值,即为新的聚类中心。为了提升计算的速度,将簇按照线程数分段,每一个线程计算对应分段的簇。举个例子,现在有10个线程,100个簇,那么0号线程计算0-9号簇,1号线程计算10-19号簇,以此类推。*/
 compute_centroids (  d, k, nx, k_frozen, x, codec, assign.get(), weights, hassign.data(), centroids.data() ); 
 
// 对于新生成的簇,有的簇可能没有向量,取一个向量较多的簇分割成两个小簇
/*找出数量为0的簇,并找出一个较大的簇,将其平均分成两份,并更新两个小簇对应的中心点*/
int nsplit = split_clusters (  d, k, nx, k_frozen, hassign.data(), centroids.data() );

通过这两个函数,可以得到更新后的中心,再将flat index 清零,并将新的中心点添加到index,作为下一次迭代的搜索索引。

这样直到迭代次数结束,选出距离偏差err最小的一次训练结果作为聚类的最后结果。

Logo

鸿蒙生态一站式服务平台。

更多推荐