GradeM8 Deploy committed on
Commit
28b7952
·
1 Parent(s): 997fd5a

refactor: replace DeepInfra with HuggingFace-only backend

Browse files
Files changed (4) hide show
  1. .gitignore +3 -0
  2. ai_router/__init__.py +3 -3
  3. ai_router/client.py +56 -35
  4. config.py +10 -10
.gitignore CHANGED
@@ -141,3 +141,6 @@ _types_backup.py
141
 
142
  # Gradio/Hugging Face
143
  gradio_cached_examples/
 
 
 
 
141
 
142
  # Gradio/Hugging Face
143
  gradio_cached_examples/
144
+
145
+ # Local secrets
146
+ .env
ai_router/__init__.py CHANGED
@@ -1,12 +1,12 @@
1
  """
2
- AI routing package for grading homework using DeepInfra.
3
 
4
- This package provides AI grading functionality using DeepInfra as the sole provider,
5
  with detailed, structured feedback including rubric breakdowns, strengths, and
6
  actionable improvements. Also includes OCR capabilities for scanned documents.
7
 
8
  Modules:
9
- client: DeepInfra HTTP client and API configuration
10
  prompt: Prompt building functions for grading requests
11
  parsing: Response parsing and validation utilities
12
  orchestration: Batch grading with concurrent processing
 
1
  """
2
+ AI routing package for grading homework using HuggingFace Inference API.
3
 
4
+ This package provides AI grading functionality using HuggingFace as the provider,
5
  with detailed, structured feedback including rubric breakdowns, strengths, and
6
  actionable improvements. Also includes OCR capabilities for scanned documents.
7
 
8
  Modules:
9
+ client: HuggingFace HTTP client and API configuration
10
  prompt: Prompt building functions for grading requests
11
  parsing: Response parsing and validation utilities
12
  orchestration: Batch grading with concurrent processing
ai_router/client.py CHANGED
@@ -1,8 +1,8 @@
1
  """
2
- DeepInfra API client module for AI grading.
3
 
4
  This module provides HTTP client functionality for communicating with
5
- the DeepInfra API, including configuration management and request handling.
6
  """
7
 
8
  from __future__ import annotations
@@ -30,35 +30,40 @@ MAX_RETRIES = 3
30
  INITIAL_BACKOFF_SECONDS = 1.0
31
  BACKOFF_MULTIPLIER = 2.0
32
 
 
 
 
33
 
34
  def _get_deepinfra_config() -> tuple[str, str, int, float]:
35
- """Get DeepInfra API configuration from environment or defaults.
 
 
36
 
37
  Returns:
38
  Tuple of (api_key, model, max_tokens, temperature)
39
 
40
  Raises:
41
- APIKeyError: If DEEPINFRA_API_KEY is not set.
42
  """
43
- api_key = os.getenv("DEEPINFRA_API_KEY")
44
  if not api_key:
45
  raise APIKeyError(
46
- "DEEPINFRA_API_KEY environment variable is not set. "
47
- "Please set your DeepInfra API key to use the grading feature."
48
  )
49
 
50
- model = os.getenv("DEEPINFRA_MODEL", config.DEEPINFRA_MODEL_DEFAULT)
51
- max_tokens = int(os.getenv("DEEPINFRA_MAX_TOKENS", config.MAX_TOKENS))
52
- temperature = float(os.getenv("DEEPINFRA_TEMPERATURE", config.TEMPERATURE))
53
 
54
  return api_key, model, max_tokens, temperature
55
 
56
 
57
  async def generate_grading(content: str, rubric: str) -> GradingResult:
