机器学习——需求预测+PCA+随机森林算法+shap可解释性分析+多模型性能对比

news/2025/2/26 22:36:37

1 数据集介绍

  1. 自行车共享租赁过程与环境和季节设置高度相关。例如,天气状况、降水、星期几、季节、一天中的小时等都会影响租赁行为。
  2. 数据集特征
  • instant:记录索引

  • dteday:日期

  • season:季节(1:春季,2:夏季,3:秋季,4:冬季)

  • yr:年份(0:2011 年,1:2012 年)

  • mnth:月份(1 到 12)

  • hr:小时(0 到 23)

  • holiday:是否为假日(从 DC 政府假日安排 提取)

  • weekday:星期几

  • workingday:如果一天既不是周末也不是假日,则为 1,否则为 0

  • weathersit

    • 1:晴朗、少量云、部分多云

    • 2:雾 + 多云、雾 + 破碎云、雾 + 少量云、雾

    • 3:小雪、小雨 + 雷暴 + 散云、小雨 + 散云

    • 4:大雨 + 冰雹 + 雷暴 + 雾、雪 + 雾

  • temp:按摄氏度归一化的温度。值除以 41(最大值)

  • atemp:按摄氏度归一化的体感温度。值除以 50(最大值)

  • hum:按 100(最大值)归一化的湿度

  • windspeed:按 67(最大值)归一化的风速

  • casual:临时用户数量

  • registered:注册用户数量

  • cnt:包括临时用户和注册用户的总租赁自行车数量


2.整体思路

  • 数据清洗

    移除冗余字段,处理缺失值。

  • 特征工程

    对分类变量进行字符映射,增强可读性。

  • EDA

    通过可视化分析变量分布及关联性。

  • 预处理

    独热编码、数据拆分、归一化。

  • 降维

    使用 PCA 减少特征维度。

  • 模型训练

    应用多种回归模型,评估性能。

  • 模型解释

    通过 SHAP 值解析特征影响。


3. 代码解析

3.1. 库导入与数据加载
# 基础数据处理与数学计算库
import numpy as np
import pandas as pd

# 数据可视化库
import matplotlib.pyplot as plt
import seaborn as sns
sns.set(rc={'axes.facecolor':'#c5c6c7'})  # 设置图表背景色

# 机器学习工具库
from sklearn.model_selection import train_test_split  # 数据拆分
from sklearn.preprocessing import MinMaxScaler  # 数据归一化
from sklearn.decomposition import PCA, IncrementalPCA  # 主成分分析
from sklearn.metrics import r2_score  # 模型评估

# 忽略警告信息
import warnings
warnings.filterwarnings("ignore")

# 加载数据并展示前5行(带颜色渐变效果)
df = pd.read_csv(r"2025-2-24-公众号python机器学习ml-predicting-rider-count.csv")
df.head().style.background_gradient(cmap='bone')  # cmap指定渐变色谱

# 打印数据维度
df.shape  # 返回 (行数, 列数)

关键点说明

  • sns.set()

     统一调整 seaborn 图表样式,#c5c6c7 是十六进制灰色代码。

  • background_gradient

     通过颜色深浅直观显示数值大小。


3.2. 数据清洗
# 删除冗余列
cols_to_drop = ['dteday','instant','casual','registered']  # 待删除列名列表
df.drop(columns=cols_to_drop, axis=1, inplace=True)  # axis=1表示按列删除

# 分类变量数值映射(增强可解释性)
mappings = {
    'season': {1:"Winter", 2:"Spring", 3:"Summer", 4:"Fall"},
    'weekday': {0:"Saturday", 1:"Sunday", 2:"Monday", 3:"Tuesday", 
                4:"Wednesday", 5:"Thursday", 6:"Friday"},
    'mnth': {1:"Jan",2:"Feb",3:"Mar",4:"Apr",5:"May",6:"Jun",
             7:"Jul",8:"Aug",9:"Sep",10:"Oct",11:"Nov",12:"Dec"},
    'weathersit': {1:"Clear",2:"Cloudy",3:"LightRain",4:"Snow_Thunderstorm"}
}

for col, map_dict in mappings.items():
    df[col] = df[col].map(map_dict)  # 将数字编码转为可读字符串

# 检查数据基本信息
df.info()  # 显示各列的非空计数、数据类型,快速定位缺失值

代码优化说明

  • 使用字典 mappings 统一管理映射关系,提高代码可维护性。

  • df.info()

     可快速发现数据缺失情况(若输出显示某列非空计数小于总行数,则存在缺失值)。


3.3. 探索性分析 (EDA)

分类变量条形图

# 提取数值型与字符型列名
num_cols = list(df.select_dtypes(["int64","float64"]))  # 数值列
cat_cols = list(df.select_dtypes("object"))  # 分类列

