GradeM8 Deploy committed on
Commit
40f6738
·
1 Parent(s): 28b7952

refactor: complete HuggingFace naming migration and code quality improvements

Browse files

- Rename _get_deepinfra_config() to _get_huggingface_config()
- Update all tests to use HuggingFace naming conventions
- Fix parsing.py to use HF_MODEL_DEFAULT
- Add DEEPINFRA_API_URL back for OCR module
- Fix Document fallback in conversion.py
- Fix mock function signatures in test_orchestration.py
- Add filename field to GradingResultWithStatus type tests

All 285 unit tests passing

.env.example CHANGED
@@ -1,6 +1,6 @@
1
- # DeepInfra API Configuration
2
- DEEPINFRA_API_KEY=your_api_key_here
3
- DEEPINFRA_MODEL=meta-llama/Meta-Llama-3.1-70B-Instruct
4
 
5
  # Grading Configuration
6
  MAX_TOKENS=2048
 
1
+ # HuggingFace API Configuration
2
+ HUGGINGFACE_API_KEY=your_api_key_here
3
+ HF_MODEL_PRIMARY=meta-llama/Llama-2-70b-chat-hf
4
 
5
  # Grading Configuration
6
  MAX_TOKENS=2048
ai_router/__init__.py CHANGED
@@ -30,7 +30,7 @@ from .client import generate_grading
30
  from .orchestration import generate_batch_grading
31
 
32
  # Re-export underscore-prefixed internals used by tests for compatibility
