File size: 4,522 Bytes
5dd1bb4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9e64e71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5dd1bb4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9e64e71
5dd1bb4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
"""E2E-style smoke coverage for the GRPO training notebook."""

from __future__ import annotations

import json
from pathlib import Path

from sql_env.training.config import GRPOConfig
from sql_env.training.notebook_pipeline import (
    build_trainer,
    run_training_with_metrics,
    sample_random_baseline,
)
from sql_env.training.data_loading import filter_questions_by_difficulty


# Location of the training notebook inspected by the structure test below.
# The path is relative, so it is resolved against the current working
# directory of the test process.
NOTEBOOK_PATH = Path("notebooks/train_grpo.ipynb")


def _read_notebook() -> dict:
    """Read the file at ``NOTEBOOK_PATH`` as UTF-8 text and decode it as JSON."""
    raw_text = NOTEBOOK_PATH.read_text(encoding="utf-8")
    return json.loads(raw_text)


def _code_sources(notebook: dict) -> list[str]:
    """Return the source text of every code cell in *notebook*, in order.

    Cells are read from the ``"cells"`` key (treated as empty when absent).
    Only cells whose ``"cell_type"`` equals ``"code"`` are kept; each cell's
    ``"source"`` fragments are concatenated into one string, and a missing
    ``"source"`` yields an empty string.
    """
    collected: list[str] = []
    for entry in notebook.get("cells", []):
        if entry.get("cell_type") != "code":
            continue
        fragments = entry.get("source", [])
        collected.append("".join(fragments))
    return collected


def test_training_notebook_smoke_structure() -> None:
    """Check that the notebook's code cells contain the expected GRPO steps.

    Only the text of the code cells is inspected; the notebook itself is
    never executed.
    """

    assert NOTEBOOK_PATH.exists(), "notebooks/train_grpo.ipynb must exist"

    code_text = "\n".join(_code_sources(_read_notebook()))

    # Snippets that must each appear somewhere in the combined code cells.
    required_snippets = (
        "GRPOConfig(",
        "load_model_and_tokenizer(config.model_name)",
        "grpo_trainer_cls=GRPOTrainer",
        "run_training_with_metrics",
        "matplotlib.pyplot as plt",
    )
    for snippet in required_snippets:
        assert snippet in code_text

    # The baseline sampling assignment has to come before the training call.
    baseline_pos = code_text.find("before_rollouts = sample_random_baseline")
    training_pos = code_text.find("run_training_with_metrics(trainer)")
    assert baseline_pos != -1
    assert training_pos != -1
    assert baseline_pos < training_pos


def test_question_filtering_by_difficulty() -> None:
    """Filtering three mixed-difficulty questions on ``["easy"]`` keeps only q1."""

    labelled = (("q1", "easy"), ("q2", "medium"), ("q3", "hard"))
    questions = [
        {"question_text": text, "difficulty": level}
        for text, level in labelled
    ]

    kept = filter_questions_by_difficulty(questions, ["easy"])
    kept_texts = [entry["question_text"] for entry in kept]
    assert kept_texts == ["q1"]


class _FakeTRLConfig:
    """Config test double that records the keyword arguments it was given."""

    def __init__(self, **options):
        # Exposed as ``kwargs`` so tests can inspect what was passed in.
        self.kwargs = options


class _FakeTrainer:
    """Trainer test double that stores its inputs and fakes one training step."""

    def __init__(
        self,
        *,
        model,
        processing_class,
        args,
        train_dataset,
        reward_funcs,
    ) -> None:
        # Nothing has run yet: the flag is off and the log history is empty.
        self.train_called = False
        state_cls = type("State", (), {"log_history": []})
        self.state = state_cls()

        # Keep every constructor argument so tests can inspect what was passed.
        self.reward_funcs = reward_funcs
        self.train_dataset = train_dataset
        self.args = args
        self.processing_class = processing_class
        self.model = model

    def train(self) -> dict[str, str]:
        """Record one logged step, flag the call, and return a status dict."""
        self.state.log_history = [{"step": 1, "reward": 0.25}]
        self.train_called = True
        return {"status": "ok"}


class _FakeTokenizer:
    """Tokenizer test double whose chat template always renders ``"prompt"``."""

    def apply_chat_template(
        self,
        messages: list[dict[str, str]],
        tokenize: bool = False,
        add_generation_prompt: bool = True,
    ) -> str:
        """Ignore every argument and return the fixed string ``"prompt"``."""
        # The arguments are deliberately discarded; only the signature matters.
        del messages, tokenize, add_generation_prompt
        return "prompt"


class _FakeModel:
    """Model test double that replies with a query first, then an answer."""

    def __init__(self) -> None:
        # Number of generate() calls made so far.
        self._count = 0

    def generate(self, prompt: str, max_new_tokens: int) -> str:
        """Return ``"QUERY: SELECT 1"`` on the first call, ``"ANSWER: 42"`` after.

        Both arguments are ignored; only the call count decides the reply.
        """
        del prompt, max_new_tokens
        self._count += 1
        return "QUERY: SELECT 1" if self._count == 1 else "ANSWER: 42"


def test_notebook_pipeline_executes_training_step() -> None:
    """Build a trainer from fakes and confirm one training run reports metrics.

    The fake trainer logs a single step with reward 0.25, so the helper is
    expected to return that step and reward alongside the train() result.
    """

    config = GRPOConfig(
        questions_path="data/questions/questions_train.json",
        db_dir="data/databases",
        output_dir="outputs/grpo_test",
        step_budget=2,
    )

    # Fakes stand in for the model, tokenizer, TRL config and trainer classes.
    trainer = build_trainer(
        model=_FakeModel(),
        tokenizer=_FakeTokenizer(),
        prompts=[{"prompt": "Count rows"}],
        config=config,
        trl_grpo_config_cls=_FakeTRLConfig,
        grpo_trainer_cls=_FakeTrainer,
        reward_funcs=[],
    )

    result = run_training_with_metrics(trainer)
    output, steps, rewards = result

    assert trainer.train_called is True
    assert output == {"status": "ok"}
    assert steps == [1]
    assert rewards == [0.25]


def test_random_baseline_transcripts_are_generated() -> None:
    """Two prompts yield two baseline items tagged "random" with a completion."""

    transcripts = sample_random_baseline(["q1", "q2"], step_budget=3, seed=7)
    assert len(transcripts) == 2

    # Every item must be labelled with the random policy...
    for entry in transcripts:
        assert entry["metadata"]["policy"] == "random"

    # ...and must carry a non-empty completion.
    for entry in transcripts:
        assert entry["completion"]