• K-means clustering with PyTorch and GPU acceleration


    Goal

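    The k-means implementation in scikit-learn runs on the CPU by default, which makes it slow. Sometimes features also need to be clustered dynamically inside a network. A k-means implementation that operates directly on PyTorch tensors can improve efficiency considerably in that setting.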

    A search turned up a similar implementation in ContrastiveSceneContexts, which can be used as a reference:

    Environment

    PyTorch, pykeops

    pip install pykeops  -i https://pypi.tuna.tsinghua.edu.cn/simple
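
Before running the clustering code, it can help to confirm what the environment actually provides. The sketch below is only an illustration: `describe_environment` is a hypothetical helper name, not part of PyTorch or pykeops, and it only checks that pykeops is importable without triggering its compilation step.

```python
import importlib.util

import torch


def describe_environment():
    """Collect the facts that decide whether the GPU path can be used."""
    return {
        "torch_version": torch.__version__,
        "cuda_available": torch.cuda.is_available(),
        # find_spec only checks that the package can be located; it does not
        # import pykeops or compile any kernels.
        "pykeops_installed": importlib.util.find_spec("pykeops") is not None,
    }


info = describe_environment()
# Fall back to the CPU when no CUDA device is visible.
device = torch.device("cuda" if info["cuda_available"] else "cpu")
print(info, device)
```

If pykeops is installed, its own self-test `pykeops.test_torch_bindings()` can then be run to check that the PyTorch bindings compile and work.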
    

    Code

    import time

    import torch
    import pykeops
    from pykeops.torch import LazyTensor

    # Clear the pykeops build cache so kernels are recompiled for this environment
    pykeops.clean_pykeops()
    
    def kmeans(pointcloud, k=10, iterations=10, verbose=True):
        n, dim = pointcloud.shape  # Number of samples, dimension of the ambient space
        start = time.time()
        clusters = pointcloud[:k, :].clone()  # Simplistic initialization: the first k points (not random)
        pointcloud_cuda = LazyTensor(pointcloud[:, None, :])  # (Npoints, 1, D)
    
        # K-means loop:
        for _ in range(iterations):
            clusters_previous = clusters.clone()
            clusters_gpu = LazyTensor(clusters[None, :, :])  # (1, Nclusters, D)
            distance_matrix = ((pointcloud_cuda - clusters_gpu) ** 2).sum(-1)  # (Npoints, Nclusters) symbolic matrix of squared distances
            closest_clusters = distance_matrix.argmin(dim=1).long().view(-1)  # Points -> Nearest cluster
    
            # Number of points assigned to each cluster
            clusters_count = torch.bincount(closest_clusters, minlength=k).float()  # Class weights
            for d in range(dim):  # Compute the cluster centroids with torch.bincount:
                # Empty clusters divide by zero here; they are restored just below
                clusters[:, d] = torch.bincount(closest_clusters, weights=pointcloud[:, d], minlength=k) / clusters_count
            
            # Clusters with no assigned points keep their previous centroid
            mask = clusters_count == 0
            clusters[mask] = clusters_previous[mask]
    
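        if pointcloud.is_cuda:
            torch.cuda.synchronize()  # Wait for queued GPU work so the timing is meaningful
        end = time.time()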
    
        if verbose:
            print("K-means example with {:,} points in dimension {:,}, K = {:,}:".format(n, dim, k))
            print('Timing for {} iterations: {:.5f}s = {} x {:.5f}s\n'.format(
                    iterations, end - start, iterations, (end-start) / iterations))
        
        # For each cluster, the index of the point nearest to its centroid: shape (k,)
        closest_points_to_centers = distance_matrix.argmin(dim=0).long().view(-1)
        return closest_points_to_centers
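
For small inputs, the same assign-and-update loop can be cross-checked without pykeops. The sketch below is not the ContrastiveSceneContexts code: `kmeans_dense` is a hypothetical name, it materializes the full (N, k) distance matrix with `torch.cdist`, it draws the initial centroids at random with `torch.randperm`, and it assumes the input has at least k points. It also recomputes distances against the final centroids before the nearest-point lookup.

```python
import torch


def kmeans_dense(points, k=10, iterations=10, seed=0):
    """Plain-PyTorch k-means that builds the full (N, k) distance matrix.

    Returns (centroids, labels, nearest_point_per_centroid).
    """
    n, dim = points.shape
    # Random initialization: k distinct points become the starting centroids.
    gen = torch.Generator().manual_seed(seed)
    init_idx = torch.randperm(n, generator=gen)[:k].to(points.device)
    centroids = points[init_idx].clone()

    for _ in range(iterations):
        # Assignment step: squared Euclidean distance to every centroid.
        dist = torch.cdist(points, centroids) ** 2  # (N, k)
        labels = dist.argmin(dim=1)  # (N,)

        # Update step: per-cluster sums via one scatter-add, counts via bincount.
        sums = torch.zeros_like(centroids).index_add_(0, labels, points)
        counts = torch.bincount(labels, minlength=k)
        nonempty = counts > 0
        # Only non-empty clusters move; empty ones keep their previous centroid.
        centroids[nonempty] = sums[nonempty] / counts[nonempty].unsqueeze(1).to(points.dtype)

    # Recompute against the final centroids before the lookups.
    dist = torch.cdist(points, centroids) ** 2
    labels = dist.argmin(dim=1)
    nearest_point = dist.argmin(dim=0)  # (k,) index of the point closest to each centroid
    return centroids, labels, nearest_point


device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(0)
# Two well-separated blobs in 2-D.
blob_a = torch.randn(100, 2, device=device)
blob_b = torch.randn(100, 2, device=device) + 20.0
points = torch.cat([blob_a, blob_b])
centroids, labels, nearest_point = kmeans_dense(points, k=2, iterations=10)
print(centroids.shape, labels.shape, nearest_point.shape)
```

The dense matrix costs O(N·k) memory, which is the cost the symbolic `LazyTensor` in the pykeops version avoids, so this variant is mainly useful as a reference on small point sets.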
    
    

    Reference

    • Ji Hou, Benjamin Graham, Matthias Nießner, Saining Xie:
      Exploring Data-Efficient 3D Scene Understanding With Contrastive Scene Contexts. CVPR 2021: 15587-15597
    • https://github.com/facebookresearch/ContrastiveSceneContexts/blob/83515bef4754b3d90fc3b3a437fa939e0e861af8/downstream/semseg/lib/sampling_points.py#L28
  • Original article: https://blog.csdn.net/a237072751/article/details/126968661