目录
本文以糖尿病数据集diabetes为基础进行线性回归训练:
"""
@Title: Collect data
@Time: 2024/3/11
@Author: Michael Jie
Data collection and preprocessing:
1. Collect the data;
2. Visualize the data;
3. Clean the data;
4. Feature engineering;
5. Build the feature set and label set (supervised learning only);
6. Split into training and test sets.
"""

import sklearn.datasets as ds
import pandas as pd

# Load the diabetes dataset (a regression task)
diabetes = ds.load_diabetes(
    # If True, return a (data, target) tuple instead of a Bunch object
    return_X_y=False,
    # If True, return the data as a pandas DataFrame/Series
    as_frame=False,
    # If True, return mean-centered features scaled by std * sqrt(n_samples);
    # False (used here) keeps the raw, unscaled feature values
    scaled=False,
)

# A Bunch is a dict whose keys can also be read as attributes (diabetes.data)
print(diabetes.keys())
# dict_keys([
#     'data',             # feature set
#     'target',           # label set
#     'frame',            # DataFrame of features + target if as_frame=True, else None
#     'DESCR',            # dataset description
#     'feature_names',    # feature column names
#     'data_filename',    # name of the packaged feature-data file
#     'target_filename',  # name of the packaged target-data file
#     'data_module',      # package the data files are loaded from
# ])

# Feature set
data = diabetes.data
print(type(data), data.shape)
# <class 'numpy.ndarray'> (442, 10)
feature_names = diabetes.feature_names
print(feature_names, type(feature_names))
# ['age', 'sex', 'bmi', 'bp', 's1', 's2', 's3', 's4', 's5', 's6'] <class 'list'>

# Label set
target = diabetes.target
print(type(target), target.shape)
# <class 'numpy.ndarray'> (442,)

# Dataset description
print(diabetes.DESCR)
# NOTE: the DESCR text is fixed. Its remark that the features "have been mean
# centered and scaled" describes scaled=True; with scaled=False (as loaded
# above) `data` holds the raw measurements.
# Sample output:
# Diabetes dataset
# ----------------
# Ten baseline variables, age, sex, body mass index, average blood
# pressure, and six blood serum measurements were obtained for each of n =
# 442 diabetes patients, as well as the response of interest, a
# quantitative measure of disease progression one year after baseline.
# **Data Set Characteristics:**
# :Number of Instances: 442
# :Number of Attributes: First 10 columns are numeric predictive values
# :Target: Column 11 is a quantitative measure of disease progression one year after baseline
# :Attribute Information:
#     - age     age in years
#     - sex
#     - bmi     body mass index
#     - bp      average blood pressure
#     - s1      tc, total serum cholesterol
#     - s2      ldl, low-density lipoproteins
#     - s3      hdl, high-density lipoproteins
#     - s4      tch, total cholesterol / HDL
#     - s5      ltg, possibly log of serum triglycerides level
#     - s6      glu, blood sugar level
# Note: Each of these 10 feature variables have been mean centered and scaled by the standard deviation times the square root of `n_samples` (i.e. the sum of squares of each column totals 1).
# Source URL:
# https://www4.stat.ncsu.edu/~boos/var.select/diabetes.html
# For more information see:
# Bradley Efron, Trevor Hastie, Iain Johnstone and Robert Tibshirani (2004) "Least Angle Regression," Annals of Statistics (with discussion), 407-499.
# (https://web.stanford.edu/~hastie/Papers/LARS/LeastAngle_2002.pdf)

# Save features and label side by side to one CSV file
# (10 feature columns followed by a 'target' column)
data_csv = pd.DataFrame(data=data, columns=feature_names)
target_csv = pd.DataFrame(data=target, columns=['target'])
diabetes_csv = pd.concat([data_csv, target_csv], axis=1)
diabetes_csv.to_csv(r'diabetes_datasets.csv', index=False)
"""
@Title: Data visualization
@Time: 2024/3/11
@Author: Michael Jie
"""

import math

import pandas as pd
import matplotlib.pyplot as plt

# Read the CSV written by the data-collection step
csv = pd.read_csv(r'diabetes_datasets.csv')
print(csv.shape)  # (442, 11)

# Scatter-plot every feature against the label, one subplot per feature.
# Features are selected by name rather than position, so the plot does not
# depend on 'target' being the last column or on there being exactly 10 features.
feature_columns = [column for column in csv.columns if column != "target"]
n_cols = 5
n_rows = max(1, math.ceil(len(feature_columns) / n_cols))
# squeeze=False keeps `axes` 2-D even when there is only one row
fig, axes = plt.subplots(n_rows, n_cols, figsize=(19.2, 10.8), squeeze=False)
for ax, column in zip(axes.flat, feature_columns):
    ax.scatter(csv[column], csv["target"])
    ax.set_title(column)
    ax.set_xlabel(column)
    ax.set_ylabel("target")
# Hide grid cells left empty when the feature count is not a multiple of n_cols
for ax in axes.flat[len(feature_columns):]:
    ax.set_visible(False)
fig.tight_layout()

# Save the figure
fig.savefig(r'diabetes_datasets.png')
# plt.show()
"""
@Title: Data cleaning
@Time: 2024/3/11
@Author: Michael Jie
"""

import pandas as pd

# Typical cleaning steps:
# 1. Missing data: drop incomplete rows, or fill them with the mean, a random
#    value or 0;
# 2. Duplicate data: drop rows that are exact duplicates;
# 3. Erroneous data: fix logically invalid values;
# 4. Unusable data: fix malformed values.

# Read the data
csv = pd.read_csv(r'diabetes_datasets.csv')