# 分类变量可视化
fig, axes = plt.subplots(nrows=4, figsize=(16,28))  # 创建4个子图

colors = ('#17252a', '#2b7a78','#3aafa9','#def2f1','#feffff')  # 自定义颜色序列

for idx, col inenumerate(['season','mnth','weekday','weathersit']):
    ax = axes[idx]
    sns.barplot(
        x=df[col], 
        y=df['cnt'], 
        ax=ax, 
        palette=colors, 
        edgecolor="#c5c6c7", 
        ci=95# 显示95%置信区间
    )
    ax.set_ylabel(col, fontsize=16)  # 设置y轴标签为分类变量名
    ax.tick_params(left=False, labelleft=False)  # 隐藏左侧刻度
    ax.bar_label(ax.containers[0], fmt='%.0f', size=12)  # 在柱子上显示均值

plt.tight_layout()  # 自动调整子图间距
plt.show()

图片

数值变量回归图

plt.figure(figsize=(16,12))

# 绘制四个子图:temp, atemp, hum, windspeed 与 cnt 的关系
features = ['temp', 'atemp', 'hum', 'windspeed']
colors = ["#00008B", "#006400", "#800080", "#8B4513"]  # 海军蓝、深绿、紫色、棕色

for i, (feat, color) inenumerate(zip(features, colors)):
    plt.subplot(2,2,i+1)
    sns.regplot(
        x=feat, 
        y='cnt', 
        data=df, 
        color=color,
        scatter_kws={'alpha':0.3},  # 散点透明度
        line_kws={'lw':2}  # 回归线粗细
    )
    plt.xlabel(feat.capitalize())

plt.tight_layout()
plt.show()

图片

热力图分析相关性

plt.figure(figsize=(16,8))
sns.heatmap(
    df.corr(), 
    cmap="viridis",  # 颜色映射
    annot=True,  # 显示数值
    fmt=".2f",  # 数值格式化为两位小数
    annot_kws={'size':12}  # 数值标签字号
)
plt.xticks(size=14)
plt.yticks(size=14, rotation=0)  # 保持y轴标签水平
plt.title("Feature Correlation Matrix", pad=20, fontsize=18)
plt.show()

图片

关键分析点

  • 高相关性特征(如 temp 与 atemp)可能需合并或剔除,避免多重共线性。

  • 通过回归图斜率判断特征与目标变量的正/负相关关系。


3.4. 数据预处理

独热编码

# 对每个分类变量生成哑变量(drop_first=True 避免多重共线性)
for col in cat_cols:
    dummies = pd.get_dummies(df[col], prefix=col, drop_first=True)
    df = pd.concat([df, dummies], axis=1)
    df.drop(col, axis=1, inplace=True)  # 删除原始分类列

数据集拆分

y = df.pop('cnt')  # 提取目标变量
X = df  # 特征矩阵
X_train, X_test, y_train, y_test = train_test_split(
    X, y, 
    train_size=0.7,  # 70%训练集
    random_state=69  # 固定随机种子保证可复现性
)

归一化处理

scaler = MinMaxScaler()  # 初始化归一化器
X_train_scaled = scaler.fit_transform(X_train)  # 训练集拟合并转换
X_test_scaled = scaler.transform(X_test)  # 测试集仅转换(避免数据泄露)

3.5. 主成分分析 (PCA)

累计方差图确定主成分数

pca = PCA(random_state=69)
pca.fit(X_train_scaled)

# 计算累计方差贡献率
explained_variance = pca.explained_variance_ratio_ * 100
cumulative_variance = np.cumsum(explained_variance)

plt.figure(figsize=(10,6))
plt.plot(range(1, len(cumulative_variance)+1), cumulative_variance, 
         marker='o', color='#9B1D20', markersize=8)
plt.axvline(17, color='red', linestyle='--', label='17 Components')
plt.axhline(92, color='blue', linestyle='--', label='92% Variance')
plt.xlabel("Number of Principal Components", fontsize=12)
plt.ylabel("Cumulative Explained Variance (%)", fontsize=12)
plt.legend()
plt.grid(axis='y')
plt.show()

图片

降维处理

# 使用增量PCA(适合大数据集)
pca_final = IncrementalPCA(n_components=17)
X_train_pca = pca_final.fit_transform(X_train_scaled)
X_test_pca = pca_final.transform(X_test_scaled)

3.6. 模型训练与评估
# 定义评估函数
defevaluate_model(model, X_train, y_train, X_test, y_test):
    model.fit(X_train, y_train)
    y_pred = model.predict(X_test)
    return {
        'MSE': mean_squared_error(y_test, y_pred),
        'MAE': mean_absolute_error(y_test, y_pred),
        'R²': r2_score(y_test, y_pred)
    }

