분산 학습을 사용하여 OpenAI의 GPT-OSS 120B 모델 미세 조정

이 Notebook은 Databricks 서버리스 GPU 컴퓨팅을 사용하여 8개의 H100 GPU에 대한 대형 120B 매개 변수 GPT-OSS 모델의 감독된 SFT(미세 조정)를 보여 줍니다. 학습은 다음을 활용합니다.

  • FSDP(완전 분할된 데이터 병렬): 단일 GPU에 맞지 않는 대형 모델을 학습할 수 있도록 GPU 간에 분할된 모델 매개 변수, 그라데이션 및 최적화 프로그램 상태를 분할합니다.
  • DDP(분산 데이터 병렬): 더 빠른 학습을 위해 여러 GPU에 학습을 분산합니다.
  • LoRA(Low-Rank 적응): 작은 어댑터 레이어를 추가하여 학습 가능한 매개 변수 수를 줄여 미세 조정의 효율성을 높입니다.
  • TRL(변환기 보충 학습): 감독된 미세 조정을 위한 SFTTrainer를 제공합니다.

16개의 GPU를 설정하고 remote=False 지정하면 16개의 GPU에서 다중 노드 학습으로 확장할 수 있습니다.

필수 패키지 설치

분산 학습 및 모델 미세 조정에 필요한 라이브러리를 설치합니다.

  • trl: SFT 학습을 위한 트랜스포머 강화 학습 라이브러리
  • peft: LoRA 어댑터에 대한 매개변수 효율적인 미세 조정
  • transformers: Hugging Face 트랜스포머 라이브러리
  • datasets: 학습 데이터 세트를 로드하는 경우
  • accelerate: 분산 학습 오케스트레이션의 경우
  • hf_transfer: Hugging Face에서 모델을 더 빠르게 다운로드하기 위하여
%pip install "trl==1.1.0"
%pip install "peft==0.19.1"
%pip install "transformers==5.5.4"
%pip install "fsspec==2024.9.0"
%pip install "huggingface_hub==1.11.0"
%pip install "datasets==3.2.0"
%pip install "accelerate==1.13.0"
%restart_python

FSDP를 사용하여 분산 학습 함수 정의

이 셀은 데코레이터를 사용하여 8개의 H100 GPU에서 실행되는 학습 함수를 @distributed 정의합니다. 함수에는 다음이 포함됩니다.

  • 모델 로드: bfloat16 정밀도로 1200억 개의 매개변수를 가진 GPT-OSS 모델을 로드합니다.
  • LoRA 구성: 학습 가능한 매개 변수를 줄이기 위해 순위 16의 Low-Rank 적응을 적용합니다.
  • FSDP 설정: 자동 계층 래핑 및 활성화 검사점을 사용하여 완전히 분할된 데이터 병렬 구성
  • 학습 구성: 일괄 처리 크기, 학습 속도, 그라데이션 누적 및 기타 하이퍼 매개 변수 설정
  • 데이터 세트: HuggingFaceH4/Multilingual-Thinking 데이터 세트를 사용하여 미세 조정

이 함수는 FSDP 래핑에 대한 변환기 블록 클래스를 자동으로 검색하고 모든 GPU에서 분산 학습 조정을 처리합니다.

dbutils.widgets.text("uc_catalog", "main")
dbutils.widgets.text("uc_schema", "default")
dbutils.widgets.text("uc_model_name", "gpt-oss-120b-peft")
dbutils.widgets.text("uc_volume", "checkpoints")
dbutils.widgets.text("model", "openai/gpt-oss-120b")
dbutils.widgets.text("dataset_path", "HuggingFaceH4/Multilingual-Thinking")

UC_CATALOG = dbutils.widgets.get("uc_catalog")
UC_SCHEMA = dbutils.widgets.get("uc_schema")
UC_MODEL_NAME = dbutils.widgets.get("uc_model_name")
UC_VOLUME = dbutils.widgets.get("uc_volume")
HF_MODEL_NAME = dbutils.widgets.get("model")
DATASET_PATH = dbutils.widgets.get("dataset_path")

print(f"UC_CATALOG: {UC_CATALOG}")
print(f"UC_SCHEMA: {UC_SCHEMA}")
print(f"UC_MODEL_NAME: {UC_MODEL_NAME}")
print(f"UC_VOLUME: {UC_VOLUME}")
print(f"HF_MODEL_NAME: {HF_MODEL_NAME}")
print(f"DATASET_PATH: {DATASET_PATH}")

OUTPUT_DIR = f"/Volumes/{UC_CATALOG}/{UC_SCHEMA}/{UC_VOLUME}/{UC_MODEL_NAME}"
print(f"OUTPUT_DIR: {OUTPUT_DIR}")
from serverless_gpu import distributed

