目录
随机森林(Random Forest)是一种集成学习算法,它通过构建多个决策树并将它们的预测结果进行汇总来提高整体模型的预测准确率、稳定性和泛化能力。随机森林属于“bagging”(Bootstrap Aggregating)方法的一种实现,它结合了决策树的强大分类能力和集成学习的优势。
构建多棵决策树:随机森林通过自助抽样法(bootstrap sampling)从原始数据集中随机抽取多个样本集,每个样本集都是原始数据集的一个有放回抽样版本。然后,基于每个样本集独立地训练一棵决策树。由于是有放回抽样,原始数据集中的某些样本可能会在多个样本集中出现,而有些样本则可能一次都不出现。
随机选择特征:在构建每棵决策树的过程中,不是使用数据集中的所有特征来寻找最佳划分,而是随机选择一部分特征(通常是总特征数的一个子集)来进行节点划分。这种特征选择的随机性进一步增加了模型的多样性,有助于减少过拟合并提高模型的泛化能力。
集成预测:对于分类问题,随机森林中的每棵决策树都会给出一个预测结果(即类别的投票)。最终,随机森林的预测结果是所有决策树预测结果的众数(即出现次数最多的类别)。对于回归问题,则取所有决策树预测结果的平均值作为最终预测。
随机森林广泛应用于各种分类和回归任务中,包括但不限于:
TotalCharges
从字符串转换为浮点数。gender
, InternetService
, Contract
等)。数据集如下图所示:
- % 数据加载
- data = readtable('WA_Fn-UseC_-Telco-Customer-Churn.csv');
-
- % 转换二分类特征为数值型
- data.Churn = strcmp(data.Churn, 'Yes'); % 'Yes'为1,'No'为0
- data.gender = strcmp(data.gender, 'Male');
- data.Partner = strcmp(data.Partner, 'Yes');
- data.Dependents = strcmp(data.Dependents, 'Yes');
- data.PhoneService = strcmp(data.PhoneService, 'Yes');
- data.PaperlessBilling = strcmp(data.PaperlessBilling, 'Yes');
-
- % 填充缺失值
- data.TotalCharges(isnan(data.TotalCharges)) = 0;
-
- % % 独热编码列
- % categoricalVars = {'MultipleLines', 'OnlineSecurity','OnlineBackup','DeviceProtection','TechSupport','StreamingTV','StreamingMovies',...
- % 'InternetService', 'Contract', 'PaymentMethod'};
-
- % 提取特征列
- allFeatures = data{:, {'SeniorCitizen', 'tenure', 'MonthlyCharges', 'TotalCharges', 'gender', 'Partner', 'Dependents', 'PhoneService', 'PaperlessBilling'}};
-
- % 提取目标变量
- target = data.Churn;
- target = table(target, 'VariableNames', {'Churn'});
-
- % 划分数据集为训练集和测试集
- cv = cvpartition(height(allFeatures), 'HoldOut', 0.3); % 70%训练,30%测试
- idx = cv.test;
- XTrain = allFeatures(~idx, :);
- YTrain = target(~idx, :);
- XTest = allFeatures(idx, :);
- YTest = target(idx, :);
-
- % 训练随机森林模型回归预测
- rfModel = TreeBagger(50, XTrain, YTrain.Churn, 'Method', 'classification');
-
- % 预测
- YTestPredicted = predict(rfModel, XTest);
-
- % 评估模型
- YTestPredicted = str2double(YTestPredicted);
- accuracyRF = sum(YTestPredicted == YTest.Churn) / numel(YTest.Churn);
- fprintf('Random Forest Accuracy: %.2f%%\n', accuracyRF * 100);
-
-
- % 获取概率预测
- [~, scores] = predict(rfModel, XTest);
- % 绘制ROC曲线
- [X,Y,T,AUC] = perfcurve(YTest.Churn, scores(:,2), 1); % 假设scores(:,2)是正类的预测概率
- figure;
- plot(X,Y);
- xlabel('False positive rate'); ylabel('True positive rate');
- title(['ROC Curve, AUC = ', num2str(AUC)]);
- grid on;
-
- % 转换预测结果为逻辑值
- YTestPredicted_logical = logical(YTestPredicted);
-
- % 计算混淆矩阵
- confMat = confusionmat(YTest.Churn, YTestPredicted_logical);
-
- % 显示混淆矩阵
- figure;
- confusionchart(confMat, {'Not Churn', 'Churn'});
- title('Confusion Matrix - Random Forest');
-
- % 预测图
- figure;
- gscatter(XTest(:,1), XTest(:,2), YTestPredicted);
- xlabel('Feature 1');
- ylabel('Feature 2');
- title('Random Forest Predicted Classes');
- legend('Class 0', 'Class 1', 'Location', 'best');