百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 技术教程 > 正文

XGBoost你真的懂吗?我不信.....

csdh11 2025-02-09 11:56 11 浏览

选自Ancestry

作者:Tyler Folkman

编译:机器之心(almosthuman2014

原文:
https://blogs.ancestry.com/ancestry/2017/12/18/understanding-machine-learning-xgboost/

在这篇文章中,我们将介绍一些技术以更好地理解 XGBoost 的预测过程。这允许我们在利用 gradient boosting 的威力的同时,仍然能理解模型的决策过程。

为了解释这些技术,我们将使用 Titanic 数据集。该数据集有每个泰坦尼克号乘客的信息(包括乘客是否生还)。我们的目标是预测一个乘客是否生还,并且理解做出该预测的过程。即使是使用这些数据,我们也能看到理解模型决策的重要性。想象一下,假如我们有一个关于最近发生的船难的乘客数据集。建立这样的预测模型的目的实际上并不在于预测结果本身,但理解预测过程可以帮助我们学习如何最大化意外中的生还者。

  1. import pandas as pd

  2. from xgboost import XGBClassifier

  3. from sklearn.model_selection import train_test_split

  4. from sklearn.metrics import accuracy_score

  5. import operator

  6. import matplotlib.pyplot as plt

  7. import seaborn as sns

  8. import lime.lime_tabular

  9. from sklearn.pipeline import Pipeline

  10. from sklearn.preprocessing import Imputer

  11. import numpy as np

  12. from sklearn.grid_search import GridSearchCV

  13. %matplotlib inline


我们要做的首件事是观察我们的数据,你可以在 Kaggle 上找到(
https://www.kaggle.com/c/titanic/data)这个数据集。拿到数据集之后,我们会对数据进行简单的清理。即:

  • 清除名字和乘客 ID

  • 把分类变量转化为虚拟变量

  • 用中位数填充和去除数据

这些清洗技巧非常简单,本文的目标不是讨论数据清洗,而是解释 XGBoost,因此这些都是快速、合理的清洗以使模型获得训练。

  1. data = pd.read_csv("./data/titantic/train.csv")

  2. y = data.Survived

  3. X = data.drop(["Survived", "Name", "PassengerId"], 1)

  4. X = pd.get_dummies(X)


现在让我们将数据集分为训练集和测试集。

  1. X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)


并通过少量的超参数测试构建一个训练管道。

  1. pipeline = Pipeline([( imputer , Imputer(strategy= median )), ( model , XGBClassifier())])


  1. parameters = dict(model__max_depth=[3, 5, 7],

  2. model__learning_rate=[.01, .1],

  3. model__n_estimators=[100, 500])



  4. cv = GridSearchCV(pipeline, param_grid=parameters)

  5. cv.fit(X_train, y_train)


接着查看测试结果。为简单起见,我们将会使用与 Kaggle 相同的指标:准确率。

  1. test_predictions = cv.predict(X_test)

  2. print("Test Accuracy: {}".format(accuracy_score(y_test, test_predictions)))

  3. Test Accuracy: 0.8101694915254237


至此我们得到了一个还不错的准确率,在 Kaggle 的大约 9000 个竞争者中排到了前 500 名。因此我们还有进一步提升的空间,但在此将作为留给读者的练习。

我们继续关于理解模型学习到什么的讨论。常用的方法是使用 XGBoost 提供的特征重要性(feature importance)。特征重要性的级别越高,表示该特征对改善模型预测的贡献越大。接下来我们将使用重要性参数对特征进行分级,并比较相对重要性。

  1. fi = list(zip(X.columns, cv.best_estimator_.named_steps[ model ].feature_importances_))

  2. fi.sort(key = operator.itemgetter(1), reverse=True)

  3. top_10 = fi[:10]

  4. x = [x[0] for x in top_10]

  5. y = [x[1] for x in top_10]


从上图可以看出,票价和年龄是很重要的特征。我们可以进一步查看生还/遇难与票价的相关分布:

我们可以很清楚地看到,那些生还者相比遇难者的平均票价要高得多,因此把票价当成重要特征可能是合理的。

特征重要性可能是理解一般的特征重要性的不错方法。假如出现了这样的特例,即模型预测一个高票价的乘客无法获得生还,则我们可以得出高票价并不必然导致生还,接下来我们将分析可能导致模型得出该乘客无法生还的其它特征。

这种个体层次上的分析对于生产式机器学习系统可能非常有用。考虑其它例子,使用模型预测是否可以某人一项贷款。我们知道信用评分将是模型的一个很重要的特征,但是却出现了一个拥有高信用评分却被模型拒绝的客户,这时我们将如何向客户做出解释?又该如何向管理者解释?

幸运的是,近期出现了华盛顿大学关于解释任意分类器的预测过程的研究。他们的方法称为 LIME,已经在 GitHub 上开源(
https://github.com/marcotcr/lime)。本文不打算对此展开讨论,可以参见论文(
https://arxiv.org/pdf/1602.04938.pdf)

接下来我们尝试在模型中应用 LIME。基本上,首先需要定义一个处理训练数据的解释器(我们需要确保传递给解释器的估算训练数据集正是将要训练的数据集):

  1. X_train_imputed = cv.best_estimator_.named_steps[ imputer ].transform(X_train)

  2. explainer = lime.lime_tabular.LimeTabularExplainer(X_train_imputed,

  3. feature_names=X_train.columns.tolist,

  4. class_names=["Not Survived", "Survived"],

  5. discretize_continuous=True)


随后你必须定义一个函数,它以特征数组为变量,并返回一个数组和每个类的概率:

  1. model = cv.best_estimator_.named_steps[ model ]

  2. def xgb_prediction(X_array_in):

  3. if len(X_array_in.shape) < 2:

  4. X_array_in = np.expand_dims(X_array_in, 0)

  5. return model.predict_proba(X_array_in)


最后,我们传递一个示例,让解释器使用你的函数输出特征数和标签:

  1. X_test_imputed = cv.best_estimator_.named_steps[ imputer ].transform(X_test)

  2. exp = explainer.explain_instance(X_test_imputed[1], xgb_prediction, num_features=5, top_labels=1)

  3. exp.show_in_notebook(show_table=True, show_all=False)


在这里我们有一个示例,76% 的可能性是不存活的。我们还想看看哪个特征对于哪个类贡献最大,重要性又如何。例如,在 Sex = Female 时,生存几率更大。让我们看看柱状图:

所以这看起来很有道理。如果你是女性,这就大大提高了你在训练数据中存活的几率。所以为什么预测结果是「未存活」?看起来 Pclass =2.0 大大降低了存活率。让我们看看:

看起来 Pclass 等于 2 的存活率还是比较低的,所以我们对于自己的预测结果有了更多的理解。看看 LIME 上展示的 top5 特征,看起来这个人似乎仍然能活下来,让我们看看它的标签:

  1. y_test.values[0]

  2. >>>1


这个人确实活下来了,所以我们的模型有错!感谢 LIME,我们可以对问题原因有一些认识:看起来 Pclass 可能需要被抛弃。这种方式可以帮助我们,希望能够找到一些改进模型的方法。

本文为读者提供了一个简单有效理解 XGBoost 的方法。希望这些方法可以帮助你合理利用 XGBoost,让你的模型能够做出更好的推断。

注:本文为机器之心编译,转载请联系本公众号获得授权。

相关推荐

探索Java项目中日志系统最佳实践:从入门到精通

探索Java项目中日志系统最佳实践:从入门到精通在现代软件开发中,日志系统如同一位默默无闻却至关重要的管家,它记录了程序运行中的各种事件,为我们排查问题、监控性能和优化系统提供了宝贵的依据。在Java...

用了这么多年的java日志框架,你真的弄懂了吗?

在项目开发过程中,有一个必不可少的环节就是记录日志,相信只要是个程序员都用过,可是咱们自问下,用了这么多年的日志框架,你确定自己真弄懂了日志框架的来龙去脉嘛?下面笔者就详细聊聊java中常用日志框架的...

物理老师教你学Java语言(中篇)(物理专业学编程)

第四章物质的基本结构——类与对象...

一文搞定!Spring Boot3 定时任务操作全攻略

各位互联网大厂的后端开发小伙伴们,在使用SpringBoot3开发项目时,你是否遇到过定时任务实现的难题呢?比如任务调度时间不准确,代码报错却找不到方向,是不是特别头疼?如今,随着互联网业务规模...

你还不懂java的日志系统吗 ?(java的日志类)

一、背景在java的开发中,使用最多也绕不过去的一个话题就是日志,在程序中除了业务代码外,使用最多的就是打印日志。经常听到的这样一句话就是“打个日志调试下”,没错在日常的开发、调试过程中打印日志是常干...

谈谈枚举的新用法--java(java枚举的作用与好处)

问题的由来前段时间改游戏buff功能,干了一件愚蠢的事情,那就是把枚举和运算集合在一起,然后运行一段时间后buff就出现各种问题,我当时懵逼了!事情是这样的,做过游戏的都知道,buff,需要分类型,且...

你还不懂java的日志系统吗(javaw 日志)

一、背景在java的开发中,使用最多也绕不过去的一个话题就是日志,在程序中除了业务代码外,使用最多的就是打印日志。经常听到的这样一句话就是“打个日志调试下”,没错在日常的开发、调试过程中打印日志是常干...

Java 8之后的那些新特性(三):Java System Logger

去年12月份log4j日志框架的一个漏洞,给Java整个行业造成了非常大的影响。这个事情也顺带把log4j这个日志框架推到了争议的最前线。在Java领域,log4j可能相对比较流行。而在log4j之外...

Java开发中的日志管理:让程序“开口说话”

Java开发中的日志管理:让程序“开口说话”日志是程序员的朋友,也是程序的“嘴巴”。它能让程序在运行过程中“开口说话”,告诉我们它的状态、行为以及遇到的问题。在Java开发中,良好的日志管理不仅能帮助...

吊打面试官(十二)--Java语言中ArrayList类一文全掌握

导读...

OS X 效率启动器 Alfred 详解与使用技巧

问:为什么要在Mac上使用效率启动器类应用?答:在非特殊专业用户的环境下,(每天)用户一般可以在系统中进行上百次操作,可以是点击,也可以是拖拽,但这些只是过程,而我们的真正目的是想获得结果,也就是...

Java中 高级的异常处理(java中异常处理的两种方式)

介绍异常处理是软件开发的一个关键方面,尤其是在Java中,这种语言以其稳健性和平台独立性而闻名。正确的异常处理不仅可以防止应用程序崩溃,还有助于调试并向用户提供有意义的反馈。...

【性能调优】全方位教你定位慢SQL,方法介绍下!

1.使用数据库自带工具...

全面了解mysql锁机制(InnoDB)与问题排查

MySQL/InnoDB的加锁,一直是一个常见的话题。例如,数据库如果有高并发请求,如何保证数据完整性?产生死锁问题如何排查并解决?下面是不同锁等级的区别表级锁:开销小,加锁快;不会出现死锁;锁定粒度...

看懂这篇文章,你就懂了数据库死锁产生的场景和解决方法

一、什么是死锁加锁(Locking)是数据库在并发访问时保证数据一致性和完整性的主要机制。任何事务都需要获得相应对象上的锁才能访问数据,读取数据的事务通常只需要获得读锁(共享锁),修改数据的事务需要获...