LightGBM算法背景、原理、特点+Python实战案例

人工智能
为了解决这些问题,Microsoft在2017年推出了LightGBM(Light Gradient Boosting Machine),一个更快速、更低内存消耗、更高性能的梯度提升框架。

大家好,我是Peter~

今天给大家分享一下树模型的经典算法:LightGBM,介绍算法产生的背景、原理和特点,最后提供一个基于LightGBM和随机搜索调优的案例。

LightGBM算法

在机器学习领域,梯度提升机(Gradient Boosting Machines, GBMs)是一类强大的集成学习算法,它们通过逐步添加弱学习器(通常是决策树)来最小化预测误差,从而构建一个强大的模型。

在大数据时代,数据集的规模急剧增长,传统的GBMs由于其计算和存储成本高昂,难以有效地扩展。

  • 例如,对于水平分割的决策树生长策略,虽然可以生成平衡的树,但往往会导致模型的区分能力下降;而对于基于叶子的生长策略,虽能提高精度却容易过拟合。
  • 此外,大多数GBM实现在每次迭代中都需要遍历整个数据集来计算梯度,这在数据量巨大时效率低下。因此,需要一个既能高效处理大规模数据又能保持模型准确度的算法。

为了解决这些问题,Microsoft在2017年推出了LightGBM(Light Gradient Boosting Machine),一个更快速、更低内存消耗、更高性能的梯度提升框架。

官方学习地址:https://lightgbm.readthedocs.io/en/stable/

LightGBM的原理

1、基于直方图的决策树算法:

  • 原理:LightGBM使用直方图优化技术,将连续的特征值离散化成特定的bin(即直方图的桶),减少了在节点分裂时需要计算的数据量。
  • 优点:这种方法可以在减少内存使用的同时,提高计算速度。
  • 实现细节:对于每个特征,算法都维护一个直方图,记录该特征在不同分桶中的统计信息。在进行节点分裂时,可以直接利用这些直方图的信息,而不需要遍历所有数据。

2、带深度限制的leaf-wise树生长策略:

  • 原理:与传统的水平分割不同,leaf-wise的生长策略是每次从当前所有叶子节点中选择分裂收益最大的节点进行分裂。
  • 优点:这种策略可以使得决策树更加侧重于数据中的异常部分,通常可以得到更好的精度。
  • 缺点:容易导致过拟合,特别是当数据中有噪声时。
  • 改进措施:LightGBM通过设置最大深度限制来防止过拟合。

3、单边梯度采样(GOSS):

  • 原理:对于数据集中的大梯度样本,GOSS算法只保留数据的一部分(通常是大梯度的样本),减少计算量同时保证不会损失太多的信息。
  • 优点:这种方法可以在不显著损失精度的情况下加快训练速度。
  • 应用场景:特别适用于数据倾斜严重的情况。

4、互斥特征捆绑(EFB):

  • 原理:EFB是一种减少特征数量,提高计算效率的技术。它将互斥的特征(即从不同时为非零的特征)进行合并,以减少特征维度。
  • 优点:提高了内存的使用效率和训练速度。
  • 实现细节:通过特征的互斥性,算法可以在同一时间处理更多的特征,从而减少了实际处理的特征数。

5、支持并行和分布式学习:

  • 原理:LightGBM支持多线程学习,能够利用多个CPU进行并行训练。
  • 优点:显著提高了在多核处理器上的训练速度。
  • 扩展性:还支持分布式学习,可以利用多台机器共同训练模型。

6、缓存优化:

  • 原理:优化了对数据的读取方式,可以使用更多的缓存来加快数据交换的速度。
  • 优点:特别是在大数据集上,缓存优化可以显著提升性能。

7、支持多种损失函数:

  • 特点:除了常用的回归和分类的损失函数外,LightGBM还支持自定义损失函数,满足不同的业务需求。

8、正则化和剪枝:

  • 原理:提供了L1和L2正则化项来控制模型复杂度,避免过拟合。
  • 实现:实现了后向剪枝的策略来进一步防止过拟合。

9、模型解释性:

  • 特点:由于是基于决策树的模型,LightGBM具有良好的模型解释性,可以通过特征重要性等方式理解模型的决策逻辑。

LightGBM的特点

高效性

  • 速度优势:通过直方图优化和 leaf-wise 生长策略,LightGBM 在保证精度的同时大幅提升了训练速度。
  • 内存使用:相比于其他GBM实现,LightGBM 需要的内存更少,这使得它能够处理更大的数据集。

准确性

  • 最佳优先的生长策略:LightGBM 采用的 leaf-wise 生长策略可以更紧密地拟合数据,通常可以得到比水平分割更好的精度。
  • 避免过拟合的方法:通过设置最大深度限制和后向剪枝,LightGBM 能够在提升模型精度的同时避免过拟合。

