【PointNet】MATLAB simulation of 3D point cloud object classification based on PointNet


    1. Software version

    MATLAB R2021a

    2. System overview

    The PointNet network structure used here is shown in the figure below:

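    In the overall network, the first stage is set abstraction (SA), the hierarchical building block of PointNet++. It partitions the point cloud into local regions and extracts features from each of them. As the figure shows, an SA level is made up of three layers: a sampling layer, a grouping layer and a PointNet layer. The sampling layer picks the region centroids with farthest point sampling (FPS); the grouping layer gathers the points around each centroid, using either multi-scale grouping (MSG) or multi-resolution grouping (MRG); finally, a PointNet extracts features from each group of points. In the MSG variant, the first SA level samples 512 centroids and groups points at three radii, 0.1, 0.2 and 0.4, keeping at most 16, 32 and 128 points per ball respectively.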

    Sampling layer

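    The sampling layer selects a set of points from the input point cloud; these points define the centroids of the local regions. The sampling algorithm is iterative farthest point sampling (FPS): start from a randomly chosen point, then repeatedly add the point whose distance to the already-selected set is largest, until the required number of points has been chosen. Compared with random sampling, FPS covers the whole point cloud more completely for the same number of centroids.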
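    A minimal pure-Python sketch of iterative FPS as described above follows; Python is used only for illustration, and the function name and the optional start_index/seed parameters are assumptions, not part of the original MATLAB program. For every point it keeps the distance to the nearest centroid chosen so far, and each iteration picks the point where that distance is largest.

```python
import math
import random


def farthest_point_sampling(points, num_samples, start_index=None, seed=0):
    """Iterative farthest point sampling over a list of (x, y, z) tuples.

    Returns the indices of the chosen centroids, in selection order.
    If start_index is None, the first centroid is drawn at random
    (seeded so the result is reproducible).
    """
    n = len(points)
    if not 0 < num_samples <= n:
        raise ValueError("num_samples must be between 1 and len(points)")
    if start_index is None:
        start_index = random.Random(seed).randrange(n)
    selected = [start_index]
    # min_dist[i]: distance from point i to its nearest already-selected centroid
    min_dist = [math.dist(p, points[start_index]) for p in points]
    while len(selected) < num_samples:
        # The next centroid is the point farthest from the selected set
        next_index = max(range(n), key=lambda i: min_dist[i])
        selected.append(next_index)
        # Only distances to the newly added centroid can shrink min_dist
        for i, p in enumerate(points):
            min_dist[i] = min(min_dist[i], math.dist(p, points[next_index]))
    return selected
```

    For the five collinear points (0,0,0), (1,0,0), (2,0,0), (3,0,0), (10,0,0) with start_index=0, asking for three centroids gives [0, 4, 3]: the outlier at x = 10 is taken first, then x = 3, which is farthest from both. Each iteration scans all N points, so selecting M centroids costs O(N·M) distance evaluations.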

    Grouping layer

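    The purpose of this layer is to build local regions from which features are then extracted. The idea is to gather the points neighboring each centroid. The paper uses a neighborhood ball (ball query, keeping up to a fixed number of points inside a radius) rather than k-nearest neighbors (kNN), because a ball guarantees a fixed region scale: a point joins the group only if its distance to the centroid is within the radius, whereas the extent of a kNN region grows or shrinks with the local point density.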
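    Ball-query grouping, applied at the three MSG scales listed earlier, can be sketched in Python as below. Padding short groups by repeating the first neighbor is an assumption here (a convention common in PointNet++ implementations), as are all function names; the knn helper exists only to contrast the two neighborhood definitions.

```python
import math


def ball_query(points, centroid, radius, max_samples):
    """Indices of up to max_samples points lying inside the ball of `radius`
    around `centroid`, in index order. Short groups are padded by repeating
    the first neighbor (an assumed convention) so all groups have equal size.
    Returns [] if no point lies inside the ball."""
    group = [i for i, p in enumerate(points) if math.dist(p, centroid) < radius]
    group = group[:max_samples]
    if group:
        group += [group[0]] * (max_samples - len(group))
    return group


def knn(points, centroid, k):
    """k nearest neighbors, shown only for contrast: k is fixed, the radius is not."""
    order = sorted(range(len(points)), key=lambda i: math.dist(points[i], centroid))
    return order[:k]


def multi_scale_group(points, centroid, scales):
    """MSG: one ball query per (radius, max_samples) scale."""
    return [ball_query(points, centroid, r, k) for r, k in scales]


# (radius, max points per ball) of the first MSG set-abstraction level
MSG_SCALES = [(0.1, 16), (0.2, 32), (0.4, 128)]
```

    With five points on the x axis at 0, 0.05, 0.15, 0.3 and 1.0 and the centroid at the origin, the three scales collect the point sets {0, 1}, {0, 1, 2} and {0, 1, 2, 3}, padded to 16, 32 and 128 entries, while 5-NN also returns the point 1.0 away: the ball keeps the region size fixed, kNN does not.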

    Pointnet layer

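    For extracting local features from a point cloud, the original PointNet already does the job well, so in PointNet++ the original PointNet becomes a sub-network of the larger model: at every level it is applied to each local group, with the group's coordinates expressed relative to its centroid, and features are extracted hierarchically, level by level.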
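    The role of PointNet inside a set-abstraction level can be sketched in Python as three steps on one local group: translate the points into the centroid's frame, apply the same small MLP to every point, and max-pool over the points. The two-layer weights below are arbitrary hypothetical values chosen so the result is easy to check by hand; a real network learns them, and the function names are illustrative only.

```python
def relu(vector):
    """Element-wise ReLU."""
    return [max(0.0, x) for x in vector]


def linear(weights, bias, vector):
    """Dense layer: weights is a list of rows (out_dim x in_dim)."""
    return [sum(w * x for w, x in zip(row, vector)) + b
            for row, b in zip(weights, bias)]


def mini_pointnet(group, centroid, layers):
    """Feature vector for one local group of 3-D points."""
    # 1) Express the group in the centroid's local frame
    local = [tuple(p - c for p, c in zip(point, centroid)) for point in group]
    # 2) Shared MLP: the same layers are applied to every point independently
    per_point = []
    for point in local:
        h = list(point)
        for weights, bias in layers:
            h = relu(linear(weights, bias, h))
        per_point.append(h)
    # 3) Max pooling over points: a symmetric function, so point order is irrelevant
    return [max(column) for column in zip(*per_point)]


# Hypothetical 3 -> 4 -> 2 weights (a trained network would learn these)
LAYERS = [
    ([[1, 0, 0], [0, 1, 0], [0, 0, 1], [-1, -1, -1]], [0, 0, 0, 0]),
    ([[1, 1, 0, 0], [0, 0, 1, 1]], [0, 0]),
]
CENTROID = (1, 2, 3)
GROUP = [(1, 2, 3), (2, 2, 2), (0, 1, 5)]
```

    Here mini_pointnet(GROUP, CENTROID, LAYERS) evaluates to [1, 2], and reversing GROUP gives the same vector because max pooling ignores order. Because the coordinates are taken relative to the centroid, shifting the group and the centroid by the same offset leaves the MLP input unchanged (up to floating-point rounding).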

    3. Core program (excerpt)

    clc;
    clear;
    close all;
    warning off;
    addpath(genpath(pwd));
    rng('default')
    % Training and validation datastores (PtCloudClassificationDatastore is a helper class not shown here)
    dsTrain = PtCloudClassificationDatastore('train');
    dsVal = PtCloudClassificationDatastore('test');
    % Display a sample point cloud
    ptCloud = pcread('Chair.ply');
    label = 'Chair';
    figure;pcshow(ptCloud)
    xlabel("X");ylabel("Y");zlabel("Z");title(label)
    % Collect the label and point count of every training sample
    dsLabelCounts = transform(dsTrain,@(data){data{2} data{1}.Count});
    labelCounts = readall(dsLabelCounts);
    labels = vertcat(labelCounts{:,1});
    counts = vertcat(labelCounts{:,2});
    figure;histogram(labels);title('class distribution')
    rng(0)
    % Number of samples per class; the largest class sets the oversampling target
    [G,classes] = findgroups(labels);
    numObservations = splitapply(@numel,labels,G);
    desiredNumObservationsPerClass = max(numObservations);
    filesOverSample=[];
    % Oversample: replicate the files of the rarer classes until every class has
    % desiredNumObservationsPerClass samples. This assumes dsTrain.Files is
    % ordered by class, in the same (sorted) order as 'classes'.
    for i=1:numel(classes)
        % The files of class i follow those of classes 1..i-1
        startIdx = sum(numObservations(1:i-1)) + 1;
        endIdx = sum(numObservations(1:i));
        targetFiles = {dsTrain.Files{startIdx:endIdx}};
        % Randomly replicate the point clouds belonging to the infrequent classes
        files = randReplicateFiles(targetFiles,desiredNumObservationsPerClass);
        filesOverSample = vertcat(filesOverSample,files');
    end
    dsTrain.Files=filesOverSample;
    % Shuffle the oversampled file list
    dsTrain.Files = dsTrain.Files(randperm(length(dsTrain.Files)));
    dsTrain.MiniBatchSize = 32;
    dsVal.MiniBatchSize = dsTrain.MiniBatchSize;
    % Data augmentation (training set only)
    dsTrain = transform(dsTrain,@augmentPointCloud);
    data = preview(dsTrain);
    ptCloud = data{1,1};
    label = data{1,2};
    figure;pcshow(ptCloud.Location,[0 0 1],"MarkerSize",40,"VerticalAxisDir","down")
    xlabel("X");ylabel("Y");zlabel("Z");title(label)
    % Per-class point-count statistics
    minPointCount = splitapply(@min,counts,G);
    maxPointCount = splitapply(@max,counts,G);
    meanPointCount = splitapply(@(x)round(mean(x)),counts,G);
    stats = table(classes,numObservations,minPointCount,maxPointCount,meanPointCount)
    % Bring every point cloud to a fixed number of points, then preprocess it
    numPoints = 1000;
    dsTrain = transform(dsTrain,@(data)selectPoints(data,numPoints));
    dsVal = transform(dsVal,@(data)selectPoints(data,numPoints));
    dsTrain = transform(dsTrain,@preprocessPointCloud);
    dsVal = transform(dsVal,@preprocessPointCloud);
    data = preview(dsTrain);
    figure;pcshow(data{1,1},[0 0 1],"MarkerSize",40,"VerticalAxisDir","down");
    xlabel("X");ylabel("Y");zlabel("Z");title(data{1,2})
    % Input transform (T-Net on the 3-D coordinates)
    inputChannelSize = 3;
    hiddenChannelSize1 = [64,128];
    hiddenChannelSize2 = 256;
    [parameters.InputTransform, state.InputTransform] = initializeTransform(inputChannelSize,hiddenChannelSize1,hiddenChannelSize2);
    % First shared MLP
    inputChannelSize = 3;
    hiddenChannelSize = [64 64];
    [parameters.SharedMLP1,state.SharedMLP1] = initializeSharedMLP(inputChannelSize,hiddenChannelSize);
    % Feature transform (T-Net on the 64-D point features)
    inputChannelSize = 64;
    hiddenChannelSize1 = [64,128];
    hiddenChannelSize2 = 256;
    [parameters.FeatureTransform, state.FeatureTransform] = initializeTransform(inputChannelSize,hiddenChannelSize1,hiddenChannelSize2);
    % Second shared MLP
    inputChannelSize = 64;
    hiddenChannelSize = 64;
    [parameters.SharedMLP2,state.SharedMLP2] = initializeSharedMLP(inputChannelSize,hiddenChannelSize);
    % Classification MLP (outputs one score per class)
    inputChannelSize = 64;
    hiddenChannelSize = [512,256];
    numClasses = numel(classes);
    [parameters.ClassificationMLP, state.ClassificationMLP] = initializeClassificationMLP(inputChannelSize,hiddenChannelSize,numClasses);
    % Training options
    numEpochs = 60;
    learnRate = 0.001;
    l2Regularization = 0.1;
    learnRateDropPeriod = 15;
    learnRateDropFactor = 0.5;
    gradientDecayFactor = 0.9;
    squaredGradientDecayFactor = 0.999;
    avgGradients = [];
    avgSquaredGradients = [];
    [lossPlotter, trainAccPlotter,valAccPlotter] = initializeTrainingProgressPlot;
    % Initialize the iteration counter
    iteration = 0;
    % Start the timer used to report the training time
    start = tic;
    % Loop over the epochs
    for epoch = 1:numEpochs
        % Reset training and validation datastores.
        reset(dsTrain);
        reset(dsVal);
        % Iterate through the training set.
        while hasdata(dsTrain) % exit when all data has been read, then start the next epoch
            iteration = iteration + 1;
            % Read data.
            data = read(dsTrain);
            % Create batch.
            [XTrain,YTrain] = batchData(data,classes);
            % Evaluate the model gradients and loss using dlfeval and the
            % modelGradients function.
            [gradients, loss, state, acc] = dlfeval(@modelGradients,XTrain,YTrain,parameters,state);
            % L2 regularization: add l2Regularization*p to each gradient.
            gradients = dlupdate(@(g,p) g + l2Regularization*p,gradients,parameters);
            % Update the network parameters using the Adam optimizer.
            [parameters, avgGradients, avgSquaredGradients] = adamupdate(parameters, gradients, ...
                avgGradients, avgSquaredGradients, iteration,learnRate,gradientDecayFactor, squaredGradientDecayFactor);
            % Update the training progress.
            D = duration(0,0,toc(start),"Format","hh:mm:ss");
            title(lossPlotter.Parent,"Epoch: " + epoch + ", Elapsed: " + string(D))
            addpoints(lossPlotter,iteration,double(gather(extractdata(loss))))
            addpoints(trainAccPlotter,iteration,acc);
            drawnow
        end
        % Create an empty confusion matrix
        cmat = sparse(numClasses,numClasses);
        % Classify the validation data to monitor the training process
        while hasdata(dsVal)
            data = read(dsVal); % Get the next batch of data.
            [XVal,YVal] = batchData(data,classes); % Create batch.
            % Compute label predictions.
            isTrainingVal = 0; % Set to zero for validation data
            YPred = pointnetClassifier(XVal,parameters,state,isTrainingVal);
            % Choose the prediction with the highest score as the class label for XVal.
            [~,YValLabel] = max(YVal,[],1);
            [~,YPredLabel] = max(YPred,[],1);
            cmat = aggreateConfusionMetric(cmat,YValLabel,YPredLabel); % Update the confusion matrix
        end
        % Update the training progress plot with the validation accuracy.
        acc = sum(diag(cmat))./sum(cmat,"all");
        addpoints(valAccPlotter,iteration,acc);
        % Update the learning rate
        if mod(epoch,learnRateDropPeriod) == 0
            learnRate = learnRate * learnRateDropFactor;
        end
        reset(dsTrain); % Reset the training data since all the training data were already read
        % Shuffle the data at every epoch
        dsTrain.UnderlyingDatastore.Files = dsTrain.UnderlyingDatastore.Files(randperm(length(dsTrain.UnderlyingDatastore.Files)));
        reset(dsVal);
    end
    % Final evaluation on the whole validation set
    cmat = sparse(numClasses,numClasses); % Empty numClasses-by-numClasses confusion matrix
    reset(dsVal); % Reset the validation data
    data = readall(dsVal); % Read all validation data
    [XVal,YVal] = batchData(data,classes); % Create batch.
    % Classify the validation data using the helper function pointnetClassifier
    isTrainingVal = 0; % Inference mode
    YPred = pointnetClassifier(XVal,parameters,state,isTrainingVal);
    % Choose the prediction with the highest score as the class label for XVal.
    [~,YValLabel] = max(YVal,[],1);
    [~,YPredLabel] = max(YPred,[],1);
    % Collect confusion metrics.
    cmat = aggreateConfusionMetric(cmat,YValLabel,YPredLabel);
    figure;chart = confusionchart(full(cmat),classes);
    % Overall accuracy = correctly classified samples / all samples
    acc = sum(diag(cmat))./sum(cmat,"all")
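    The optimization part of the training loop (L2 regularization folded into the gradient, an Adam step, and the step drop of the learning rate every learnRateDropPeriod epochs) can be sketched for a single scalar parameter in Python. This is textbook Adam with bias correction and an assumed epsilon of 1e-8; it is not a reimplementation of MATLAB's adamupdate, and the function names are hypothetical.

```python
import math


def lr_for_epoch(epoch, base_lr=0.001, drop_period=15, drop_factor=0.5):
    """Learning rate in effect during `epoch` (1-based) when the rate is
    multiplied by drop_factor at the end of every drop_period-th epoch."""
    return base_lr * drop_factor ** ((epoch - 1) // drop_period)


def adam_step_l2(param, grad, m, v, t, lr, l2=0.1,
                 beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam step for a scalar parameter, with L2 added to the gradient
    (g + l2 * p) before the moment updates. t is the 1-based global iteration."""
    g = grad + l2 * param
    m = beta1 * m + (1 - beta1) * g          # first-moment estimate
    v = beta2 * v + (1 - beta2) * g * g      # second-moment estimate
    m_hat = m / (1 - beta1 ** t)             # bias corrections
    v_hat = v / (1 - beta2 ** t)
    param = param - lr * m_hat / (math.sqrt(v_hat) + eps)
    return param, m, v
```

    With the settings above, epochs 1 to 15 train at 0.001, epochs 16 to 30 at 0.0005, and epochs 46 to 60 at 0.000125. Adding l2Regularization*p to the gradient is classic L2 regularization; because Adam rescales each gradient by its running magnitude, it is not equivalent to decoupled weight decay (AdamW).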

    4. Simulation results

    5. References

    [1] Qi C R, Su H, Mo K, et al. PointNet: Deep Learning on Point Sets for 3D Classification and Segmentation[C]// 2017 IEEE Conference on Computer Vision and Pattern Recognition (CVPR). IEEE, 2017.

  Original article: https://blog.csdn.net/ccsss22/article/details/125434324