Text Generation
Transformers
Safetensors
English
Korean
hrm_text
terminal
sft
vllm
tb2-lite
evaluation-pending
Instructions to use LLM-OS-Models/KoHRM-Text-1.4B with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use LLM-OS-Models/KoHRM-Text-1.4B with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("text-generation", model="LLM-OS-Models/KoHRM-Text-1.4B")
```
```python
# Load model directly
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("LLM-OS-Models/KoHRM-Text-1.4B")
model = AutoModelForCausalLM.from_pretrained("LLM-OS-Models/KoHRM-Text-1.4B")
```
- Notebooks
- Google Colab
- Kaggle
- Local Apps Settings
- vLLM
How to use LLM-OS-Models/KoHRM-Text-1.4B with vLLM:
Install from pip and serve model
```bash
# Install vLLM from pip:
pip install vllm

# Start the vLLM server:
vllm serve "LLM-OS-Models/KoHRM-Text-1.4B"

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:8000/v1/completions" \
  -H "Content-Type: application/json" \
  --data '{
    "model": "LLM-OS-Models/KoHRM-Text-1.4B",
    "prompt": "Once upon a time,",
    "max_tokens": 512,
    "temperature": 0.5
  }'
```
Use Docker
docker model run hf.co/LLM-OS-Models/KoHRM-Text-1.4B
- SGLang
How to use LLM-OS-Models/KoHRM-Text-1.4B with SGLang:
Install from pip and serve model
```bash
# Install SGLang from pip:
pip install sglang

# Start the SGLang server:
python3 -m sglang.launch_server \
  --model-path "LLM-OS-Models/KoHRM-Text-1.4B" \
  --host 0.0.0.0 \
  --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
  -H "Content-Type: application/json" \
  --data '{
    "model": "LLM-OS-Models/KoHRM-Text-1.4B",
    "prompt": "Once upon a time,",
    "max_tokens": 512,
    "temperature": 0.5
  }'
```
Use Docker images
```bash
docker run --gpus all \
  --shm-size 32g \
  -p 30000:30000 \
  -v ~/.cache/huggingface:/root/.cache/huggingface \
  --env "HF_TOKEN=<secret>" \
  --ipc=host \
  lmsysorg/sglang:latest \
  python3 -m sglang.launch_server \
    --model-path "LLM-OS-Models/KoHRM-Text-1.4B" \
    --host 0.0.0.0 \
    --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
  -H "Content-Type: application/json" \
  --data '{
    "model": "LLM-OS-Models/KoHRM-Text-1.4B",
    "prompt": "Once upon a time,",
    "max_tokens": 512,
    "temperature": 0.5
  }'