可扩展性

  • 并行和分布式学习:LightGBM 的设计支持多线程和分布式计算,这使得它能够充分利用现代硬件的计算能力。
  • 多平台支持:LightGBM 可以在 Windows、macOS 和 Linux 等多种操作系统上运行,支持 Python、R、Java 等多种编程语言。

易用性

  • 参数调优:LightGBM 提供了丰富的参数选项,方便用户根据具体问题进行调整。
  • 预训练模型:用户可以从预训练的模型开始,加速自己的建模过程。
  • 模型解释工具:LightGBM 提供了特征重要性评估工具,帮助用户理解模型的决策过程。

导入库

In [1]:

import numpy as np

import lightgbm as lgb
from sklearn.model_selection import train_test_split, RandomizedSearchCV
from sklearn.datasets import load_iris
from sklearn.metrics import accuracy_score
import warnings
warnings.filterwarnings("ignore")

加载数据

加载公开的iris数据集:

In [2]:

# 加载数据集
data = load_iris()
X, y = data.data, data.target
y = [int(i) for i in y]  # 将标签转换为整数

In [3]:

X[:3]

Out[3]:

array([[5.1, 3.5, 1.4, 0.2],
       [4.9, 3. , 1.4, 0.2],
       [4.7, 3.2, 1.3, 0.2]])

In [4]:

y[:10]

Out[4]:

[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

划分数据

In [5]:

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

同时创建LightGBM数据集:

In [6]:

lgb_train = lgb.Dataset(X_train, label=y_train)

参数设置

In [7]:

# 设置参数范围
param_dist = {
    'boosting_type': ['gbdt', 'dart'],  # 提升类型  梯度提升决策树(gbdt)和Dropouts meet Multiple Additive Regression Trees(dart)
    'objective': ['binary', 'multiclass'],  # 目标;二分类和多分类
    'num_leaves': range(20, 150),  # 叶子节点数量
    'learning_rate': [0.01, 0.05, 0.1],  # 学习率
    'feature_fraction': [0.6, 0.8, 1.0],  # 特征采样比例
    'bagging_fraction': [0.6, 0.8, 1.0],  # 数据采样比例
    'bagging_freq': range(0, 80),  # 数据采样频率
    'verbose': [-1]  # 是否显示训练过程中的详细信息,-1表示不显示
}

随机搜索调参

In [8]:

# 初始化模型
model = lgb.LGBMClassifier()


# 使用随机搜索进行参数调优
random_search = RandomizedSearchCV(estimator=model,
                                   param_distributinotallow=param_dist, # 参数组合
                                   n_iter=100, 
                                   cv=5, # 5折交叉验证
                                   verbose=2, 
                                   random_state=42, 
                                   n_jobs=-1)
# 模型训练
random_search.fit(X_train, y_train)
Fitting 5 folds for each of 100 candidates, totalling 500 fits

输出最佳的参数组合:

In [9]:

# 输出最佳参数
print("Best parameters found: ", random_search.best_params_)
Best parameters found:  {'verbose': -1, 'objective': 'multiclass', 'num_leaves': 87, 'learning_rate': 0.05, 'feature_fraction': 0.6, 'boosting_type': 'gbdt', 'bagging_freq': 22, 'bagging_fraction': 0.6}

使用最佳参数建模

In [10]:

# 使用最佳参数训练模型
best_model = random_search.best_estimator_
best_model.fit(X_train, y_train)

# 预测
y_pred = best_model.predict(X_test)
y_pred = [round(i) for i in y_pred]  # 将概率转换为类别

# 评估模型
print('Accuracy: %.4f' % accuracy_score(y_test, y_pred))
Accuracy: 0.9667

责任编辑:武晓燕 来源: 尤而小屋
相关推荐

2023-04-11 08:00:00

PythonOtsu阈值算法图像背景分割

2024-06-06 10:08:32

2011-05-18 11:14:45

JSP

2009-07-07 17:30:58

JSP应用开发

2017-05-26 11:00:38

Python算法

2010-09-16 14:42:44

JVM

2014-10-16 14:51:42

RFID

2010-09-26 08:50:11

JVM工作原理

2009-07-09 14:01:22

JVM工作原理

2010-06-24 13:55:41

LEACH协议

2023-04-26 06:22:45

NLPPython知识图谱

2010-09-17 15:32:52

JVM工作原理

2021-06-25 17:47:12

腾讯NTA

2021-06-25 17:41:35

腾讯NTA

2021-06-25 18:40:33

主机安全

2021-06-25 17:45:25

腾讯NTA

2012-11-08 11:17:24

VDIvSphere

2024-11-11 16:55:54

2024-09-06 17:57:35

2010-10-15 09:24:32

无线网络原理
点赞
收藏

51CTO技术栈公众号