@distributed(gpus=8, gpu_type='H100')
def train_gpt_oss_fsdp_120b():
    """
    Fine-tune a 120B-class model with TRL SFTTrainer + FSDP2 on H100s.
    Uses LoRA + activation ckpt + full_shard auto_wrap.
    """

    # --- imports inside for pickle safety ---
    import os, torch, torch.distributed as dist
    from transformers import AutoModelForCausalLM, AutoTokenizer, Mxfp4Config
    from trl import SFTTrainer, SFTConfig
    from datasets import load_dataset
    from peft import LoraConfig, get_peft_model

    # ---------- DDP / CUDA binding ----------
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    torch.cuda.set_device(local_rank)
    os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
    os.environ.setdefault("NCCL_DEBUG", "WARN")
    os.environ.setdefault("CUDA_LAUNCH_BLOCKING", "0")

    os.environ.setdefault("TORCH_NCCL_ASYNC_ERROR_HANDLING", "1")  # replaces NCCL_ASYNC_ERROR_HANDLING

    # ---------- Config ----------
    MAX_LENGTH = 2048
    PER_DEVICE_BATCH = 1                 # start conservative for 120B
    GRAD_ACCUM = 4                       # tune for throughput
    LR = 1.5e-4
    EPOCHS = 1

    is_main  = int(os.environ.get("RANK", "0")) == 0
    world_size = int(os.environ.get("WORLD_SIZE", "1"))

    if is_main:
        print("=" * 60)
        print("FSDP (full_shard) launch for 120B")
        print(f"WORLD_SIZE={world_size} | LOCAL_RANK={local_rank}")
        print("=" * 60)

    # ---------- Tokenizer ----------
    tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_NAME)
    if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.model_max_length = MAX_LENGTH
    tokenizer.truncation_side = "right"

    # ---------- Model ----------
    # IMPORTANT: no device_map, no .to(device) — let Trainer/Accelerate+FSDP handle placement
    # low_cpu_mem_usage helps with massive checkpoints (still needs decent host RAM)
    quantization_config = Mxfp4Config(dequantize=True)
    model = AutoModelForCausalLM.from_pretrained(
        HF_MODEL_NAME,
        dtype=torch.bfloat16,
        attn_implementation="eager",
        quantization_config=quantization_config,
        use_cache=False,                  # needed for grad ckpt
        low_cpu_mem_usage=True,
    )

    # ---------- LoRA ----------
    # the following config works
    # include MoE layers as well.
    peft_config = LoraConfig(
        r=32,
        lora_alpha=32,
        target_modules="all-linear",
        rank_pattern={
            "mlp.experts.gate_up_proj": 8,
            "mlp.experts.down_proj": 8
        },
        target_parameters=["mlp.experts.gate_up_proj", "mlp.experts.down_proj"],
        lora_dropout=0.0,
        bias="none",
        task_type="CAUSAL_LM",
    )

    model = get_peft_model(model, peft_config)

    # Cast all parameters to bfloat16 so FSDP sees a uniform dtype
    # (LoRA adapters are initialized in float32 by default)
    model = model.to(torch.bfloat16)

    if is_main:
        model.print_trainable_parameters()

    # ---------- Data ----------
    dataset = load_dataset("HuggingFaceH4/Multilingual-Thinking", split="train")
    if is_main:
        print(f"Dataset size: {len(dataset)}")

    # ---------- FSDP settings ----------
    def infer_transformer_blocks_for_fsdp(model):
        COMMON = {
            "LlamaDecoderLayer", "MistralDecoderLayer", "MixtralDecoderLayer",
            "Qwen2DecoderLayer", "Gemma2DecoderLayer", "Phi3DecoderLayer",
            "GPTNeoXLayer", "MPTBlock", "BloomBlock", "FalconDecoderLayer",
            "DecoderLayer", "GPTJBlock", "OPTDecoderLayer"
        }
        hits = set()
        for _, m in model.named_modules():
            name = m.__class__.__name__
            if name in COMMON:
                hits.add(name)
        # Fallback: grab anything that *looks* like a decoder block
        if not hits:
            for _, m in model.named_modules():
                name = m.__class__.__name__
                if any(s in name for s in ["Block", "DecoderLayer", "Layer"]) and "Embedding" not in name:
                    hits.add(name)
        return sorted(hits)


    fsdp_wrap_classes = infer_transformer_blocks_for_fsdp(model)
    if not fsdp_wrap_classes:
        raise RuntimeError("Could not infer transformer block classes for FSDP wrapping; "
                       "print(model) and add the block class explicitly.")


    training_args = SFTConfig(
        output_dir=OUTPUT_DIR,
        num_train_epochs=EPOCHS,
        per_device_train_batch_size=PER_DEVICE_BATCH,
        gradient_accumulation_steps=GRAD_ACCUM,
        learning_rate=LR,
        warmup_ratio=0.03,
        lr_scheduler_type="cosine",
        bf16=True,
        logging_steps=5,
        logging_strategy="steps",
        save_strategy="no",
        report_to="none",
        ddp_find_unused_parameters=False,
        dataloader_pin_memory=True,
        max_length=MAX_LENGTH,
        gradient_checkpointing=False,

        # ---- FSDP2 knobs ----
        fsdp="full_shard auto_wrap",
        fsdp_config={
            "version": 2,
            "fsdp_transformer_layer_cls_to_wrap": fsdp_wrap_classes,
            "reshard_after_forward": True,
            "activation_checkpointing": True,    # <- use activation ckpt (not gradient)
            "xla": False,
            "limit_all_gathers": True,
        },
    )

    # ---------- Trainer ----------
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        processing_class=tokenizer,
    )

    # verify distributed init & FSDP
    rank = int(os.getenv("RANK", "0"))
    print(f"[rank {rank}] dist.is_initialized() -> {dist.is_initialized()}")
    acc = getattr(trainer, "accelerator", None)
    print(f"[rank {rank}] accelerator.distributed_type = {getattr(getattr(acc,'state',None),'distributed_type','n/a')}")
    print(f"[rank {rank}] accelerator.num_processes = {getattr(acc, 'num_processes', 'n/a')}")

    # ---------- Train ----------
    result = trainer.train()

    if is_main:
        print("\nTraining complete (FSDP).")
        print(result.metrics)

분산 학습 작업 실행

8개의 H100 GPU에서 학습 함수를 실행합니다. 데코레이터는 @distributed 적절한 분산 설정을 사용하여 모든 GPU에서 학습을 시작하는 오케스트레이션을 처리합니다.

train_gpt_oss_fsdp_120b.distributed()

다음 단계

예제 노트

분산 학습을 사용하여 OpenAI의 GPT-OSS 120B 모델 미세 조정

노트북 받기