Spaces:
Sleeping
Sleeping
GradeM8 Deploy commited on
Commit ·
40f6738
1
Parent(s): 28b7952
refactor: complete HuggingFace naming migration and code quality improvements
Browse files- Rename _get_deepinfra_config() to _get_huggingface_config()
- Update all tests to use HuggingFace naming conventions
- Fix parsing.py to use HF_MODEL_DEFAULT
- Add DEEPINFRA_API_URL back for OCR module
- Fix Document fallback in conversion.py
- Fix mock function signatures in test_orchestration.py
- Add filename field to GradingResultWithStatus type tests
All 285 unit tests passing
- .env.example +3 -3
- ai_router/__init__.py +2 -2
- ai_router/client.py +2 -4
- ai_router/parsing.py +1 -1
- app.py +25 -26
- config.py +11 -8
- document/conversion.py +1 -0
- exceptions.py +0 -45
- tests/conftest.py +5 -5
- tests/unit/test_ai_router.py +15 -15
- tests/unit/test_client.py +125 -123
- tests/unit/test_config.py +10 -10
- tests/unit/test_orchestration.py +2 -2
- tests/unit/test_parsing.py +1 -1
- tests/unit/test_types.py +1 -1
.env.example
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
-
#
|
| 2 |
-
|
| 3 |
-
|
| 4 |
|
| 5 |
# Grading Configuration
|
| 6 |
MAX_TOKENS=2048
|
|
|
|
| 1 |
+
# HuggingFace API Configuration
|
| 2 |
+
HUGGINGFACE_API_KEY=your_api_key_here
|
| 3 |
+
HF_MODEL_PRIMARY=meta-llama/Llama-2-70b-chat-hf
|
| 4 |
|
| 5 |
# Grading Configuration
|
| 6 |
MAX_TOKENS=2048
|
ai_router/__init__.py
CHANGED
|
@@ -30,7 +30,7 @@ from .client import generate_grading
|
|
| 30 |
from .orchestration import generate_batch_grading
|
| 31 |
|
| 32 |
# Re-export underscore-prefixed internals used by tests for compatibility
|
| 33 |
-
from .client import
|
| 34 |
from .orchestration import (
|
| 35 |
_build_grading_error_result,
|
| 36 |
_build_grading_success_result,
|
|
@@ -49,7 +49,7 @@ __all__ = [
|
|
| 49 |
# OCR sub-module (access as ai_router.ocr)
|
| 50 |
"ocr",
|
| 51 |
# Underscore-prefixed internals (test compatibility)
|
| 52 |
-
"
|
| 53 |
"_build_grading_success_result",
|
| 54 |
"_build_grading_error_result",
|
| 55 |
"_calculate_batch_stats",
|
|
|
|
| 30 |
from .orchestration import generate_batch_grading
|
| 31 |
|
| 32 |
# Re-export underscore-prefixed internals used by tests for compatibility
|
| 33 |
+
from .client import _get_huggingface_config
|
| 34 |
from .orchestration import (
|
| 35 |
_build_grading_error_result,
|
| 36 |
_build_grading_success_result,
|
|
|
|
| 49 |
# OCR sub-module (access as ai_router.ocr)
|
| 50 |
"ocr",
|
| 51 |
# Underscore-prefixed internals (test compatibility)
|
| 52 |
+
"_get_huggingface_config",
|
| 53 |
"_build_grading_success_result",
|
| 54 |
"_build_grading_error_result",
|
| 55 |
"_calculate_batch_stats",
|
ai_router/client.py
CHANGED
|
@@ -34,10 +34,8 @@ BACKOFF_MULTIPLIER = 2.0
|
|
| 34 |
HF_API_URL = "https://api-inference.huggingface.co/models"
|
| 35 |
|
| 36 |
|
| 37 |
-
def
|
| 38 |
"""Get HuggingFace API configuration from environment or defaults.
|
| 39 |
-
|
| 40 |
-
Legacy name kept for backwards compatibility with tests.
|
| 41 |
|
| 42 |
Returns:
|
| 43 |
Tuple of (api_key, model, max_tokens, temperature)
|
|
@@ -90,7 +88,7 @@ async def generate_grading(content: str, rubric: str) -> GradingResult:
|
|
| 90 |
"""
|
| 91 |
from .prompt import build_grading_prompt
|
| 92 |
|
| 93 |
-
api_key, model, max_tokens, temperature =
|
| 94 |
|
| 95 |
# Build the prompt
|
| 96 |
prompt = build_grading_prompt(content, rubric)
|
|
|
|
| 34 |
HF_API_URL = "https://api-inference.huggingface.co/models"
|
| 35 |
|
| 36 |
|
| 37 |
+
def _get_huggingface_config() -> tuple[str, str, int, float]:
|
| 38 |
"""Get HuggingFace API configuration from environment or defaults.
|
|
|
|
|
|
|
| 39 |
|
| 40 |
Returns:
|
| 41 |
Tuple of (api_key, model, max_tokens, temperature)
|
|
|
|
| 88 |
"""
|
| 89 |
from .prompt import build_grading_prompt
|
| 90 |
|
| 91 |
+
api_key, model, max_tokens, temperature = _get_huggingface_config()
|
| 92 |
|
| 93 |
# Build the prompt
|
| 94 |
prompt = build_grading_prompt(content, rubric)
|
ai_router/parsing.py
CHANGED
|
@@ -122,7 +122,7 @@ def _validate_grading_result(result: dict) -> GradingResult:
|
|
| 122 |
"strengths": strengths,
|
| 123 |
"improvements": improvements,
|
| 124 |
"feedback": str(result.get("feedback", "")),
|
| 125 |
-
"details": f"Graded using {config.
|
| 126 |
}
|
| 127 |
|
| 128 |
|
|
|
|
| 122 |
"strengths": strengths,
|
| 123 |
"improvements": improvements,
|
| 124 |
"feedback": str(result.get("feedback", "")),
|
| 125 |
+
"details": f"Graded using {config.HF_MODEL_DEFAULT}",
|
| 126 |
}
|
| 127 |
|
| 128 |
|
app.py
CHANGED
|
@@ -20,7 +20,6 @@ Just upload, review the rubric, and let GradeM8 do the rest!
|
|
| 20 |
from __future__ import annotations
|
| 21 |
|
| 22 |
import asyncio
|
| 23 |
-
import json
|
| 24 |
import logging
|
| 25 |
from pathlib import Path
|
| 26 |
from typing import TYPE_CHECKING, Any
|
|
@@ -33,7 +32,6 @@ import config
|
|
| 33 |
import document
|
| 34 |
from exceptions import DocumentError, DocumentConversionError, UnsupportedFileTypeError
|
| 35 |
from ui.components import (
|
| 36 |
-
get_accessibility_css,
|
| 37 |
create_results_cards,
|
| 38 |
create_error_html,
|
| 39 |
create_waiting_html,
|
|
@@ -53,6 +51,9 @@ logger = logging.getLogger(__name__)
|
|
| 53 |
# Empty DataFrame template for consistent empty results
|
| 54 |
_EMPTY_RESULTS_DF = pd.DataFrame(columns=["Student", "Score", "Status"])
|
| 55 |
|
|
|
|
|
|
|
|
|
|
| 56 |
|
| 57 |
# =============================================================================
|
| 58 |
# Built-in Rubric Templates
|
|
@@ -153,8 +154,6 @@ def get_rubric_content(rubric_key: str) -> str:
|
|
| 153 |
async def process_batch_submissions(
|
| 154 |
rubric: str,
|
| 155 |
file_objs: list[Any],
|
| 156 |
-
text_size: str = "standard",
|
| 157 |
-
high_contrast: bool = False,
|
| 158 |
progress: gr.Progress = gr.Progress(), # noqa: B008
|
| 159 |
) -> tuple[str, str, str, bytes, pd.DataFrame, str]:
|
| 160 |
"""
|
|
@@ -247,7 +246,7 @@ def _process_extraction_results(
|
|
| 247 |
|
| 248 |
for result in extract_results:
|
| 249 |
if isinstance(result, Exception):
|
| 250 |
-
logger.error(
|
| 251 |
continue
|
| 252 |
filename, content, status = result
|
| 253 |
if status == "success" and content:
|
|
@@ -324,8 +323,12 @@ async def _extract_single(file_obj: Any, index: int) -> tuple[str, str, str]:
|
|
| 324 |
# Get the actual file path and read content
|
| 325 |
file_path = file_obj.name if hasattr(file_obj, "name") else file_obj
|
| 326 |
|
| 327 |
-
|
| 328 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 329 |
|
| 330 |
# Extract text using format-specific handlers
|
| 331 |
text_content = document.extract_text(content_bytes, filename)
|
|
@@ -340,10 +343,10 @@ async def _extract_single(file_obj: Any, index: int) -> tuple[str, str, str]:
|
|
| 340 |
return filename, text_content, "success"
|
| 341 |
|
| 342 |
except DocumentError as e:
|
| 343 |
-
logger.error(
|
| 344 |
return filename, "", f"Could not read file: {e}"
|
| 345 |
-
except
|
| 346 |
-
logger.error(
|
| 347 |
return filename, "", "Unexpected error. Please try a different file format."
|
| 348 |
|
| 349 |
|
|
@@ -376,8 +379,8 @@ def _generate_word_report(batch_result: dict[str, Any], rubric: str) -> bytes:
|
|
| 376 |
results=batch_result["results"],
|
| 377 |
rubric=rubric,
|
| 378 |
)
|
| 379 |
-
except
|
| 380 |
-
logger.error(
|
| 381 |
return b""
|
| 382 |
|
| 383 |
|
|
@@ -412,13 +415,13 @@ def process_document_conversion(
|
|
| 412 |
return converted_bytes, output_filename, f"✓ Successfully converted to {output_format}"
|
| 413 |
|
| 414 |
except UnsupportedFileTypeError as e:
|
| 415 |
-
logger.error(
|
| 416 |
return None, "", f"❌ {e}"
|
| 417 |
except DocumentConversionError as e:
|
| 418 |
-
logger.error(
|
| 419 |
return None, "", f"❌ Conversion failed: {e}"
|
| 420 |
-
except
|
| 421 |
-
logger.error(
|
| 422 |
return None, "", "❌ An unexpected error occurred."
|
| 423 |
|
| 424 |
|
|
@@ -448,8 +451,8 @@ def process_images_to_pdf(
|
|
| 448 |
|
| 449 |
return pdf_bytes, "combined_images.pdf", f"✓ Combined {len(file_objs)} images into PDF"
|
| 450 |
|
| 451 |
-
except
|
| 452 |
-
logger.error(
|
| 453 |
return None, "", f"❌ Could not create PDF: {e}"
|
| 454 |
|
| 455 |
|
|
@@ -476,7 +479,7 @@ def create_interface() -> gr.Blocks:
|
|
| 476 |
""")
|
| 477 |
|
| 478 |
with gr.Row():
|
| 479 |
-
|
| 480 |
label="Text Size",
|
| 481 |
choices=[
|
| 482 |
("Standard", "standard"),
|
|
@@ -487,17 +490,13 @@ def create_interface() -> gr.Blocks:
|
|
| 487 |
elem_id="text-size-setting",
|
| 488 |
)
|
| 489 |
|
| 490 |
-
|
| 491 |
label="High Contrast Mode",
|
| 492 |
value=False,
|
| 493 |
elem_id="high-contrast-setting",
|
| 494 |
info="Increases contrast for better visibility",
|
| 495 |
)
|
| 496 |
|
| 497 |
-
# Apply accessibility settings
|
| 498 |
-
def apply_accessibility_settings(text_size: str, high_contrast: bool) -> str:
|
| 499 |
-
return get_accessibility_css(text_size, high_contrast)
|
| 500 |
-
|
| 501 |
# Header with better structure
|
| 502 |
gr.Markdown("""
|
| 503 |
<header role="banner">
|
|
@@ -556,7 +555,7 @@ def create_interface() -> gr.Blocks:
|
|
| 556 |
|
| 557 |
file_input = gr.File(
|
| 558 |
label="Upload papers (PDF, Word, or images)",
|
| 559 |
-
file_types=
|
| 560 |
file_count="multiple",
|
| 561 |
elem_id="file-upload",
|
| 562 |
)
|
|
@@ -740,7 +739,7 @@ def create_interface() -> gr.Blocks:
|
|
| 740 |
with gr.Column():
|
| 741 |
convert_file_input = gr.File(
|
| 742 |
label="Upload document to convert",
|
| 743 |
-
file_types=
|
| 744 |
file_count="single",
|
| 745 |
)
|
| 746 |
|
|
|
|
| 20 |
from __future__ import annotations
|
| 21 |
|
| 22 |
import asyncio
|
|
|
|
| 23 |
import logging
|
| 24 |
from pathlib import Path
|
| 25 |
from typing import TYPE_CHECKING, Any
|
|
|
|
| 32 |
import document
|
| 33 |
from exceptions import DocumentError, DocumentConversionError, UnsupportedFileTypeError
|
| 34 |
from ui.components import (
|
|
|
|
| 35 |
create_results_cards,
|
| 36 |
create_error_html,
|
| 37 |
create_waiting_html,
|
|
|
|
| 51 |
# Empty DataFrame template for consistent empty results
|
| 52 |
_EMPTY_RESULTS_DF = pd.DataFrame(columns=["Student", "Score", "Status"])
|
| 53 |
|
| 54 |
+
# Supported file types for grading and conversion
|
| 55 |
+
SUPPORTED_FILE_TYPES = [".pdf", ".docx", ".doc", ".png", ".jpg", ".jpeg", ".gif", ".webp", ".bmp"]
|
| 56 |
+
|
| 57 |
|
| 58 |
# =============================================================================
|
| 59 |
# Built-in Rubric Templates
|
|
|
|
| 154 |
async def process_batch_submissions(
|
| 155 |
rubric: str,
|
| 156 |
file_objs: list[Any],
|
|
|
|
|
|
|
| 157 |
progress: gr.Progress = gr.Progress(), # noqa: B008
|
| 158 |
) -> tuple[str, str, str, bytes, pd.DataFrame, str]:
|
| 159 |
"""
|
|
|
|
| 246 |
|
| 247 |
for result in extract_results:
|
| 248 |
if isinstance(result, Exception):
|
| 249 |
+
logger.error("Could not read file: %s", result)
|
| 250 |
continue
|
| 251 |
filename, content, status = result
|
| 252 |
if status == "success" and content:
|
|
|
|
| 323 |
# Get the actual file path and read content
|
| 324 |
file_path = file_obj.name if hasattr(file_obj, "name") else file_obj
|
| 325 |
|
| 326 |
+
# Use asyncio.to_thread for sync file I/O in async context
|
| 327 |
+
def read_file_sync():
|
| 328 |
+
with open(file_path, "rb") as f:
|
| 329 |
+
return f.read()
|
| 330 |
+
|
| 331 |
+
content_bytes = await asyncio.to_thread(read_file_sync)
|
| 332 |
|
| 333 |
# Extract text using format-specific handlers
|
| 334 |
text_content = document.extract_text(content_bytes, filename)
|
|
|
|
| 343 |
return filename, text_content, "success"
|
| 344 |
|
| 345 |
except DocumentError as e:
|
| 346 |
+
logger.error("Document error for %s: %s", filename, e)
|
| 347 |
return filename, "", f"Could not read file: {e}"
|
| 348 |
+
except (OSError, IOError) as e:
|
| 349 |
+
logger.error("Failed to extract text from %s: %s", filename, e)
|
| 350 |
return filename, "", "Unexpected error. Please try a different file format."
|
| 351 |
|
| 352 |
|
|
|
|
| 379 |
results=batch_result["results"],
|
| 380 |
rubric=rubric,
|
| 381 |
)
|
| 382 |
+
except (OSError, IOError, ValueError) as e:
|
| 383 |
+
logger.error("Failed to create Word report: %s", e)
|
| 384 |
return b""
|
| 385 |
|
| 386 |
|
|
|
|
| 415 |
return converted_bytes, output_filename, f"✓ Successfully converted to {output_format}"
|
| 416 |
|
| 417 |
except UnsupportedFileTypeError as e:
|
| 418 |
+
logger.error("Unsupported conversion: %s", e)
|
| 419 |
return None, "", f"❌ {e}"
|
| 420 |
except DocumentConversionError as e:
|
| 421 |
+
logger.error("Conversion failed: %s", e)
|
| 422 |
return None, "", f"❌ Conversion failed: {e}"
|
| 423 |
+
except (OSError, IOError, ValueError) as e:
|
| 424 |
+
logger.error("Unexpected error during conversion: %s", e)
|
| 425 |
return None, "", "❌ An unexpected error occurred."
|
| 426 |
|
| 427 |
|
|
|
|
| 451 |
|
| 452 |
return pdf_bytes, "combined_images.pdf", f"✓ Combined {len(file_objs)} images into PDF"
|
| 453 |
|
| 454 |
+
except (OSError, IOError, ValueError) as e:
|
| 455 |
+
logger.error("Failed to convert images to PDF: %s", e)
|
| 456 |
return None, "", f"❌ Could not create PDF: {e}"
|
| 457 |
|
| 458 |
|
|
|
|
| 479 |
""")
|
| 480 |
|
| 481 |
with gr.Row():
|
| 482 |
+
_text_size_setting = gr.Radio( # noqa: F841 - Future accessibility feature
|
| 483 |
label="Text Size",
|
| 484 |
choices=[
|
| 485 |
("Standard", "standard"),
|
|
|
|
| 490 |
elem_id="text-size-setting",
|
| 491 |
)
|
| 492 |
|
| 493 |
+
_high_contrast_setting = gr.Checkbox( # noqa: F841 - Future accessibility feature
|
| 494 |
label="High Contrast Mode",
|
| 495 |
value=False,
|
| 496 |
elem_id="high-contrast-setting",
|
| 497 |
info="Increases contrast for better visibility",
|
| 498 |
)
|
| 499 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 500 |
# Header with better structure
|
| 501 |
gr.Markdown("""
|
| 502 |
<header role="banner">
|
|
|
|
| 555 |
|
| 556 |
file_input = gr.File(
|
| 557 |
label="Upload papers (PDF, Word, or images)",
|
| 558 |
+
file_types=SUPPORTED_FILE_TYPES,
|
| 559 |
file_count="multiple",
|
| 560 |
elem_id="file-upload",
|
| 561 |
)
|
|
|
|
| 739 |
with gr.Column():
|
| 740 |
convert_file_input = gr.File(
|
| 741 |
label="Upload document to convert",
|
| 742 |
+
file_types=SUPPORTED_FILE_TYPES,
|
| 743 |
file_count="single",
|
| 744 |
)
|
| 745 |
|
config.py
CHANGED
|
@@ -46,6 +46,9 @@ OCR_TIMEOUT_SECONDS: float = 300.0
|
|
| 46 |
# HuggingFace Inference API endpoint URL
|
| 47 |
HF_API_URL: str = "https://api-inference.huggingface.co/models"
|
| 48 |
|
|
|
|
|
|
|
|
|
|
| 49 |
# =============================================================================
|
| 50 |
# Concurrency Settings
|
| 51 |
# =============================================================================
|
|
@@ -348,18 +351,18 @@ h3 {{
|
|
| 348 |
# =============================================================================
|
| 349 |
|
| 350 |
# Required environment variable:
|
| 351 |
-
# -
|
| 352 |
#
|
| 353 |
# Optional overrides:
|
| 354 |
-
# -
|
| 355 |
-
# -
|
| 356 |
-
# -
|
| 357 |
# - OCR_FALLBACK_ENABLED: Enable/disable OCR fallback (true/false)
|
| 358 |
ENV_VARS: dict[str, str] = {
|
| 359 |
-
"
|
| 360 |
-
"
|
| 361 |
-
"
|
| 362 |
-
"
|
| 363 |
"ocr_fallback_enabled": "OCR_FALLBACK_ENABLED",
|
| 364 |
}
|
| 365 |
|
|
|
|
| 46 |
# HuggingFace Inference API endpoint URL
|
| 47 |
HF_API_URL: str = "https://api-inference.huggingface.co/models"
|
| 48 |
|
| 49 |
+
# DeepInfra API URL (used for OCR with DeepSeek-OCR model)
|
| 50 |
+
DEEPINFRA_API_URL: str = "https://api.deepinfra.com/v1/openai/chat/completions"
|
| 51 |
+
|
| 52 |
# =============================================================================
|
| 53 |
# Concurrency Settings
|
| 54 |
# =============================================================================
|
|
|
|
| 351 |
# =============================================================================
|
| 352 |
|
| 353 |
# Required environment variable:
|
| 354 |
+
# - HUGGINGFACE_API_KEY: Get from https://huggingface.co/settings/tokens
|
| 355 |
#
|
| 356 |
# Optional overrides:
|
| 357 |
+
# - HF_MODEL_PRIMARY: Override the default model
|
| 358 |
+
# - HF_MAX_TOKENS: Override max tokens
|
| 359 |
+
# - HF_TEMPERATURE: Override temperature
|
| 360 |
# - OCR_FALLBACK_ENABLED: Enable/disable OCR fallback (true/false)
|
| 361 |
ENV_VARS: dict[str, str] = {
|
| 362 |
+
"huggingface_key": "HUGGINGFACE_API_KEY",
|
| 363 |
+
"hf_model": "HF_MODEL_PRIMARY",
|
| 364 |
+
"hf_max_tokens": "HF_MAX_TOKENS",
|
| 365 |
+
"hf_temperature": "HF_TEMPERATURE",
|
| 366 |
"ocr_fallback_enabled": "OCR_FALLBACK_ENABLED",
|
| 367 |
}
|
| 368 |
|
document/conversion.py
CHANGED
|
@@ -36,6 +36,7 @@ try:
|
|
| 36 |
DOCX_SUPPORT = True
|
| 37 |
except ImportError:
|
| 38 |
DOCX_SUPPORT = False
|
|
|
|
| 39 |
|
| 40 |
# Image support
|
| 41 |
try:
|
|
|
|
| 36 |
DOCX_SUPPORT = True
|
| 37 |
except ImportError:
|
| 38 |
DOCX_SUPPORT = False
|
| 39 |
+
Document = None # type: ignore
|
| 40 |
|
| 41 |
# Image support
|
| 42 |
try:
|
exceptions.py
CHANGED
|
@@ -88,48 +88,3 @@ class AIServiceError(GradingError):
|
|
| 88 |
|
| 89 |
class ConfigurationError(GradeM8Error):
|
| 90 |
"""Raised when there's a configuration issue."""
|
| 91 |
-
pass
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
# =============================================================================
|
| 95 |
-
# AI Grading Errors
|
| 96 |
-
# =============================================================================
|
| 97 |
-
|
| 98 |
-
class GradingError(GradeM8Error):
|
| 99 |
-
"""Base exception for AI grading errors."""
|
| 100 |
-
pass
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
class ResponseParseError(GradingError):
|
| 104 |
-
"""Raised when the AI response cannot be parsed as valid JSON."""
|
| 105 |
-
|
| 106 |
-
def __init__(self, message: str, raw_response: str | None = None) -> None:
|
| 107 |
-
self.raw_response = raw_response
|
| 108 |
-
super().__init__(message)
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
class InvalidResponseError(GradingError):
|
| 112 |
-
"""Raised when the AI response is valid JSON but missing required fields."""
|
| 113 |
-
pass
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
class APIKeyError(GradingError):
|
| 117 |
-
"""Raised when the API key is missing or invalid."""
|
| 118 |
-
pass
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
class AIServiceError(GradingError):
|
| 122 |
-
"""Raised when the AI service returns an error."""
|
| 123 |
-
|
| 124 |
-
def __init__(self, message: str, status_code: int | None = None) -> None:
|
| 125 |
-
self.status_code = status_code
|
| 126 |
-
super().__init__(message)
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
# =============================================================================
|
| 130 |
-
# Configuration Errors
|
| 131 |
-
# =============================================================================
|
| 132 |
-
|
| 133 |
-
class ConfigurationError(GradeM8Error):
|
| 134 |
-
"""Raised when there's a configuration issue."""
|
| 135 |
-
pass
|
|
|
|
| 88 |
|
| 89 |
class ConfigurationError(GradeM8Error):
|
| 90 |
"""Raised when there's a configuration issue."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tests/conftest.py
CHANGED
|
@@ -93,14 +93,14 @@ from unittest.mock import Mock, AsyncMock
|
|
| 93 |
@pytest.fixture(autouse=True)
|
| 94 |
def setup_test_env() -> Generator[None, None, None]:
|
| 95 |
"""Set up test environment with mock API key."""
|
| 96 |
-
old_key = os.environ.get("
|
| 97 |
-
os.environ["
|
| 98 |
yield
|
| 99 |
# Cleanup
|
| 100 |
if old_key is not None:
|
| 101 |
-
os.environ["
|
| 102 |
-
elif "
|
| 103 |
-
del os.environ["
|
| 104 |
|
| 105 |
|
| 106 |
@pytest.fixture
|
|
|
|
| 93 |
@pytest.fixture(autouse=True)
|
| 94 |
def setup_test_env() -> Generator[None, None, None]:
|
| 95 |
"""Set up test environment with mock API key."""
|
| 96 |
+
old_key = os.environ.get("HUGGINGFACE_API_KEY")
|
| 97 |
+
os.environ["HUGGINGFACE_API_KEY"] = "test_api_key_12345"
|
| 98 |
yield
|
| 99 |
# Cleanup
|
| 100 |
if old_key is not None:
|
| 101 |
+
os.environ["HUGGINGFACE_API_KEY"] = old_key
|
| 102 |
+
elif "HUGGINGFACE_API_KEY" in os.environ:
|
| 103 |
+
del os.environ["HUGGINGFACE_API_KEY"]
|
| 104 |
|
| 105 |
|
| 106 |
@pytest.fixture
|
tests/unit/test_ai_router.py
CHANGED
|
@@ -146,16 +146,16 @@ class TestTransformToFinalFormat:
|
|
| 146 |
assert results[0]["score"] == 80
|
| 147 |
|
| 148 |
|
| 149 |
-
class
|
| 150 |
-
"""Test
|
| 151 |
|
| 152 |
def test_uses_environment_variables(self, monkeypatch):
|
| 153 |
-
monkeypatch.setenv("
|
| 154 |
-
monkeypatch.setenv("
|
| 155 |
-
monkeypatch.setenv("
|
| 156 |
-
monkeypatch.setenv("
|
| 157 |
|
| 158 |
-
api_key, model, max_tokens, temperature = ai_router.
|
| 159 |
|
| 160 |
assert api_key == "test_key"
|
| 161 |
assert model == "test_model"
|
|
@@ -163,22 +163,22 @@ class TestGetDeepinfraConfig:
|
|
| 163 |
assert temperature == 0.5
|
| 164 |
|
| 165 |
def test_uses_defaults_for_optional_vars(self, monkeypatch):
|
| 166 |
-
monkeypatch.setenv("
|
| 167 |
# Unset optional vars
|
| 168 |
-
monkeypatch.delenv("
|
| 169 |
-
monkeypatch.delenv("
|
| 170 |
-
monkeypatch.delenv("
|
| 171 |
|
| 172 |
-
api_key, model, max_tokens, temperature = ai_router.
|
| 173 |
|
| 174 |
assert api_key == "test_key"
|
| 175 |
-
assert model == ai_router.config.
|
| 176 |
assert max_tokens == ai_router.config.MAX_TOKENS
|
| 177 |
assert temperature == ai_router.config.TEMPERATURE
|
| 178 |
|
| 179 |
def test_raises_when_api_key_missing(self, monkeypatch):
|
| 180 |
-
monkeypatch.delenv("
|
| 181 |
|
| 182 |
from exceptions import APIKeyError
|
| 183 |
with pytest.raises(APIKeyError):
|
| 184 |
-
ai_router.
|
|
|
|
| 146 |
assert results[0]["score"] == 80
|
| 147 |
|
| 148 |
|
| 149 |
+
class TestGetHuggingfaceConfig:
|
| 150 |
+
"""Test _get_huggingface_config helper."""
|
| 151 |
|
| 152 |
def test_uses_environment_variables(self, monkeypatch):
|
| 153 |
+
monkeypatch.setenv("HUGGINGFACE_API_KEY", "test_key")
|
| 154 |
+
monkeypatch.setenv("HF_MODEL_PRIMARY", "test_model")
|
| 155 |
+
monkeypatch.setenv("HF_MAX_TOKENS", "1000")
|
| 156 |
+
monkeypatch.setenv("HF_TEMPERATURE", "0.5")
|
| 157 |
|
| 158 |
+
api_key, model, max_tokens, temperature = ai_router._get_huggingface_config()
|
| 159 |
|
| 160 |
assert api_key == "test_key"
|
| 161 |
assert model == "test_model"
|
|
|
|
| 163 |
assert temperature == 0.5
|
| 164 |
|
| 165 |
def test_uses_defaults_for_optional_vars(self, monkeypatch):
|
| 166 |
+
monkeypatch.setenv("HUGGINGFACE_API_KEY", "test_key")
|
| 167 |
# Unset optional vars
|
| 168 |
+
monkeypatch.delenv("HF_MODEL_PRIMARY", raising=False)
|
| 169 |
+
monkeypatch.delenv("HF_MAX_TOKENS", raising=False)
|
| 170 |
+
monkeypatch.delenv("HF_TEMPERATURE", raising=False)
|
| 171 |
|
| 172 |
+
api_key, model, max_tokens, temperature = ai_router._get_huggingface_config()
|
| 173 |
|
| 174 |
assert api_key == "test_key"
|
| 175 |
+
assert model == ai_router.config.HF_MODEL_DEFAULT
|
| 176 |
assert max_tokens == ai_router.config.MAX_TOKENS
|
| 177 |
assert temperature == ai_router.config.TEMPERATURE
|
| 178 |
|
| 179 |
def test_raises_when_api_key_missing(self, monkeypatch):
|
| 180 |
+
monkeypatch.delenv("HUGGINGFACE_API_KEY", raising=False)
|
| 181 |
|
| 182 |
from exceptions import APIKeyError
|
| 183 |
with pytest.raises(APIKeyError):
|
| 184 |
+
ai_router._get_huggingface_config()
|
tests/unit/test_client.py
CHANGED
|
@@ -9,29 +9,29 @@ import pytest
|
|
| 9 |
import respx
|
| 10 |
from httpx import Response
|
| 11 |
|
| 12 |
-
from ai_router.client import
|
| 13 |
from exceptions import APIKeyError, AIServiceError, ResponseParseError
|
| 14 |
import config
|
| 15 |
|
| 16 |
|
| 17 |
-
class
|
| 18 |
-
"""Tests for
|
| 19 |
|
| 20 |
def test_raises_when_api_key_missing(self, monkeypatch):
|
| 21 |
-
"""Test error when
|
| 22 |
-
monkeypatch.delenv("
|
| 23 |
|
| 24 |
-
with pytest.raises(APIKeyError, match="
|
| 25 |
-
|
| 26 |
|
| 27 |
def test_uses_environment_variables(self, monkeypatch):
|
| 28 |
"""Test using environment variables."""
|
| 29 |
-
monkeypatch.setenv("
|
| 30 |
-
monkeypatch.setenv("
|
| 31 |
-
monkeypatch.setenv("
|
| 32 |
-
monkeypatch.setenv("
|
| 33 |
|
| 34 |
-
api_key, model, max_tokens, temperature =
|
| 35 |
|
| 36 |
assert api_key == "test_key"
|
| 37 |
assert model == "test_model"
|
|
@@ -40,24 +40,24 @@ class TestGetDeepinfraConfig:
|
|
| 40 |
|
| 41 |
def test_uses_defaults_for_optional_vars(self, monkeypatch):
|
| 42 |
"""Test using defaults when optional vars not set."""
|
| 43 |
-
monkeypatch.setenv("
|
| 44 |
-
monkeypatch.delenv("
|
| 45 |
-
monkeypatch.delenv("
|
| 46 |
-
monkeypatch.delenv("
|
| 47 |
|
| 48 |
-
api_key, model, max_tokens, temperature =
|
| 49 |
|
| 50 |
assert api_key == "test_key"
|
| 51 |
-
assert model == config.
|
| 52 |
assert max_tokens == config.MAX_TOKENS
|
| 53 |
assert temperature == config.TEMPERATURE
|
| 54 |
|
| 55 |
def test_empty_api_key_raises_error(self, monkeypatch):
|
| 56 |
"""Test error when API key is empty string."""
|
| 57 |
-
monkeypatch.setenv("
|
| 58 |
|
| 59 |
with pytest.raises(APIKeyError):
|
| 60 |
-
|
| 61 |
|
| 62 |
|
| 63 |
class TestGenerateGrading:
|
|
@@ -67,7 +67,7 @@ class TestGenerateGrading:
|
|
| 67 |
@pytest.mark.asyncio
|
| 68 |
async def test_successful_grading(self, monkeypatch):
|
| 69 |
"""Test successful grading API call."""
|
| 70 |
-
monkeypatch.setenv("
|
| 71 |
|
| 72 |
# Mock API response with valid grading JSON
|
| 73 |
grading_response = {
|
|
@@ -79,10 +79,9 @@ class TestGenerateGrading:
|
|
| 79 |
"feedback": "Well done!"
|
| 80 |
}
|
| 81 |
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
})
|
| 86 |
)
|
| 87 |
|
| 88 |
result = await generate_grading("Student submission", "Grade on clarity")
|
|
@@ -93,9 +92,9 @@ class TestGenerateGrading:
|
|
| 93 |
|
| 94 |
@respx.mock
|
| 95 |
@pytest.mark.asyncio
|
| 96 |
-
async def
|
| 97 |
-
"""Test grading response
|
| 98 |
-
monkeypatch.setenv("
|
| 99 |
|
| 100 |
grading_response = {
|
| 101 |
"score": 90,
|
|
@@ -106,12 +105,11 @@ class TestGenerateGrading:
|
|
| 106 |
"feedback": "Great job!"
|
| 107 |
}
|
| 108 |
|
| 109 |
-
|
|
|
|
| 110 |
|
| 111 |
-
route = respx.post(config.
|
| 112 |
-
return_value=Response(200, json={
|
| 113 |
-
"choices": [{"message": {"content": content}}]
|
| 114 |
-
})
|
| 115 |
)
|
| 116 |
|
| 117 |
result = await generate_grading("Submission", "Rubric")
|
|
@@ -121,7 +119,7 @@ class TestGenerateGrading:
|
|
| 121 |
@pytest.mark.asyncio
|
| 122 |
async def test_missing_api_key_raises_error(self, monkeypatch):
|
| 123 |
"""Test error when API key is missing."""
|
| 124 |
-
monkeypatch.delenv("
|
| 125 |
|
| 126 |
with pytest.raises(APIKeyError):
|
| 127 |
await generate_grading("Submission", "Rubric")
|
|
@@ -130,9 +128,9 @@ class TestGenerateGrading:
|
|
| 130 |
@pytest.mark.asyncio
|
| 131 |
async def test_http_401_error(self, monkeypatch):
|
| 132 |
"""Test 401 unauthorized error."""
|
| 133 |
-
monkeypatch.setenv("
|
| 134 |
|
| 135 |
-
route = respx.post(config.
|
| 136 |
return_value=Response(401, text="Unauthorized")
|
| 137 |
)
|
| 138 |
|
|
@@ -140,15 +138,15 @@ class TestGenerateGrading:
|
|
| 140 |
await generate_grading("Submission", "Rubric")
|
| 141 |
|
| 142 |
assert exc_info.value.status_code == 401
|
| 143 |
-
assert "
|
| 144 |
|
| 145 |
@respx.mock
|
| 146 |
@pytest.mark.asyncio
|
| 147 |
async def test_http_429_rate_limit(self, monkeypatch):
|
| 148 |
"""Test 429 rate limit error."""
|
| 149 |
-
monkeypatch.setenv("
|
| 150 |
|
| 151 |
-
route = respx.post(config.
|
| 152 |
return_value=Response(429, text="Rate limited")
|
| 153 |
)
|
| 154 |
|
|
@@ -161,9 +159,9 @@ class TestGenerateGrading:
|
|
| 161 |
@pytest.mark.asyncio
|
| 162 |
async def test_http_500_error(self, monkeypatch):
|
| 163 |
"""Test 500 server error."""
|
| 164 |
-
monkeypatch.setenv("
|
| 165 |
|
| 166 |
-
route = respx.post(config.
|
| 167 |
return_value=Response(500, text="Internal Server Error")
|
| 168 |
)
|
| 169 |
|
|
@@ -172,69 +170,68 @@ class TestGenerateGrading:
|
|
| 172 |
|
| 173 |
assert exc_info.value.status_code == 500
|
| 174 |
|
| 175 |
-
|
| 176 |
@respx.mock
|
| 177 |
@pytest.mark.asyncio
|
| 178 |
-
async def
|
| 179 |
-
"""Test
|
| 180 |
-
monkeypatch.setenv("
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 181 |
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
})
|
| 186 |
-
)
|
| 187 |
-
|
| 188 |
-
with pytest.raises(ResponseParseError):
|
| 189 |
-
await generate_grading("Submission", "Rubric")
|
| 190 |
|
| 191 |
@respx.mock
|
| 192 |
@pytest.mark.asyncio
|
| 193 |
-
async def
|
| 194 |
-
"""Test
|
| 195 |
-
monkeypatch.setenv("
|
| 196 |
|
| 197 |
-
route = respx.post(config.
|
| 198 |
-
return_value=Response(200, json={"
|
| 199 |
)
|
| 200 |
|
| 201 |
-
with pytest.raises(
|
| 202 |
await generate_grading("Submission", "Rubric")
|
| 203 |
|
| 204 |
@respx.mock
|
| 205 |
@pytest.mark.asyncio
|
| 206 |
-
async def
|
| 207 |
-
"""Test response with empty
|
| 208 |
-
monkeypatch.setenv("
|
| 209 |
|
| 210 |
-
route = respx.post(config.
|
| 211 |
-
return_value=Response(200, json=
|
| 212 |
)
|
| 213 |
|
| 214 |
-
with pytest.raises(
|
| 215 |
-
await generate_grading("Submission", "Rubric")
|
| 216 |
-
|
| 217 |
-
@respx.mock
|
| 218 |
-
@pytest.mark.asyncio
|
| 219 |
-
async def test_missing_message_content(self, monkeypatch):
|
| 220 |
-
"""Test response missing message content."""
|
| 221 |
-
monkeypatch.setenv("DEEPINFRA_API_KEY", "test_key")
|
| 222 |
-
|
| 223 |
-
route = respx.post(config.DEEPINFRA_API_URL).mock(
|
| 224 |
-
return_value=Response(200, json={
|
| 225 |
-
"choices": [{"message": {}}]
|
| 226 |
-
})
|
| 227 |
-
)
|
| 228 |
-
|
| 229 |
-
with pytest.raises(AIServiceError, match="Invalid API response"):
|
| 230 |
await generate_grading("Submission", "Rubric")
|
| 231 |
|
| 232 |
@respx.mock
|
| 233 |
@pytest.mark.asyncio
|
| 234 |
async def test_details_field_set(self, monkeypatch):
|
| 235 |
"""Test details field includes model info."""
|
| 236 |
-
monkeypatch.setenv("
|
| 237 |
-
monkeypatch.setenv("
|
| 238 |
|
| 239 |
grading_response = {
|
| 240 |
"score": 75,
|
|
@@ -245,67 +242,51 @@ class TestGenerateGrading:
|
|
| 245 |
"feedback": "Test feedback"
|
| 246 |
}
|
| 247 |
|
| 248 |
-
route = respx.post(
|
| 249 |
-
return_value=Response(200, json={
|
| 250 |
-
"choices": [{"message": {"content": json.dumps(grading_response)}}]
|
| 251 |
-
})
|
| 252 |
)
|
| 253 |
|
| 254 |
result = await generate_grading("Submission", "Rubric")
|
| 255 |
assert "custom-model" in result["details"]
|
| 256 |
-
assert "
|
| 257 |
-
|
| 258 |
-
@respx.mock
|
| 259 |
-
@pytest.mark.asyncio
|
| 260 |
-
async def test_response_format_is_json_object(self, monkeypatch):
|
| 261 |
-
"""Test that response_format is set to json_object."""
|
| 262 |
-
monkeypatch.setenv("DEEPINFRA_API_KEY", "test_key")
|
| 263 |
-
|
| 264 |
-
captured_payload = {}
|
| 265 |
-
|
| 266 |
-
def capture_request(request):
|
| 267 |
-
captured_payload.update(json.loads(request.content))
|
| 268 |
-
return Response(200, json={
|
| 269 |
-
"choices": [{"message": {"content": '{"score": 80}'}}]
|
| 270 |
-
})
|
| 271 |
-
|
| 272 |
-
route = respx.post(config.DEEPINFRA_API_URL).mock(side_effect=capture_request)
|
| 273 |
-
|
| 274 |
-
await generate_grading("Submission", "Rubric")
|
| 275 |
-
|
| 276 |
-
assert captured_payload["response_format"]["type"] == "json_object"
|
| 277 |
|
| 278 |
@respx.mock
|
| 279 |
@pytest.mark.asyncio
|
| 280 |
async def test_prompt_building(self, monkeypatch):
|
| 281 |
"""Test that prompt is built correctly."""
|
| 282 |
-
monkeypatch.setenv("
|
| 283 |
|
| 284 |
captured_payload = {}
|
| 285 |
|
| 286 |
def capture_request(request):
|
| 287 |
captured_payload.update(json.loads(request.content))
|
| 288 |
-
|
| 289 |
-
"
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 293 |
|
| 294 |
content = "Student essay content"
|
| 295 |
rubric = "Grade on clarity and grammar"
|
| 296 |
await generate_grading(content, rubric)
|
| 297 |
|
| 298 |
-
|
| 299 |
-
assert
|
| 300 |
-
assert
|
| 301 |
-
assert
|
| 302 |
-
assert
|
| 303 |
|
| 304 |
@respx.mock
|
| 305 |
@pytest.mark.asyncio
|
| 306 |
async def test_truncated_content_handling(self, monkeypatch):
|
| 307 |
"""Test handling of long content that needs truncation."""
|
| 308 |
-
monkeypatch.setenv("
|
| 309 |
|
| 310 |
long_content = "a" * (config.MAX_PROMPT_CHARS + 1000)
|
| 311 |
|
|
@@ -318,11 +299,32 @@ class TestGenerateGrading:
|
|
| 318 |
"feedback": "Test"
|
| 319 |
}
|
| 320 |
|
| 321 |
-
route = respx.post(config.
|
| 322 |
-
return_value=Response(200, json={
|
| 323 |
-
"choices": [{"message": {"content": json.dumps(grading_response)}}]
|
| 324 |
-
})
|
| 325 |
)
|
| 326 |
|
| 327 |
result = await generate_grading(long_content, "Rubric")
|
| 328 |
assert result["score"] == 80
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
import respx
|
| 10 |
from httpx import Response
|
| 11 |
|
| 12 |
+
from ai_router.client import _get_huggingface_config, generate_grading, HF_API_URL
|
| 13 |
from exceptions import APIKeyError, AIServiceError, ResponseParseError
|
| 14 |
import config
|
| 15 |
|
| 16 |
|
| 17 |
+
class TestGetHuggingfaceConfig:
|
| 18 |
+
"""Tests for _get_huggingface_config function."""
|
| 19 |
|
| 20 |
def test_raises_when_api_key_missing(self, monkeypatch):
|
| 21 |
+
"""Test error when HUGGINGFACE_API_KEY is not set."""
|
| 22 |
+
monkeypatch.delenv("HUGGINGFACE_API_KEY", raising=False)
|
| 23 |
|
| 24 |
+
with pytest.raises(APIKeyError, match="HUGGINGFACE_API_KEY"):
|
| 25 |
+
_get_huggingface_config()
|
| 26 |
|
| 27 |
def test_uses_environment_variables(self, monkeypatch):
|
| 28 |
"""Test using environment variables."""
|
| 29 |
+
monkeypatch.setenv("HUGGINGFACE_API_KEY", "test_key")
|
| 30 |
+
monkeypatch.setenv("HF_MODEL_PRIMARY", "test_model")
|
| 31 |
+
monkeypatch.setenv("HF_MAX_TOKENS", "1000")
|
| 32 |
+
monkeypatch.setenv("HF_TEMPERATURE", "0.5")
|
| 33 |
|
| 34 |
+
api_key, model, max_tokens, temperature = _get_huggingface_config()
|
| 35 |
|
| 36 |
assert api_key == "test_key"
|
| 37 |
assert model == "test_model"
|
|
|
|
| 40 |
|
| 41 |
def test_uses_defaults_for_optional_vars(self, monkeypatch):
|
| 42 |
"""Test using defaults when optional vars not set."""
|
| 43 |
+
monkeypatch.setenv("HUGGINGFACE_API_KEY", "test_key")
|
| 44 |
+
monkeypatch.delenv("HF_MODEL_PRIMARY", raising=False)
|
| 45 |
+
monkeypatch.delenv("HF_MAX_TOKENS", raising=False)
|
| 46 |
+
monkeypatch.delenv("HF_TEMPERATURE", raising=False)
|
| 47 |
|
| 48 |
+
api_key, model, max_tokens, temperature = _get_huggingface_config()
|
| 49 |
|
| 50 |
assert api_key == "test_key"
|
| 51 |
+
assert model == config.HF_MODEL_DEFAULT
|
| 52 |
assert max_tokens == config.MAX_TOKENS
|
| 53 |
assert temperature == config.TEMPERATURE
|
| 54 |
|
| 55 |
def test_empty_api_key_raises_error(self, monkeypatch):
|
| 56 |
"""Test error when API key is empty string."""
|
| 57 |
+
monkeypatch.setenv("HUGGINGFACE_API_KEY", "")
|
| 58 |
|
| 59 |
with pytest.raises(APIKeyError):
|
| 60 |
+
_get_huggingface_config()
|
| 61 |
|
| 62 |
|
| 63 |
class TestGenerateGrading:
|
|
|
|
| 67 |
@pytest.mark.asyncio
|
| 68 |
async def test_successful_grading(self, monkeypatch):
|
| 69 |
"""Test successful grading API call."""
|
| 70 |
+
monkeypatch.setenv("HUGGINGFACE_API_KEY", "test_key")
|
| 71 |
|
| 72 |
# Mock API response with valid grading JSON
|
| 73 |
grading_response = {
|
|
|
|
| 79 |
"feedback": "Well done!"
|
| 80 |
}
|
| 81 |
|
| 82 |
+
# HuggingFace returns a list with generated_text
|
| 83 |
+
route = respx.post(f"{HF_API_URL}/{config.HF_MODEL_DEFAULT}").mock(
|
| 84 |
+
return_value=Response(200, json=[{"generated_text": json.dumps(grading_response)}])
|
|
|
|
| 85 |
)
|
| 86 |
|
| 87 |
result = await generate_grading("Student submission", "Grade on clarity")
|
|
|
|
| 92 |
|
| 93 |
@respx.mock
|
| 94 |
@pytest.mark.asyncio
|
| 95 |
+
async def test_grading_with_embedded_json(self, monkeypatch):
|
| 96 |
+
"""Test grading response with JSON embedded in text."""
|
| 97 |
+
monkeypatch.setenv("HUGGINGFACE_API_KEY", "test_key")
|
| 98 |
|
| 99 |
grading_response = {
|
| 100 |
"score": 90,
|
|
|
|
| 105 |
"feedback": "Great job!"
|
| 106 |
}
|
| 107 |
|
| 108 |
+
# Response with JSON embedded in text
|
| 109 |
+
content = f'Here is the grading: {json.dumps(grading_response)} End of response.'
|
| 110 |
|
| 111 |
+
route = respx.post(f"{HF_API_URL}/{config.HF_MODEL_DEFAULT}").mock(
|
| 112 |
+
return_value=Response(200, json=[{"generated_text": content}])
|
|
|
|
|
|
|
| 113 |
)
|
| 114 |
|
| 115 |
result = await generate_grading("Submission", "Rubric")
|
|
|
|
| 119 |
@pytest.mark.asyncio
|
| 120 |
async def test_missing_api_key_raises_error(self, monkeypatch):
|
| 121 |
"""Test error when API key is missing."""
|
| 122 |
+
monkeypatch.delenv("HUGGINGFACE_API_KEY", raising=False)
|
| 123 |
|
| 124 |
with pytest.raises(APIKeyError):
|
| 125 |
await generate_grading("Submission", "Rubric")
|
|
|
|
| 128 |
@pytest.mark.asyncio
|
| 129 |
async def test_http_401_error(self, monkeypatch):
|
| 130 |
"""Test 401 unauthorized error."""
|
| 131 |
+
monkeypatch.setenv("HUGGINGFACE_API_KEY", "test_key")
|
| 132 |
|
| 133 |
+
route = respx.post(f"{HF_API_URL}/{config.HF_MODEL_DEFAULT}").mock(
|
| 134 |
return_value=Response(401, text="Unauthorized")
|
| 135 |
)
|
| 136 |
|
|
|
|
| 138 |
await generate_grading("Submission", "Rubric")
|
| 139 |
|
| 140 |
assert exc_info.value.status_code == 401
|
| 141 |
+
assert "HuggingFace API error" in str(exc_info.value)
|
| 142 |
|
| 143 |
@respx.mock
|
| 144 |
@pytest.mark.asyncio
|
| 145 |
async def test_http_429_rate_limit(self, monkeypatch):
|
| 146 |
"""Test 429 rate limit error."""
|
| 147 |
+
monkeypatch.setenv("HUGGINGFACE_API_KEY", "test_key")
|
| 148 |
|
| 149 |
+
route = respx.post(f"{HF_API_URL}/{config.HF_MODEL_DEFAULT}").mock(
|
| 150 |
return_value=Response(429, text="Rate limited")
|
| 151 |
)
|
| 152 |
|
|
|
|
| 159 |
@pytest.mark.asyncio
|
| 160 |
async def test_http_500_error(self, monkeypatch):
|
| 161 |
"""Test 500 server error."""
|
| 162 |
+
monkeypatch.setenv("HUGGINGFACE_API_KEY", "test_key")
|
| 163 |
|
| 164 |
+
route = respx.post(f"{HF_API_URL}/{config.HF_MODEL_DEFAULT}").mock(
|
| 165 |
return_value=Response(500, text="Internal Server Error")
|
| 166 |
)
|
| 167 |
|
|
|
|
| 170 |
|
| 171 |
assert exc_info.value.status_code == 500
|
| 172 |
|
|
|
|
| 173 |
@respx.mock
|
| 174 |
@pytest.mark.asyncio
|
| 175 |
+
async def test_http_503_model_loading(self, monkeypatch):
|
| 176 |
+
"""Test 503 model loading triggers retry."""
|
| 177 |
+
monkeypatch.setenv("HUGGINGFACE_API_KEY", "test_key")
|
| 178 |
+
|
| 179 |
+
# 503 on first call, success on retry
|
| 180 |
+
call_count = 0
|
| 181 |
+
|
| 182 |
+
def mock_response(request):
|
| 183 |
+
nonlocal call_count
|
| 184 |
+
call_count += 1
|
| 185 |
+
if call_count == 1:
|
| 186 |
+
return Response(503, text="Model is loading")
|
| 187 |
+
grading_response = {
|
| 188 |
+
"score": 80,
|
| 189 |
+
"rubric_breakdown": {},
|
| 190 |
+
"summary": "Good",
|
| 191 |
+
"strengths": [],
|
| 192 |
+
"improvements": [],
|
| 193 |
+
"feedback": "Nice work"
|
| 194 |
+
}
|
| 195 |
+
return Response(200, json=[{"generated_text": json.dumps(grading_response)}])
|
| 196 |
+
|
| 197 |
+
route = respx.post(f"{HF_API_URL}/{config.HF_MODEL_DEFAULT}").mock(side_effect=mock_response)
|
| 198 |
|
| 199 |
+
result = await generate_grading("Submission", "Rubric")
|
| 200 |
+
assert result["score"] == 80
|
| 201 |
+
assert call_count == 2
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 202 |
|
| 203 |
@respx.mock
|
| 204 |
@pytest.mark.asyncio
|
| 205 |
+
async def test_invalid_json_response(self, monkeypatch):
|
| 206 |
+
"""Test invalid JSON in response."""
|
| 207 |
+
monkeypatch.setenv("HUGGINGFACE_API_KEY", "test_key")
|
| 208 |
|
| 209 |
+
route = respx.post(f"{HF_API_URL}/{config.HF_MODEL_DEFAULT}").mock(
|
| 210 |
+
return_value=Response(200, json=[{"generated_text": "not valid json at all"}])
|
| 211 |
)
|
| 212 |
|
| 213 |
+
with pytest.raises(ResponseParseError):
|
| 214 |
await generate_grading("Submission", "Rubric")
|
| 215 |
|
| 216 |
@respx.mock
|
| 217 |
@pytest.mark.asyncio
|
| 218 |
+
async def test_empty_response_list(self, monkeypatch):
|
| 219 |
+
"""Test response with empty list."""
|
| 220 |
+
monkeypatch.setenv("HUGGINGFACE_API_KEY", "test_key")
|
| 221 |
|
| 222 |
+
route = respx.post(f"{HF_API_URL}/{config.HF_MODEL_DEFAULT}").mock(
|
| 223 |
+
return_value=Response(200, json=[])
|
| 224 |
)
|
| 225 |
|
| 226 |
+
with pytest.raises(ResponseParseError):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 227 |
await generate_grading("Submission", "Rubric")
|
| 228 |
|
| 229 |
@respx.mock
|
| 230 |
@pytest.mark.asyncio
|
| 231 |
async def test_details_field_set(self, monkeypatch):
|
| 232 |
"""Test details field includes model info."""
|
| 233 |
+
monkeypatch.setenv("HUGGINGFACE_API_KEY", "test_key")
|
| 234 |
+
monkeypatch.setenv("HF_MODEL_PRIMARY", "custom-model")
|
| 235 |
|
| 236 |
grading_response = {
|
| 237 |
"score": 75,
|
|
|
|
| 242 |
"feedback": "Test feedback"
|
| 243 |
}
|
| 244 |
|
| 245 |
+
route = respx.post(f"{HF_API_URL}/custom-model").mock(
|
| 246 |
+
return_value=Response(200, json=[{"generated_text": json.dumps(grading_response)}])
|
|
|
|
|
|
|
| 247 |
)
|
| 248 |
|
| 249 |
result = await generate_grading("Submission", "Rubric")
|
| 250 |
assert "custom-model" in result["details"]
|
| 251 |
+
assert "HuggingFace" in result["details"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 252 |
|
| 253 |
@respx.mock
|
| 254 |
@pytest.mark.asyncio
|
| 255 |
async def test_prompt_building(self, monkeypatch):
|
| 256 |
"""Test that prompt is built correctly."""
|
| 257 |
+
monkeypatch.setenv("HUGGINGFACE_API_KEY", "test_key")
|
| 258 |
|
| 259 |
captured_payload = {}
|
| 260 |
|
| 261 |
def capture_request(request):
|
| 262 |
captured_payload.update(json.loads(request.content))
|
| 263 |
+
grading_response = {
|
| 264 |
+
"score": 80,
|
| 265 |
+
"rubric_breakdown": {},
|
| 266 |
+
"summary": "Good",
|
| 267 |
+
"strengths": [],
|
| 268 |
+
"improvements": [],
|
| 269 |
+
"feedback": "Nice"
|
| 270 |
+
}
|
| 271 |
+
return Response(200, json=[{"generated_text": json.dumps(grading_response)}])
|
| 272 |
+
|
| 273 |
+
route = respx.post(f"{HF_API_URL}/{config.HF_MODEL_DEFAULT}").mock(side_effect=capture_request)
|
| 274 |
|
| 275 |
content = "Student essay content"
|
| 276 |
rubric = "Grade on clarity and grammar"
|
| 277 |
await generate_grading(content, rubric)
|
| 278 |
|
| 279 |
+
# HuggingFace format: inputs as full prompt, parameters for generation
|
| 280 |
+
assert "inputs" in captured_payload
|
| 281 |
+
assert content in captured_payload["inputs"]
|
| 282 |
+
assert rubric in captured_payload["inputs"]
|
| 283 |
+
assert "parameters" in captured_payload
|
| 284 |
|
| 285 |
@respx.mock
|
| 286 |
@pytest.mark.asyncio
|
| 287 |
async def test_truncated_content_handling(self, monkeypatch):
|
| 288 |
"""Test handling of long content that needs truncation."""
|
| 289 |
+
monkeypatch.setenv("HUGGINGFACE_API_KEY", "test_key")
|
| 290 |
|
| 291 |
long_content = "a" * (config.MAX_PROMPT_CHARS + 1000)
|
| 292 |
|
|
|
|
| 299 |
"feedback": "Test"
|
| 300 |
}
|
| 301 |
|
| 302 |
+
route = respx.post(f"{HF_API_URL}/{config.HF_MODEL_DEFAULT}").mock(
|
| 303 |
+
return_value=Response(200, json=[{"generated_text": json.dumps(grading_response)}])
|
|
|
|
|
|
|
| 304 |
)
|
| 305 |
|
| 306 |
result = await generate_grading(long_content, "Rubric")
|
| 307 |
assert result["score"] == 80
|
| 308 |
+
|
| 309 |
+
@respx.mock
|
| 310 |
+
@pytest.mark.asyncio
|
| 311 |
+
async def test_dict_response_format(self, monkeypatch):
|
| 312 |
+
"""Test handling of dict response format (alternative HF response)."""
|
| 313 |
+
monkeypatch.setenv("HUGGINGFACE_API_KEY", "test_key")
|
| 314 |
+
|
| 315 |
+
grading_response = {
|
| 316 |
+
"score": 85,
|
| 317 |
+
"rubric_breakdown": {},
|
| 318 |
+
"summary": "Well done",
|
| 319 |
+
"strengths": [],
|
| 320 |
+
"improvements": [],
|
| 321 |
+
"feedback": "Good work"
|
| 322 |
+
}
|
| 323 |
+
|
| 324 |
+
# Some HF models return dict instead of list
|
| 325 |
+
route = respx.post(f"{HF_API_URL}/{config.HF_MODEL_DEFAULT}").mock(
|
| 326 |
+
return_value=Response(200, json={"generated_text": json.dumps(grading_response)})
|
| 327 |
+
)
|
| 328 |
+
|
| 329 |
+
result = await generate_grading("Submission", "Rubric")
|
| 330 |
+
assert result["score"] == 85
|
tests/unit/test_config.py
CHANGED
|
@@ -10,9 +10,9 @@ import config
|
|
| 10 |
class TestConfigConstants:
|
| 11 |
"""Test configuration constants have expected values."""
|
| 12 |
|
| 13 |
-
def
|
| 14 |
-
"""Test default
|
| 15 |
-
assert config.
|
| 16 |
|
| 17 |
def test_max_tokens(self):
|
| 18 |
"""Test max tokens setting."""
|
|
@@ -60,13 +60,13 @@ class TestEnvironmentVariables:
|
|
| 60 |
|
| 61 |
def test_env_vars_defined(self):
|
| 62 |
"""Test that environment variable names are defined."""
|
| 63 |
-
assert "
|
| 64 |
-
assert config.ENV_VARS["
|
| 65 |
|
| 66 |
@pytest.mark.parametrize("env_var,config_key", [
|
| 67 |
-
("
|
| 68 |
-
("
|
| 69 |
-
("
|
| 70 |
])
|
| 71 |
def test_env_var_mappings(self, env_var, config_key):
|
| 72 |
"""Test environment variable mappings."""
|
|
@@ -120,8 +120,8 @@ class TestConfigTypes:
|
|
| 120 |
def test_string_configs_are_strings(self):
|
| 121 |
"""Test string configuration values."""
|
| 122 |
string_configs = [
|
| 123 |
-
config.
|
| 124 |
-
config.
|
| 125 |
config.LOG_LEVEL,
|
| 126 |
config.LOG_FORMAT,
|
| 127 |
config.GRADING_SYSTEM_PROMPT,
|
|
|
|
| 10 |
class TestConfigConstants:
|
| 11 |
"""Test configuration constants have expected values."""
|
| 12 |
|
| 13 |
+
def test_hf_model_default(self):
|
| 14 |
+
"""Test default HuggingFace model."""
|
| 15 |
+
assert config.HF_MODEL_DEFAULT == "meta-llama/Llama-2-70b-chat-hf"
|
| 16 |
|
| 17 |
def test_max_tokens(self):
|
| 18 |
"""Test max tokens setting."""
|
|
|
|
| 60 |
|
| 61 |
def test_env_vars_defined(self):
|
| 62 |
"""Test that environment variable names are defined."""
|
| 63 |
+
assert "huggingface_key" in config.ENV_VARS
|
| 64 |
+
assert config.ENV_VARS["huggingface_key"] == "HUGGINGFACE_API_KEY"
|
| 65 |
|
| 66 |
@pytest.mark.parametrize("env_var,config_key", [
|
| 67 |
+
("HF_MODEL_PRIMARY", "hf_model"),
|
| 68 |
+
("HF_MAX_TOKENS", "hf_max_tokens"),
|
| 69 |
+
("HF_TEMPERATURE", "hf_temperature"),
|
| 70 |
])
|
| 71 |
def test_env_var_mappings(self, env_var, config_key):
|
| 72 |
"""Test environment variable mappings."""
|
|
|
|
| 120 |
def test_string_configs_are_strings(self):
|
| 121 |
"""Test string configuration values."""
|
| 122 |
string_configs = [
|
| 123 |
+
config.HF_MODEL_DEFAULT,
|
| 124 |
+
config.HF_API_URL,
|
| 125 |
config.LOG_LEVEL,
|
| 126 |
config.LOG_FORMAT,
|
| 127 |
config.GRADING_SYSTEM_PROMPT,
|
tests/unit/test_orchestration.py
CHANGED
|
@@ -180,7 +180,7 @@ class TestGenerateBatchGrading:
|
|
| 180 |
|
| 181 |
call_count = 0
|
| 182 |
|
| 183 |
-
async def mock_generate(
|
| 184 |
nonlocal call_count
|
| 185 |
call_count += 1
|
| 186 |
if call_count == 1:
|
|
@@ -208,7 +208,7 @@ class TestGenerateBatchGrading:
|
|
| 208 |
concurrent_calls = 0
|
| 209 |
max_concurrent = 0
|
| 210 |
|
| 211 |
-
async def mock_generate(
|
| 212 |
nonlocal concurrent_calls, max_concurrent
|
| 213 |
concurrent_calls += 1
|
| 214 |
max_concurrent = max(max_concurrent, concurrent_calls)
|
|
|
|
| 180 |
|
| 181 |
call_count = 0
|
| 182 |
|
| 183 |
+
async def mock_generate(content, rubric):
|
| 184 |
nonlocal call_count
|
| 185 |
call_count += 1
|
| 186 |
if call_count == 1:
|
|
|
|
| 208 |
concurrent_calls = 0
|
| 209 |
max_concurrent = 0
|
| 210 |
|
| 211 |
+
async def mock_generate(content, rubric):
|
| 212 |
nonlocal concurrent_calls, max_concurrent
|
| 213 |
concurrent_calls += 1
|
| 214 |
max_concurrent = max(max_concurrent, concurrent_calls)
|
tests/unit/test_parsing.py
CHANGED
|
@@ -254,7 +254,7 @@ class TestValidateGradingResult:
|
|
| 254 |
"""Test details field is set with model info."""
|
| 255 |
result = {"score": 80}
|
| 256 |
validated = _validate_grading_result(result)
|
| 257 |
-
assert config.
|
| 258 |
|
| 259 |
|
| 260 |
class TestParseGradingResponse:
|
|
|
|
| 254 |
"""Test details field is set with model info."""
|
| 255 |
result = {"score": 80}
|
| 256 |
validated = _validate_grading_result(result)
|
| 257 |
+
assert config.HF_MODEL_DEFAULT in validated["details"]
|
| 258 |
|
| 259 |
|
| 260 |
class TestParseGradingResponse:
|
tests/unit/test_types.py
CHANGED
|
@@ -26,7 +26,7 @@ class TestTypeDefinitions:
|
|
| 26 |
def test_grading_result_with_status_required_fields(self):
|
| 27 |
"""Test GradingResultWithStatus has all required fields."""
|
| 28 |
hints = get_type_hints(GradingResultWithStatus)
|
| 29 |
-
required_fields = {"index", "score", "feedback", "summary", "rubric_breakdown", "strengths", "improvements", "details", "status"}
|
| 30 |
|
| 31 |
assert set(hints.keys()) == required_fields
|
| 32 |
|
|
|
|
| 26 |
def test_grading_result_with_status_required_fields(self):
|
| 27 |
"""Test GradingResultWithStatus has all required fields."""
|
| 28 |
hints = get_type_hints(GradingResultWithStatus)
|
| 29 |
+
required_fields = {"index", "score", "feedback", "summary", "rubric_breakdown", "strengths", "improvements", "details", "status", "filename"}
|
| 30 |
|
| 31 |
assert set(hints.keys()) == required_fields
|
| 32 |
|