58
  """
59
- Generate grading feedback using DeepInfra API with automatic retry logic.
60
 
61
- This function sends the submission to DeepInfra for evaluation and
62
  returns a structured grading result with detailed feedback.
63
 
64
  Implements exponential backoff retry on transient failures (429, 5xx).
@@ -78,7 +83,7 @@ async def generate_grading(content: str, rubric: str) -> GradingResult:
78
  - details: Additional context
79
 
80
  Raises:
81
- APIKeyError: If DeepInfra API key is not configured
82
  AIServiceError: If API returns error status or times out
83
  ResponseParseError: If response cannot be parsed
84
  InvalidResponseError: If response is missing required fields
@@ -89,23 +94,26 @@ async def generate_grading(content: str, rubric: str) -> GradingResult:
89
 
90
  # Build the prompt
91
  prompt = build_grading_prompt(content, rubric)
 
 
 
92
 
93
- # Prepare the API request
94
  headers = {
95
  "Authorization": f"Bearer {api_key}",
96
  "Content-Type": "application/json",
97
  }
98
 
99
  payload = {
100
- "model": model,
101
- "messages": [
102
- {"role": "system", "content": config.GRADING_SYSTEM_PROMPT},
103
- {"role": "user", "content": prompt},
104
- ],
105
- "max_tokens": max_tokens,
106
- "temperature": temperature,
107
- "response_format": {"type": "json_object"},
108
  }
 
 
109
 
110
  # Implement retry logic with exponential backoff
111
  backoff_seconds = INITIAL_BACKOFF_SECONDS
@@ -114,7 +122,7 @@ async def generate_grading(content: str, rubric: str) -> GradingResult:
114
  try:
115
  async with httpx.AsyncClient(timeout=config.HTTP_TIMEOUT_SECONDS) as client:
116
  response = await client.post(
117
- config.DEEPINFRA_API_URL,
118
  headers=headers,
119
  json=payload,
120
  )
@@ -122,28 +130,29 @@ async def generate_grading(content: str, rubric: str) -> GradingResult:
122
 
123
  except httpx.HTTPStatusError as e:
124
  status_code = e.response.status_code
125
- logger.warning(f"DeepInfra API error (attempt {attempt + 1}/{MAX_RETRIES}): {status_code}")
126
 
127
- # Determine if error is retryable
128
  is_retryable = status_code in (429, 500, 502, 503, 504)
129
 
130
  if not is_retryable or attempt == MAX_RETRIES - 1:
131
- logger.error(f"DeepInfra API error: {status_code} - {e.response.text}")
132
  raise AIServiceError(
133
- f"DeepInfra API error: {status_code}",
134
  status_code,
135
  ) from e
136
 
137
- # Backoff before retry
138
- logger.info(f"Retrying in {backoff_seconds:.1f}s...")
139
- await asyncio.sleep(backoff_seconds)
 
140
  backoff_seconds *= BACKOFF_MULTIPLIER
141
  continue
142
 
143
  except httpx.RequestError as e:
144
- logger.error(f"DeepInfra request error (attempt {attempt + 1}/{MAX_RETRIES}): {e}")
145
  if attempt == MAX_RETRIES - 1:
146
- raise AIServiceError(f"Failed to connect to DeepInfra API: {e}") from e
147
  await asyncio.sleep(backoff_seconds)
148
  backoff_seconds *= BACKOFF_MULTIPLIER
149
  continue
@@ -151,14 +160,26 @@ async def generate_grading(content: str, rubric: str) -> GradingResult:
151
  # Successful response - parse it
152
  try:
153
  result = response.json()
154
- generated_text = result["choices"][0]["message"]["content"]
 
 
 
 
 
 
155
  except (KeyError, IndexError, json.JSONDecodeError) as e:
156
- logger.error(f"Unexpected API response format: {e}")
157
  raise AIServiceError(f"Invalid API response format: {e}") from e
158
 
 
 
 
 
 
 
159
  # Parse the grading response
160
  parsed_result = parse_grading_response(generated_text)
161
- parsed_result["details"] = f"Graded using {model} via DeepInfra"
162
 
163
  return parsed_result
164
 
 
1
  """
2
+ HuggingFace API client module for AI grading.
3
 
4
  This module provides HTTP client functionality for communicating with
5
+ the HuggingFace Inference API for AI-powered grading.
6
  """
7
 
8
  from __future__ import annotations
 
30
  INITIAL_BACKOFF_SECONDS = 1.0
31
  BACKOFF_MULTIPLIER = 2.0
32
 
33
+ # HuggingFace Inference API URL
34
+ HF_API_URL = "https://api-inference.huggingface.co/models"
35
+
36
 
37
  def _get_deepinfra_config() -> tuple[str, str, int, float]:
38
+ """Get HuggingFace API configuration from environment or defaults.
39
+
40
+ Legacy name kept for backwards compatibility with tests.
41
 
42
  Returns:
43
  Tuple of (api_key, model, max_tokens, temperature)