# 初始化模型
models = {
    "Linear Regression": LinearRegression(),
    "Support Vector Regression": SVR(kernel='rbf'),
    "Random Forest": RandomForestRegressor(n_estimators=100, random_state=123),
    "XGBoost": XGBRegressor(n_estimators=100, random_state=123)
}

# 训练并评估模型
results = {}
for name, model in models.items():
    results[name] = evaluate_model(model, X_train_pca, y_train, X_test_pca, y_test)

# 结果转DataFrame并打印
results_df = pd.DataFrame(results).T
print(results_df.round(3))  # 保留3位小数

# 绘制评估指标对比图
results_df.plot(kind='bar', subplots=True, layout=(1,3), 
                figsize=(18,6), legend=False, edgecolor='black')
plt.suptitle("Model Performance Comparison", y=1.02)
plt.tight_layout()
plt.show()

图片

关键点说明

  • evaluate_model

     函数封装了训练和评估流程,提升代码复用性。

  • XGBoost

     需注意参数 n_estimators(树的数量)对性能的影响。


3.7. 模型解释 (SHAP)
 
# 初始化随机森林解释器
rf_model = models["Random Forest"]
explainer = shap.TreeExplainer(rf_model)

# 计算SHAP值(使用降维后的测试数据)
shap_values = explainer.shap_values(X_test_pca)

# 特征重要性总览(需传递原始特征名)
shap.summary_plot(
    shap_values, 
    features=X_test_pca, 
    feature_names=X.columns,  # 原始特征名
    plot_type='bar'
)

# 具体特征依赖图(示例:temp)
shap.dependence_plot(
    "temp", 
    shap_values, 
    X_test_pca, 
    feature_names=X.columns,
    interaction_index=None# 不显示交互效应
)

 

图片

图片

图片


http://www.niftyadmin.cn/n/5869237.html

相关文章

YOLOv11-ultralytics-8.3.67部分代码阅读笔记-trainer.py

trainer.py ultralytics\engine\trainer.py 目录 trainer.py 1.所需的库和模块 2.class BaseTrainer: 1.所需的库和模块 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license """ Train a model on a dataset.Usage:$ yolo …

如何实现在Redis集群情况下,同一类数据固定保存在同一个Redis实例中

1. 使用哈希标签(Hash Tags) 概述 Redis Cluster使用一致性哈希算法来分配数据到不同的节点上。为了确保相同类型的数据被分配到同一个Redis实例上,可以利用哈希标签(Hash Tags)。哈希标签是指在键名中用花括号 {} 包…

LabVIEW Browser.vi 库说明

browser.llb 库位于C:\Program Files (x86)\National Instruments\LabVIEW 2019\vi.lib\Platform目录,它是 LabVIEW 平台下用于与网络浏览器相关操作的重要库。该库为 LabVIEW 开发者提供了一系列工具,用于实现网页浏览控制、网页数据获取与交互等功能&a…

《AI 大模型 ChatGPT 的传奇》

《AI 大模型 ChatGPT 的传奇》 ——段方 某世界 100 强企业大数据/AI 总设计师 教授 北京大学博士后 助理 :1三6三二四61四五4 1 AI 大模型的概念和特点 1.1 什么是”大模型、多模态“? 1.2 大模型带来了什么? 1.3 大模型为什么能产生质变&am…

2.25力扣-回溯组合总和

39. 组合总和 - 力扣&#xff08;LeetCode&#xff09; 一&#xff1a;Java class Solution {List<List<Integer>> ansnew LinkedList<>();List<Integer> tempnew LinkedList<>();int sum0;public List<List<Integer>> combinatio…

Go红队开发—基础语法入门

文章目录 基础语法语法框架数据类型类型转换变量var定义常量iota 枚举数组切片 结构体结构体方法 指针map类型转换导入包字符串strings包字符拼接ContainsReplace更多函数解释 输入输出字符串格式化fmt&#xff1a;Scanf、Scan、ScanlnScanfScanScanln fmt&#xff1a;Println、…

【Python爬虫(50)】从0到1:打造分布式爬虫项目全攻略

【Python爬虫】专栏简介&#xff1a;本专栏是 Python 爬虫领域的集大成之作&#xff0c;共 100 章节。从 Python 基础语法、爬虫入门知识讲起&#xff0c;深入探讨反爬虫、多线程、分布式等进阶技术。以大量实例为支撑&#xff0c;覆盖网页、图片、音频等各类数据爬取&#xff…

【Spring详解六】容器的功能扩展-ApplicationContext

六、容器的功能扩展_ApplicationContext 经过前几章的分析&#xff0c;对Spring中的容器功能有了简单的了解&#xff0c;在前面几章中一直以 BeanFactory接口以及它的默认实现类XmlBeanFacotory为例进行分析&#xff0c;但是Spring中还提供了 另一个接口ApplicationContext&…