使用分布式训练微调 OpenAI 的 GPT-OSS 120B 模型

此笔记本演示如何使用 Databricks 无服务器 GPU 计算,在 8 个 H100 GPU 上对拥有 120B 参数的大型 GPT-OSS 模型进行监督微调(SFT)。训练过程利用了以下技术:

  • FSDP(完全分片数据并行):将模型参数、梯度和优化器状态分片到多个 GPU 上,以训练那些无法在单个 GPU 上运行的大型模型。
  • DDP (分布式数据并行):跨多个 GPU 分配训练,以加快训练速度。
  • LoRA (低秩适应):通过添加小型适配器层来减少可训练参数的数量,从而使微调更高效。
  • TRL(Transformer Reinforcement Learning):提供用于监督微调的 SFTTrainer。

通过设置 remote=True 并指定 16 个 GPU,可以将其扩展为跨 16 个 GPU 的多节点训练。

安装所需程序包

安装用于分布式训练和模型微调的必要库:

  • trl:用于 SFT 训练的 Transformer Reinforcement Learning(TRL)库
  • peft:参数高效微调的 LoRA 适配器
  • transformers:Hugging Face Transformers库
  • datasets:用于加载训练数据集
  • accelerate:用于分布式训练协调
  • hf_transfer:提高从 Hugging Face 下载模型的速度
%pip install "trl==1.1.0"
%pip install "peft==0.19.1"
%pip install "transformers==5.5.4"
%pip install "fsspec==2024.9.0"
%pip install "huggingface_hub==1.11.0"
%pip install "datasets==3.2.0"
%pip install "accelerate==1.13.0"
%restart_python

使用 FSDP 定义分布式训练函数

此单元格定义训练函数,该函数将通过 @distributed 修饰器在 8 个 H100 GPU 上运行。该函数包括:

  • 模型加载:以 bfloat16 精度加载 120B 参数 GPT-OSS 模型
  • LoRA 配置:应用低秩适应方法(默认秩为 32,MoE 专家层的秩为 8),以减少可训练参数的数量。
  • FSDP 设置:配置全分片数据并行处理,并自动进行层包装和激活检查点设定
  • 训练配置:设置批大小、学习率、梯度累积和其他超参数
  • 数据集:使用 HuggingFaceH4/Multilingual-Thinking 数据集进行微调

该函数自动检测用于 FSDP 包装的 Transformer 块类,并处理所有 GPU 之间的分布式训练协调。

# Notebook parameters. Each widget is registered with a default value; the
# resolved values are then read back into module-level constants that the
# training function below refers to.
_WIDGET_DEFAULTS = {
    "uc_catalog": "main",
    "uc_schema": "default",
    "uc_model_name": "gpt-oss-120b-peft",
    "uc_volume": "checkpoints",
    "model": "openai/gpt-oss-120b",
    "dataset_path": "HuggingFaceH4/Multilingual-Thinking",
}

# Register every widget first (same order as the table above) ...
for _widget_name, _widget_default in _WIDGET_DEFAULTS.items():
    dbutils.widgets.text(_widget_name, _widget_default)

# ... then read the current values back, again in table order.
(
    UC_CATALOG,
    UC_SCHEMA,
    UC_MODEL_NAME,
    UC_VOLUME,
    HF_MODEL_NAME,
    DATASET_PATH,
) = (dbutils.widgets.get(_widget_name) for _widget_name in _WIDGET_DEFAULTS)

# Echo the resolved configuration, one setting per line.
print(
    f"UC_CATALOG: {UC_CATALOG}",
    f"UC_SCHEMA: {UC_SCHEMA}",
    f"UC_MODEL_NAME: {UC_MODEL_NAME}",
    f"UC_VOLUME: {UC_VOLUME}",
    f"HF_MODEL_NAME: {HF_MODEL_NAME}",
    f"DATASET_PATH: {DATASET_PATH}",
    sep="\n",
)

# Directory inside the Unity Catalog volume that is handed to the trainer as
# its output directory.
OUTPUT_DIR = f"/Volumes/{UC_CATALOG}/{UC_SCHEMA}/{UC_VOLUME}/{UC_MODEL_NAME}"
print(f"OUTPUT_DIR: {OUTPUT_DIR}")
from serverless_gpu import distributed