44
 
45
  Raises:
46
+ APIKeyError: If HUGGINGFACE_API_KEY is not set.
47
  """
48
+ api_key = os.getenv("HUGGINGFACE_API_KEY")
49
  if not api_key:
50
  raise APIKeyError(
51
+ "HUGGINGFACE_API_KEY environment variable is not set. "
52
+ "Please set your HuggingFace API key to use the grading feature."
53
  )
54
 
55
+ model = os.getenv("HF_MODEL_PRIMARY", config.HF_MODEL_DEFAULT)
56
+ max_tokens = int(os.getenv("HF_MAX_TOKENS", config.MAX_TOKENS))
57
+ temperature = float(os.getenv("HF_TEMPERATURE", config.TEMPERATURE))
58
 
59
  return api_key, model, max_tokens, temperature
60
 
61
 
62
  async def generate_grading(content: str, rubric: str) -> GradingResult:
63
  """
64
+ Generate grading feedback using HuggingFace Inference API with automatic retry logic.
65
 
66
+ This function sends the submission to HuggingFace for evaluation and
67
  returns a structured grading result with detailed feedback.
68
 
69
  Implements exponential backoff retry on transient failures (429, 5xx).
 
83
  - details: Additional context
84
 
85
  Raises:
86
+ APIKeyError: If HuggingFace API key is not configured
87
  AIServiceError: If API returns error status or times out
88
  ResponseParseError: If response cannot be parsed
89
  InvalidResponseError: If response is missing required fields
 
94
 
95
  # Build the prompt
96
  prompt = build_grading_prompt(content, rubric)
97
+
98
+ # Build full prompt with system message
99
+ full_prompt = f"{config.GRADING_SYSTEM_PROMPT}\n\n{prompt}"
100
 
101
+ # Prepare the API request for HuggingFace Inference API
102
  headers = {
103
  "Authorization": f"Bearer {api_key}",
104
  "Content-Type": "application/json",
105
  }
106
 
107
  payload = {
108
+ "inputs": full_prompt,
109
+ "parameters": {
110
+ "max_new_tokens": max_tokens,
111
+ "temperature": temperature,
112
+ "return_full_text": False,
113
+ },
 
 
114
  }
115
+
116
+ api_url = f"{HF_API_URL}/{model}"
117
 
118
  # Implement retry logic with exponential backoff
119
  backoff_seconds = INITIAL_BACKOFF_SECONDS
 
122
  try:
123
  async with httpx.AsyncClient(timeout=config.HTTP_TIMEOUT_SECONDS) as client:
124
  response = await client.post(
125
+ api_url,
126
  headers=headers,
127
  json=payload,
128
  )
 
130
 
131
  except httpx.HTTPStatusError as e:
132
  status_code = e.response.status_code
133
+ logger.warning("HuggingFace API error (attempt %d/%d): %s", attempt + 1, MAX_RETRIES, status_code)
134
 
135
+ # Determine if error is retryable (including 503 for model loading)
136
  is_retryable = status_code in (429, 500, 502, 503, 504)
137
 
138
  if not is_retryable or attempt == MAX_RETRIES - 1:
139
+ logger.error("HuggingFace API error: %s - %s", status_code, e.response.text)
140
  raise AIServiceError(
141
+ f"HuggingFace API error: {status_code}",
142
  status_code,
143
  ) from e
144
 
145
+ # Backoff before retry (longer for 503 model loading)
146
+ wait_time = backoff_seconds * 2 if status_code == 503 else backoff_seconds
147
+ logger.info("Retrying in %.1fs...", wait_time)
148
+ await asyncio.sleep(wait_time)
149
  backoff_seconds *= BACKOFF_MULTIPLIER
150
  continue
151
 
152
  except httpx.RequestError as e:
153
+ logger.error("HuggingFace request error (attempt %d/%d): %s", attempt + 1, MAX_RETRIES, e)
154
  if attempt == MAX_RETRIES - 1:
155
+ raise AIServiceError(f"Failed to connect to HuggingFace API: {e}") from e
156
  await asyncio.sleep(backoff_seconds)
157
  backoff_seconds *= BACKOFF_MULTIPLIER
158
  continue
 
160
  # Successful response - parse it
161
  try:
162
  result = response.json()
