Skip to content

配置文件 config.json 详解

config.json 是 GPTFF 模型训练的核心配置文件,用于精确控制训练过程的各项参数。文件采用 JSON 格式,分为 trainingdata 两个主要部分。

一、training 部分:训练参数

这部分包含了与模型训练过程直接相关的参数,如网络结构、优化器设置、硬件配置等。

  • workers:

    • 说明: 数据加载时使用的工作线程数。增加此值可以加快数据预处理速度,但会消耗更多CPU和内存。
    • 类型: integer
    • 示例: 20
  • epochs:

    • 说明: 训练的总轮数。一个 epoch 表示模型完整遍历一次整个训练数据集。
    • 类型: integer
    • 示例: 1000
  • batch_size:

    • 说明: 每批次训练的样本数量。此值影响内存(特别是GPU显存)使用和训练速度。如果遇到内存不足的错误,应减小此值。
    • 类型: integer
    • 示例: 32
  • learning_rate:

    • 说明: 学习率,控制模型参数更新的步长。是训练中最重要的超参数之一,对收敛速度和最终性能有决定性影响。
    • 类型: float
    • 示例: 1e-3
  • weight_decay:

    • 说明: 权重衰减(L2正则化)系数。用于防止模型过拟合,通过惩罚过大的权重值来实现。
    • 类型: float
    • 示例: 1e-4
  • node_feature_len:

    • 说明: 图神经网络中节点特征向量的维度。
    • 类型: integer
    • 示例: 64
  • edge_feature_len:

    • 说明: 图神经网络中边特征向量的维度。
    • 类型: integer
    • 示例: 64
  • n_layers:

    • 说明: 图神经网络的层数。增加层数可以提升模型的表达能力,但也可能导致过拟合和梯度消失问题。
    • 类型: integer
    • 示例: 3
  • n_readout_layers:

    • 说明: 读出(Readout)阶段的神经网络层数,用于从图表示中预测最终的能量、力等属性。
    • 类型: integer
    • 示例: 3
  • warmup_steps:

    • 说明: 在训练初期,学习率从一个较低的值线性增加到设定值的步数。有助于训练初期的稳定性。
    • 类型: integer
    • 示例: 0 (表示不使用warmup)
  • device:

    • 说明: 指定训练使用的计算设备。可选值为 "cuda" (使用NVIDIA GPU) 或 "cpu"
    • 类型: string
    • 示例: "cuda"
  • val_fold:

    • 说明: 指定交叉验证中用作验证集的折(fold)的编号。例如,如果数据被分为5折,此值可以是0到4。
    • 类型: integer
    • 示例: 0
  • resume:

    • 说明: 是否从上一个检查点(checkpoint)恢复训练。如果为 true,训练将从最后保存的状态继续。
    • 类型: boolean
    • 示例: false
  • transformer_activate:

    • 说明: 是否在模型中激活 Transformer 模块。这是一个实验性功能。
    • 类型: boolean
    • 示例: false
  • start_epoch:

    • 说明: 训练的起始轮数。主要在恢复训练时使用,以确保日志和学习率调度正确。
    • 类型: integer
    • 示例: 0
  • weight_energy:

    • 说明: 损失函数中能量项的权重。用于平衡不同物理量(能量、力、应力)在总损失中的贡献。
    • 类型: float
    • 示例: 1.0
  • weight_force:

    • 说明: 损失函数中力项的权重。
    • 类型: float
    • 示例: 1.0
  • weight_stress:

    • 说明: 损失函数中应力项的权重。
    • 类型: float
    • 示例: 1.0

二、data 部分:数据路径

这部分定义了训练数据的来源。

  • data_path:

    • 说明: 训练数据文件所在的目录路径。
    • 类型: string
    • 示例: "./" (表示当前目录)
  • data_file:

    • 说明: 训练数据集的CSV文件名。
    • 类型: string
    • 示例: "training_dataset.csv"