码农知识堂 - 1000bd
  •   Python
  •   PHP
  •   JS/TS
  •   JAVA
  •   C/C++
  •   C#
  •   GO
  •   Kotlin
  •   Swift
  • 手写LASSO回归python实现


    1. import numpy as np
    2. from matplotlib.font_manager import FontProperties
    3. from sklearn.datasets import make_regression
    4. from sklearn.model_selection import train_test_split
    5. import matplotlib.pyplot as plt
    6. class Lasso():
    7. def __init__(self):
    8. pass
    9. # 数据准备
    10. def prepare_data(self):
    11. # 生成样本数据
    12. X, y = make_regression(n_samples=40, n_features=80, random_state=0, noise=0.5)
    13. # 划分数据集
    14. X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
    15. return X_train, X_test, y_train.reshape(-1,1), y_test.reshape(-1,1)
    16. # 参数初始化
    17. def initialize_params(self, dims):
    18. w = np.zeros((dims, 1))
    19. b = 0
    20. return w, b
    21. # 定义L1损失函数
    22. def l1_loss(self, X, y, w, b, alpha):
    23. num_train = X.shape[0] # 样本数
    24. num_feature = X.shape[1] # 特征数
    25. y_hat = np.dot(X, w) + b # 回归预测数据
    26. # 计算损失
    27. loss = np.sum((y_hat - y) ** 2) / num_train + alpha * np.sum(np.abs(w)) # 修改此处
    28. # 计算梯度,即参数的变化
    29. dw = np.dot(X.T, (y_hat - y)) / num_train + alpha * np.sign(w) # 修改此处
    30. db = np.sum((y_hat - y)) / num_train
    31. return y_hat, loss, dw, db
    32. def lasso_train(self, X, y, learning_rate, epochs, alpha):
    33. loss_list = []
    34. w, b = self.initialize_params(X.shape[1])
    35. # 归一化特征
    36. X = (X - np.mean(X, axis=0)) / np.std(X, axis=0)
    37. for i in range(1, epochs):
    38. y_hat, loss, dw, db = self.l1_loss(X, y, w, b, alpha)
    39. # 更新参数
    40. w += -learning_rate * dw
    41. b += -learning_rate * db
    42. loss_list.append(loss)
    43. # if i % 300 == 0:
    44. # print('epoch %d loss %f' % (i, loss))
    45. params = {
    46. 'w': w,
    47. 'b': b
    48. }
    49. grads = {
    50. 'dw': dw,
    51. 'db': db
    52. }
    53. return loss, loss_list, params, grads
    54. # 根据计算的得到的参数进行预测
    55. def predict(self, X, params):
    56. w = params['w']
    57. b = params['b']
    58. y_pred = np.dot(X, w) + b
    59. return y_pred
    60. if __name__ == '__main__':
    61. lasso = Lasso()
    62. X_train, X_test, y_train, y_test = lasso.prepare_data()
    63. alphas=np.arange(0.01,0.11,0.01)
    64. wc=[]#统计参数w中绝对值小于0.1的个数,模拟稀疏度
    65. for alpha in alphas:
    66. # 参数:训练集x,训练集y,学习率,迭代次数,正则化系数
    67. loss, loss_list, params, grads = lasso.lasso_train(X_train, y_train, 0.02, 3000,alpha)
    68. w=np.squeeze(params['w'])
    69. count=np.sum(np.abs(w)<1e-1)
    70. wc.append(count)
    71. # 设置中文字体
    72. plt.rcParams['font.sans-serif'] = ['SimHei']
    73. plt.rcParams['axes.unicode_minus'] = False
    74. plt.figure(figsize=(10, 8))
    75. plt.plot(alphas, wc, 'o-')
    76. plt.xlabel('正则项系数',fontsize=15)
    77. plt.ylabel('参数w矩阵的稀疏度',fontsize=15)
    78. plt.show()

  • 相关阅读:
    Java框架(三)--Spring IoC容器与Bean管理(7)--基于注解配置IoC容器
    聊聊Java的垃圾回收机制
    2022年新一批获得能力评估CS认证证书的企业名单
    实现简易minishell
    在执行对 HDFS 中创建用户目录的指令时,回复的命令如下图所示
    安卓手机APP开发___设置闹钟
    《安富莱嵌入式周报》第283期:全开源逆向“爆破”硬件工具,Linux内核6.1将正式引入RUST语言,I3C培训教程,80款市场成熟的电感式位置传感器设计
    Docker swarm --集群和编排
    s21.云原生发展经历的阶段与未来发展趋势
    强强联合,波卡生态正成为物联网赛道关键入口
  • 原文地址:https://blog.csdn.net/qq_58158950/article/details/134435537
  • 最新文章
  • 攻防演习之三天拿下官网站群
    数据安全治理学习——前期安全规划和安全管理体系建设
    企业安全 | 企业内一次钓鱼演练准备过程
    内网渗透测试 | Kerberos协议及其部分攻击手法
    0day的产生 | 不懂代码的"代码审计"
    安装scrcpy-client模块av模块异常,环境问题解决方案
    leetcode hot100【LeetCode 279. 完全平方数】java实现
    OpenWrt下安装Mosquitto
    AnatoMask论文汇总
    【AI日记】24.11.01 LangChain、openai api和github copilot
  • 热门文章
  • 十款代码表白小特效 一个比一个浪漫 赶紧收藏起来吧!!!
    奉劝各位学弟学妹们,该打造你的技术影响力了!
    五年了,我在 CSDN 的两个一百万。
    Java俄罗斯方块,老程序员花了一个周末,连接中学年代!
    面试官都震惊,你这网络基础可以啊!
    你真的会用百度吗?我不信 — 那些不为人知的搜索引擎语法
    心情不好的时候,用 Python 画棵樱花树送给自己吧
    通宵一晚做出来的一款类似CS的第一人称射击游戏Demo!原来做游戏也不是很难,连憨憨学妹都学会了!
    13 万字 C 语言从入门到精通保姆级教程2021 年版
    10行代码集2000张美女图,Python爬虫120例,再上征途
Copyright © 2022 侵权请联系2656653265@qq.com    京ICP备2022015340号-1
正则表达式工具 cron表达式工具 密码生成工具

京公网安备 11010502049817号