| |
| from __future__ import annotations |
|
|
| import os |
| from typing import Any, Dict, Union |
|
|
| import torch |
| from transformers import AutoModelForSeq2SeqLM, AutoTokenizer |
|
|
|
|
# Token budget for the encoder input. The handler passes it to the tokenizer
# as model_max_length and keeps only the last MAX_INPUT_TOKENS prompt tokens.
MAX_INPUT_TOKENS = 512

# Generation length used when a request omits parameters.max_new_tokens.
DEFAULT_MAX_NEW_TOKENS = 128

# Instruction placed between the context and the question when a request
# does not supply its own system_prompt. The sentences are joined with single
# spaces, which reproduces the exact prompt text.
DEFAULT_SYSTEM_PROMPT = " ".join(
    [
        "You are Teapot, an open-source AI assistant optimized for low-end devices,",
        "providing short, accurate responses without hallucinating while excelling at",
        "information extraction and text summarization.",
        "If the context does not answer the question, reply exactly:",
        "'I am sorry but I don't have any information on that'.",
    ]
)
|
|
|
|
def _path_exists(p: str) -> bool:
    """Return whether ``p`` exists on disk, without ever raising.

    This is used only for diagnostic logging. Any error raised while checking
    (for example an argument that is not a path at all) is reported as
    ``False`` instead of being propagated.
    """
    found = False
    try:
        found = os.path.exists(p)
    except Exception:
        # Best effort: an unusable path simply counts as "not there".
        found = False
    return found
