from __future__ import annotations

import re
from dataclasses import dataclass
from typing import Sequence

import torch
from torch import nn
from transformers.generation.utils import GenerationMixin
from transformers.modeling_utils import PreTrainedModel
from transformers.utils.generic import ModelOutput

from .config import CircuitGPTConfig
from .gpt import GPT
from .hook_utils import hook_recorder


@dataclass
class CircuitGPTCausalLMOutput(ModelOutput):
    """Output of `CircuitGPTForCausalLM.forward`: standard causal-LM fields
    plus any activations recorded by hook name."""

    loss: torch.Tensor | None = None
    logits: torch.Tensor | None = None
    activations: dict[str, torch.Tensor] | None = None


def _activations_regex(keys: Sequence[str]) -> str:
    """Build an anchored regex that matches exactly the given hook keys."""
    escaped = (re.escape(k) for k in keys)
    return "^(" + "|".join(escaped) + ")$"
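
# Illustrative only (the hook key names here are hypothetical):
#   _activations_regex(["h.0.attn", "h.0.mlp"])  ->  r"^(h\.0\.attn|h\.0\.mlp)$"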


class CircuitGPTPreTrainedModel(PreTrainedModel):
    """Shared Hugging Face plumbing for wrappers around a circuit `GPT`."""

    config_class = CircuitGPTConfig
    base_model_prefix = "circuit_model"
    circuit_model: GPT

    def __init__(self, config: CircuitGPTConfig, *inputs, **kwargs) -> None:
        super().__init__(config, *inputs, **kwargs)

    def get_input_embeddings(self) -> nn.Module:
        return self.circuit_model.transformer.wte

    def set_input_embeddings(self, value: nn.Module) -> None:
        self.circuit_model.transformer.wte = value

    def get_output_embeddings(self) -> nn.Module:
        return self.circuit_model.lm_head

    def set_output_embeddings(self, new_embeddings: nn.Module) -> None:
        self.circuit_model.lm_head = new_embeddings
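
    # These accessors let generic HF utilities operate on the wrapped model,
    # e.g. `model.resize_token_embeddings(len(tokenizer))`, which works via
    # the get/set embedding hooks above (sketch; resizing also assumes the
    # wrapped GPT tolerates a swapped-out `wte`/`lm_head`).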


class CircuitGPTForCausalLM(CircuitGPTPreTrainedModel, GenerationMixin):
    """
    Hugging Face-compatible wrapper around `circuit_sparsity.gpt.GPT`.

    All math happens inside the wrapped module, so numerical parity with the
    original implementation is guaranteed.
    """

    def __init__(self, config: CircuitGPTConfig, circuit_model: GPT | None = None) -> None:
        super().__init__(config)

        if circuit_model is None:
            # Fresh model: build from config and run HF post-init (weight init).
            self.circuit_model = GPT(config.to_circuit_config())
            self.post_init()
        else:
            # Wrapping an existing (possibly trained) model: keep its weights
            # untouched, so post_init() is deliberately skipped.
            self.circuit_model = circuit_model

    @classmethod
    def from_circuit_model(cls, circuit_model: GPT) -> "CircuitGPTForCausalLM":
        """Wrap an existing `GPT` without re-initializing its weights."""
        config = CircuitGPTConfig.from_circuit_config(circuit_model.config)
        return cls(config, circuit_model=circuit_model)
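
    # Usage sketch (assumes a trained `GPT` instance named `gpt`):
    #   model = CircuitGPTForCausalLM.from_circuit_model(gpt)
    #   out = model(input_ids)  # same logits as gpt(input_ids, targets=None)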

    def forward(
        self,
        input_ids: torch.Tensor,
        labels: torch.LongTensor | None = None,
        output_activations: Sequence[str] | None = None,
        return_dict: bool | None = None,
        use_cache: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        **kwargs,
    ) -> CircuitGPTCausalLMOutput:
        # The standard HF flags above are accepted (so generate() can pass
        # them) but ignored; any other kwarg that is actually set is an error.
        remaining_kwargs = {k: v for k, v in kwargs.items() if v is not None}
        if remaining_kwargs:
            unsupported = ", ".join(remaining_kwargs.keys())
            raise ValueError(f"Unsupported arguments for CircuitGPTForCausalLM: {unsupported}")

        if input_ids.size(-1) > self.config.block_size:
            raise ValueError(
                f"Sequence length {input_ids.size(-1)} exceeds block size {self.config.block_size}"
            )

        if output_activations:
            # Record only the requested hook sites during this forward pass.
            regex = _activations_regex(output_activations)
            with hook_recorder(regex=regex) as recorded:
                logits, loss, _ = self.circuit_model(input_ids, targets=labels)
            activations = {key: recorded[key] for key in output_activations if key in recorded}
        else:
            activations = None
            logits, loss, _ = self.circuit_model(input_ids, targets=labels)

        if labels is None:
            loss = None

        return CircuitGPTCausalLMOutput(
            loss=loss,
            logits=logits,
            activations=activations,
        )
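
    # Activation-capture sketch (hook keys here are illustrative; valid names
    # depend on what `hook_recorder` registers):
    #   out = model(input_ids, output_activations=["transformer.h.0.mlp"])
    #   mlp_acts = out.activations["transformer.h.0.mlp"]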

    def prepare_inputs_for_generation(self, input_ids: torch.Tensor, **kwargs):
        # No KV cache is used, so each decode step re-feeds the full context,
        # truncated to the last `block_size` tokens.
        if input_ids.size(-1) > self.config.block_size:
            input_ids = input_ids[:, -self.config.block_size :]
        return {"input_ids": input_ids}

    def _reorder_cache(self, past, beam_idx):
        # No cache is kept, so beam search has nothing to reorder.
        return past
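
# End-to-end generation sketch (the tokenizer choice is an assumption, not
# part of this module):
#   from transformers import AutoTokenizer
#   tok = AutoTokenizer.from_pretrained("gpt2")
#   ids = tok("Hello world", return_tensors="pt").input_ids
#   out_ids = model.generate(ids, max_new_tokens=16, do_sample=False)
#   print(tok.decode(out_ids[0]))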