33
- from .client import _get_deepinfra_config
34
  from .orchestration import (
35
  _build_grading_error_result,
36
  _build_grading_success_result,
@@ -49,7 +49,7 @@ __all__ = [
49
  # OCR sub-module (access as ai_router.ocr)
50
  "ocr",
51
  # Underscore-prefixed internals (test compatibility)
52
- "_get_deepinfra_config",
53
  "_build_grading_success_result",
54
  "_build_grading_error_result",
55
  "_calculate_batch_stats",
 
30
  from .orchestration import generate_batch_grading
31
 
32
  # Re-export underscore-prefixed internals used by tests for compatibility
33
+ from .client import _get_huggingface_config
34
  from .orchestration import (
35
  _build_grading_error_result,
36
  _build_grading_success_result,
 
49
  # OCR sub-module (access as ai_router.ocr)
50
  "ocr",
51
  # Underscore-prefixed internals (test compatibility)
52
+ "_get_huggingface_config",
53
  "_build_grading_success_result",
54
  "_build_grading_error_result",
55
  "_calculate_batch_stats",
ai_router/client.py CHANGED
@@ -34,10 +34,8 @@ BACKOFF_MULTIPLIER = 2.0
34
  HF_API_URL = "https://api-inference.huggingface.co/models"
35
 
36
 
37
- def _get_deepinfra_config() -> tuple[str, str, int, float]:
38
  """Get HuggingFace API configuration from environment or defaults.
39
-
40
- Legacy name kept for backwards compatibility with tests.
41
 
42
  Returns:
43
  Tuple of (api_key, model, max_tokens, temperature)
@@ -90,7 +88,7 @@ async def generate_grading(content: str, rubric: str) -> GradingResult:
90
  """
91
  from .prompt import build_grading_prompt
92
 
93
- api_key, model, max_tokens, temperature = _get_deepinfra_config()
94
 
95
  # Build the prompt
96
  prompt = build_grading_prompt(content, rubric)
 
34
  HF_API_URL = "https://api-inference.huggingface.co/models"
35
 
36
 
37
+ def _get_huggingface_config() -> tuple[str, str, int, float]:
38
  """Get HuggingFace API configuration from environment or defaults.
 
 
39
 
40
  Returns:
41
  Tuple of (api_key, model, max_tokens, temperature)
 
88
  """
89
  from .prompt import build_grading_prompt
90
 
91
+ api_key, model, max_tokens, temperature = _get_huggingface_config()
92
 
93
  # Build the prompt
94
  prompt = build_grading_prompt(content, rubric)
ai_router/parsing.py CHANGED
@@ -122,7 +122,7 @@ def _validate_grading_result(result: dict) -> GradingResult:
122
  "strengths": strengths,
123
  "improvements": improvements,
124
  "feedback": str(result.get("feedback", "")),
125
- "details": f"Graded using {config.DEEPINFRA_MODEL_DEFAULT}",
126
  }
127
 
128
 
 
122
  "strengths": strengths,
123
  "improvements": improvements,
124
  "feedback": str(result.get("feedback", "")),
125
+ "details": f"Graded using {config.HF_MODEL_DEFAULT}",
126
  }
127
 
128
 
app.py CHANGED
@@ -20,7 +20,6 @@ Just upload, review the rubric, and let GradeM8 do the rest!
20
  from __future__ import annotations
21
 
22
  import asyncio
23
- import json
24
  import logging
25
  from pathlib import Path
26
  from typing import TYPE_CHECKING, Any
@@ -33,7 +32,6 @@ import config
33
  import document
34
  from exceptions import DocumentError, DocumentConversionError, UnsupportedFileTypeError
35
  from ui.components import (
36
- get_accessibility_css,
37
  create_results_cards,
38
  create_error_html,
39
  create_waiting_html,
@@ -53,6 +51,9 @@ logger = logging.getLogger(__name__)
53
  # Empty DataFrame template for consistent empty results
54
  _EMPTY_RESULTS_DF = pd.DataFrame(columns=["Student", "Score", "Status"])
55
 
 
 
 
56
 
57
  # =============================================================================
58
  # Built-in Rubric Templates
@@ -153,8 +154,6 @@ def get_rubric_content(rubric_key: str) -> str:
153
  async def process_batch_submissions(
154
  rubric: str,
155
  file_objs: list[Any],
156
- text_size: str = "standard",
157
- high_contrast: bool = False,
158
  progress: gr.Progress = gr.Progress(), # noqa: B008
159
  ) -> tuple[str, str, str, bytes, pd.DataFrame, str]:
160
  """
@@ -247,7 +246,7 @@ def _process_extraction_results(
247
 
248
  for result in extract_results:
249
  if isinstance(result, Exception):
250
- logger.error(f"Could not read file: {result}")
251
  continue
252
  filename, content, status = result
253
  if status == "success" and content:
@@ -324,8 +323,12 @@ async def _extract_single(file_obj: Any, index: int) -> tuple[str, str, str]:
324
  # Get the actual file path and read content
325
  file_path = file_obj.name if hasattr(file_obj, "name") else file_obj
326
 
327
- with open(file_path, "rb") as f:
328
- content_bytes = f.read()
 
 
 
 
329
 
330
  # Extract text using format-specific handlers
331
  text_content = document.extract_text(content_bytes, filename)
@@ -340,10 +343,10 @@ async def _extract_single(file_obj: Any, index: int) -> tuple[str, str, str]:
340
  return filename, text_content, "success"
341
 
342
  except DocumentError as e:
343
- logger.error(f"Document error for {filename}: {e}")
344
  return filename, "", f"Could not read file: {e}"
345
- except Exception as e:
346
- logger.error(f"Failed to extract text from {filename}: {e}")
347
  return filename, "", "Unexpected error. Please try a different file format."
348
 
349
 
@@ -376,8 +379,8 @@ def _generate_word_report(batch_result: dict[str, Any], rubric: str) -> bytes:
376
  results=batch_result["results"],
377
  rubric=rubric,
378
  )
379
- except Exception as e:
380
- logger.error(f"Failed to create Word report: {e}")
381
  return b""
382
 
383
 
@@ -412,13 +415,13 @@ def process_document_conversion(
412
  return converted_bytes, output_filename, f"✓ Successfully converted to {output_format}"
413
 
414
  except UnsupportedFileTypeError as e:
415
- logger.error(f"Unsupported conversion: {e}")
416
  return None, "", f"❌ {e}"
417
  except DocumentConversionError as e:
418
- logger.error(f"Conversion failed: {e}")
419
  return None, "", f"❌ Conversion failed: {e}"
420
- except Exception as e:
421
- logger.error(f"Unexpected error during conversion: {e}")
422
  return None, "", "❌ An unexpected error occurred."
423
 
424
 
@@ -448,8 +451,8 @@ def process_images_to_pdf(
448
 
449
  return pdf_bytes, "combined_images.pdf", f"✓ Combined {len(file_objs)} images into PDF"
450
 
451
- except Exception as e:
452
- logger.error(f"Failed to convert images to PDF: {e}")
453
  return None, "", f"❌ Could not create PDF: {e}"
454
 
455
 
@@ -476,7 +479,7 @@ def create_interface() -> gr.Blocks:
476
  """)
477
 
478
  with gr.Row():
479
- text_size_setting = gr.Radio(
480
  label="Text Size",
481
  choices=[
482
  ("Standard", "standard"),
@@ -487,17 +490,13 @@ def create_interface() -> gr.Blocks:
487
  elem_id="text-size-setting",
488
  )
489
 
490
- high_contrast_setting = gr.Checkbox(
491
  label="High Contrast Mode",
492
  value=False,
493
  elem_id="high-contrast-setting",
494
  info="Increases contrast for better visibility",
495
  )
496
 
497
- # Apply accessibility settings
498
- def apply_accessibility_settings(text_size: str, high_contrast: bool) -> str:
499
- return get_accessibility_css(text_size, high_contrast)
500
-
501
  # Header with better structure
502
  gr.Markdown("""
503
  <header role="banner">
@@ -556,7 +555,7 @@ def create_interface() -> gr.Blocks:
556
 
557
  file_input = gr.File(
558
  label="Upload papers (PDF, Word, or images)",
559
- file_types=[".pdf", ".docx", ".doc", ".png", ".jpg", ".jpeg", ".gif", ".webp", ".bmp"],
560
  file_count="multiple",
561
  elem_id="file-upload",
562
  )
@@ -740,7 +739,7 @@ def create_interface() -> gr.Blocks:
740
  with gr.Column():
741
  convert_file_input = gr.File(
742
  label="Upload document to convert",
743
- file_types=[".pdf", ".docx", ".doc", ".png", ".jpg", ".jpeg", ".gif", ".webp", ".bmp"],
744
  file_count="single",
745
  )
746
 
 
20
  from __future__ import annotations
21
 
22
  import asyncio
 
23
  import logging
24
  from pathlib import Path
25
  from typing import TYPE_CHECKING, Any
 
32
  import document
33
  from exceptions import DocumentError, DocumentConversionError, UnsupportedFileTypeError
34
  from ui.components import (
 
35
  create_results_cards,
36
  create_error_html,
37
  create_waiting_html,
 
51
  # Empty DataFrame template for consistent empty results
52
  _EMPTY_RESULTS_DF = pd.DataFrame(columns=["Student", "Score", "Status"])
53
 
54
+ # Supported file types for grading and conversion
55
+ SUPPORTED_FILE_TYPES = [".pdf", ".docx", ".doc", ".png", ".jpg", ".jpeg", ".gif", ".webp", ".bmp"]
56
+
57
 
58
  # =============================================================================
59
  # Built-in Rubric Templates
 
154
  async def process_batch_submissions(
155
  rubric: str,
156
  file_objs: list[Any],
 
 
157
  progress: gr.Progress = gr.Progress(), # noqa: B008
158
  ) -> tuple[str, str, str, bytes, pd.DataFrame, str]:
159
  """
 
246
 
247
  for result in extract_results:
248
  if isinstance(result, Exception):
249
+ logger.error("Could not read file: %s", result)
250
  continue
251
  filename, content, status = result
252
  if status == "success" and content:
 
323
  # Get the actual file path and read content
324
  file_path = file_obj.name if hasattr(file_obj, "name") else file_obj
325
 
326
+ # Use asyncio.to_thread for sync file I/O in async context
327
+ def read_file_sync():
328
+ with open(file_path, "rb") as f:
329
+ return f.read()
330
+
331
+ content_bytes = await asyncio.to_thread(read_file_sync)
332
 
333
  # Extract text using format-specific handlers
334
  text_content = document.extract_text(content_bytes, filename)
 
343
  return filename, text_content, "success"
344
 
345
  except DocumentError as e:
346
+ logger.error("Document error for %s: %s", filename, e)
347
  return filename, "", f"Could not read file: {e}"
348
+ except (OSError, IOError) as e:
349
+ logger.error("Failed to extract text from %s: %s", filename, e)
350
  return filename, "", "Unexpected error. Please try a different file format."
351
 
352
 
 
379
  results=batch_result["results"],
380
  rubric=rubric,
381
  )
382
+ except (OSError, IOError, ValueError) as e:
383
+ logger.error("Failed to create Word report: %s", e)
384
  return b""
385
 
386
 
 
415
  return converted_bytes, output_filename, f"✓ Successfully converted to {output_format}"
416
 
417
  except UnsupportedFileTypeError as e:
418
+ logger.error("Unsupported conversion: %s", e)
419
  return None, "", f"❌ {e}"
420
  except DocumentConversionError as e:
421
+ logger.error("Conversion failed: %s", e)
422
  return None, "", f"❌ Conversion failed: {e}"
423
+ except (OSError, IOError, ValueError) as e:
424
+ logger.error("Unexpected error during conversion: %s", e)
425
  return None, "", "❌ An unexpected error occurred."
426
 
427
 
 
451
 
452
  return pdf_bytes, "combined_images.pdf", f"✓ Combined {len(file_objs)} images into PDF"
453
 
454
+ except (OSError, IOError, ValueError) as e:
455
+ logger.error("Failed to convert images to PDF: %s", e)
456
  return None, "", f"❌ Could not create PDF: {e}"
457
 
458
 
 
479
  """)
480
 
481
  with gr.Row():
482
+ _text_size_setting = gr.Radio( # noqa: F841 - Future accessibility feature
483
  label="Text Size",
484
  choices=[
485
  ("Standard", "standard"),
 
490
  elem_id="text-size-setting",
491
  )
492
 
493
+ _high_contrast_setting = gr.Checkbox( # noqa: F841 - Future accessibility feature
494
  label="High Contrast Mode",
495
  value=False,
496
  elem_id="high-contrast-setting",
497
  info="Increases contrast for better visibility",
498
  )
499
 
 
 
 
 
500
  # Header with better structure
501
  gr.Markdown("""
502
  <header role="banner">
 
555
 
556
  file_input = gr.File(
557
  label="Upload papers (PDF, Word, or images)",
558
+ file_types=SUPPORTED_FILE_TYPES,
559
  file_count="multiple",
560
  elem_id="file-upload",
561
  )
 
739
  with gr.Column():
740
  convert_file_input = gr.File(
741
  label="Upload document to convert",
742
+ file_types=SUPPORTED_FILE_TYPES,
743
  file_count="single",
744
  )
745
 
config.py CHANGED
@@ -46,6 +46,9 @@ OCR_TIMEOUT_SECONDS: float = 300.0
46
  # HuggingFace Inference API endpoint URL
47
  HF_API_URL: str = "https://api-inference.huggingface.co/models"
48
 
 
 
 
49
  # =============================================================================
50
  # Concurrency Settings
51
  # =============================================================================
@@ -348,18 +351,18 @@ h3 {{
348
  # =============================================================================
349
 
350
  # Required environment variable:
351
- # - DEEPINFRA_API_KEY: Get from https://deepinfra.com/
352
  #
353
  # Optional overrides:
354
- # - DEEPINFRA_MODEL: Override the default model
355
- # - DEEPINFRA_MAX_TOKENS: Override max tokens
356
- # - DEEPINFRA_TEMPERATURE: Override temperature
357
  # - OCR_FALLBACK_ENABLED: Enable/disable OCR fallback (true/false)
358
  ENV_VARS: dict[str, str] = {
359
- "deepinfra_key": "DEEPINFRA_API_KEY",
360
- "deepinfra_model": "DEEPINFRA_MODEL",
361
- "deepinfra_max_tokens": "DEEPINFRA_MAX_TOKENS",
362
- "deepinfra_temperature": "DEEPINFRA_TEMPERATURE",
363
  "ocr_fallback_enabled": "OCR_FALLBACK_ENABLED",
364
  }
365
 
 
46
  # HuggingFace Inference API endpoint URL
47
  HF_API_URL: str = "https://api-inference.huggingface.co/models"
48
 
49
+ # DeepInfra API URL (used for OCR with DeepSeek-OCR model)
50
+ DEEPINFRA_API_URL: str = "https://api.deepinfra.com/v1/openai/chat/completions"
51
+
52
  # =============================================================================
53
  # Concurrency Settings
54
  # =============================================================================
 
351
  # =============================================================================
352
 
353
  # Required environment variable:
354
+ # - HUGGINGFACE_API_KEY: Get from https://huggingface.co/settings/tokens
355
  #
356
  # Optional overrides:
357
+ # - HF_MODEL_PRIMARY: Override the default model
358
+ # - HF_MAX_TOKENS: Override max tokens
359
+ # - HF_TEMPERATURE: Override temperature
360
  # - OCR_FALLBACK_ENABLED: Enable/disable OCR fallback (true/false)
361
  ENV_VARS: dict[str, str] = {
362
+ "huggingface_key": "HUGGINGFACE_API_KEY",
363
+ "hf_model": "HF_MODEL_PRIMARY",
364
+ "hf_max_tokens": "HF_MAX_TOKENS",
365
+ "hf_temperature": "HF_TEMPERATURE",
366
  "ocr_fallback_enabled": "OCR_FALLBACK_ENABLED",
367
  }
368
 
document/conversion.py CHANGED
@@ -36,6 +36,7 @@ try:
36
  DOCX_SUPPORT = True
37
  except ImportError:
38
  DOCX_SUPPORT = False
 
39
 
40
  # Image support
41
  try:
 
36
  DOCX_SUPPORT = True
37
  except ImportError:
38
  DOCX_SUPPORT = False
39
+ Document = None # type: ignore
40
 
41
  # Image support
42
  try:
exceptions.py CHANGED
@@ -88,48 +88,3 @@ class AIServiceError(GradingError):
88
 
89
  class ConfigurationError(GradeM8Error):
90
  """Raised when there's a configuration issue."""
91
- pass
92
-
93
-
94
- # =============================================================================
95
- # AI Grading Errors
96
- # =============================================================================
97
-
98
- class GradingError(GradeM8Error):
99
- """Base exception for AI grading errors."""
100
- pass
101
-
102
-
103
- class ResponseParseError(GradingError):
104
- """Raised when the AI response cannot be parsed as valid JSON."""
105
-
106
- def __init__(self, message: str, raw_response: str | None = None) -> None:
107
- self.raw_response = raw_response
108
- super().__init__(message)
109
-
110
-
111
- class InvalidResponseError(GradingError):
112
- """Raised when the AI response is valid JSON but missing required fields."""
113
- pass
114
-
115
-
116
- class APIKeyError(GradingError):
117
- """Raised when the API key is missing or invalid."""
118
- pass
119
-
120
-
121
- class AIServiceError(GradingError):
122
- """Raised when the AI service returns an error."""
123
-
124
- def __init__(self, message: str, status_code: int | None = None) -> None:
125
- self.status_code = status_code
126
- super().__init__(message)
127
-
128
-
129
- # =============================================================================
130
- # Configuration Errors
131
- # =============================================================================
132
-
133
- class ConfigurationError(GradeM8Error):
134
- """Raised when there's a configuration issue."""
135
- pass
 
88
 
89
  class ConfigurationError(GradeM8Error):
90
  """Raised when there's a configuration issue."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tests/conftest.py CHANGED
@@ -93,14 +93,14 @@ from unittest.mock import Mock, AsyncMock
93
  @pytest.fixture(autouse=True)
94
  def setup_test_env() -> Generator[None, None, None]:
95
  """Set up test environment with mock API key."""
96
- old_key = os.environ.get("DEEPINFRA_API_KEY")
97
- os.environ["DEEPINFRA_API_KEY"] = "test_api_key_12345"
98
  yield
99
  # Cleanup
100
  if old_key is not None:
101
- os.environ["DEEPINFRA_API_KEY"] = old_key
102
- elif "DEEPINFRA_API_KEY" in os.environ:
103
- del os.environ["DEEPINFRA_API_KEY"]
104
 
105
 
106
  @pytest.fixture
 
93
  @pytest.fixture(autouse=True)
94
  def setup_test_env() -> Generator[None, None, None]:
95
  """Set up test environment with mock API key."""
96
+ old_key = os.environ.get("HUGGINGFACE_API_KEY")
97
+ os.environ["HUGGINGFACE_API_KEY"] = "test_api_key_12345"
98
  yield
99
  # Cleanup
100
  if old_key is not None:
101
+ os.environ["HUGGINGFACE_API_KEY"] = old_key
102
+ elif "HUGGINGFACE_API_KEY" in os.environ:
103
+ del os.environ["HUGGINGFACE_API_KEY"]
104
 
105
 
106
  @pytest.fixture
tests/unit/test_ai_router.py CHANGED
@@ -146,16 +146,16 @@ class TestTransformToFinalFormat:
146
  assert results[0]["score"] == 80
147
 
148
 
149
- class TestGetDeepinfraConfig:
150
- """Test _get_deepinfra_config helper."""
151
 
152
  def test_uses_environment_variables(self, monkeypatch):
153
- monkeypatch.setenv("DEEPINFRA_API_KEY", "test_key")
154
- monkeypatch.setenv("DEEPINFRA_MODEL", "test_model")
155
- monkeypatch.setenv("DEEPINFRA_MAX_TOKENS", "1000")
156
- monkeypatch.setenv("DEEPINFRA_TEMPERATURE", "0.5")
157
 
158
- api_key, model, max_tokens, temperature = ai_router._get_deepinfra_config()
159
 
160
  assert api_key == "test_key"
161
  assert model == "test_model"
@@ -163,22 +163,22 @@ class TestGetDeepinfraConfig:
163
  assert temperature == 0.5
164
 
165
  def test_uses_defaults_for_optional_vars(self, monkeypatch):
166
- monkeypatch.setenv("DEEPINFRA_API_KEY", "test_key")
167
  # Unset optional vars
168
- monkeypatch.delenv("DEEPINFRA_MODEL", raising=False)
169
- monkeypatch.delenv("DEEPINFRA_MAX_TOKENS", raising=False)
170
- monkeypatch.delenv("DEEPINFRA_TEMPERATURE", raising=False)
171
 
172
- api_key, model, max_tokens, temperature = ai_router._get_deepinfra_config()
173
 
174
  assert api_key == "test_key"
175
- assert model == ai_router.config.DEEPINFRA_MODEL_DEFAULT
176
  assert max_tokens == ai_router.config.MAX_TOKENS
177
  assert temperature == ai_router.config.TEMPERATURE
178
 
179
  def test_raises_when_api_key_missing(self, monkeypatch):
180
- monkeypatch.delenv("DEEPINFRA_API_KEY", raising=False)
181
 
182
  from exceptions import APIKeyError
183
  with pytest.raises(APIKeyError):
184
- ai_router._get_deepinfra_config()
 
146
  assert results[0]["score"] == 80
147
 
148
 
149
+ class TestGetHuggingfaceConfig:
150
+ """Test _get_huggingface_config helper."""
151
 
152
  def test_uses_environment_variables(self, monkeypatch):
153
+ monkeypatch.setenv("HUGGINGFACE_API_KEY", "test_key")
154
+ monkeypatch.setenv("HF_MODEL_PRIMARY", "test_model")
155
+ monkeypatch.setenv("HF_MAX_TOKENS", "1000")
156
+ monkeypatch.setenv("HF_TEMPERATURE", "0.5")
157
 
158
+ api_key, model, max_tokens, temperature = ai_router._get_huggingface_config()
159
 
160
  assert api_key == "test_key"
161
  assert model == "test_model"
 
163
  assert temperature == 0.5
164
 
165
  def test_uses_defaults_for_optional_vars(self, monkeypatch):
166
+ monkeypatch.setenv("HUGGINGFACE_API_KEY", "test_key")
167
  # Unset optional vars
168
+ monkeypatch.delenv("HF_MODEL_PRIMARY", raising=False)
169
+ monkeypatch.delenv("HF_MAX_TOKENS", raising=False)
170
+ monkeypatch.delenv("HF_TEMPERATURE", raising=False)
171
 
172
+ api_key, model, max_tokens, temperature = ai_router._get_huggingface_config()
173
 
174
  assert api_key == "test_key"
175
+ assert model == ai_router.config.HF_MODEL_DEFAULT
176
  assert max_tokens == ai_router.config.MAX_TOKENS
177
  assert temperature == ai_router.config.TEMPERATURE
178
 
179
  def test_raises_when_api_key_missing(self, monkeypatch):
180
+ monkeypatch.delenv("HUGGINGFACE_API_KEY", raising=False)
181
 
182
  from exceptions import APIKeyError
183
  with pytest.raises(APIKeyError):
184
+ ai_router._get_huggingface_config()
tests/unit/test_client.py CHANGED
@@ -9,29 +9,29 @@ import pytest
9
  import respx
10
  from httpx import Response
11
 
12
- from ai_router.client import _get_deepinfra_config, generate_grading
13
  from exceptions import APIKeyError, AIServiceError, ResponseParseError
14
  import config
15
 
16
 
17
- class TestGetDeepinfraConfig:
18
- """Tests for _get_deepinfra_config function."""
19
 
20
  def test_raises_when_api_key_missing(self, monkeypatch):
21
- """Test error when DEEPINFRA_API_KEY is not set."""
22
- monkeypatch.delenv("DEEPINFRA_API_KEY", raising=False)
23
 
24
- with pytest.raises(APIKeyError, match="DEEPINFRA_API_KEY"):
25
- _get_deepinfra_config()
26
 
27
  def test_uses_environment_variables(self, monkeypatch):
28
  """Test using environment variables."""
29
- monkeypatch.setenv("DEEPINFRA_API_KEY", "test_key")
30
- monkeypatch.setenv("DEEPINFRA_MODEL", "test_model")
31
- monkeypatch.setenv("DEEPINFRA_MAX_TOKENS", "1000")
32
- monkeypatch.setenv("DEEPINFRA_TEMPERATURE", "0.5")
33
 
34
- api_key, model, max_tokens, temperature = _get_deepinfra_config()
35
 
36
  assert api_key == "test_key"
37
  assert model == "test_model"
@@ -40,24 +40,24 @@ class TestGetDeepinfraConfig:
40
 
41
  def test_uses_defaults_for_optional_vars(self, monkeypatch):
42
  """Test using defaults when optional vars not set."""
43
- monkeypatch.setenv("DEEPINFRA_API_KEY", "test_key")
44
- monkeypatch.delenv("DEEPINFRA_MODEL", raising=False)
45
- monkeypatch.delenv("DEEPINFRA_MAX_TOKENS", raising=False)
46
- monkeypatch.delenv("DEEPINFRA_TEMPERATURE", raising=False)
47
 
48
- api_key, model, max_tokens, temperature = _get_deepinfra_config()
49
 
50
  assert api_key == "test_key"
51
- assert model == config.DEEPINFRA_MODEL_DEFAULT
52
  assert max_tokens == config.MAX_TOKENS
53
  assert temperature == config.TEMPERATURE
54
 
55
  def test_empty_api_key_raises_error(self, monkeypatch):
56
  """Test error when API key is empty string."""
57
- monkeypatch.setenv("DEEPINFRA_API_KEY", "")
58
 
59
  with pytest.raises(APIKeyError):
60
- _get_deepinfra_config()
61
 
62
 
63
  class TestGenerateGrading:
@@ -67,7 +67,7 @@ class TestGenerateGrading:
67
  @pytest.mark.asyncio
68
  async def test_successful_grading(self, monkeypatch):
69
  """Test successful grading API call."""
70
- monkeypatch.setenv("DEEPINFRA_API_KEY", "test_key")
71
 
72
  # Mock API response with valid grading JSON
73
  grading_response = {
@@ -79,10 +79,9 @@ class TestGenerateGrading:
79
  "feedback": "Well done!"
80
  }
81
 
82
- route = respx.post(config.DEEPINFRA_API_URL).mock(
83
- return_value=Response(200, json={
84
- "choices": [{"message": {"content": json.dumps(grading_response)}}]
85
- })
86
  )
87
 
88
  result = await generate_grading("Student submission", "Grade on clarity")
@@ -93,9 +92,9 @@ class TestGenerateGrading:
93
 
94
  @respx.mock
95
  @pytest.mark.asyncio
96
- async def test_grading_with_markdown_json(self, monkeypatch):
97
- """Test grading response in markdown code block."""
98
- monkeypatch.setenv("DEEPINFRA_API_KEY", "test_key")
99
 
100
  grading_response = {
101
  "score": 90,
@@ -106,12 +105,11 @@ class TestGenerateGrading:
106
  "feedback": "Great job!"
107
  }
108
 
109
- content = f'```json\n{json.dumps(grading_response)}\n```'
 
110
 
111
- route = respx.post(config.DEEPINFRA_API_URL).mock(
112
- return_value=Response(200, json={
113
- "choices": [{"message": {"content": content}}]
114
- })
115
  )
116
 
117
  result = await generate_grading("Submission", "Rubric")
@@ -121,7 +119,7 @@ class TestGenerateGrading:
121
  @pytest.mark.asyncio
122
  async def test_missing_api_key_raises_error(self, monkeypatch):
123
  """Test error when API key is missing."""
124
- monkeypatch.delenv("DEEPINFRA_API_KEY", raising=False)
125
 
126
  with pytest.raises(APIKeyError):
127
  await generate_grading("Submission", "Rubric")
@@ -130,9 +128,9 @@ class TestGenerateGrading:
130
  @pytest.mark.asyncio
131
  async def test_http_401_error(self, monkeypatch):
132
  """Test 401 unauthorized error."""
133
- monkeypatch.setenv("DEEPINFRA_API_KEY", "test_key")
134
 
135
- route = respx.post(config.DEEPINFRA_API_URL).mock(
136
  return_value=Response(401, text="Unauthorized")
137
  )
138
 
@@ -140,15 +138,15 @@ class TestGenerateGrading:
140
  await generate_grading("Submission", "Rubric")
141
 
142
  assert exc_info.value.status_code == 401
143
- assert "DeepInfra API error" in str(exc_info.value)
144
 
145
  @respx.mock
146
  @pytest.mark.asyncio
147
  async def test_http_429_rate_limit(self, monkeypatch):
148
  """Test 429 rate limit error."""
149
- monkeypatch.setenv("DEEPINFRA_API_KEY", "test_key")
150
 
151
- route = respx.post(config.DEEPINFRA_API_URL).mock(
152
  return_value=Response(429, text="Rate limited")
153
  )
154
 
@@ -161,9 +159,9 @@ class TestGenerateGrading:
161
  @pytest.mark.asyncio
162
  async def test_http_500_error(self, monkeypatch):
163
  """Test 500 server error."""
164
- monkeypatch.setenv("DEEPINFRA_API_KEY", "test_key")
165
 
166
- route = respx.post(config.DEEPINFRA_API_URL).mock(
167
  return_value=Response(500, text="Internal Server Error")
168
  )
169
 
@@ -172,69 +170,68 @@ class TestGenerateGrading:
172
 
173
  assert exc_info.value.status_code == 500
174
 
175
-
176
  @respx.mock
177
  @pytest.mark.asyncio
178
- async def test_invalid_json_response(self, monkeypatch):
179
- """Test invalid JSON in response."""
180
- monkeypatch.setenv("DEEPINFRA_API_KEY", "test_key")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
181
 
182
- route = respx.post(config.DEEPINFRA_API_URL).mock(
183
- return_value=Response(200, json={
184
- "choices": [{"message": {"content": "not valid json"}}]
185
- })
186
- )
187
-
188
- with pytest.raises(ResponseParseError):
189
- await generate_grading("Submission", "Rubric")
190
 
191
  @respx.mock
192
  @pytest.mark.asyncio
193
- async def test_missing_choices_in_response(self, monkeypatch):
194
- """Test response missing choices field."""
195
- monkeypatch.setenv("DEEPINFRA_API_KEY", "test_key")
196
 
197
- route = respx.post(config.DEEPINFRA_API_URL).mock(
198
- return_value=Response(200, json={"id": "test"})
199
  )
200
 
201
- with pytest.raises(AIServiceError, match="Invalid API response"):
202
  await generate_grading("Submission", "Rubric")
203
 
204
  @respx.mock
205
  @pytest.mark.asyncio
206
- async def test_empty_choices_in_response(self, monkeypatch):
207
- """Test response with empty choices array."""
208
- monkeypatch.setenv("DEEPINFRA_API_KEY", "test_key")
209
 
210
- route = respx.post(config.DEEPINFRA_API_URL).mock(
211
- return_value=Response(200, json={"choices": []})
212
  )
213
 
214
- with pytest.raises(AIServiceError, match="Invalid API response"):
215
- await generate_grading("Submission", "Rubric")
216
-
217
- @respx.mock
218
- @pytest.mark.asyncio
219
- async def test_missing_message_content(self, monkeypatch):
220
- """Test response missing message content."""
221
- monkeypatch.setenv("DEEPINFRA_API_KEY", "test_key")
222
-
223
- route = respx.post(config.DEEPINFRA_API_URL).mock(
224
- return_value=Response(200, json={
225
- "choices": [{"message": {}}]
226
- })
227
- )
228
-
229
- with pytest.raises(AIServiceError, match="Invalid API response"):
230
  await generate_grading("Submission", "Rubric")
231
 
232
  @respx.mock
233
  @pytest.mark.asyncio
234
  async def test_details_field_set(self, monkeypatch):
235
  """Test details field includes model info."""
236
- monkeypatch.setenv("DEEPINFRA_API_KEY", "test_key")
237
- monkeypatch.setenv("DEEPINFRA_MODEL", "custom-model")
238
 
239
  grading_response = {
240
  "score": 75,
@@ -245,67 +242,51 @@ class TestGenerateGrading:
245
  "feedback": "Test feedback"
246
  }
247
 
248
- route = respx.post(config.DEEPINFRA_API_URL).mock(
249
- return_value=Response(200, json={
250
- "choices": [{"message": {"content": json.dumps(grading_response)}}]
251
- })
252
  )
253
 
254
  result = await generate_grading("Submission", "Rubric")
255
  assert "custom-model" in result["details"]
256
- assert "DeepInfra" in result["details"]
257
-
258
- @respx.mock
259
- @pytest.mark.asyncio
260
- async def test_response_format_is_json_object(self, monkeypatch):
261
- """Test that response_format is set to json_object."""
262
- monkeypatch.setenv("DEEPINFRA_API_KEY", "test_key")
263
-
264
- captured_payload = {}
265
-
266
- def capture_request(request):
267
- captured_payload.update(json.loads(request.content))
268
- return Response(200, json={
269
- "choices": [{"message": {"content": '{"score": 80}'}}]
270
- })
271
-
272
- route = respx.post(config.DEEPINFRA_API_URL).mock(side_effect=capture_request)
273
-
274
- await generate_grading("Submission", "Rubric")
275
-
276
- assert captured_payload["response_format"]["type"] == "json_object"
277
 
278
  @respx.mock
279
  @pytest.mark.asyncio
280
  async def test_prompt_building(self, monkeypatch):
281
  """Test that prompt is built correctly."""
282
- monkeypatch.setenv("DEEPINFRA_API_KEY", "test_key")
283
 
284
  captured_payload = {}
285
 
286
  def capture_request(request):
287
  captured_payload.update(json.loads(request.content))
288
- return Response(200, json={
289
- "choices": [{"message": {"content": '{"score": 80}'}}]
290
- })
291
-
292
- route = respx.post(config.DEEPINFRA_API_URL).mock(side_effect=capture_request)
 
 
 
 
 
 
293
 
294
  content = "Student essay content"
295
  rubric = "Grade on clarity and grammar"
296
  await generate_grading(content, rubric)
297
 
298
- messages = captured_payload["messages"]
299
- assert messages[0]["role"] == "system"
300
- assert messages[1]["role"] == "user"
301
- assert content in messages[1]["content"]
302
- assert rubric in messages[1]["content"]
303
 
304
  @respx.mock
305
  @pytest.mark.asyncio
306
  async def test_truncated_content_handling(self, monkeypatch):
307
  """Test handling of long content that needs truncation."""
308
- monkeypatch.setenv("DEEPINFRA_API_KEY", "test_key")
309
 
310
  long_content = "a" * (config.MAX_PROMPT_CHARS + 1000)
311
 
@@ -318,11 +299,32 @@ class TestGenerateGrading:
318
  "feedback": "Test"
319
  }
320
 
321
- route = respx.post(config.DEEPINFRA_API_URL).mock(
322
- return_value=Response(200, json={
323
- "choices": [{"message": {"content": json.dumps(grading_response)}}]
324
- })
325
  )
326
 
327
  result = await generate_grading(long_content, "Rubric")
328
  assert result["score"] == 80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  import respx
10
  from httpx import Response
11
 
12
+ from ai_router.client import _get_huggingface_config, generate_grading, HF_API_URL
13
  from exceptions import APIKeyError, AIServiceError, ResponseParseError
14
  import config
15
 
16
 
17
+ class TestGetHuggingfaceConfig:
18
+ """Tests for _get_huggingface_config function."""
19
 
20
  def test_raises_when_api_key_missing(self, monkeypatch):
21
+ """Test error when HUGGINGFACE_API_KEY is not set."""
22
+ monkeypatch.delenv("HUGGINGFACE_API_KEY", raising=False)
23
 
24
+ with pytest.raises(APIKeyError, match="HUGGINGFACE_API_KEY"):
25
+ _get_huggingface_config()
26
 
27
  def test_uses_environment_variables(self, monkeypatch):
28
  """Test using environment variables."""
29
+ monkeypatch.setenv("HUGGINGFACE_API_KEY", "test_key")
30
+ monkeypatch.setenv("HF_MODEL_PRIMARY", "test_model")
31
+ monkeypatch.setenv("HF_MAX_TOKENS", "1000")
32
+ monkeypatch.setenv("HF_TEMPERATURE", "0.5")
33
 
34
+ api_key, model, max_tokens, temperature = _get_huggingface_config()
35
 
36
  assert api_key == "test_key"
37
  assert model == "test_model"
 
40
 
41
  def test_uses_defaults_for_optional_vars(self, monkeypatch):
42
  """Test using defaults when optional vars not set."""
43
+ monkeypatch.setenv("HUGGINGFACE_API_KEY", "test_key")
44
+ monkeypatch.delenv("HF_MODEL_PRIMARY", raising=False)
45
+ monkeypatch.delenv("HF_MAX_TOKENS", raising=False)
46
+ monkeypatch.delenv("HF_TEMPERATURE", raising=False)
47
 
48
+ api_key, model, max_tokens, temperature = _get_huggingface_config()
49
 
50
  assert api_key == "test_key"
51
+ assert model == config.HF_MODEL_DEFAULT
52
  assert max_tokens == config.MAX_TOKENS
53
  assert temperature == config.TEMPERATURE
54
 
55
  def test_empty_api_key_raises_error(self, monkeypatch):
56
  """Test error when API key is empty string."""
57
+ monkeypatch.setenv("HUGGINGFACE_API_KEY", "")
58
 
59
  with pytest.raises(APIKeyError):
60
+ _get_huggingface_config()
61
 
62
 
63
  class TestGenerateGrading:
 
67
  @pytest.mark.asyncio
68
  async def test_successful_grading(self, monkeypatch):
69
  """Test successful grading API call."""
70
+ monkeypatch.setenv("HUGGINGFACE_API_KEY", "test_key")
71
 
72
  # Mock API response with valid grading JSON
73
  grading_response = {
 
79
  "feedback": "Well done!"
80
  }
81
 
82
+ # HuggingFace returns a list with generated_text
83
+ route = respx.post(f"{HF_API_URL}/{config.HF_MODEL_DEFAULT}").mock(
84
+ return_value=Response(200, json=[{"generated_text": json.dumps(grading_response)}])
 
85
  )
86
 
87
  result = await generate_grading("Student submission", "Grade on clarity")
 
92
 
93
  @respx.mock
94
  @pytest.mark.asyncio
95
+ async def test_grading_with_embedded_json(self, monkeypatch):
96
+ """Test grading response with JSON embedded in text."""
97
+ monkeypatch.setenv("HUGGINGFACE_API_KEY", "test_key")
98
 
99
  grading_response = {
100
  "score": 90,
 
105
  "feedback": "Great job!"
106
  }
107
 
108
+ # Response with JSON embedded in text
109
+ content = f'Here is the grading: {json.dumps(grading_response)} End of response.'
110
 
111
+ route = respx.post(f"{HF_API_URL}/{config.HF_MODEL_DEFAULT}").mock(
112
+ return_value=Response(200, json=[{"generated_text": content}])
 
 
113
  )
114
 
115
  result = await generate_grading("Submission", "Rubric")
 
119
  @pytest.mark.asyncio
120
  async def test_missing_api_key_raises_error(self, monkeypatch):
121
  """Test error when API key is missing."""
122
+ monkeypatch.delenv("HUGGINGFACE_API_KEY", raising=False)
123
 
124
  with pytest.raises(APIKeyError):
125
  await generate_grading("Submission", "Rubric")
 
128
  @pytest.mark.asyncio
129
  async def test_http_401_error(self, monkeypatch):
130
  """Test 401 unauthorized error."""
131
+ monkeypatch.setenv("HUGGINGFACE_API_KEY", "test_key")
132
 
133
+ route = respx.post(f"{HF_API_URL}/{config.HF_MODEL_DEFAULT}").mock(
134
  return_value=Response(401, text="Unauthorized")
135
  )
136
 
 
138
  await generate_grading("Submission", "Rubric")
139
 
140
  assert exc_info.value.status_code == 401
141
+ assert "HuggingFace API error" in str(exc_info.value)
142
 
143
  @respx.mock
144
  @pytest.mark.asyncio
145
  async def test_http_429_rate_limit(self, monkeypatch):
146
  """Test 429 rate limit error."""
147
+ monkeypatch.setenv("HUGGINGFACE_API_KEY", "test_key")
148
 
149
+ route = respx.post(f"{HF_API_URL}/{config.HF_MODEL_DEFAULT}").mock(
150
  return_value=Response(429, text="Rate limited")
151
  )
152
 
 
159
  @pytest.mark.asyncio
160
  async def test_http_500_error(self, monkeypatch):
161
  """Test 500 server error."""
162
+ monkeypatch.setenv("HUGGINGFACE_API_KEY", "test_key")
163
 
164
+ route = respx.post(f"{HF_API_URL}/{config.HF_MODEL_DEFAULT}").mock(
165
  return_value=Response(500, text="Internal Server Error")
166
  )
167
 
 
170
 
171
  assert exc_info.value.status_code == 500
172
 
 
173
  @respx.mock
174
  @pytest.mark.asyncio
175
+ async def test_http_503_model_loading(self, monkeypatch):
176
+ """Test 503 model loading triggers retry."""
177
+ monkeypatch.setenv("HUGGINGFACE_API_KEY", "test_key")
178
+
179
+ # 503 on first call, success on retry
180
+ call_count = 0
181
+
182
+ def mock_response(request):
183
+ nonlocal call_count
184
+ call_count += 1
185
+ if call_count == 1:
186
+ return Response(503, text="Model is loading")
187
+ grading_response = {
188
+ "score": 80,
189
+ "rubric_breakdown": {},
190
+ "summary": "Good",
191
+ "strengths": [],
192
+ "improvements": [],
193
+ "feedback": "Nice work"
194
+ }
195
+ return Response(200, json=[{"generated_text": json.dumps(grading_response)}])
196
+
197
+ route = respx.post(f"{HF_API_URL}/{config.HF_MODEL_DEFAULT}").mock(side_effect=mock_response)
198
 
199
+ result = await generate_grading("Submission", "Rubric")
200
+ assert result["score"] == 80
201
+ assert call_count == 2
 
 
 
 
 
202
 
203
  @respx.mock
204
  @pytest.mark.asyncio
205
+ async def test_invalid_json_response(self, monkeypatch):
206
+ """Test invalid JSON in response."""
207
+ monkeypatch.setenv("HUGGINGFACE_API_KEY", "test_key")
208
 
209
+ route = respx.post(f"{HF_API_URL}/{config.HF_MODEL_DEFAULT}").mock(
210
+ return_value=Response(200, json=[{"generated_text": "not valid json at all"}])
211
  )
212
 
213
+ with pytest.raises(ResponseParseError):
214
  await generate_grading("Submission", "Rubric")
215
 
216
  @respx.mock
217
  @pytest.mark.asyncio
218
+ async def test_empty_response_list(self, monkeypatch):
219
+ """Test response with empty list."""
220
+ monkeypatch.setenv("HUGGINGFACE_API_KEY", "test_key")
221
 
222
+ route = respx.post(f"{HF_API_URL}/{config.HF_MODEL_DEFAULT}").mock(
223
+ return_value=Response(200, json=[])
224
  )
225
 
226
+ with pytest.raises(ResponseParseError):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
227
  await generate_grading("Submission", "Rubric")
228
 
229
  @respx.mock
230
  @pytest.mark.asyncio
231
  async def test_details_field_set(self, monkeypatch):
232
  """Test details field includes model info."""
233
+ monkeypatch.setenv("HUGGINGFACE_API_KEY", "test_key")
234
+ monkeypatch.setenv("HF_MODEL_PRIMARY", "custom-model")
235
 
236
  grading_response = {
237
  "score": 75,
 
242
  "feedback": "Test feedback"
243
  }
244
 
245
+ route = respx.post(f"{HF_API_URL}/custom-model").mock(
246
+ return_value=Response(200, json=[{"generated_text": json.dumps(grading_response)}])
 
 
247
  )
248
 
249
  result = await generate_grading("Submission", "Rubric")
250
  assert "custom-model" in result["details"]
251
+ assert "HuggingFace" in result["details"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
252
 
253
  @respx.mock
254
  @pytest.mark.asyncio
255
  async def test_prompt_building(self, monkeypatch):
256
  """Test that prompt is built correctly."""
257
+ monkeypatch.setenv("HUGGINGFACE_API_KEY", "test_key")
258
 
259
  captured_payload = {}
260
 
261
  def capture_request(request):
262
  captured_payload.update(json.loads(request.content))
263
+ grading_response = {
264
+ "score": 80,
265
+ "rubric_breakdown": {},
266
+ "summary": "Good",
267
+ "strengths": [],
268
+ "improvements": [],
269
+ "feedback": "Nice"
270
+ }
271
+ return Response(200, json=[{"generated_text": json.dumps(grading_response)}])
272
+
273
+ route = respx.post(f"{HF_API_URL}/{config.HF_MODEL_DEFAULT}").mock(side_effect=capture_request)
274
 
275
  content = "Student essay content"
276
  rubric = "Grade on clarity and grammar"
277
  await generate_grading(content, rubric)
278
 
279
+ # HuggingFace format: inputs as full prompt, parameters for generation
280
+ assert "inputs" in captured_payload
281
+ assert content in captured_payload["inputs"]
282
+ assert rubric in captured_payload["inputs"]
283
+ assert "parameters" in captured_payload
284
 
285
  @respx.mock
286
  @pytest.mark.asyncio
287
  async def test_truncated_content_handling(self, monkeypatch):
288
  """Test handling of long content that needs truncation."""
289
+ monkeypatch.setenv("HUGGINGFACE_API_KEY", "test_key")
290
 
291
  long_content = "a" * (config.MAX_PROMPT_CHARS + 1000)
292
 
 
299
  "feedback": "Test"
300
  }
301
 
302
+ route = respx.post(f"{HF_API_URL}/{config.HF_MODEL_DEFAULT}").mock(
303
+ return_value=Response(200, json=[{"generated_text": json.dumps(grading_response)}])
 
 
304
  )
305
 
306
  result = await generate_grading(long_content, "Rubric")
307
  assert result["score"] == 80
308
+
309
+ @respx.mock
310
+ @pytest.mark.asyncio
311
+ async def test_dict_response_format(self, monkeypatch):
312
+ """Test handling of dict response format (alternative HF response)."""
313
+ monkeypatch.setenv("HUGGINGFACE_API_KEY", "test_key")
314
+
315
+ grading_response = {
316
+ "score": 85,
317
+ "rubric_breakdown": {},
318
+ "summary": "Well done",
319
+ "strengths": [],
320
+ "improvements": [],
321
+ "feedback": "Good work"
322
+ }
323
+
324
+ # Some HF models return dict instead of list
325
+ route = respx.post(f"{HF_API_URL}/{config.HF_MODEL_DEFAULT}").mock(
326
+ return_value=Response(200, json={"generated_text": json.dumps(grading_response)})
327
+ )
328
+
329
+ result = await generate_grading("Submission", "Rubric")
330
+ assert result["score"] == 85
tests/unit/test_config.py CHANGED
@@ -10,9 +10,9 @@ import config
10
  class TestConfigConstants:
11
  """Test configuration constants have expected values."""
12
 
13
- def test_deepinfra_model_default(self):
14
- """Test default DeepInfra model."""
15
- assert config.DEEPINFRA_MODEL_DEFAULT == "meta-llama/Meta-Llama-3.1-70B-Instruct"
16
 
17
  def test_max_tokens(self):
18
  """Test max tokens setting."""
@@ -60,13 +60,13 @@ class TestEnvironmentVariables:
60
 
61
  def test_env_vars_defined(self):
62
  """Test that environment variable names are defined."""
63
- assert "deepinfra_key" in config.ENV_VARS
64
- assert config.ENV_VARS["deepinfra_key"] == "DEEPINFRA_API_KEY"
65
 
66
  @pytest.mark.parametrize("env_var,config_key", [
67
- ("DEEPINFRA_MODEL", "deepinfra_model"),
68
- ("DEEPINFRA_MAX_TOKENS", "deepinfra_max_tokens"),
69
- ("DEEPINFRA_TEMPERATURE", "deepinfra_temperature"),
70
  ])
71
  def test_env_var_mappings(self, env_var, config_key):
72
  """Test environment variable mappings."""
@@ -120,8 +120,8 @@ class TestConfigTypes:
120
  def test_string_configs_are_strings(self):
121
  """Test string configuration values."""
122
  string_configs = [
123
- config.DEEPINFRA_MODEL_DEFAULT,
124
- config.DEEPINFRA_API_URL,
125
  config.LOG_LEVEL,
126
  config.LOG_FORMAT,
127
  config.GRADING_SYSTEM_PROMPT,
 
10
  class TestConfigConstants:
11
  """Test configuration constants have expected values."""
12
 
13
+ def test_hf_model_default(self):
14
+ """Test default HuggingFace model."""
15
+ assert config.HF_MODEL_DEFAULT == "meta-llama/Llama-2-70b-chat-hf"
16
 
17
  def test_max_tokens(self):
18
  """Test max tokens setting."""
 
60
 
61
  def test_env_vars_defined(self):
62
  """Test that environment variable names are defined."""
63
+ assert "huggingface_key" in config.ENV_VARS
64
+ assert config.ENV_VARS["huggingface_key"] == "HUGGINGFACE_API_KEY"
65
 
66
  @pytest.mark.parametrize("env_var,config_key", [
67
+ ("HF_MODEL_PRIMARY", "hf_model"),
68
+ ("HF_MAX_TOKENS", "hf_max_tokens"),
69
+ ("HF_TEMPERATURE", "hf_temperature"),
70
  ])
71
  def test_env_var_mappings(self, env_var, config_key):
72
  """Test environment variable mappings."""
 
120
  def test_string_configs_are_strings(self):
121
  """Test string configuration values."""
122
  string_configs = [
123
+ config.HF_MODEL_DEFAULT,
124
+ config.HF_API_URL,
125
  config.LOG_LEVEL,
126
  config.LOG_FORMAT,
127
  config.GRADING_SYSTEM_PROMPT,
tests/unit/test_orchestration.py CHANGED
@@ -180,7 +180,7 @@ class TestGenerateBatchGrading:
180
 
181
  call_count = 0
182
 
183
- async def mock_generate(**kwargs):
184
  nonlocal call_count
185
  call_count += 1
186
  if call_count == 1:
@@ -208,7 +208,7 @@ class TestGenerateBatchGrading:
208
  concurrent_calls = 0
209
  max_concurrent = 0
210
 
211
- async def mock_generate(**kwargs):
212
  nonlocal concurrent_calls, max_concurrent
213
  concurrent_calls += 1
214
  max_concurrent = max(max_concurrent, concurrent_calls)
 
180
 
181
  call_count = 0
182
 
183
+ async def mock_generate(content, rubric):
184
  nonlocal call_count
185
  call_count += 1
186
  if call_count == 1:
 
208
  concurrent_calls = 0
209
  max_concurrent = 0
210
 
211
+ async def mock_generate(content, rubric):
212
  nonlocal concurrent_calls, max_concurrent
213
  concurrent_calls += 1
214
  max_concurrent = max(max_concurrent, concurrent_calls)
tests/unit/test_parsing.py CHANGED
@@ -254,7 +254,7 @@ class TestValidateGradingResult:
254
  """Test details field is set with model info."""
255
  result = {"score": 80}
256
  validated = _validate_grading_result(result)
257
- assert config.DEEPINFRA_MODEL_DEFAULT in validated["details"]
258
 
259
 
260
  class TestParseGradingResponse:
 
254
  """Test details field is set with model info."""
255
  result = {"score": 80}
256
  validated = _validate_grading_result(result)
257
+ assert config.HF_MODEL_DEFAULT in validated["details"]
258
 
259
 
260
  class TestParseGradingResponse:
tests/unit/test_types.py CHANGED
@@ -26,7 +26,7 @@ class TestTypeDefinitions:
26
  def test_grading_result_with_status_required_fields(self):
27
  """Test GradingResultWithStatus has all required fields."""
28
  hints = get_type_hints(GradingResultWithStatus)
29
- required_fields = {"index", "score", "feedback", "summary", "rubric_breakdown", "strengths", "improvements", "details", "status"}
30
 
31
  assert set(hints.keys()) == required_fields
32
 
 
26
  def test_grading_result_with_status_required_fields(self):
27
  """Test GradingResultWithStatus has all required fields."""
28
  hints = get_type_hints(GradingResultWithStatus)
29
+ required_fields = {"index", "score", "feedback", "summary", "rubric_breakdown", "strengths", "improvements", "details", "status", "filename"}
30
 
31
  assert set(hints.keys()) == required_fields
32