百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 技术文章 > 正文

【Kaggle】模型调参利器 gridSearchCV(网格搜索)

nanshan 2024-10-12 05:41 25 浏览 0 评论

gridSearchCV(网格搜索)的参数、方法及示例

1.简介

GridSearchCV的sklearn官方网址:


http://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html#sklearn.model_selection.GridSearchCV


GridSearchCV,它存在的意义就是自动调参,只要把参数输进去,就能给出最优化的结果和参数。但是这个方法适合于小数据集,一旦数据的量级上去了,很难得出结果。这个时候就是需要动脑筋了。数据量比较大的时候可以使用一个快速调优的方法——坐标下降。它其实是一种贪心算法:拿当前对模型影响最大的参数调优,直到最优化;再拿下一个影响最大的参数调优,如此下去,直到所有的参数调整完毕。这个方法的缺点就是可能会调到局部最优而不是全局最优,但是省时间省力,巨大的优势面前,还是试一试吧,后续可以再拿bagging再优化。


通常算法不够好,需要调试参数时必不可少。比如SVM的惩罚因子C,核函数kernel,gamma参数等,对于不同的数据使用不同的参数,结果效果可能差1-5个点,sklearn为我们提供专门调试参数的函数grid_search。


2.参数说明

class sklearn.model_selection.GridSearchCV(estimator, param_grid, scoring=None, fit_params=None, n_jobs=1, iid=True, refit=True, cv=None, verbose=0, pre_dispatch=‘2*n_jobs’, error_score=’raise’, return_train_score=’warn’)


(1)estimator


选择使用的分类器,并且传入除需要确定最佳的参数之外的其他参数。每一个分类器都需要一个scoring参数,或者score方法:estimator=RandomForestClassifier(min_samples_split=100,min_samples_leaf=20,max_depth=8,max_features='sqrt',random_state=10),


(2)param_grid


需要最优化的参数的取值,值为字典或者列表,例如:param_grid =param_test1,param_test1 = {'n_estimators':range(10,71,10)}。


(3)scoring=None


模型评价标准,默认None,这时需要使用score函数;或者如scoring='roc_auc',根据所选模型不同,评价准则不同。字符串(函数名),或是可调用对象,需要其函数签名形如:scorer(estimator, X, y);如果是None,则使用estimator的误差估计函数。具体值的选取看本篇第三节内容。


(4)fit_params=None


(5) n_jobs=1


n_jobs: 并行数,int:个数,-1:跟CPU核数一致, 1:默认值


(6) iid=True


iid:默认True,为True时,默认为各个样本fold概率分布一致,误差估计为所有样本之和,而非各个fold的平均。


(7) refit=True


默认为True,程序将会以交叉验证训练集得到的最佳参数,重新对所有可用的训练集与开发集进行,作为最终用于性能评估的最佳模型参数。即在搜索参数结束后,用最佳参数结果再次fit一遍全部数据集。


(8)cv=None


交叉验证参数,默认None,使用三折交叉验证。指定fold数量,默认为3,也可以是yield训练/测试数据的生成器。


(9)verbose=0, scoring=None


verbose:日志冗长度,int:冗长度,0:不输出训练过程,1:偶尔输出,>1:对每个子模型都输出。


(10)pre_dispatch=‘2*n_jobs’


指定总共分发的并行任务数。当n_jobs大于1时,数据将在每个运行点进行复制,这可能导致OOM,而设置pre_dispatch参数,则可以预先划分总共的job数量,使数据最多被复制pre_dispatch次


(11)error_score=’raise’


(12)return_train_score=’warn’


如果“False”,cv_results_属性将不包括训练分数


回到sklearn里面的GridSearchCV,GridSearchCV用于系统地遍历多种参数组合,通过交叉验证确定最佳效果参数。


3. Scoring parameter:评价标准参数详细说明

Model-evaluation tools using cross-validation (such as model_selection.cross_val_score andmodel_selection.GridSearchCV) rely on an internal scoring strategy. This is discussed in the section The scoring parameter: defining model evaluation rules.


