• Rust的Linfa和Polars库进行机器学习


    使用Rust的Linfa库和Polars库来实现机器学习中的线性回归算法。
    Linfa crate旨在提供一个全面的工具包来使用Rust构建机器学习应用程序。
    Polars是Rust的一个DataFrame库,它基于Apache Arrow的内存模型。Apache arrow提供了非常高效的列数据结构,并且正在成为列数据结构事实上的标准。
    在下面的例子中,我们使用一个糖尿病数据集来训练线性回归算法
    使用以下命令创建一个Rust新项目:

    cargo new machine_learning_linfa
    
    • 1

    在Cargo.toml文件中加入以下依赖项:

    [dependencies]
    linfa = "0.7.0"
    linfa-linear = "0.7.0"
    ndarray = "0.15.6"
    polars = { version = "0.35.4", features = ["ndarray"]}
    
    • 1
    • 2
    • 3
    • 4
    • 5

    在项目根目录下创建一个diabetes_file.csv文件,将数据集写入文件。

    AGE    SEX BMI BP  S1  S2  S3  S4  S5  S6  Y
    59    2   32.1    101 157 93.2    38  4   4.8598  87  151
    48    1   21.6    87  183 103.2   70  3   3.8918  69  75
    72    2   30.5    93  156 93.6    41  4   4.6728  85  141
    24    1   25.3    84  198 131.4   40  5   4.8903  89  206
    50    1   23  101 192 125.4   52  4   4.2905  80  135
    23    1   22.6    89  139 64.8    61  2   4.1897  68  97
    36    2   22  90  160 99.6    50  3   3.9512  82  138
    66    2   26.2    114 255 185 56  4.55    4.2485  92  63
    60    2   32.1    83  179 119.4   42  4   4.4773  94  110
    .............
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11

    数据集从这里下载:
    https://www4.stat.ncsu.edu/~boos/var.select/diabetes.tab.txt

    在src/main.rs文件中写入以下代码:

    use linfa::prelude::*;
    use linfa::traits::Fit;
    use linfa_linear::LinearRegression;
    use ndarray::{ArrayBase, OwnedRepr};
    use polars::prelude::*; // Import polars
    
    fn main() -> Result<(), Box<dyn std::error::Error>> {
        // 将制表符定义为分隔符
        let separator = b'\t';
    
        let df = polars::prelude::CsvReader::from_path("./diabetes_file.csv")?
            .infer_schema(None)
            .with_separator(separator)
            .has_header(true)
            .finish()?;
    
        println!("{:?}", df);
    
        // 提取并转换目标列
        let age_series = df.column("AGE")?.cast(&DataType::Float64)?;
        let target = age_series.f64()?;
    
        println!("Creating features dataset");
    
        let mut features = df.drop("AGE")?;
    
        // 遍历列并将每个列强制转换为Float64
        for col_name in features.get_column_names_owned() {
            let casted_col = df
                .column(&col_name)?
                .cast(&DataType::Float64)
                .expect("Failed to cast column");
    
            features.with_column(casted_col)?;
        }
    
        println!("{:?}", df);
    
        let features_ndarray: ArrayBase<OwnedRepr<_>, _> =
            features.to_ndarray::<Float64Type>(IndexOrder::C)?;
        let target_ndarray = target.to_ndarray()?.to_owned();
        let (dataset_training, dataset_validation) =
            Dataset::new(features_ndarray, target_ndarray).split_with_ratio(0.80);
    
        // 训练模型
        let model = LinearRegression::default().fit(&dataset_training)?;
    
        // 预测
        let pred = model.predict(&dataset_validation);
    
        // 评价模型
        let r2 = pred.r2(&dataset_validation)?;
        println!("r2 from prediction: {}", r2);
    
        Ok(())
    }
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57

    使用polar的CSV reader读取CSV文件。

    将数据帧打印到控制台以供检查。

    从DataFrame中提取“AGE”列作为线性回归的目标变量。将目标列强制转换为Float64(双精度浮点数),这是机器学习中数值数据的常用格式。

    将features DataFrame转换为narray::ArrayBase(一个多维数组)以与linfa兼容。将目标序列转换为数组,这些数组与用于机器学习的linfa库兼容。

    使用80-20的比例将数据集分割为训练集和验证集,这是机器学习中评估模型在未知数据上的常见做法。

    使用linfa的线性回归算法在训练数据集上训练线性回归模型。

    使用训练好的模型对验证数据集进行预测。

    计算验证数据集上的R²(决定系数)度量,以评估模型的性能。R²值表示回归预测与实际数据点的近似程度。

    执行cargo run,运行结果如下:

    shape: (442, 11)
    ┌─────┬─────┬──────┬───────┬───┬──────┬────────┬─────┬─────┐
    │ AGESEXBMIBP    ┆ … ┆ S4S5S6Y   │
    │ ------------   ┆   ┆ ------------ │
    │ i64i64f64f64   ┆   ┆ f64f64i64i64 │
    ╞═════╪═════╪══════╪═══════╪═══╪══════╪════════╪═════╪═════╡
    │ 59232.1101.0 ┆ … ┆ 4.04.859887151 │
    │ 48121.687.0  ┆ … ┆ 3.03.89186975  │
    │ 72230.593.0  ┆ … ┆ 4.04.672885141 │
    │ 24125.384.0  ┆ … ┆ 5.04.890389206 │
    │ …   ┆ …   ┆ …    ┆ …     ┆ … ┆ …    ┆ …      ┆ …   ┆ …   │
    │ 47224.975.0  ┆ … ┆ 5.04.4427102104 │
    │ 60224.999.67 ┆ … ┆ 3.774.127195132 │
    │ 36130.095.0  ┆ … ┆ 4.795.129985220 │
    │ 36119.671.0  ┆ … ┆ 3.04.59519257  │
    └─────┴─────┴──────┴───────┴───┴──────┴────────┴─────┴─────┘
    Creating features dataset
    shape: (442, 11)
    ┌─────┬─────┬──────┬───────┬───┬──────┬────────┬─────┬─────┐
    │ AGESEXBMIBP    ┆ … ┆ S4S5S6Y   │
    │ ------------   ┆   ┆ ------------ │
    │ i64i64f64f64   ┆   ┆ f64f64i64i64 │
    ╞═════╪═════╪══════╪═══════╪═══╪══════╪════════╪═════╪═════╡
    │ 59232.1101.0 ┆ … ┆ 4.04.859887151 │
    │ 48121.687.0  ┆ … ┆ 3.03.89186975  │
    │ 72230.593.0  ┆ … ┆ 4.04.672885141 │
    │ 24125.384.0  ┆ … ┆ 5.04.890389206 │
    │ …   ┆ …   ┆ …    ┆ …     ┆ … ┆ …    ┆ …      ┆ …   ┆ …   │
    │ 47224.975.0  ┆ … ┆ 5.04.4427102104 │
    │ 60224.999.67 ┆ … ┆ 3.774.127195132 │
    │ 36130.095.0  ┆ … ┆ 4.795.129985220 │
    │ 36119.671.0  ┆ … ┆ 3.04.59519257  │
    └─────┴─────┴──────┴───────┴───┴──────┴────────┴─────┴─────┘
    r2 from prediction: 0.15937814745521017
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34

    对于优先考虑快速迭代和快速原型的数据科学家来说,Rust的编译时间可能是令人头疼的问题。Rust的强静态类型系统虽然有利于确保类型安全和减少运行时错误,但也会在编码过程中增加一层复杂性。

  • 相关阅读:
    机器学习——K-means算法详解及python应用
    模块化---common.js
    你的编程能力从什么时候开始突飞猛进的?
    vue使用jsencrypt实现rsa前端加密
    ​孤网双机并联逆变器下垂控制策略(包括仿真模型,功率计算模块、下垂控制模块、电压电流双环控制模块​)(Simulink仿真)
    SpringBoot中HttpClient的使用
    AtCoder 265G 线段树
    怎么下载微信视频号视频?
    MATLAB读取每行文本并提取字符串后的数字
    移动中兴ZXHN F6610M光猫拨号密码查询
  • 原文地址:https://blog.csdn.net/weixin_43114209/article/details/136498905