@distributed(gpus=8, gpu_type='H100')
def train_gpt_oss_fsdp_120b():
    """Fine-tune a 120B-class model with TRL ``SFTTrainer`` + FSDP2 on H100s.

    Loads the model named by ``HF_MODEL_NAME`` in bfloat16, attaches LoRA
    adapters (including the MoE expert projections), and runs supervised
    fine-tuning on the ``train`` split of the dataset named by
    ``DATASET_PATH`` using ``full_shard auto_wrap`` FSDP with activation
    checkpointing.

    Reads the notebook-level globals ``HF_MODEL_NAME``, ``DATASET_PATH`` and
    ``OUTPUT_DIR``. Returns ``None``; rank 0 prints the training metrics.

    Raises:
        RuntimeError: if no transformer block class can be inferred for FSDP
            auto-wrapping.
    """

    # --- imports inside for pickle safety ---
    import os

    import torch
    import torch.distributed as dist
    from transformers import AutoModelForCausalLM, AutoTokenizer, Mxfp4Config
    from trl import SFTTrainer, SFTConfig
    from datasets import load_dataset
    from peft import LoraConfig, get_peft_model

    # ---------- DDP / CUDA binding ----------
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    torch.cuda.set_device(local_rank)
    # setdefault: values already present in the environment take precedence.
    os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
    os.environ.setdefault("NCCL_DEBUG", "WARN")
    os.environ.setdefault("CUDA_LAUNCH_BLOCKING", "0")

    os.environ.setdefault("TORCH_NCCL_ASYNC_ERROR_HANDLING", "1")  # replaces NCCL_ASYNC_ERROR_HANDLING

    # ---------- Config ----------
    MAX_LENGTH = 2048
    PER_DEVICE_BATCH = 1                 # start conservative for 120B
    GRAD_ACCUM = 4                       # tune for throughput
    LR = 1.5e-4
    EPOCHS = 1

    # Parse RANK once; it drives both the rank-0-only logging and the
    # per-rank diagnostics printed before training.
    rank = int(os.environ.get("RANK", "0"))
    is_main = rank == 0
    world_size = int(os.environ.get("WORLD_SIZE", "1"))

    if is_main:
        print("=" * 60)
        print("FSDP (full_shard) launch for 120B")
        print(f"WORLD_SIZE={world_size} | LOCAL_RANK={local_rank}")
        print("=" * 60)

    # ---------- Tokenizer ----------
    tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_NAME)
    # Fall back to EOS as the padding token when the tokenizer defines none.
    if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.model_max_length = MAX_LENGTH
    tokenizer.truncation_side = "right"

    # ---------- Model ----------
    # IMPORTANT: no device_map, no .to(device) — let Trainer/Accelerate+FSDP handle placement
    # low_cpu_mem_usage helps with massive checkpoints (still needs decent host RAM)
    quantization_config = Mxfp4Config(dequantize=True)
    model = AutoModelForCausalLM.from_pretrained(
        HF_MODEL_NAME,
        dtype=torch.bfloat16,
        attn_implementation="eager",
        quantization_config=quantization_config,
        use_cache=False,                  # needed for grad ckpt
        low_cpu_mem_usage=True,
    )

    # ---------- LoRA ----------
    # the following config works
    # include MoE layers as well.
    peft_config = LoraConfig(
        r=32,
        lora_alpha=32,
        target_modules="all-linear",
        # The expert projections get a smaller rank than the default r above.
        rank_pattern={
            "mlp.experts.gate_up_proj": 8,
            "mlp.experts.down_proj": 8
        },
        target_parameters=["mlp.experts.gate_up_proj", "mlp.experts.down_proj"],
        lora_dropout=0.0,
        bias="none",
        task_type="CAUSAL_LM",
    )

    model = get_peft_model(model, peft_config)

    # Cast all parameters to bfloat16 so FSDP sees a uniform dtype
    # (LoRA adapters are initialized in float32 by default)
    model = model.to(torch.bfloat16)

    if is_main:
        model.print_trainable_parameters()

    # ---------- Data ----------
    # Honour the `dataset_path` notebook widget instead of a hard-coded
    # dataset id, mirroring how HF_MODEL_NAME is taken from the `model` widget.
    dataset = load_dataset(DATASET_PATH, split="train")
    if is_main:
        print(f"Dataset size: {len(dataset)}")

    # ---------- FSDP settings ----------
    def infer_transformer_blocks_for_fsdp(model):
        """Return the sorted class names of modules to use as FSDP wrap units.

        Exact matches against a list of known decoder-block class names are
        preferred; if none are present, any module class whose name contains
        "Block", "DecoderLayer" or "Layer" (and not "Embedding") is used.
        """
        COMMON = {
            "LlamaDecoderLayer", "MistralDecoderLayer", "MixtralDecoderLayer",
            "Qwen2DecoderLayer", "Gemma2DecoderLayer", "Phi3DecoderLayer",
            "GPTNeoXLayer", "MPTBlock", "BloomBlock", "FalconDecoderLayer",
            "DecoderLayer", "GPTJBlock", "OPTDecoderLayer"
        }
        hits = set()
        for _, m in model.named_modules():
            name = m.__class__.__name__
            if name in COMMON:
                hits.add(name)
        # Fallback: grab anything that *looks* like a decoder block
        # NOTE(review): the bare "Layer" substring is broad (it would also
        # match a class named e.g. "LayerNorm"); confirm the inferred list
        # when pointing this at a different architecture.
        if not hits:
            for _, m in model.named_modules():
                name = m.__class__.__name__
                if any(s in name for s in ["Block", "DecoderLayer", "Layer"]) and "Embedding" not in name:
                    hits.add(name)
        return sorted(hits)


    fsdp_wrap_classes = infer_transformer_blocks_for_fsdp(model)
    if not fsdp_wrap_classes:
        raise RuntimeError("Could not infer transformer block classes for FSDP wrapping; "
                       "print(model) and add the block class explicitly.")


    training_args = SFTConfig(
        output_dir=OUTPUT_DIR,
        num_train_epochs=EPOCHS,
        per_device_train_batch_size=PER_DEVICE_BATCH,
        gradient_accumulation_steps=GRAD_ACCUM,
        learning_rate=LR,
        warmup_ratio=0.03,
        lr_scheduler_type="cosine",
        bf16=True,
        logging_steps=5,
        logging_strategy="steps",
        # NOTE(review): with save_strategy="no" and no explicit save call in
        # this function, nothing appears to persist the trained adapter —
        # confirm whether a save step is expected elsewhere.
        save_strategy="no",
        report_to="none",
        ddp_find_unused_parameters=False,
        dataloader_pin_memory=True,
        max_length=MAX_LENGTH,
        # Left off on purpose: checkpointing is enabled through the FSDP
        # config below instead.
        gradient_checkpointing=False,

        # ---- FSDP2 knobs ----
        fsdp="full_shard auto_wrap",
        fsdp_config={
            "version": 2,
            "fsdp_transformer_layer_cls_to_wrap": fsdp_wrap_classes,
            "reshard_after_forward": True,
            "activation_checkpointing": True,    # <- use activation ckpt (not gradient)
            "xla": False,
            "limit_all_gathers": True,
        },
    )

    # ---------- Trainer ----------
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        processing_class=tokenizer,
    )

    # verify distributed init & FSDP
    print(f"[rank {rank}] dist.is_initialized() -> {dist.is_initialized()}")
    acc = getattr(trainer, "accelerator", None)
    print(f"[rank {rank}] accelerator.distributed_type = {getattr(getattr(acc,'state',None),'distributed_type','n/a')}")
    print(f"[rank {rank}] accelerator.num_processes = {getattr(acc, 'num_processes', 'n/a')}")

    # ---------- Train ----------
    result = trainer.train()

    if is_main:
        print("\nTraining complete (FSDP).")
        print(result.metrics)

运行分布式训练任务

在 8 个 H100 GPU 上执行训练函数。@distributed 修饰器负责协调在所有 GPU 上启动训练,并确保分布式环境设置正确。

# Launch the training function through the `.distributed()` entry point of
# the @distributed-decorated function defined above.
train_gpt_oss_fsdp_120b.distributed()

后续步骤

示例笔记本

使用分布式训练微调 OpenAI 的 GPT-OSS 120B 模型

获取笔记本