For the most common use cases, you can designate a scorer object with the scoring parameter; the table below shows all possible values. All scorer objects follow the convention that higher return values are better than lower return values. Thus metrics which measure the distance between the model and the data, like metrics.mean_squared_error, are available as neg_mean_squared_error which return the negated value of the metric.



4.属性

(1)cv_results_ : dict of numpy (masked) ndarrays


具有键作为列标题和值作为列的dict,可以导入到DataFrame中。注意,“params”键用于存储所有参数候选项的参数设置列表。


(2)best_estimator_ : estimator


通过搜索选择的估计器,即在左侧数据上给出最高分数(或指定的最小损失)的估计器。 如果refit = False,则不可用。


(3)best_score_ : float best_estimator的分数


(4)best_params_ : dict 在保存数据上给出最佳结果的参数设置


(5)best_index_ : int 对应于最佳候选参数设置的索引(cv_results_数组)。

search.cv_results _ ['params'] [search.best_index_]中的dict给出了最佳模型的参数设置,给出了最高的平均分数(search.best_score_)。


(6)scorer_ : function


Scorer function used on the held out data to choose the best parameters for the model.


(7)n_splits_ : int


The number of cross-validation splits (folds/iterations).


(8)grid_scores_:给出不同参数情况下的评价结果


6.网格搜索实例

import pandas as pd # 数据科学计算工具

import numpy as np # 数值计算工具

import matplotlib.pyplot as plt # 可视化

import seaborn as sns # matplotlib的高级API

from sklearn.model_selection import StratifiedKFold #交叉验证

from sklearn.model_selection import GridSearchCV #网格搜索

from sklearn.model_selection import train_test_split #将数据集分开成训练集和测试集

from xgboost import XGBClassifier                     #xgboost

 

 

pima = pd.read_csv("pima_indians-diabetes.csv")

print(pima.head())

 

x = pima.iloc[:,0:8]

y = pima.iloc[:,8]

 

seed = 7 #重现随机生成的训练

test_size = 0.33 #33%测试,67%训练