163
+ # HuggingFace returns a list with generated_text
164
+ if isinstance(result, list) and len(result) > 0:
165
+ generated_text = result[0].get("generated_text", "")
166
+ elif isinstance(result, dict):
167
+ generated_text = result.get("generated_text", result.get("text", ""))
168
+ else:
169
+ generated_text = str(result)
170
  except (KeyError, IndexError, json.JSONDecodeError) as e:
171
+ logger.error("Unexpected API response format: %s", e)
172
  raise AIServiceError(f"Invalid API response format: {e}") from e
173
 
174
+ # Extract JSON from response if wrapped in text
175
+ if "{" in generated_text and "}" in generated_text:
176
+ json_start = generated_text.find("{")
177
+ json_end = generated_text.rfind("}") + 1
178
+ generated_text = generated_text[json_start:json_end]
179
+
180
  # Parse the grading response
181
  parsed_result = parse_grading_response(generated_text)
182
+ parsed_result["details"] = f"Graded using {model} via HuggingFace"
183
 
184
  return parsed_result
185
 
config.py CHANGED
@@ -7,16 +7,16 @@ All values can be overridden via environment variables where applicable.
7
  """
8
 
9
  # =============================================================================
10
- # AI Model Configuration (DeepInfra Only)
11
  # =============================================================================
12
 
13
- # Default DeepInfra model to use for grading.
14
- # Using Llama 3.1 70B for high-quality grading with good instruction following.
15
  # Alternative options:
16
- # - "meta-llama/Meta-Llama-3.1-8B-Instruct" (faster, cheaper)
17
- # - "meta-llama/Meta-Llama-3.1-70B-Instruct" (better quality)
18
- # - "microsoft/WizardLM-2-8x22B" (excellent for long-form feedback)
19
- DEEPINFRA_MODEL_DEFAULT: str = "meta-llama/Meta-Llama-3.1-70B-Instruct"
20
 
21
  # DeepSeek OCR model for extracting text from images and scanned documents
22
  DEEPSK_OCR_MODEL: str = "deepseek-ai/DeepSeek-OCR"
@@ -36,15 +36,15 @@ TEMPERATURE: float = 0.2
36
  # HTTP Client Settings
37
  # =============================================================================
38
 
39
- # Timeout for HTTP requests to DeepInfra API (in seconds).
40
  # Set to 180s to accommodate larger models and detailed responses.
41
  HTTP_TIMEOUT_SECONDS: float = 180.0
42
 
43
  # Timeout for OCR requests (longer due to image processing)
44
  OCR_TIMEOUT_SECONDS: float = 300.0
45
 
46
- # DeepInfra API endpoint URL
47
- DEEPINFRA_API_URL: str = "https://api.deepinfra.com/v1/openai/chat/completions"
48
 
49
  # =============================================================================
50
  # Concurrency Settings
 
7
  """
8
 
9
  # =============================================================================
10
+ # AI Model Configuration (HuggingFace)
11
  # =============================================================================
12
 
13
+ # Default HuggingFace model to use for grading.
14
+ # Using Llama 2 70B for high-quality grading with good instruction following.
15
  # Alternative options:
16
+ # - "mistralai/Mistral-7B-Instruct-v0.1" (faster, smaller)
17
+ # - "meta-llama/Llama-2-70b-chat-hf" (high quality)
18
+ # - "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO" (excellent for long-form)
19
+ HF_MODEL_DEFAULT: str = "meta-llama/Llama-2-70b-chat-hf"
20
 
21
  # DeepSeek OCR model for extracting text from images and scanned documents
22
  DEEPSK_OCR_MODEL: str = "deepseek-ai/DeepSeek-OCR"
 
36
  # HTTP Client Settings
37
  # =============================================================================
38
 
39
+ # Timeout for HTTP requests to HuggingFace API (in seconds).
40
  # Set to 180s to accommodate larger models and detailed responses.
41
  HTTP_TIMEOUT_SECONDS: float = 180.0
42
 
43
  # Timeout for OCR requests (longer due to image processing)
44
  OCR_TIMEOUT_SECONDS: float = 300.0
45
 
46
+ # HuggingFace Inference API endpoint URL
47
+ HF_API_URL: str = "https://api-inference.huggingface.co/models"
48
 
49
  # =============================================================================
50
  # Concurrency Settings