第 5 章 使用 Training Operator 调整模型


要使用 Kubeflow Training Operator 调整模型,您需要配置和运行培训作业。

另外,您可以使用 Low-Rank Adaptation (LoRA)高效地微调大型语言模型,如 Llama 3。该集成可优化计算要求并减少内存占用量,从而允许对消费者级 GPU 进行微调。该解决方案结合了 PyTorch Fully Sharded Data Parallel (FSDP)和 LoRA 的组合,以启用可扩展、经济型模型的培训和认证,从而增强了 OpenShift 环境中 AI 工作负载的灵活性和性能。

5.1. 配置培训作业

在使用培训作业调优模型之前,您必须配置培训工作。本节中的培训工作示例基于 GitHub 中提供的 IBM 和 Hugging Face 调优示例。

先决条件

  • 您已登陆到 OpenShift。
  • 您可以访问配置为运行分布式工作负载的数据科学项目集群,如 管理分布式工作负载 中所述。
  • 您已创建了数据科学项目。有关如何创建项目的详情,请参考 创建数据科学项目
  • 您有数据科学项目的 Admin 访问权限。

    • 如果创建项目,则自动具有 Admin 访问权限。
    • 如果没有创建项目,您的集群管理员必须授予 Admin 访问权限。
  • 您可以访问模型。
  • 您可以访问可用于培训模型的数据。