X_train, X_test, Y_train, Y_test = train_test_split(x, y, test_size=test_size, random_state=seed

model = XGBClassifier()               

learning_rate = [0.0001,0.001,0.01,0.1,0.2,0.3] #学习率

gamma = [1, 0.1, 0.01, 0.001]

 

param_grid = dict(learning_rate = learning_rate,gamma = gamma)#转化为字典格式,网络搜索要求

 

kflod = StratifiedKFold(n_splits=10, shuffle = True,random_state=7)#将训练/测试数据集划分10个互斥子集,

 

grid_search = GridSearchCV(model,param_grid,scoring = 'neg_log_loss',n_jobs = -1,cv = kflod)

#scoring指定损失函数类型,n_jobs指定全部cpu跑,cv指定交叉验证

grid_result = grid_search.fit(X_train, Y_train) #运行网格搜索

print("Best: %f using %s" % (grid_result.best_score_,grid_search.best_params_))

#grid_scores_:给出不同参数情况下的评价结果。best_params_:描述了已取得最佳结果的参数的组合

#best_score_:成员提供优化过程期间观察到的最好的评分

#具有键作为列标题和值作为列的dict,可以导入到DataFrame中。

#注意,“params”键用于存储所有参数候选项的参数设置列表。

means = grid_result.cv_results_['mean_test_score']

params = grid_result.cv_results_['params']

for mean,param in zip(means,params):

    print("%f  with:   %r" % (mean,param))

结果如下:


参考:机器学习(四)--模型调参利器 gridSearchCV(网格搜索)


相关推荐

实战派 | Java项目中玩转Redis6.0客户端缓存

铺垫首先介绍一下今天要使用到的工具Lettuce,它是一个可伸缩线程安全的redis客户端。多个线程可以共享同一个RedisConnection,利用nio框架Netty来高效地管理多个连接。放眼望向...

轻松掌握redis缓存穿透、击穿、雪崩问题解决方案(20230529版)

1、缓存穿透所谓缓存穿透就是非法传输了一个在数据库中不存在的条件,导致查询redis和数据库中都没有,并且有大量的请求进来,就会导致对数据库产生压力,解决这一问题的方法如下:1、使用空缓存解决对查询到...

Redis与本地缓存联手:多级缓存架构的奥秘

多级缓存(如Redis+本地缓存)是一种在系统架构中广泛应用的提高系统性能和响应速度的技术手段,它综合利用了不同类型缓存的优势,以下为你详细介绍:基本概念本地缓存:指的是在应用程序所在的服务器内...

腾讯云国际站:腾讯云服务器如何配置Redis缓存?

本文由【云老大】TG@yunlaoda360撰写一、安装Redis使用包管理器安装(推荐)在CentOS系统中,可以通过yum包管理器安装Redis:sudoyumupdate-...

Spring Boot3 整合 Redis 实现数据缓存,你做对了吗?

你是否在开发互联网大厂后端项目时,遇到过系统响应速度慢的问题?当高并发请求涌入,数据库压力剧增,响应时间拉长,用户体验直线下降。相信不少后端开发同行都被这个问题困扰过。其实,通过在SpringBo...

【Redis】Redis应用问题-缓存穿透缓存击穿、缓存雪崩及解决方案

在我们使用redis时,也会存在一些问题,导致请求直接打到数据库上,导致数据库挂掉。下面我们来说说这些问题及解决方案。1、缓存穿透1.1场景一个请求进来后,先去redis进行查找,redis存在,则...

Spring boot 整合Redis缓存你了解多少

在前一篇里面讲到了Redis缓存击穿、缓存穿透、缓存雪崩这三者区别,接下来我们讲解Springboot整合Redis中的一些知识点:之前遇到过,有的了四五年,甚至更长时间的后端Java开发,并且...

揭秘!Redis 缓存与数据库一致性问题的终极解决方案

在现代软件开发中,Redis作为一款高性能的缓存数据库,被广泛应用于提升系统的响应速度和吞吐量。然而,缓存与数据库之间的数据一致性问题,一直是开发者们面临的一大挑战。本文将深入探讨Redis缓存...

高并发下Spring Cache缓存穿透?我用Caffeine+Redis破局

一、什么是缓存穿透?缓存穿透是指查询一个根本不存在的数据,导致请求直接穿透缓存层到达数据库,可能压垮数据库的现象。在高并发场景下,这尤其危险。典型场景:恶意攻击:故意查询不存在的ID(如负数或超大数值...

Redis缓存三剑客:穿透、雪崩、击穿—手把手教你解决

缓存穿透菜小弟:我先问问什么是缓存穿透?我听说是缓存查不到,直接去查数据库了。表哥:没错。缓存穿透是指查询一个缓存中不存在且数据库中也不存在的数据,导致每次请求都直接访问数据库的行为。这种行为会让缓存...

Redis中缓存穿透问题与解决方法

缓存穿透问题概述在Redis作为缓存使用时,缓存穿透是常见问题。正常查询流程是先从Redis缓存获取数据,若有则直接使用;若没有则去数据库查询,查到后存入缓存。但当请求的数据在缓存和数据库中都...

Redis客户端缓存的几种实现方式

前言:Redis作为当今最流行的内存数据库和缓存系统,被广泛应用于各类应用场景。然而,即使Redis本身性能卓越,在高并发场景下,应用于Redis服务器之间的网络通信仍可能成为性能瓶颈。所以客户端缓存...

Nginx合集-常用功能指导

1)启动、重启以及停止nginx进入sbin目录之后,输入以下命令#启动nginx./nginx#指定配置文件启动nginx./nginx-c/usr/local/nginx/conf/n...

腾讯云国际站:腾讯云怎么提升服务器速度?

本文由【云老大】TG@yunlaoda360撰写升级服务器规格选择更高性能的CPU、内存和带宽,以提供更好的处理能力和网络性能。优化网络配置调整网络接口卡(NIC)驱动,优化TCP/IP参数...

雷霆一击服务器管理员教程

本文转载莱卡云游戏服务器雷霆一击管理员教程(搜索莱卡云面版可搜到)首先你需要给服务器设置管理员密码,默认是空的管理员密码在启动页面进行设置设置完成后你需要重启服务器才可生效加入游戏后,点击键盘左上角E...

取消回复欢迎 发表评论: