PyTorch 训练,除了会训练还要了解这些

开发
本文让我们讨论一下在训练过程中帮助你进行实验的技术。我将提供一些理论、代码片段和完整的流程示例。

让我们讨论一下在训练过程中帮助你进行实验的技术。我将提供一些理论、代码片段和完整的流程示例。主要要点包括:

  • 数据集分割
  • 指标
  • 可重复性
  • 配置、日志记录和可视化

分割数据集

我喜欢有训练集、验证集和测试集的分割。这里没什么好说的;你可以使用随机分割,或者如果你有一个不平衡的数据集(就像在实际情况中经常发生的那样)——分层分割。

对于测试集,尝试手动挑选一个“黄金数据集”,包含你希望模型擅长的所有示例。测试集应该在实验之间保持不变。它应该只在你完成模型训练后使用。这将在部署到生产环境之前给你客观的指标。别忘了,你的数据集应该尽可能接近生产环境,这样才有代表性。

指标

为你的任务选择正确的指标至关重要。我最喜欢的错误使用指标的例子是 Kaggle 的“深空系外行星狩猎”数据集,在那里你可以找到很多笔记本,人们在大约有 5000 个负样本和 50 个正样本的严重不平衡的数据集上使用准确率。当然,他们得到了 99% 的准确率,并且总是预测负样本。那样的话,他们永远也找不到系外行星,所以让我们明智地选择指标。

深入讨论指标超出了本文的范围,但我将简要提及一些可靠的选项:

  • F1 分数
  • 精确度和召回率
  • mAP(检测任务)
  • IoU(分割任务)
  • 准确率(对于平衡的数据集)
  • ROC-AUC

真实图像分类问题的分数示例:

+--------+----------+--------+-----------+--------+
| split  | accuracy |   f1   | precision | recall |
+--------+----------+--------+-----------+--------+
| val    | 0.9915   | 0.9897 | 0.9895    | 0.99   |
| test   | 0.9926   | 0.9921 | 0.9927    | 0.9915 |
+--------+----------+--------+-----------+--------+

为你的任务选择几个指标:

def get_metrics(gt_labels: List[int], preds: List[int]) -> Dict[str, float]:
    num_classes = len(set(gt_labels))
    if num_classes == 2:
        average = "binary"
    else:
        average = "macro"

    metrics = {}
    metrics["accuracy"] = accuracy_score(gt_labels, preds)
    metrics["f1"] = f1_score(gt_labels, preds, average=average)
    metrics["precision"] = precision_score(gt_labels, preds, average=average)
    metrics["recall"] = recall_score(gt_labels, preds, average=average)
    return metrics

此外,绘制精确度-阈值和召回率-阈值曲线,以更好地选择置信度阈值。

可重复性

没有可靠的可重复性,我们就不能谈论实验。当你不改变任何东西时,你应该得到相同的结果。简单的例子,如果你使用 torch 和 Nvidia,如何冻结所有种子:

def set_seeds(seed: int, cudnn_fixed: bool = False) -> None:
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    if cudnn_fixed:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

注意:cudnn_fixed 可能会影响性能。我在实验期间使用它,然后在选择参数后的最终训练阶段关闭它。

这是当你有 cudnn 固定并且不改变参数时会发生什么——训练完全一样。这就是我们希望从 cudnn_fixed 得到的结果。

现在你可以处理参数,并确保结果的变化是因为参数的变化。

配置、日志记录和可视化

这是人们经常忘记的部分,但我发现它非常有用:

  • 包含变量和超参数的配置文件
  • 记录所有指标和配置的日志
  • 指标的可视化

当你有一个包含很多模块的项目时,配置文件非常方便,你可以将变量放在一个配置文件中,并在所有模块中使用。你还应该在那里存储训练配置。这里有一个我使用的 Hydra 配置的例子:


project_name: project_name
exp_name: baseline

exp: ${exp_name}_${now_dir}

train:
  root: /path/to/project
  device: cuda

  label_to_name: {0: "class_1", 1: "class_2", 2: "class_3"}
  img_size: [256, 256] # (h, w)

  train_split: 0.8
  val_split: 0.1 # test_split = 1 - train_split - val_split

  batch_size: 64
  epochs: 15
  use_scheduler: True

  layers_to_train: -1

  num_workers: 10
  threads_to_use: 10

  data_path: ${train.root}/dataset
  path_to_save: ${train.root}/output/models/${exp}
  vis_path: ${train.root}/output/visualized

  seed: 42
  cudnn_fixed: False
  debug_img_processing: False


export: # TensorRT must be done on the inference device
  half: False
  max_batch_size: 1

  model_path: ${train.path_to_save}
  path_to_data: ${train.root}/to_test

每次训练会议的配置文件和指标都应该被记录。通过 wandb(或类似的东西)集成,每次训练会议都被记录和可视化。

我也更喜欢保存在本地:

在训练期间:

  • 每个时代后打印验证指标
  • 如果它实现了最佳指标,则保存模型
  • 如果调试模式开启,则保存预处理图像

训练结束时:

  • 保存包含最佳验证和测试指标的 metrics.csv

训练后:

  • 保存 model.onnx,model.engine 和在模型导出期间创建的其他格式
  • 保存显示模型注意力的可视化

结构示例:

output
|
├── debug_img_processing
|   ├── img_1
|   └── img_2
|
├── models
|   ├── experiment_1
|       ├── model.pt
|       ├── model.engine
|       ├── precision_recall_curves
|           └── val_precision_recall_vs_threshold.png
|       └── metrics.csv
|
└── visualized
    ├── class_1
        ├── img_1
        └── img_2
责任编辑:赵宁宁 来源: 小白玩转Python
相关推荐

2021-09-15 09:51:36

数据库架构技术

2017-11-17 08:48:18

IOPSSSD性能

2022-10-26 07:21:15

网络视频开发

2021-03-25 15:19:33

深度学习Pytorch技巧

2023-02-19 15:26:51

深度学习数据集

2020-12-03 10:17:25

Kubernetes架构微服务

2020-08-03 12:47:58

DevOps数据科学家代码

2015-06-03 10:34:10

iOS 9苹果WWDC

2023-08-14 07:42:01

模型训练

2024-07-25 08:25:35

2015-08-12 15:12:56

黑客攻击云安全云服务

2021-03-15 12:00:19

Kubernetes微服务架构

2022-10-17 08:00:00

机器学习数据驱动数据科学

2020-07-13 14:30:35

人工智能机器学习技术

2021-02-03 13:22:53

区块链数据隐私

2023-05-25 21:35:00

稳定性建设前端

2019-09-18 17:35:52

2017-12-22 10:48:00

AI深度学习迁移学习

2020-10-27 09:37:43

PyTorchTensorFlow机器学习

2018-06-21 06:56:03

CASB云安全加密
点赞
收藏

51CTO技术栈公众号