|
|
|
|
class EndpointHandler:
    """Inference-endpoint handler serving a seq2seq model on CPU.

    ``__init__`` loads the tokenizer and model from ``path``, prints
    diagnostics, and refuses to start if tokenizer and model vocabularies
    cannot line up. Calling the instance answers one request with greedy
    decoding and returns ``{"generated_text": ...}``.
    """

    def __init__(self, path: str = ""):
        """Load tokenizer and model from ``path`` and validate them.

        Args:
            path: Model location handed to ``from_pretrained``. It is also
                joined with file names to report which tokenizer/config files
                are present.

        Raises:
            RuntimeError: If the tokenizer has more tokens than the model has
                embedding rows, or if ``config.vocab_size`` differs from the
                number of embedding rows.
        """
        self._log_model_files(path)

        self.tokenizer = AutoTokenizer.from_pretrained(
            path,
            use_fast=False,  # request the slow (non-Rust) tokenizer implementation
            model_max_length=MAX_INPUT_TOKENS,
        )
        self.model = AutoModelForSeq2SeqLM.from_pretrained(path)

        self.device = torch.device("cpu")
        self.model.to(self.device)
        self.model.eval()  # inference only

        self._check_vocab_consistency()

        self.system_prompt = DEFAULT_SYSTEM_PROMPT

    @staticmethod
    def _log_model_files(path: str) -> None:
        """Print the model directory and which expected files exist in it."""
        spiece_path = os.path.join(path, "spiece.model")
        tokjson_path = os.path.join(path, "tokenizer.json")
        cfg_path = os.path.join(path, "config.json")

        print(f"[teapot] model_dir={path}")
        print(
            f"[teapot] exists config.json={_path_exists(cfg_path)} "
            f"tokenizer.json={_path_exists(tokjson_path)} "
            f"spiece.model={_path_exists(spiece_path)}"
        )

    def _check_vocab_consistency(self) -> None:
        """Print vocabulary diagnostics; raise if tokenizer ids cannot index the model."""
        tok_len = len(self.tokenizer)
        tok_vocab_size = getattr(self.tokenizer, "vocab_size", None)
        cfg_vocab = getattr(self.model.config, "vocab_size", None)
        emb_rows = int(self.model.get_input_embeddings().weight.shape[0])

        print(
            f"[teapot] tokenizer_class={type(self.tokenizer).__name__} "
            f"use_fast={getattr(self.tokenizer, 'is_fast', None)}"
        )
        print(
            f"[teapot] len(tokenizer)={tok_len} tokenizer.vocab_size={tok_vocab_size} "
            f"model.config.vocab_size={cfg_vocab} embedding_rows={emb_rows}"
        )
        print(
            f"[teapot] special_tokens: pad={self.tokenizer.pad_token} "
            f"eos={self.tokenizer.eos_token} unk={self.tokenizer.unk_token}"
        )

        # Fatal: the tokenizer can emit ids >= emb_rows, which have no
        # embedding row to look up.
        if tok_len > emb_rows:
            raise RuntimeError(
                f"[teapot] FATAL: embedding_rows ({emb_rows}) < len(tokenizer) ({tok_len}). "
                "This means your model weights and tokenizer files are out of sync in the repo. "
                "Fix by re-saving model+tokenizer together after resize_token_embeddings."
            )
        # Spare embedding rows are legitimate: matrices are often padded past
        # len(tokenizer) (cf. resize_token_embeddings(pad_to_multiple_of=...)).
        # Only warn, so such checkpoints can still be served.
        if emb_rows > tok_len:
            print(
                f"[teapot] WARNING: embedding_rows ({emb_rows}) > len(tokenizer) ({tok_len}); "
                "extra rows are never produced by the tokenizer."
            )
        if cfg_vocab is not None and cfg_vocab != emb_rows:
            raise RuntimeError(
                f"[teapot] FATAL: model.config.vocab_size ({cfg_vocab}) != embedding_rows ({emb_rows}). "
                "Your config.json is inconsistent with the weights. Re-save model to update config."
            )

    @staticmethod
    def _parse_max_new_tokens(params: Dict[str, Any]) -> int:
        """Return ``params["max_new_tokens"]`` as a positive int (default DEFAULT_MAX_NEW_TOKENS)."""
        raw = params.get("max_new_tokens", DEFAULT_MAX_NEW_TOKENS)
        try:
            value = int(raw)
        except (TypeError, ValueError) as err:
            raise ValueError("'parameters.max_new_tokens' must be an integer.") from err
        if value < 1:
            raise ValueError("'parameters.max_new_tokens' must be a positive integer.")
        return value

    def _build_prompt(self, inputs: Any) -> str:
        """Turn the request's ``inputs`` (string or {context, question, system_prompt}) into a prompt."""
        if isinstance(inputs, str):
            return inputs
        if isinstance(inputs, dict):
            # Treat explicit JSON nulls like missing fields, so they are not
            # rendered into the prompt as the literal text "None". An explicit
            # empty system_prompt is kept: it means "no system prompt".
            context = inputs.get("context")
            question = inputs.get("question")
            system_prompt = inputs.get("system_prompt")
            if context is None:
                context = ""
            if question is None:
                question = ""
            if system_prompt is None:
                system_prompt = self.system_prompt
            return f"{context}\n{system_prompt}\n{question}\n"
        raise ValueError("'inputs' must be a string or an object with {context, question}.")

    @torch.inference_mode()
    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        """Generate a completion for one request.

        Args:
            data: Request payload. ``data["inputs"]`` is either a raw prompt
                string or an object with optional ``context``, ``question``
                and ``system_prompt`` strings. ``data["parameters"]``, if
                present, is an object with an optional integer
                ``max_new_tokens``.

        Returns:
            ``{"generated_text": <decoded output with special tokens removed>}``.

        Raises:
            ValueError: If the payload, its ``inputs``, or its ``parameters``
                are malformed.
        """
        if not isinstance(data, dict) or "inputs" not in data:
            raise ValueError("Request must be JSON with an 'inputs' field.")

        params = data.get("parameters") or {}
        if not isinstance(params, dict):
            raise ValueError("'parameters' must be an object.")
        max_new_tokens = self._parse_max_new_tokens(params)

        prompt = self._build_prompt(data["inputs"])

        enc = self.tokenizer(prompt, return_tensors="pt")
        input_ids = enc["input_ids"]
        attention_mask = enc.get("attention_mask")

        # Left-truncate to the last MAX_INPUT_TOKENS tokens. For object inputs
        # the prompt ends with the system prompt and question, so an
        # over-long context loses its beginning rather than the question.
        if input_ids.shape[1] > MAX_INPUT_TOKENS:
            input_ids = input_ids[:, -MAX_INPUT_TOKENS:]
            if attention_mask is not None:
                attention_mask = attention_mask[:, -MAX_INPUT_TOKENS:]

        input_ids = input_ids.to(self.device)
        if attention_mask is not None:
            attention_mask = attention_mask.to(self.device)

        # Deterministic greedy decoding with mild repetition suppression.
        out = self.model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            do_sample=False,
            num_beams=1,
            max_new_tokens=max_new_tokens,
            repetition_penalty=1.05,
            no_repeat_ngram_size=3,
        )

        # A single prompt was encoded, so only the first sequence is decoded.
        text = self.tokenizer.decode(out[0], skip_special_tokens=True)
        return {"generated_text": text}