Harden AI response parsing and retry generation on malformed JSON
This commit is contained in:
@@ -8,6 +8,7 @@ Implements caching, user-level reuse checking, and prompt engineering.
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import ast
|
||||
from typing import Any, Dict, Literal, Optional, Union
|
||||
|
||||
import httpx
|
||||
@@ -120,34 +121,17 @@ def parse_ai_response(response_text: str) -> Optional[GeneratedQuestion]:
|
||||
if not response_text:
|
||||
return None
|
||||
|
||||
# Clean the response text
|
||||
cleaned = response_text.strip()
|
||||
candidates = _extract_json_candidates(cleaned)
|
||||
for candidate in candidates:
|
||||
candidate_clean = _sanitize_json_candidate(candidate)
|
||||
parsed = _try_parse_json_like(candidate_clean)
|
||||
if isinstance(parsed, dict):
|
||||
question = validate_and_create_question(parsed)
|
||||
if question:
|
||||
return question
|
||||
|
||||
# Try to extract JSON from code blocks if present
|
||||
json_patterns = [
|
||||
r"```json\s*([\s\S]*?)\s*```", # ```json ... ```
|
||||
r"```\s*([\s\S]*?)\s*```", # ``` ... ```
|
||||
r"(\{[\s\S]*\})", # Raw JSON object
|
||||
]
|
||||
|
||||
for pattern in json_patterns:
|
||||
match = re.search(pattern, cleaned)
|
||||
if match:
|
||||
json_str = match.group(1).strip()
|
||||
try:
|
||||
data = json.loads(json_str)
|
||||
return validate_and_create_question(data)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
# Try parsing the entire response as JSON
|
||||
try:
|
||||
data = json.loads(cleaned)
|
||||
return validate_and_create_question(data)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
logger.warning(f"Failed to parse AI response: {cleaned[:200]}...")
|
||||
logger.warning(f"Failed to parse AI response: {cleaned[:240]}...")
|
||||
return None
|
||||
|
||||
|
||||
@@ -161,15 +145,14 @@ def validate_and_create_question(data: Dict[str, Any]) -> Optional[GeneratedQues
|
||||
Returns:
|
||||
GeneratedQuestion if valid, None otherwise
|
||||
"""
|
||||
required_fields = ["stem", "options", "correct"]
|
||||
if not all(field in data for field in required_fields):
|
||||
logger.warning(f"Missing required fields in AI response: {data.keys()}")
|
||||
stem = str(data.get("stem") or data.get("question") or "").strip()
|
||||
if not stem:
|
||||
logger.warning(f"Missing question stem in AI response: {data.keys()}")
|
||||
return None
|
||||
|
||||
# Validate options
|
||||
options = data.get("options", {})
|
||||
if not isinstance(options, dict):
|
||||
logger.warning("Options is not a dictionary")
|
||||
options = _normalize_options(data.get("options"))
|
||||
if not options:
|
||||
logger.warning("Options cannot be normalized to A/B/C/D map")
|
||||
return None
|
||||
|
||||
required_options = {"A", "B", "C", "D"}
|
||||
@@ -177,20 +160,129 @@ def validate_and_create_question(data: Dict[str, Any]) -> Optional[GeneratedQues
|
||||
logger.warning(f"Missing required options: {required_options - set(options.keys())}")
|
||||
return None
|
||||
|
||||
# Validate correct answer
|
||||
correct = str(data.get("correct", "")).upper()
|
||||
correct = _normalize_correct_answer(
|
||||
data.get("correct") or data.get("correct_answer") or data.get("answer")
|
||||
)
|
||||
if correct not in required_options:
|
||||
logger.warning(f"Invalid correct answer: {correct}")
|
||||
return None
|
||||
|
||||
return GeneratedQuestion(
|
||||
stem=str(data["stem"]).strip(),
|
||||
options={k: str(v).strip() for k, v in options.items()},
|
||||
stem=stem,
|
||||
options=options,
|
||||
correct=correct,
|
||||
explanation=str(data.get("explanation", "")).strip() or None,
|
||||
explanation=str(data.get("explanation") or data.get("rationale") or "").strip() or None,
|
||||
)
|
||||
|
||||
|
||||
def _extract_json_candidates(text: str) -> list[str]:
    """Collect JSON-looking substrings of *text*, most specific first.

    Preference order: fenced code blocks, the first balanced ``{...}``
    object, then the whole stripped text. Empty strings and duplicates
    are dropped while preserving that order.
    """
    found: list[str] = []

    for block in re.findall(r"```(?:json)?\s*([\s\S]*?)\s*```", text):
        stripped = block.strip()
        if stripped:
            found.append(stripped)

    balanced = _extract_first_balanced_object(text)
    if balanced:
        found.append(balanced)

    found.append(text.strip())

    # dict preserves insertion order, so this dedupes without losing priority.
    return [candidate for candidate in dict.fromkeys(found) if candidate]
|
||||
|
||||
|
||||
def _extract_first_balanced_object(text: str) -> str | None:
|
||||
start = text.find("{")
|
||||
if start == -1:
|
||||
return None
|
||||
depth = 0
|
||||
in_string = False
|
||||
escape_next = False
|
||||
for idx in range(start, len(text)):
|
||||
ch = text[idx]
|
||||
if escape_next:
|
||||
escape_next = False
|
||||
continue
|
||||
if ch == "\\" and in_string:
|
||||
escape_next = True
|
||||
continue
|
||||
if ch == '"':
|
||||
in_string = not in_string
|
||||
continue
|
||||
if in_string:
|
||||
continue
|
||||
if ch == "{":
|
||||
depth += 1
|
||||
elif ch == "}":
|
||||
depth -= 1
|
||||
if depth == 0:
|
||||
return text[start: idx + 1]
|
||||
return None
|
||||
|
||||
|
||||
def _sanitize_json_candidate(candidate: str) -> str:
|
||||
cleaned = candidate.strip().lstrip("\ufeff")
|
||||
cleaned = cleaned.replace("“", '"').replace("”", '"').replace("’", "'")
|
||||
cleaned = re.sub(r",\s*([}\]])", r"\1", cleaned)
|
||||
return cleaned
|
||||
|
||||
|
||||
def _try_parse_json_like(candidate: str) -> Any:
|
||||
try:
|
||||
return json.loads(candidate)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
try:
|
||||
# Fallback for Python-like dict outputs using single quotes.
|
||||
return ast.literal_eval(candidate)
|
||||
except (ValueError, SyntaxError):
|
||||
return None
|
||||
|
||||
|
||||
def _normalize_options(raw_options: Any) -> dict[str, str]:
|
||||
if isinstance(raw_options, dict):
|
||||
normalized = {str(k).strip().upper(): str(v).strip() for k, v in raw_options.items()}
|
||||
return {k: normalized.get(k, "") for k in ["A", "B", "C", "D"] if normalized.get(k, "")}
|
||||
|
||||
if isinstance(raw_options, list):
|
||||
mapped: dict[str, str] = {}
|
||||
for idx, opt in enumerate(raw_options):
|
||||
if isinstance(opt, dict):
|
||||
key = str(opt.get("increment") or opt.get("key") or "").strip().upper()
|
||||
text = str(opt.get("text") or opt.get("label") or opt.get("value") or "").strip()
|
||||
else:
|
||||
key = ""
|
||||
text = str(opt).strip()
|
||||
if not key and idx < 4:
|
||||
key = ["A", "B", "C", "D"][idx]
|
||||
if key in {"A", "B", "C", "D"} and text:
|
||||
mapped[key] = text
|
||||
return mapped
|
||||
|
||||
return {}
|
||||
|
||||
|
||||
def _normalize_correct_answer(raw_correct: Any) -> str:
|
||||
if raw_correct is None:
|
||||
return ""
|
||||
raw_text = str(raw_correct).strip().upper()
|
||||
if raw_text in {"A", "B", "C", "D"}:
|
||||
return raw_text
|
||||
if raw_text.isdigit():
|
||||
idx = int(raw_text)
|
||||
if 1 <= idx <= 4:
|
||||
return ["A", "B", "C", "D"][idx - 1]
|
||||
if 0 <= idx <= 3:
|
||||
return ["A", "B", "C", "D"][idx]
|
||||
if raw_text in {"OPTION A", "OPTION B", "OPTION C", "OPTION D"}:
|
||||
return raw_text[-1]
|
||||
return raw_text[:1]
|
||||
|
||||
|
||||
async def call_openrouter_api(
|
||||
prompt: str,
|
||||
model: str,
|
||||
@@ -309,22 +401,24 @@ async def generate_question(
|
||||
target_level=target_level,
|
||||
)
|
||||
|
||||
# Call OpenRouter API
|
||||
max_generation_attempts = 2
|
||||
for attempt in range(1, max_generation_attempts + 1):
|
||||
response_text = await call_openrouter_api(prompt, ai_model)
|
||||
|
||||
if not response_text:
|
||||
logger.error("No response from OpenRouter API")
|
||||
return None
|
||||
continue
|
||||
|
||||
# Parse response
|
||||
generated = parse_ai_response(response_text)
|
||||
|
||||
if not generated:
|
||||
logger.error("Failed to parse AI response")
|
||||
return None
|
||||
|
||||
if generated:
|
||||
return generated
|
||||
|
||||
logger.warning(
|
||||
"Failed to parse AI response (attempt %s/%s), retrying",
|
||||
attempt,
|
||||
max_generation_attempts,
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
async def check_cache_reuse(
|
||||
tryout_id: str,
|
||||
|
||||
Reference in New Issue
Block a user