您现在的位置是:首页 > 技术教程 正文

Python 训练集、测试集以及验证集切分方法:sklearn及手动切分

admin 阅读: 2024-03-23
后台-插件-广告管理-内容页头部广告(手机)

目录

方法一

方法二


需求目的:针对模型训练输入,按照6:2:2的比例进行训练集、测试集和验证集的划分。当前数据量约10万条。如果针对的是记录条数达上百万的数据集,可按照98:1:1的比例进行切分。

方法一:切分训练集和测试集,采用机器学习包sklearn中的train_test_split()函数
方法二:切分训练集、测试集以及验证集,针对dataframe手动切分

方法一

采用Sklearn包中的sklearn.model_selection.train_test_split()函数,该函数功能是将原始数据按照比例切分为训练集和测试集。

  1. 函数形式:
  2. sklearn.model_selection.train_test_split(*arrays, test_size=None,
  3. train_size=None, random_state=None, shuffle=True, stratify=None)
  4. 参数解读:
  5. *arrays:等长的列表、数组或者dataframe等
  6. test_size: 01之间,默认0.25
  7. train_size: 01之间,默认1
  8. random_state: 传递一个int值,以便在多个函数调用之间产生可复制的输出
  9. shuffle: 拆分前是否进行洗牌
  10. strafity: 是否对数据进行分层
  11. 返回结果:
  12. 输入序列的train test分割序列

例子

  1. >>> import numpy as np
  2. >>> from sklearn.model_selection import train_test_split
  3. >>> X, y = np.arange(10).reshape((5, 2)), range(5)
  4. >>> X
  5. array([[0, 1],
  6. [2, 3],
  7. [4, 5],
  8. [6, 7],
  9. [8, 9]])
  10. >>> list(y)
  11. [0, 1, 2, 3, 4]
  12. >>> X_train, X_test, y_train, y_test = train_test_split(
  13. ... X, y, test_size=0.33, random_state=42)
  14. ...
  15. >>> X_train
  16. array([[4, 5],
  17. [0, 1],
  18. [6, 7]])
  19. >>> y_train
  20. [2, 0, 3]
  21. >>> X_test
  22. array([[2, 3],
  23. [8, 9]])
  24. >>> y_test
  25. [1, 4]

方法二

手动切分,代码如下。输入采用Python的DataFrame,同样输出三个文件。如果需要每次都输入同样的切分数据,可采用random.seed()定义随机数种子。

  1. def split_train_test_valid():
  2. # read file
  3. input_path = "E:\\Data\\"
  4. file = "flow.csv"
  5. df_flow = pd.read_csv(input_path + file, header=None, encoding='gbk')
  6. # define the ratios 6:2:2
  7. train_len = int(len(df_flow) * 0.6)
  8. test_len = int(len(df_flow) * 0.2)
  9. # split the dataframe
  10. idx = list(df_flow.index)
  11. random.shuffle(idx) # 将index列表打乱
  12. df_train = df_flow.loc[idx[:train_len]]
  13. df_test = df_flow.loc[idx[train_len:train_len+test_len]]
  14. df_valid = df_flow.loc[idx[train_len+test_len:]] # 剩下的就是valid
  15. # output
  16. df_train.to_csv(input_path+'train.txt', header=False, index=False, sep='\t')
  17. df_test.to_csv(input_path+'test.txt', header=False, index=False, sep='\t')
  18. df_valid.to_csv(input_path+'valid.txt', header=False, index=False, sep='\t')

参考资料:https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.train_test_split.html 


 

标签:
声明

1.本站遵循行业规范,任何转载的稿件都会明确标注作者和来源;2.本站的原创文章,请转载时务必注明文章作者和来源,不尊重原创的行为我们将追究责任;3.作者投稿可能会经我们编辑修改或补充。

在线投稿:投稿 站长QQ:1888636

后台-插件-广告管理-内容页尾部广告(手机)
关注我们

扫一扫关注我们,了解最新精彩内容

搜索
排行榜