```
- Docker Model Runner
How to use LLM-OS-Models/KoHRM-Text-1.4B with Docker Model Runner:
docker model run hf.co/LLM-OS-Models/KoHRM-Text-1.4B
| """Minimal KoHRM-Text generation runtime for Colab. | |
| This file intentionally avoids `transformers` and FlashAttention. It loads the | |
| public `model.safetensors` export and runs HRM-Text generation with PyTorch | |
| scaled-dot-product attention. It is built for long pretraining-checkpoint | |
| knowledge probes on Colab T4 and small CUDA machines. | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import math | |
| import argparse | |
| from pathlib import Path | |
| from typing import Any | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from safetensors.torch import load_file | |
| from tokenizers import Tokenizer | |
# Default lookup from HRM-Text condition name to the tokenizer control token
# that `condition_to_tokens` substitutes for it.
DEFAULT_CONDITION_TOKENS = dict(
    direct="<|object_ref_start|>",
    cot="<|object_ref_end|>",
    noisy="<|quad_start|>",
    synth="<|quad_end|>",
)
| def _rms_norm(x: torch.Tensor, eps: float) -> torch.Tensor: | |
| return F.rms_norm(x, (x.shape[-1],), eps=eps) | |
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Swap the two halves of the last axis and negate the half moved to the front.

    A last axis laid out as ``[a, b]`` becomes ``[-b, a]``.
    """
    half = x.size(-1) // 2
    front = x[..., :half]
    back = x[..., half:]
    return torch.cat([back.neg(), front], dim=-1)
def _rope_cos_sin(position_ids: torch.Tensor, head_dim: int, theta: float, dtype: torch.dtype) -> tuple[torch.Tensor, torch.Tensor]:
    """Build the rotary-embedding cosine and sine tables for ``position_ids``.

    Args:
        position_ids: Two-dimensional tensor of positions, indexed as
            (batch, time) by the einsum below.
        head_dim: Per-head feature width; one frequency is generated for every
            second feature index.
        theta: Base of the geometric frequency progression.
        dtype: Dtype of the returned tables.

    Returns:
        ``(cos, sin)``. Each table holds the per-position angles repeated twice
        along the last axis and is cast to ``dtype``.
    """
    exponents = torch.arange(0, head_dim, 2, device=position_ids.device, dtype=torch.float32) / head_dim
    inv_freq = 1.0 / (theta ** exponents)
    # Angles are computed in float32 and only the final tables are cast down.
    angles = torch.einsum("bt,d->btd", position_ids.to(torch.float32), inv_freq)
    table = torch.cat((angles, angles), dim=-1)
    cos_table = table.cos().to(dtype)
    sin_table = table.sin().to(dtype)
    return cos_table, sin_table
def _apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Rotate ``x`` using the rotary-embedding tables ``cos`` and ``sin``.

    Both tables get a singleton axis inserted before their last dimension so
    they broadcast over the second-to-last axis of ``x``. The result is cast
    back to the dtype of ``x``.
    """
    cos_b = cos.unsqueeze(-2)
    sin_b = sin.unsqueeze(-2)
    rotated = (x * cos_b) + (_rotate_half(x) * sin_b)
    return rotated.to(x.dtype)
class KoHRMAttention(nn.Module):
    """Sigmoid-gated multi-head self-attention with one fused gate/q/k/v projection."""

    def __init__(self, hidden_size: int, num_heads: int, head_dim: int, device: str = "meta") -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        inner_dim = num_heads * head_dim
        # A single projection emits four equally sized head groups: gate, q, k, v.
        self.gqkv_proj = nn.Linear(hidden_size, 4 * inner_dim, bias=False, device=device)
        self.o_proj = nn.Linear(inner_dim, hidden_size, bias=False, device=device)

    def forward(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        cache: dict[str, torch.Tensor] | None,
        cache_pos: int,
    ) -> torch.Tensor:
        """Attend over ``x`` and, when a cache is given, over all cached positions.

        Args:
            x: Input of shape (batch, time, hidden).
            cos: Rotary cosine table applied to queries and keys.
            sin: Rotary sine table applied to queries and keys.
            cache: Optional dict with ``"k"`` and ``"v"`` buffers laid out as
                (batch, position, head, head_dim). The new keys/values are
                written in place at ``cache_pos`` and attention then reads the
                buffers from position 0 up to the end of the written span.
            cache_pos: First cache position written by this call.

        Returns:
            Tensor of shape (batch, time, hidden).
        """
        batch, steps, _ = x.shape
        fused = self.gqkv_proj(x).view(batch, steps, 4 * self.num_heads, self.head_dim)
        gate, query, key, value = fused.chunk(4, dim=-2)
        query = _apply_rope(query, cos, sin)
        key = _apply_rope(key, cos, sin)

        if cache is not None:
            stop = cache_pos + steps
            cache["k"][:, cache_pos:stop].copy_(key)
            cache["v"][:, cache_pos:stop].copy_(value)
            key, value = cache["k"][:, :stop], cache["v"][:, :stop]

        # The attention kernel takes (batch, heads, time, head_dim).
        # NOTE(review): no mask is passed and is_causal=False, so when `steps` > 1
        # every query sees every key in the span, including later positions.
        # Confirm this matches how the checkpoint was trained.
        attended = F.scaled_dot_product_attention(
            query.transpose(1, 2),
            key.transpose(1, 2),
            value.transpose(1, 2),
            dropout_p=0.0,
            is_causal=False,
        ).transpose(1, 2)
        gated = torch.sigmoid(gate) * attended
        return self.o_proj(gated.reshape(batch, steps, self.num_heads * self.head_dim))
class KoHRMMLP(nn.Module):
    """Feed-forward layer with a SiLU-gated hidden activation and fused gate/up projection."""

    def __init__(self, hidden_size: int, intermediate_size: int, device: str = "meta") -> None:
        super().__init__()
        # The first half of the fused output is the gate, the second half the "up" branch.
        self.gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False, device=device)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False, device=device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``down_proj(silu(gate) * up)`` for the fused projection of ``x``."""
        fused = self.gate_up_proj(x)
        gate, up = torch.chunk(fused, 2, dim=-1)
        activated = F.silu(gate) * up
        return self.down_proj(activated)
class KoHRMBlock(nn.Module):
    """Residual block: normalized attention followed by a normalized MLP."""

    def __init__(self, cfg: dict[str, Any], device: str = "meta") -> None:
        super().__init__()
        self.eps = float(cfg["rms_norm_eps"])
        hidden = cfg["hidden_size"]
        self.attn = KoHRMAttention(hidden, cfg["num_attention_heads"], cfg["head_dim"], device=device)
        self.mlp = KoHRMMLP(hidden, cfg["intermediate_size"], device=device)

    def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, cache: dict[str, torch.Tensor] | None, cache_pos: int) -> torch.Tensor:
        """Run both sub-layers, each on an RMS-normalized input, adding each result back to ``x``.

        ``cos``, ``sin``, ``cache`` and ``cache_pos`` are passed to the
        attention sub-layer unchanged.
        """
        attn_out = self.attn(_rms_norm(x, self.eps), cos, sin, cache, cache_pos)
        x = x + attn_out
        mlp_out = self.mlp(_rms_norm(x, self.eps))
        return x + mlp_out
class KoHRMModule(nn.Module):
    """A stack of ``KoHRMBlock`` layers followed by a final RMS normalization."""

    def __init__(self, cfg: dict[str, Any], num_layers: int, device: str = "meta") -> None:
        super().__init__()
        self.eps = float(cfg["rms_norm_eps"])
        blocks = [KoHRMBlock(cfg, device=device) for _ in range(num_layers)]
        self.layers = nn.ModuleList(blocks)

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_injection: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        caches: list[dict[str, torch.Tensor]] | None,
        cache_pos: int,
    ) -> torch.Tensor:
        """Add ``input_injection`` to ``hidden_states`` and run the layer stack.

        Args:
            hidden_states: Current state of this module.
            input_injection: Tensor added to ``hidden_states`` before the first layer.
            cos: Rotary cosine table passed to every layer.
            sin: Rotary sine table passed to every layer.
            caches: Optional list holding one key/value cache dict per layer,
                indexed by layer position; ``None`` runs every layer uncached.
            cache_pos: Cache write position passed to every layer.

        Returns:
            The RMS-normalized output of the last layer.
        """
        per_layer = caches if caches is not None else [None] * len(self.layers)
        state = hidden_states + input_injection
        for depth, block in enumerate(self.layers):
            state = block(state, cos, sin, per_layer[depth], cache_pos)
        return _rms_norm(state, self.eps)
class KoHRMCore(nn.Module):
    """Token embedding plus the nested H/L refinement loop of the model."""

    def __init__(self, cfg: dict[str, Any], num_layers: int, device: str = "meta") -> None:
        super().__init__()
        self.cfg = cfg
        self.embedding_scale = float(cfg.get("embedding_scale", 1.0))
        self.embed_tokens = nn.Embedding(cfg["vocab_size"], cfg["hidden_size"], device=device)
        # Initial L-state vector; allocated uninitialized here and saved with the state dict.
        self.register_buffer("z_L_init", torch.empty(cfg["hidden_size"], device=device), persistent=True)
        self.H_module = KoHRMModule(cfg, num_layers, device=device)
        self.L_module = KoHRMModule(cfg, num_layers, device=device)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        caches: dict[str, list[list[dict[str, torch.Tensor]]]] | None,
        cache_pos: int,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and run ``H_cycles`` outer by ``L_cycles`` inner passes.

        Each outer cycle runs the L module ``L_cycles`` times, with the current
        H state as its injection, then runs the H module once with the
        resulting L state as its injection.

        Args:
            input_ids: Token ids to embed.
            position_ids: Positions used to build the rotary tables.
            caches: Optional dict with ``"L"`` and ``"H"`` lists. ``"L"`` is
                indexed by the running L-pass number, ``"H"`` by the outer
                cycle number.
            cache_pos: Cache write position passed to both modules.

        Returns:
            The final H state.
        """
        embedded = self.embedding_scale * self.embed_tokens(input_ids)
        cos, sin = _rope_cos_sin(position_ids, self.cfg["head_dim"], float(self.cfg["rope_theta"]), embedded.dtype)
        high = embedded
        low = self.z_L_init.to(dtype=embedded.dtype).view(1, 1, -1).expand_as(embedded)
        n_high = int(self.cfg["H_cycles"])
        n_low = int(self.cfg["L_cycles"])
        use_cache = caches is not None

        low_pass = 0  # running index over all L passes, equal to outer * n_low + inner
        for outer in range(n_high):
            for _ in range(n_low):
                low_cache = caches["L"][low_pass] if use_cache else None
                low = self.L_module(low, high, cos, sin, low_cache, cache_pos)
                low_pass += 1
            high_cache = caches["H"][outer] if use_cache else None
            high = self.H_module(high, low, cos, sin, high_cache, cache_pos)
        return high
class KoHRMTextForGeneration(nn.Module):
    """``KoHRMCore`` with a linear language-model head, plus KV-cache allocation."""

    def __init__(self, cfg: dict[str, Any], num_layers: int, device: str = "meta") -> None:
        super().__init__()
        self.cfg = cfg
        self.num_layers = num_layers
        self.model = KoHRMCore(cfg, num_layers, device=device)
        self.lm_head = nn.Linear(cfg["hidden_size"], cfg["vocab_size"], bias=False, device=device)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        caches: dict[str, list[list[dict[str, torch.Tensor]]]] | None = None,
        cache_pos: int = 0,
    ) -> torch.Tensor:
        """Return vocabulary logits for every position of ``input_ids``."""
        core_out = self.model(input_ids, position_ids, caches, cache_pos)
        return self.lm_head(core_out)

    def init_cache(self, batch_size: int, max_seq_len: int, device: torch.device, dtype: torch.dtype) -> dict[str, list[list[dict[str, torch.Tensor]]]]:
        """Allocate uninitialized key/value buffers for every pass and layer.

        Returns:
            A dict with ``"H"`` (one entry per H cycle) and ``"L"`` (one entry
            per H cycle times L cycle). Each entry is a list with one
            ``{"k", "v"}`` dict per layer, and every buffer has shape
            (batch_size, max_seq_len, num_attention_heads, head_dim).
        """
        n_heads = int(self.cfg["num_attention_heads"])
        width = int(self.cfg["head_dim"])
        kv_shape = (batch_size, max_seq_len, n_heads, width)

        def new_pass() -> list[dict[str, torch.Tensor]]:
            # One independent k/v pair per layer of a module pass.
            return [
                {name: torch.empty(kv_shape, device=device, dtype=dtype) for name in ("k", "v")}
                for _ in range(self.num_layers)
            ]

        h_passes = int(self.cfg["H_cycles"])
        l_passes = int(self.cfg["H_cycles"]) * int(self.cfg["L_cycles"])
        return {
            "H": [new_pass() for _ in range(h_passes)],
            "L": [new_pass() for _ in range(l_passes)],
        }
def _module_layer_count(state: dict[str, torch.Tensor], prefix: str) -> int:
    """Infer the number of layers under ``prefix`` from checkpoint key names.

    Keys of the form ``"<prefix>.layers.<index>..."`` are scanned and the
    result is the highest index found plus one.

    Args:
        state: Checkpoint state dict; only its keys are inspected.
        prefix: Module path that owns the ``layers`` list, e.g. ``"model.H_module"``.

    Returns:
        The layer count implied by the largest layer index.

    Raises:
        ValueError: If no key starts with ``"<prefix>.layers."``, or if the
            segment after that marker is not an integer.
    """
    marker = f"{prefix}.layers."
    indices = {
        int(key[len(marker):].split(".", 1)[0])
        for key in state
        if key.startswith(marker)
    }
    if not indices:
        # Fail with a message that names the missing keys instead of max()'s
        # generic empty-sequence error.
        raise ValueError(f"No checkpoint keys start with {marker!r}; cannot infer the layer count")
    return max(indices) + 1
def load_kohrm(repo_dir: str | Path, device: str | None = None, max_gpu_memory_gib: float | None = None) -> tuple[KoHRMTextForGeneration, Tokenizer, dict[str, Any]]:
    """Load the config, tokenizer and weights from a local model directory.

    Args:
        repo_dir: Directory containing ``config.json``, ``tokenizer.json`` and
            ``model.safetensors``.
        device: Target device string. When ``None``, ``"cuda"`` is used if
            available and ``"cpu"`` otherwise.
        max_gpu_memory_gib: When not ``None`` and the target is a CUDA device,
            the free/total memory of that device is printed after loading.
            The value itself is not enforced as a limit.

    Returns:
        ``(model, tokenizer, cfg)``. The model is in eval mode, in float16 on
        CUDA and float32 on any other device.
    """
    repo_dir = Path(repo_dir)
    # Read the config as UTF-8 explicitly so parsing does not depend on the locale.
    cfg = json.loads((repo_dir / "config.json").read_text(encoding="utf-8"))
    tokenizer = Tokenizer.from_file(str(repo_dir / "tokenizer.json"))
    state = load_file(str(repo_dir / "model.safetensors"), device="cpu")

    # The module is built on the meta device (no storage allocated) and the
    # checkpoint tensors are assigned into it directly.
    num_layers = _module_layer_count(state, "model.H_module")
    model = KoHRMTextForGeneration(cfg, num_layers=num_layers, device="meta")
    model.load_state_dict(state, strict=True, assign=True)
    del state

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    target = torch.device(device)
    dtype = torch.float16 if target.type == "cuda" else torch.float32
    model = model.to(device=target, dtype=dtype).eval()

    if target.type == "cuda":
        torch.set_float32_matmul_precision("high")
        if max_gpu_memory_gib is not None:
            # Query the device the model lives on rather than whichever CUDA
            # device happens to be current.
            free, total = torch.cuda.mem_get_info(target)
            print(f"GPU memory free/total GiB: {free / 2**30:.2f}/{total / 2**30:.2f}")
    return model, tokenizer, cfg
def condition_to_tokens(condition: str = "direct", mapping: dict[str, str] | None = None) -> str:
    """Translate comma-separated HRM-Text condition names into control tokens.

    Args:
        condition: One or more condition names separated by commas.
            Surrounding whitespace and empty entries are ignored.
        mapping: Name-to-token table. ``DEFAULT_CONDITION_TOKENS`` is used
            when this is ``None`` or empty.

    Returns:
        The tokens for the listed names, concatenated in the order given. If
        no names remain after skipping empty entries, the ``"direct"`` token.

    Raises:
        ValueError: If a name is not a key of the table.
    """
    table = mapping if mapping else DEFAULT_CONDITION_TOKENS
    tokens: list[str] = []
    for name in filter(None, (part.strip() for part in condition.split(","))):
        if name in table:
            tokens.append(table[name])
            continue
        valid = ", ".join(sorted(table))
        raise ValueError(f"Unknown condition {name!r}; expected one of: {valid}")
    return "".join(tokens or [table["direct"]])
def format_kohrm_prompt(
    prompt: str,
    condition: str = "direct",
    condition_token: str | None = None,
) -> str:
    """Wrap ``prompt`` in the control tokens used for KoHRM generation.

    The result is ``<|im_start|>`` + condition tokens + ``prompt`` +
    ``<|im_end|>``, following the upstream
    ``InferenceCheckpoint.tokenize_prompt()`` layout
    ``<boq><condition_tokens><instruction><eoq>``.

    Args:
        prompt: Instruction text, inserted unchanged.
        condition: Comma-separated condition names, resolved through
            ``condition_to_tokens``. Use ``"direct"`` for answer-only generation.
        condition_token: Raw condition token string. Kept for backward
            compatibility; when not ``None`` it is used as-is and
            ``condition`` is ignored.
    """
    if condition_token is not None:
        tokens = condition_token
    else:
        tokens = condition_to_tokens(condition)
    return f"<|im_start|>{tokens}{prompt}<|im_end|>"
def _apply_repetition_penalty(logits: torch.Tensor, seen_ids: list[int], penalty: float) -> torch.Tensor:
    """Lower the logits of tokens that already occur in ``seen_ids``.

    Negative logits of seen tokens are multiplied by ``penalty`` and the
    others are divided by it, so both move toward "less likely". Each
    distinct token is penalized once, however often it occurs.

    Args:
        logits: Logits whose last axis is the vocabulary. Modified in place.
        seen_ids: Token ids seen so far.
        penalty: Penalty factor. Values ``<= 1.0`` disable the penalty.

    Returns:
        The same ``logits`` tensor.
    """
    if penalty <= 1.0 or not seen_ids:
        return logits
    # Penalize all distinct seen tokens in one gather/scatter instead of a
    # Python loop of per-token tensor ops on every decode step.
    token_idx = torch.tensor(sorted(set(seen_ids)), device=logits.device, dtype=torch.long)
    selected = logits[..., token_idx]
    logits[..., token_idx] = torch.where(selected < 0, selected * penalty, selected / penalty)
    return logits
def _apply_no_repeat_ngram(logits: torch.Tensor, seen_ids: list[int], ngram_size: int) -> torch.Tensor:
    """Ban tokens that would complete an n-gram already present in ``seen_ids``.

    The last ``ngram_size - 1`` ids form the current prefix. Every token that
    followed an earlier occurrence of that prefix gets a logit of ``-inf``.
    With ``ngram_size == 1`` the prefix is empty, so every seen token is banned.

    Args:
        logits: Logits whose last axis is the vocabulary. Modified in place.
        seen_ids: Token ids seen so far.
        ngram_size: N-gram length. Values ``<= 0`` disable the filter.

    Returns:
        The same ``logits`` tensor.
    """
    prefix_len = ngram_size - 1
    if ngram_size <= 0 or len(seen_ids) < prefix_len:
        return logits
    # Slice from an explicit start index: `seen_ids[-0:]` would be the whole
    # list, which made ngram_size == 1 match nothing.
    prefix = seen_ids[len(seen_ids) - prefix_len:]
    blocked = {
        seen_ids[start + prefix_len]
        for start in range(len(seen_ids) - prefix_len)
        if seen_ids[start:start + prefix_len] == prefix
    }
    if blocked:
        logits[..., list(blocked)] = -torch.inf
    return logits
def _sample_next(
    logits: torch.Tensor,
    temperature: float,
    top_p: float,
    seen_ids: list[int] | None = None,
    repetition_penalty: float = 1.0,
    no_repeat_ngram_size: int = 0,
    blocked_ids: set[int] | None = None,
) -> int:
    """Pick the next token id from ``logits`` for a single sequence.

    The repetition penalty, the no-repeat-n-gram filter and the
    ``blocked_ids`` mask are applied first. Then the token is chosen greedily
    (``temperature <= 0``) or sampled from the temperature-scaled
    distribution, optionally restricted by ``top_p``.

    Args:
        logits: Logits for one position; the last axis is the vocabulary. The
            result is read with ``.item()``, so only one sequence is supported.
            The caller's tensor is not modified.
        temperature: Softmax temperature. ``<= 0`` selects the arg-max token.
        top_p: Cumulative-probability cutoff. Tokens are kept while the running
            sum of sorted probabilities is ``<= top_p``, and the most likely
            token is always kept. ``>= 1.0`` disables the cutoff.
        seen_ids: Token ids used by the repetition filters.
        repetition_penalty: See ``_apply_repetition_penalty``.
        no_repeat_ngram_size: See ``_apply_no_repeat_ngram``.
        blocked_ids: Token ids that must not be chosen.

    Returns:
        The chosen token id.
    """
    # Work on a private float32 copy: `.float()` returns the same tensor when
    # the input is already float32, and the filters below write in place.
    logits = logits.to(dtype=torch.float32, copy=True)
    seen_ids = seen_ids or []
    logits = _apply_repetition_penalty(logits, seen_ids, repetition_penalty)
    logits = _apply_no_repeat_ngram(logits, seen_ids, no_repeat_ngram_size)
    if blocked_ids:
        logits[..., list(blocked_ids)] = -torch.inf

    if temperature <= 0:
        return int(torch.argmax(logits, dim=-1).item())

    probs = torch.softmax(logits / temperature, dim=-1)
    if top_p < 1.0:
        sorted_probs, sorted_idx = torch.sort(probs, descending=True)
        keep = torch.cumsum(sorted_probs, dim=-1) <= top_p
        keep[..., 0] = True  # never drop the most likely token
        sorted_probs = sorted_probs.masked_fill(~keep, 0)
        sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
        next_sorted = torch.multinomial(sorted_probs, num_samples=1)
        # Map the sampled rank back to a vocabulary id.
        return int(sorted_idx.gather(-1, next_sorted).item())
    return int(torch.multinomial(probs, num_samples=1).item())
@torch.no_grad()
def generate_from_loaded(
    model: KoHRMTextForGeneration,
    tokenizer: Tokenizer,
    cfg: dict[str, Any],
    prompt: str,
    *,
    max_new_tokens: int = 64,
    min_new_tokens: int = 0,
    max_seq_len: int = 512,
    temperature: float = 0.0,
    top_p: float = 0.9,
    repetition_penalty: float = 1.18,
    no_repeat_ngram_size: int = 4,
    condition: str = "direct",
    condition_token: str | None = None,
) -> str:
    """Generate a completion for ``prompt`` with an already loaded model.

    The prompt is wrapped with ``format_kohrm_prompt``, encoded, and run
    through the model in one pass that fills the KV cache. Tokens are then
    decoded one at a time until a stop token is chosen or ``max_new_tokens``
    tokens have been produced. The whole call runs with gradient tracking
    disabled.

    Args:
        model: Loaded model; device and dtype are taken from its first parameter.
        tokenizer: Tokenizer matching the model.
        cfg: Model config. ``eos_token_id`` (an int or a list of ints) is added
            to the stop tokens when present.
        prompt: Instruction text.
        max_new_tokens: Upper bound on generated tokens.
        min_new_tokens: Stop tokens are blocked until this many tokens exist.
        max_seq_len: KV-cache length. Must hold the prompt, the generated
            tokens and one extra position.
        temperature: Sampling temperature; ``<= 0`` means greedy decoding.
        top_p: Cumulative-probability cutoff passed to the sampler.
        repetition_penalty: Repetition penalty passed to the sampler.
        no_repeat_ngram_size: N-gram blocking size passed to the sampler.
        condition: Condition names for ``format_kohrm_prompt``.
        condition_token: Raw condition token override for ``format_kohrm_prompt``.

    Returns:
        The decoded text with special tokens skipped and surrounding
        whitespace stripped.

    Raises:
        ValueError: If prompt tokens + ``max_new_tokens`` + 1 exceed ``max_seq_len``.
    """
    first_param = next(model.parameters())
    dev, dtype = first_param.device, first_param.dtype

    wrapped = format_kohrm_prompt(prompt, condition=condition, condition_token=condition_token)
    input_ids = tokenizer.encode(wrapped, add_special_tokens=False).ids
    if len(input_ids) + max_new_tokens + 1 > max_seq_len:
        raise ValueError(f"Prompt plus generation exceeds max_seq_len={max_seq_len}: prompt_tokens={len(input_ids)}")

    # Stop on the chat/box terminators plus whatever the config declares as EOS.
    # Candidates are filtered with `is not None` so that a legitimate token id
    # of 0 is kept and a missing id is skipped instead of crashing int().
    stop_candidates: list[Any] = [
        tokenizer.token_to_id("<|im_end|>"),
        tokenizer.token_to_id("<|box_end|>"),
    ]
    eos_cfg = cfg.get("eos_token_id")
    if isinstance(eos_cfg, (list, tuple)):
        stop_candidates.extend(eos_cfg)
    else:
        stop_candidates.append(eos_cfg)
    stop_ids = {int(x) for x in stop_candidates if x is not None}

    out_ids: list[int] = []
    seen_ids = list(input_ids)

    def pick(step_logits: torch.Tensor) -> int:
        # Stop tokens stay blocked until min_new_tokens tokens have been emitted.
        must_continue = len(out_ids) < min_new_tokens
        return _sample_next(
            step_logits,
            temperature,
            top_p,
            seen_ids,
            repetition_penalty,
            no_repeat_ngram_size,
            blocked_ids=stop_ids if must_continue else None,
        )

    # Prefill: run the whole prompt once and keep only the last position's logits.
    caches = model.init_cache(1, max_seq_len, dev, dtype)
    ids = torch.tensor([input_ids], device=dev, dtype=torch.long)
    pos = torch.arange(ids.shape[1], device=dev, dtype=torch.long).unsqueeze(0)
    logits = model(ids, pos, caches=caches, cache_pos=0)[:, -1, :]
    cache_pos = ids.shape[1]

    next_id = pick(logits)
    for _ in range(max_new_tokens):
        if next_id in stop_ids and len(out_ids) >= min_new_tokens:
            break
        out_ids.append(next_id)
        seen_ids.append(next_id)
        # Decode step: feed only the new token at the next cache position.
        token = torch.tensor([[next_id]], device=dev, dtype=torch.long)
        pos = torch.tensor([[cache_pos]], device=dev, dtype=torch.long)
        logits = model(token, pos, caches=caches, cache_pos=cache_pos)[:, -1, :]
        cache_pos += 1
        next_id = pick(logits)
    return tokenizer.decode(out_ids, skip_special_tokens=True).strip()
def generate_text(
    repo_dir: str | Path,
    prompt: str,
    *,
    max_new_tokens: int = 64,
    min_new_tokens: int = 0,
    max_seq_len: int = 512,
    temperature: float = 0.0,
    top_p: float = 0.9,
    repetition_penalty: float = 1.18,
    no_repeat_ngram_size: int = 4,
    condition: str = "direct",
    condition_token: str | None = None,
    device: str | None = None,
) -> str:
    """Load the model from ``repo_dir`` and generate one completion for ``prompt``.

    Convenience wrapper: calls ``load_kohrm`` (with ``max_gpu_memory_gib=14.0``)
    and passes every generation option through to ``generate_from_loaded``.
    The model is reloaded on every call.
    """
    model, tokenizer, cfg = load_kohrm(repo_dir, device=device, max_gpu_memory_gib=14.0)
    generation_options = dict(
        max_new_tokens=max_new_tokens,
        min_new_tokens=min_new_tokens,
        max_seq_len=max_seq_len,
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        no_repeat_ngram_size=no_repeat_ngram_size,
        condition=condition,
        condition_token=condition_token,
    )
    return generate_from_loaded(model, tokenizer, cfg, prompt, **generation_options)
def main() -> None:
    """Parse command-line options, run one generation, and print the result."""
    # Default probe prompt (Korean). In English: "The following is part of the
    # original text of a Korean Wikipedia article. Learn the encyclopedic
    # Korean, proper nouns, dates, and technical/social/cultural knowledge as
    # is." followed by "[Document title] Hunminjeongeum" and "[Part] 1/1".
    default_prompt = (
        "다음은 한국어 위키백과 문서 원문 일부입니다. 백과사전식 한국어, "
        "고유명사, 날짜, 기술/사회/문화 지식을 그대로 학습하십시오.\n\n"
        "[문서명]\n훈민정음\n\n[부분]\n1/1"
    )

    parser = argparse.ArgumentParser(description="Run a KoHRM-Text long generation probe without transformers.")
    parser.add_argument("repo_dir", type=Path, help="Directory containing config.json, tokenizer.json, and model.safetensors")
    parser.add_argument("--prompt", default=default_prompt)
    parser.add_argument("--max-new-tokens", type=int, default=384)
    parser.add_argument("--min-new-tokens", type=int, default=160)
    parser.add_argument("--max-seq-len", type=int, default=1536)
    parser.add_argument("--temperature", type=float, default=0.65)
    parser.add_argument("--top-p", type=float, default=0.92)
    parser.add_argument("--repetition-penalty", type=float, default=1.05)
    parser.add_argument("--no-repeat-ngram-size", type=int, default=0)
    parser.add_argument(
        "--condition",
        default="direct",
        help="Comma-separated HRM-Text condition names: direct, cot, noisy, synth. Use direct for answer-only outputs.",
    )
    parser.add_argument(
        "--condition-token",
        default=None,
        help="Optional raw condition token override. Normally use --condition direct instead.",
    )
    parser.add_argument("--device", default=None)
    args = parser.parse_args()

    print(generate_text(
        args.repo_dir,
        args.prompt,
        max_new_tokens=args.max_new_tokens,
        min_new_tokens=args.min_new_tokens,
        max_seq_len=args.max_seq_len,
        temperature=args.temperature,
        top_p=args.top_p,
        repetition_penalty=args.repetition_penalty,
        no_repeat_ngram_size=args.no_repeat_ngram_size,
        condition=args.condition,
        condition_token=args.condition_token,
        device=args.device,
    ))


if __name__ == "__main__":
    main()