流程

  1. 在一个终端窗口中,如果您还没有登录到 OpenShift 集群,请登录到 OpenShift CLI,如下例所示:

    $ oc login <openshift_cluster_url> -u <username> -p <password>
    Copy to Clipboard Toggle word wrap
  2. 配置培训工作,如下所示:

    1. 创建名为 config_trainingjob.yaml 的 YAML 文件。
    2. 添加 ConfigMap 对象定义,如下所示:

      training-job 配置示例

      kind: ConfigMap
      apiVersion: v1
      metadata:
        name: training-config
        namespace: kfto
      data:
        config.json: |
          {
            "model_name_or_path": "bigscience/bloom-560m",
            "training_data_path": "/data/input/twitter_complaints.json",
            "output_dir": "/data/output/tuning/bloom-twitter",
            "save_model_dir": "/mnt/output/model",
            "num_train_epochs": 10.0,
            "per_device_train_batch_size": 4,
            "per_device_eval_batch_size": 4,
            "gradient_accumulation_steps": 4,
            "save_strategy": "no",
            "learning_rate": 1e-05,
            "weight_decay": 0.0,
            "lr_scheduler_type": "cosine",
            "include_tokens_per_second": true,
            "response_template": "\n### Label:",
            "dataset_text_field": "output",
            "padding_free": ["huggingface"],
            "multipack": [16],
            "use_flash_attn": false
          }
      Copy to Clipboard Toggle word wrap

    3. 可选: 要使用 Low Rank Adaptation (LoRA)微调,请按如下所示更新 config.json 部分:

      1. peft_method 参数设置为 "lora "。
      2. 添加 lora_r,lora_alpha,lora_dropout,bias, 和 target_modules 参数。

        LoRA 配置示例

              ...
              "peft_method": "lora",
              "lora_r": 8,
              "lora_alpha": 8,
              "lora_dropout": 0.1,
              "bias": "none",
              "target_modules": ["all-linear"]
            }
        Copy to Clipboard Toggle word wrap

    4. 可选: 要使用 Quantized Low Rank Adaptation (QLoRA)微调,请按如下所示更新 config.json 部分:

      1. use_flash_attn 参数设置为 "true "。
      2. peft_method 参数设置为 "lora "。
      3. 添加 LoRA 参数: lora_r,lora_alpha,lora_dropout,bias, 和 target_modules
      4. 添加 QLoRA 强制参数: auto_gptqtorch_dtypefp16
      5. 如果需要,添加 QLoRA 可选参数: fused_lorafast_kernels

        QLoRA 配置示例

              ...
              "use_flash_attn": true,
              "peft_method": "lora",
              "lora_r": 8,
              "lora_alpha": 8,
              "lora_dropout": 0.1,
              "bias": "none",
              "target_modules": ["all-linear"],
              "auto_gptq": ["triton_v2"],
              "torch_dtype": float16,
              "fp16": true,
              "fused_lora": ["auto_gptq", true],
              "fast_kernels": [true, true, true]
            }
        Copy to Clipboard Toggle word wrap

    5. 按照下表所示,编辑 training-job 配置的元数据。

      Expand
      表 5.1. training-job 配置元数据
      参数value

      名称

      training-job 配置的名称

      namespace

      项目的名称

    6. 按照下表所示,编辑 training-job 配置的参数。

      Expand
      表 5.2. training-job 配置参数
      参数value

      model_name_or_path

      预遍历模型或 training-job 容器中模型的路径的名称;在本例中,模型名称取自 Hugging Face 网页

      training_data_path

      您在 training_data.yaml ConfigMap 中设置的培训数据的路径

      output_dir

      模型的输出目录

      save_model_dir

      保存 tuned 模型的目录

      num_train_epochs

      培训的时期数;在本示例中,培训工作设置为运行 10 次

      per_device_train_batch_size

      批处理大小,要一起处理的数据集示例;在此示例中,培训作业处理 4 个示例

      per_device_eval_batch_size

      批处理大小、每个 GPU 或 TPU 内核或 CPU 共同处理的示例;在这个示例中,培训作业处理 4 个示例。

      gradient_accumulation_steps

      科学累积步骤数量

      save_strategy

      模型检查点可以保存的频率;默认值为 "epoch " (保存模型检查点每个时),其他可能的值是 "steps" ( 每个培训步骤保存模型检查点)和 "no" ( 不要保存模型检查点)

      save_total_limit

      要保存的模型检查点数;如果 save_strategy 设为 "no" (没有保存模型检查点,则省略)

      learning_rate

      培训学习率

      weight_decay

      要应用的权重 decay

      lr_scheduler_type

      可选:要使用的调度程序类型;默认值为 "linear",其他可能的值是 "cosine" , "cosine _with_restarts", "polynomial", "constant", 和 "constant_with_warmup"

      include_tokens_per_second

      可选:对于培训速度指标,是否计算每个设备每秒令牌数

      response_template

      响应的模板格式化

      dataset_text_field

      培训输出的 dataset 字段,如 training_data.yaml 配置映射中设置

      padding_free

      是否使用技术在单个批处理中处理多个示例,而无需添加浪费计算资源的 padding 令牌;如果使用,此参数必须设置为 ["huggingface"]

      multipack

      是否使用多GPU培训技术来平衡每个设备中处理的令牌数量,以最大程度缩短等待时间;您可以使用不同的值,找到您的培训工作的最佳价值。

      use_flash_attn

      是否使用闪存关注

      peft_method

      调优方法:要完全微调,请省略此参数;对于 LoRA 和 QLoRA,设置为 "lora" ;对于提示调整,设置为 "pt"

      lora_r

      LoRA:选择低等级的 decom 组成

      lora_alpha

      LoRA:扩展低承诺,以控制其对模型适应的影响

      lora_dropout

      LoRA: Dropout rate applied to the LoRA 层,这是一种常规技术以防止过度处理

      bias

      LoRA:是否要适应模型中的双向术语;将bias 设为 "none" 表示,不会适应任何双向术语

      target_modules

      LoRA:要应用 LoRA 的模块名称;包括所有线性层,设置为"all_linear";某些模型的可选参数

      auto_gptq

      QLoRA:使用 AutoGPTQ 设置 4 位 GPTQ-LoRA;使用时,此参数必须设置为 ["triton_v2"]

      torch_dtype

      QLoRA: Tensor datatype; 使用时,此参数必须设置为 float16

      fp16

      QLoRA:是否要使用半精确浮动点格式;当使用时,此参数必须设置为 true

      fused_lora

      QLoRA:是否要将 fused LoRA 培训用于更有效的 LoRA 培训;如果使用,此参数必须设置为 ["auto_gptq", true]

      fast_kernels

      QLoRA:是否使用快速跨熵、rope、rms 丢失内核;如果使用,此参数必须设置为 [true, true, true]

    7. config_trainingjob.yaml 文件中保存您的更改。
    8. 应用配置以创建 training-config 对象:

      $ oc apply -f config_trainingjob.yaml
      Copy to Clipboard Toggle word wrap
  3. 创建培训数据。

    注意

    本简单示例中的培训数据仅用于演示目的,不适用于生产用途。提供培训数据的常见方法是使用持久卷。

    1. 创建名为 training_data.yaml 的 YAML 文件。
    2. 添加以下 ConfigMap 对象定义:

      kind: ConfigMap
      apiVersion: v1
      metadata:
        name: twitter-complaints
        namespace: kfto
      data:
        twitter_complaints.json: |
          [
              {"Tweet text":"@HMRCcustomers No this is my first job","ID":0,"Label":2,"text_label":"no complaint","output":"### Text: @HMRCcustomers No this is my first job\n\n### Label: no complaint"},
              {"Tweet text":"@KristaMariePark Thank you for your interest! If you decide to cancel, you can call Customer Care at 1-800-NYTIMES.","ID":1,"Label":2,"text_label":"no complaint","output":"### Text: @KristaMariePark Thank you for your interest! If you decide to cancel, you can call Customer Care at 1-800-NYTIMES.\n\n### Label: no complaint"},
              {"Tweet text":"@EE On Rosneath Arial having good upload and download speeds but terrible latency 200ms. Why is this.","ID":3,"Label":1,"text_label":"complaint","output":"### Text: @EE On Rosneath Arial having good upload and download speeds but terrible latency 200ms. Why is this.\n\n### Label: complaint"},
              {"Tweet text":"Couples wallpaper, so cute. :) #BrothersAtHome","ID":4,"Label":2,"text_label":"no complaint","output":"### Text: Couples wallpaper, so cute. :) #BrothersAtHome\n\n### Label: no complaint"},
              {"Tweet text":"@mckelldogs This might just be me, but-- eyedrops? Artificial tears are so useful when you're sleep-deprived and sp… https:\/\/t.co\/WRtNsokblG","ID":5,"Label":2,"text_label":"no complaint","output":"### Text: @mckelldogs This might just be me, but-- eyedrops? Artificial tears are so useful when you're sleep-deprived and sp… https:\/\/t.co\/WRtNsokblG\n\n### Label: no complaint"},
              {"Tweet text":"@Yelp can we get the exact calculations for a business rating (for example if its 4 stars but actually 4.2) or do we use a 3rd party site?","ID":6,"Label":2,"text_label":"no complaint","output":"### Text: @Yelp can we get the exact calculations for a business rating (for example if its 4 stars but actually 4.2) or do we use a 3rd party site?\n\n### Label: no complaint"},
              {"Tweet text":"@nationalgridus I have no water and the bill is current and paid. Can you do something about this?","ID":7,"Label":1,"text_label":"complaint","output":"### Text: @nationalgridus I have no water and the bill is current and paid. Can you do something about this?\n\n### Label: complaint"},
              {"Tweet text":"@JenniferTilly Merry Christmas to as well. You get more stunning every year ��","ID":9,"Label":2,"text_label":"no complaint","output":"### Text: @JenniferTilly Merry Christmas to as well. You get more stunning every year ��\n\n### Label: no complaint"}
          ]
      Copy to Clipboard Toggle word wrap
    3. 将示例命名空间值 kfto 替换为项目的名称。
    4. 使用培训数据替换示例培训数据。
    5. training_data.yaml 文件中保存您的更改。
    6. 应用配置来创建培训数据:

      $ oc apply -f training_data.yaml
      Copy to Clipboard Toggle word wrap
  4. 创建持久性卷声明(PVC),如下所示:

    1. 创建名为 trainedmodelpvc.yaml 的 YAML 文件。
    2. 添加以下 PersistentVolumeClaim 对象定义:

      apiVersion: v1
      kind: PersistentVolumeClaim
      metadata:
        name: trained-model
        namespace: kfto
      spec:
        accessModes:
          - ReadWriteOnce
        resources:
          requests:
            storage: 50Gi
      Copy to Clipboard Toggle word wrap
    3. 将示例命名空间值 kfto 替换为项目的名称,并更新其他参数以适合您的环境。要计算 存储 值,请将模型大小乘以 epoch 的数量,然后添加很少的额外作为缓冲区。
    4. trainedmodelpvc.yaml 文件中保存您的更改。
    5. 应用配置来为培训作业创建持久性卷声明(PVC):

      $ oc apply -f trainedmodelpvc.yaml
      Copy to Clipboard Toggle word wrap

验证

  1. 在 OpenShift 控制台中,从 Project 列表中选择您的项目。
  2. ConfigMaps 并验证是否列出了 training-configtwitter-complaints ConfigMap。
  3. Search。在 Resources 列表中,选择 PersistentVolumeClaim 并验证是否列出了 trained-model PVC。
返回顶部
Red Hat logoGithubredditYoutubeTwitter

学习

尝试、购买和销售

社区

关于红帽文档

通过我们的产品和服务,以及可以信赖的内容,帮助红帽用户创新并实现他们的目标。 了解我们当前的更新.

让开源更具包容性

红帽致力于替换我们的代码、文档和 Web 属性中存在问题的语言。欲了解更多详情,请参阅红帽博客.

關於紅帽

我们提供强化的解决方案,使企业能够更轻松地跨平台和环境(从核心数据中心到网络边缘)工作。

Theme

© 2025 Red Hat