burtenshaw HF Staff committed on
Commit
09817db
·
verified ·
1 Parent(s): 8f7db32

Upload sft_nemo3_native.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. sft_nemo3_native.py +79 -0
sft_nemo3_native.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # /// script
2
+ # requires-python = ">=3.10"
3
+ # dependencies = [
4
+ # "torch",
5
+ # "transformers>=4.57.0",
6
+ # "trl>=0.12.0",
7
+ # "datasets",
8
+ # "peft>=0.7.0",
9
+ # "accelerate",
10
+ # "bitsandbytes",
11
+ # "sentencepiece",
12
+ # "protobuf",
13
+ # "trackio",
14
+ # ]
15
+ # ///
16
+ from __future__ import annotations
17
+ import os
18
+ from typing import Any, Dict
19
+ os.environ.setdefault("HF_HOME", "./.hf_home")
20
+ import torch
21
+ from datasets import load_dataset
22
+ from peft import LoraConfig
23
+ from transformers import AutoModelForCausalLM, AutoTokenizer
24
+ from trl import SFTConfig, SFTTrainer
25
+
# Runtime configuration — every value can be overridden via an environment
# variable of the same name; the literals below are the defaults.
MODEL_ID = os.getenv("MODEL_ID", "unsloth/Nemotron-3-Nano-30B-A3B")
DATASET_NAME = os.getenv("DATASET_NAME", "HuggingFaceH4/Multilingual-Thinking")
HUB_MODEL_ID = os.getenv("HUB_MODEL_ID", "burtenshaw/nemotron3-nano-multilingual-thinking")
MAX_STEPS = int(os.getenv("MAX_STEPS", "100"))
def merge_thinking(example):
    """Fold a per-message ``thinking`` field into the message content.

    For every message in ``example["messages"]``, a non-empty string
    ``thinking`` value is removed from the message and prepended to its
    content wrapped in ``<think>...</think>`` tags. Messages are copied,
    so the input example is never mutated. Returns a new example dict
    with the rewritten ``messages`` list.
    """

    def _fold(raw):
        # Work on a copy so the caller's message dict stays untouched.
        msg = dict(raw)
        body = msg.get("content", "")
        reasoning = msg.pop("thinking", None)
        # Only inline reasoning that is a non-blank string.
        if isinstance(reasoning, str) and reasoning.strip():
            body = f"<think>\n{reasoning}\n</think>\n{body}"
        msg["content"] = body
        return msg

    return {**example, "messages": [_fold(m) for m in example["messages"]]}
+
43
+ def main():
44
+ print(f"[config] model={MODEL_ID} dataset={DATASET_NAME} hub={HUB_MODEL_ID} steps={MAX_STEPS}")
45
+ torch.manual_seed(42)
46
+ if torch.cuda.is_available():
47
+ torch.cuda.manual_seed_all(42)
48
+ print(f"[cuda] {torch.cuda.get_device_name(0)}")
49
+
50
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
51
+ if tokenizer.pad_token is None:
52
+ tokenizer.pad_token = tokenizer.eos_token
53
+
54
+ print("[loading] model...")
55
+ model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16, use_cache=False, trust_remote_code=True, device_map="auto", low_cpu_mem_usage=True)
56
+
57
+ peft_config = LoraConfig(r=16, lora_alpha=32, target_modules=["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"], lora_dropout=0.05, bias="none", task_type="CAUSAL_LM")
58
+
59
+ print("[loading] dataset...")
60
+ dataset = load_dataset(DATASET_NAME, split="train")
61
+ drop = [c for c in ["reasoning_language","developer","user","analysis","final"] if c in dataset.column_names]
62
+ if drop: dataset = dataset.remove_columns(drop)
63
+ dataset = dataset.map(merge_thinking)
64
+
65
+ def fmt(ex):
66
+ return {"text": [tokenizer.apply_chat_template(c, tokenize=False, add_generation_prompt=False) for c in ex["messages"]]}
67
+ dataset = dataset.map(fmt, batched=True, remove_columns=[c for c in dataset.column_names if c != "text"])
68
+ print(f"[info] {len(dataset)} examples")
69
+
70
+ args = SFTConfig(per_device_train_batch_size=1, gradient_accumulation_steps=8, warmup_steps=10, max_steps=MAX_STEPS, learning_rate=2e-4, optim="paged_adamw_8bit", logging_steps=1, save_steps=25, output_dir="nemotron3-sft", max_length=2048, gradient_checkpointing=True, gradient_checkpointing_kwargs={"use_reentrant": False}, bf16=True, push_to_hub=True, hub_model_id=HUB_MODEL_ID, hub_strategy="every_save", report_to="trackio", run_name="nemotron3-multilingual-thinking")
71
+
72
+ trainer = SFTTrainer(model=model, args=args, train_dataset=dataset, peft_config=peft_config, processing_class=tokenizer)
73
+ print("[training]...")
74
+ trainer.train()
75
+ trainer.push_to_hub()
76
+ print(f"[done] https://huggingface.co/{HUB_MODEL_ID}")
77
+
# Standard script entry point: run training only when executed directly.
if __name__ == "__main__":
    main()