# Step 1: count NaN values per column
print(csv.isna().sum())
# age       0
# sex       0
# bmi       0
# bp        0
# s1        0
# s2        0
# s3        0
# s4        0
# s5        0
# s6        0
# target    0
# dtype: int64

# Step 2: count rows that exactly duplicate an earlier row
# (0 means there is nothing to drop)
print(csv.duplicated().sum())
- """
- @Title: 特征工程
- @Time: 2024/3/11
- @Author: Michael Jie
- """
-
- import numpy as np
- import sklearn.datasets as ds
-
-
# Standardization (z-score)
def z_score_normalization(x, axis=0):
    """Standardize ``x`` to zero mean and unit standard deviation along ``axis``.

    Uses the population standard deviation (``ddof=0``, NumPy's default).

    Args:
        x: Array-like data; converted with ``np.asarray``.
        axis: Axis along which mean and std are computed. ``0`` (default)
            standardizes each column of a 2-D feature matrix; ``None`` uses
            all elements.

    Returns:
        np.ndarray: Standardized array with the same shape as ``x``. Slices
        whose standard deviation is zero (constant features) become 0 instead
        of NaN.
    """
    x = np.asarray(x)
    # keepdims=True keeps the reduced axis as size 1 so the statistics
    # broadcast back onto x for any axis; without it, axis=1 on a 2-D array
    # fails or, for a square matrix, silently subtracts the wrong values.
    mean = np.mean(x, axis=axis, keepdims=True)
    std = np.std(x, axis=axis, keepdims=True)
    # Divide constant slices by 1 rather than 0 so they map to 0, not NaN.
    std = np.where(std == 0, 1, std)
    return (x - mean) / std
-
-
# With scaled=True, sklearn returns features that are mean-centered and
# divided by std * sqrt(n_samples), so each column's sum of squares is 1.
# These are NOT plain z-scores.
diabetes_pre = ds.load_diabetes(scaled=True)
print(diabetes_pre.data)

# Standardize the raw features manually (plain z-score: zero mean, unit std)
diabetes = ds.load_diabetes(scaled=False)
z_scores = z_score_normalization(diabetes.data)
print(z_scores)
# The z-scores differ from sklearn's scaled features by a factor of
# sqrt(n_samples); dividing by it should reproduce them (expected: True)
print(np.allclose(z_scores / np.sqrt(z_scores.shape[0]), diabetes_pre.data))
无(load_diabetes 已直接提供特征集 data 与标签集 target,无需另行构建)。
"""
@Title: Split training and test sets; create models
@Time: 2024/3/11
@Author: Michael Jie
"""

import sklearn.datasets as ds
from sklearn.linear_model import LinearRegression, Ridge
from sklearn.model_selection import train_test_split

# Load the data
diabetes = ds.load_diabetes(scaled=False)

# Split the dataset into 80% training and 20% test data
# (random_state=0 makes the shuffle reproducible)
x_train, x_test, y_train, y_test = train_test_split(
    diabetes.data, diabetes.target, test_size=0.2, random_state=0
)
print(x_train.shape, x_test.shape, y_train.shape, y_test.shape)
# (353, 10) (89, 10) (353,) (89,)

# Create an ordinary least-squares linear regression model
linear = LinearRegression(
    # Whether to fit an intercept term
    fit_intercept=True,
    # Whether to copy the feature set (False may let fitting overwrite it)
    copy_X=True,
)

# Create a ridge (L2-regularized) linear regression model
ridge = Ridge(
    # Regularization strength, NOT a learning rate: larger values shrink
    # the coefficients more strongly
    alpha=1.0,
    # Whether to fit an intercept term
    fit_intercept=True,
    # Whether to copy the feature set
    copy_X=True,
    # Maximum iterations for iterative solvers; None uses the solver default
    max_iter=None,
    # Convergence tolerance (precision of the solution) for the solver
    tol=1e-4,
)
"""
@Title: Train and evaluate models
@Time: 2024/3/11
@Author: Michael Jie
"""

import sklearn.datasets as ds
from sklearn.linear_model import LinearRegression, Ridge
from sklearn.model_selection import train_test_split

# Load the data with scaled features (each column mean-centered and scaled so
# that its sum of squares is 1)
diabetes = ds.load_diabetes(scaled=True)

# Split the dataset into 80% training and 20% test data
x_train, x_test, y_train, y_test = train_test_split(
    diabetes.data, diabetes.target, test_size=0.2, random_state=0
)

# Ordinary least-squares linear regression: train, then inspect the parameters
linear = LinearRegression()
linear.fit(x_train, y_train)
print(linear.coef_, linear.intercept_)
# [ -35.55025079 -243.16508959  562.76234744  305.46348218 -662.70290089
#   324.20738537   24.74879489  170.3249615   731.63743545   43.0309307 ] 152.5380470138517

# Ridge (L2-regularized) regression with alpha=1.0. Every feature column has
# unit sum of squares, so this penalty is comparatively strong. That is
# presumably why these coefficients are much smaller than the OLS ones.
ridge = Ridge()
ridge.fit(x_train, y_train)
print(ridge.coef_, ridge.intercept_)
# [  21.34794489  -72.97401935  301.36593604  177.49036347    2.82093648
#   -35.27784862 -155.52090285  118.33395129  257.37783937  102.22540041] 151.9441509473086

# Evaluate the already-fitted models on the held-out test set; refitting them
# on the same training data would only repeat the work above.
# score() returns the coefficient of determination R^2, which is NOT bounded
# to [0, 1]. 1.0 is a perfect fit, and a constant model that always predicts
# the mean of y scores 0.0. Models worse than that score below 0.
print(linear.score(x_test, y_test))
# 0.33223321731061806
print(ridge.score(x_test, y_test))
# 0.3409800318493461