• 机器学习基础之《回归与聚类算法(6)—模型保存与加载》


    一、背景

    现在我们预测每次都要重新运行一遍模型。完整的流程应该是不断调整阈值重复计算。
    当训练或者计算好一个模型之后,那么如果别人需要我们提供结果预测,就需要保存模型(主要是保存算法的参数)。

    二、sklearn模型的保存和加载API

    1、import joblib
    保存:joblib.dump(rf, "test.pkl")
        rf:是预估器estimator
        test.pkl:是保存的名字
        将预估器序列化保存在本地    
    加载:estimator = joblib.load("test.pkl")

    2、代码

    1. from sklearn.datasets import load_boston
    2. from sklearn.model_selection import train_test_split
    3. from sklearn.preprocessing import StandardScaler
    4. from sklearn.linear_model import LinearRegression, SGDRegressor, Ridge
    5. from sklearn.metrics import mean_squared_error
    6. import joblib
    7. def linear1():
    8. """
    9. 正规方程的优化方法对波士顿房价进行预测
    10. """
    11. # 1、获取数据
    12. boston = load_boston()
    13. # 2、划分数据集
    14. x_train,x_test, y_train, y_test = train_test_split(boston.data, boston.target, random_state=10)
    15. # 3、标准化
    16. transfer = StandardScaler()
    17. x_train = transfer.fit_transform(x_train)
    18. x_test = transfer.transform(x_test)
    19. # 4、预估器
    20. estimator = LinearRegression()
    21. estimator.fit(x_train, y_train)
    22. # 5、得出模型
    23. print("正规方程-权重系数为:\n", estimator.coef_)
    24. print("正规方程-偏置为:\n", estimator.intercept_)
    25. # 6、模型评估
    26. y_predict = estimator.predict(x_test)
    27. print("预测房价:\n", y_predict)
    28. error = mean_squared_error(y_test, y_predict)
    29. print("正规方程-均方误差为:\n", error)
    30. return None
    31. def linear2():
    32. """
    33. 梯度下降的优化方法对波士顿房价进行预测
    34. """
    35. # 1、获取数据
    36. boston = load_boston()
    37. # 2、划分数据集
    38. x_train,x_test, y_train, y_test = train_test_split(boston.data, boston.target, random_state=10)
    39. # 3、标准化
    40. transfer = StandardScaler()
    41. x_train = transfer.fit_transform(x_train)
    42. x_test = transfer.transform(x_test)
    43. # 4、预估器
    44. estimator = SGDRegressor()
    45. estimator.fit(x_train, y_train)
    46. # 5、得出模型
    47. print("梯度下降-权重系数为:\n", estimator.coef_)
    48. print("梯度下降-偏置为:\n", estimator.intercept_)
    49. # 6、模型评估
    50. y_predict = estimator.predict(x_test)
    51. print("预测房价:\n", y_predict)
    52. error = mean_squared_error(y_test, y_predict)
    53. print("梯度下降-均方误差为:\n", error)
    54. return None
    55. def linear3():
    56. """
    57. 岭回归对波士顿房价进行预测
    58. """
    59. # 1、获取数据
    60. boston = load_boston()
    61. # 2、划分数据集
    62. x_train,x_test, y_train, y_test = train_test_split(boston.data, boston.target, random_state=10)
    63. # 3、标准化
    64. transfer = StandardScaler()
    65. x_train = transfer.fit_transform(x_train)
    66. x_test = transfer.transform(x_test)
    67. # 4、预估器
    68. estimator = Ridge()
    69. estimator.fit(x_train, y_train)
    70. # 保存模型
    71. joblib.dump(estimator, "my_ridge.pkl")
    72. # 5、得出模型
    73. print("岭回归-权重系数为:\n", estimator.coef_)
    74. print("岭回归-偏置为:\n", estimator.intercept_)
    75. # 6、模型评估
    76. y_predict = estimator.predict(x_test)
    77. print("预测房价:\n", y_predict)
    78. error = mean_squared_error(y_test, y_predict)
    79. print("岭回归-均方误差为:\n", error)
    80. return None
    81. def linear4():
    82. """
    83. 岭回归对波士顿房价进行预测
    84. """
    85. # 1、获取数据
    86. boston = load_boston()
    87. # 2、划分数据集
    88. x_train,x_test, y_train, y_test = train_test_split(boston.data, boston.target, random_state=10)
    89. # 3、标准化
    90. transfer = StandardScaler()
    91. x_train = transfer.fit_transform(x_train)
    92. x_test = transfer.transform(x_test)
    93. # 加载模型
    94. estimator = joblib.load("my_ridge.pkl")
    95. # 5、得出模型
    96. print("岭回归-权重系数为:\n", estimator.coef_)
    97. print("岭回归-偏置为:\n", estimator.intercept_)
    98. # 6、模型评估
    99. y_predict = estimator.predict(x_test)
    100. print("预测房价:\n", y_predict)
    101. error = mean_squared_error(y_test, y_predict)
    102. print("岭回归-均方误差为:\n", error)
    103. return None
    104. if __name__ == "__main__":
    105. # 代码1:正规方程的优化方法对波士顿房价进行预测
    106. linear1()
    107. # 代码2:梯度下降的优化方法对波士顿房价进行预测
    108. linear2()
    109. # 代码3:岭回归对波士顿房价进行预测
    110. linear3()
    111. # 代码4:加载模型
    112. linear4()

  • 相关阅读:
    深入底层学git:目录中包含的秘密
    Hadoop的安装和使用,Windows使用shell命令简单操作HDFS
    ROS2自学笔记:通信接口
    回溯算法 | 排列问题 | leecode刷题笔记
    探讨Acrel-1000DP分布式光伏系统的设计与应用-安科瑞 蒋静
    软件设计模式系列之十一——装饰模式
    Perl 中的模式匹配修饰符
    8年经验之谈 —— 如何用 JMeter 编写性能测试脚本?
    springboot+nodejs+vue+Elementui在线旅游管理系统
    【校招VIP】前端HTML考察之cavas、svg
  • 原文地址:https://blog.csdn.net/csj50/article/details/134394889