  • Caffe 权重分布统计与可视化脚本


    先上效果图（图片未随本文保留），Python 统计脚本如下:

    1. import numpy as np
    2. import matplotlib.pyplot as plt
    3. import random
    def Statistics_weight(save_dir, type, name, weight):
        """Save a bar chart of how |weight| is spread over fixed magnitude bins.

        One bar per bin, labelled with the bin's upper edge:

            "0"     : |w| == 0
            "1e-25" : 0 < |w| < 1e-25
            "1e-15" : 1e-25 <= |w| < 1e-15
            "1e-10" : 1e-15 <= |w| < 1e-10
            "1e-5"  : 1e-10 <= |w| < 1e-5
            "1e-1"  : 1e-5 <= |w| < 1e-1
            "1"     : 1e-1 <= |w| < 1
            "2"     : 1 <= |w| < 2
            "10"    : 2 <= |w| < 10

        Magnitudes >= 10 fall into no bar, so the ratios sum to less than 1
        when such weights exist.

        Args:
            save_dir: output directory; created (including parents) if missing.
            type: tag used in the plot title and file name, e.g. "L1+sparse".
                (It shadows the builtin ``type`` but is kept for callers.)
            name: layer name; used as the matplotlib figure name and in the
                title and file name.
            weight: weights as an array-like; the absolute value is taken here.

        Returns:
            The per-bar ratios (bin count / total number of weights), in bar
            order. All ratios are 0.0 for an empty ``weight``.
        """
        import os  # not among the imports at the top of this script

        os.makedirs(save_dir, exist_ok=True)
        weight_abs = np.abs(weight)
        total = weight_abs.size
        x_data = [0, 1e-25, 1e-15, 1e-10, 1e-5, 1e-1, 1, 2, 10]
        x_data_show = ["0", "1e-25", "1e-15", "1e-10", "1e-5", "1e-1", "1", "2", "10"]

        y_data = []
        for i, upper in enumerate(x_data):
            if i == 0:
                # Exact zeros (e.g. pruned weights) get a bar of their own.
                in_bin = weight_abs == 0
            elif i == 1:
                # Strictly positive: zeros were already counted in the "0" bar.
                in_bin = (weight_abs > 0) & (weight_abs < upper)
            else:
                in_bin = (weight_abs >= x_data[i - 1]) & (weight_abs < upper)
            count = int(np.count_nonzero(in_bin))
            y_data.append(count / total if total else 0.0)

        fig = plt.figure(name)
        for label, ratio in zip(x_data_show, y_data):
            # One bar() call per bin so each bar takes the next colour of the cycle.
            plt.bar(label, ratio)
        for label, ratio in zip(x_data_show, y_data):
            plt.text(label, ratio + 0.005, ("%.2f" % ratio), ha='center', va='bottom', fontsize=11)
        plt.title(type + "_" + name)
        plt.xlabel("value")
        plt.ylabel("ratio")
        plt.savefig(os.path.join(save_dir, type + "_" + name + ".png"))
        # Close the figure: this runs once per layer, and open figures would
        # otherwise accumulate (and a repeated name would draw onto the old one).
        plt.close(fig)
        return y_data
    # ---- conv layers ----
    # NOTE(review): `net` (presumably a loaded caffe.Net) and the threshold `T`
    # are defined earlier in the original script; they are not part of this listing.
    print("==========>>conv" * 5)
    total_weight = 0        # number of weights over all analysed layers
    total_weight_avail = 0  # how many of them have |w| > T
    for layer_para_name, para in net.params.items():
        # Skip BN / Scale / bias parameter blobs (matched by name, case-sensitive).
        if any(key in layer_para_name for key in ("bn", "scale", "Scale", "bias")):
            continue
        weights_np = abs(para[0].data)  # para[0]: weights, para[1]: bias
        Statistics_weight("/media/xxx_sparse/caffe-jacinto/0000/deply/show/0930/0930_L1+sprse", "L1+sparse", layer_para_name, weights_np)
        n_avail = (weights_np > T).sum()
        total_weight += weights_np.size
        total_weight_avail += n_avail
        # Fraction of this layer's weights with |w| <= T.
        ratio_zero = (1 - (n_avail * 1.0 / weights_np.size))
        print("layer_para_name=", layer_para_name, " ratio_zero=", ratio_zero)
    if total_weight:
        ratio_avail = total_weight_avail * 1.0 / total_weight
        print("ratio_conv_avail_weight=", ratio_avail, " ratio_conv_not_avail_weight=",
              1 - ratio_avail)
    else:
        # Avoid a ZeroDivisionError when every layer was filtered out.
        print("no conv weights found")
    ##################################

    C++ 部分：添加到 blob.hpp（Blob 类）中的代码:

    1. double statistics_weight(const string name, int start, int n, const float &max_threshold_value, const float &threshold_fraction_selected)
    2. {
    3. const double* data_vec = cpu_data<double>() + start;
    4. double max_tmp = -DBL_MIN;
    5. // double min_tmp = -DBL_MAX;
    6. // cv::rectangle();
    7. for(int i=0; i< n; i++)
    8. {
    9. max_tmp = abs(data_vec[i]) > max_tmp ? abs(data_vec[i]) : max_tmp;
    10. }
    11. int split = 10;
    12. float each_ = max_tmp / split;
    13. std::vector<int> Histogram_(split, 0);
    14. for(int i=0; i< n; i++)
    15. {
    16. int idx = abs(data_vec[i]) / each_;
    17. if(split == idx)
    18. {
    19. idx -= 1;
    20. }
    21. Histogram_[idx] += 1;
    22. }
    23. int height_img = 500;
    24. cv::Mat hist(height_img, height_img*1.8, CV_8UC3, cv::Scalar(0,0,0));
    25. int T_hist_width = 60;
    26. int T_hist_gap = T_hist_width + 20;
    27. for(int i=0;i<split;i++)
    28. {
    29. float ratio = Histogram_[i] * 1.0 / n;
    30. int height = ratio * height_img;
    31. cv::Point pt_tl = cv::Point(i*T_hist_gap, height_img - height);
    32. cv::Point pt_br = cv::Point(i*T_hist_gap+T_hist_width,height_img - 0);
    33. cv::rectangle(hist, pt_tl, pt_br, cv::Scalar(255,0,0), -1);
    34. cv::putText(hist, std::to_string(ratio*100.0) + "%", cv::Point(pt_tl.x, pt_tl.y - 30), cv::FONT_HERSHEY_PLAIN, 1, cv::Scalar(0,255,255),1);
    35. string str_each_1 = std::to_string((each_ * (i+1)));
    36. int pos_decimal_point = str_each_1.find(".");
    37. string str_each_new = str_each_1.substr(0,pos_decimal_point+3);
    38. cv::putText(hist, str_each_new, cv::Point((pt_tl.x+pt_br.x)/2-5, height_img), cv::FONT_HERSHEY_PLAIN, 1, cv::Scalar(0,0,255),1);
    39. cv::putText(hist, "max_threshold_value="+ std::to_string(max_threshold_value), cv::Point(hist.cols*0.25, 50), cv::FONT_HERSHEY_PLAIN, 1, cv::Scalar(0,255,0),1);
    40. cv::putText(hist, "threshold_fraction_selected="+ std::to_string(threshold_fraction_selected*100) + "%", cv::Point(hist.cols*0.25, 120), cv::FONT_HERSHEY_PLAIN, 1, cv::Scalar(0,255,0),1);
    41. }
    42. cv::imshow("hist_"+name,hist);
    43. cv::waitKey(0);
    44. return max_tmp;
    45. }

    在 net.cpp 的稀疏化（sparsify）代码中调用该函数（C++）:

    1. for(int c=0; c<no; c++) {
    2. // LOG(INFO) <<"=========>c="<<c;
    3. int weight_count_channel = ni * kernel_shape_data[0] * kernel_shape_data[1] / num_group;
    4. int start_index = weight_count_channel * c;
    5. float max_abs = std::abs(conv_weights.max(start_index, weight_count_channel));
    6. float min_abs = std::abs(conv_weights.min(start_index, weight_count_channel));
    7. float max_abs_value = std::max<float>(max_abs, min_abs);
    8. float step_size = max_abs_value * threshold_step_factor;
    9. float max_threshold_value = std::min<float>(std::min<float>(threshold_value_max, max_abs_value*threshold_value_maxratio), max_abs_value);
    10. float aa = conv_weights.statistics_weight(layer_name, start_index, weight_count_channel, max_threshold_value, threshold_fraction_selected);
    11. bool verbose_th_val = false;
    12. if(verbose && verbose_th_val || 0) {
    13. if ((max_abs_value*threshold_value_maxratio) > threshold_value_max) {
    14. LOG(INFO) << "threshold_value_max " << threshold_value_max;
    15. LOG(INFO) << "threshold_value_maxratio " << threshold_value_maxratio;
    16. LOG(INFO) << "max_abs_value*threshold_value_maxratio " << (max_abs_value*threshold_value_maxratio);
    17. LOG(INFO) << "final threshold_value used" << max_threshold_value;
    18. }
    19. }
  • 相关阅读:
    【计算机视觉40例】案例15:KNN数字识别
    考研数据结构——(图)
    05-01 jdk,tomcat,mariadb数据库和profile多环境
    【QML】在QML中布局的四种方法
    制造企业如何满足客户需求?精益生产教您三招
    uniapp项目实践总结(二十六)安卓应用商店上架教程
    Java面试题相关
    Shiro框架详解
    Kubernetes(22):Ingress详解
    代码随想录算法训练营第四天 | 24. 两两交换链表中的节点、19.删除链表的倒数第N个节点、160.链表相交、142.环形链表II
  • 原文地址:https://blog.csdn.net/yang332233/article/details/127128575