Harden AI response parsing and retry generation on malformed JSON

This commit is contained in:
dwindown
2026-04-28 18:50:58 +07:00
parent c3f7a4463b
commit f91c54f197

View File

@@ -8,6 +8,7 @@ Implements caching, user-level reuse checking, and prompt engineering.
import json
import logging
import re
import ast
from typing import Any, Dict, Literal, Optional, Union
import httpx
@@ -120,34 +121,17 @@ def parse_ai_response(response_text: str) -> Optional[GeneratedQuestion]:
if not response_text:
return None
# Clean the response text
cleaned = response_text.strip()
candidates = _extract_json_candidates(cleaned)
for candidate in candidates:
candidate_clean = _sanitize_json_candidate(candidate)
parsed = _try_parse_json_like(candidate_clean)
if isinstance(parsed, dict):
question = validate_and_create_question(parsed)
if question:
return question
# Try to extract JSON from code blocks if present
json_patterns = [
r"```json\s*([\s\S]*?)\s*```", # ```json ... ```
r"```\s*([\s\S]*?)\s*```", # ``` ... ```
r"(\{[\s\S]*\})", # Raw JSON object
]
for pattern in json_patterns:
match = re.search(pattern, cleaned)
if match:
json_str = match.group(1).strip()
try:
data = json.loads(json_str)
return validate_and_create_question(data)
except json.JSONDecodeError:
continue
# Try parsing the entire response as JSON
try:
data = json.loads(cleaned)
return validate_and_create_question(data)
except json.JSONDecodeError:
pass
logger.warning(f"Failed to parse AI response: {cleaned[:200]}...")
logger.warning(f"Failed to parse AI response: {cleaned[:240]}...")
return None
@@ -161,15 +145,14 @@ def validate_and_create_question(data: Dict[str, Any]) -> Optional[GeneratedQues
Returns:
GeneratedQuestion if valid, None otherwise
"""
required_fields = ["stem", "options", "correct"]
if not all(field in data for field in required_fields):
logger.warning(f"Missing required fields in AI response: {data.keys()}")
stem = str(data.get("stem") or data.get("question") or "").strip()
if not stem:
logger.warning(f"Missing question stem in AI response: {data.keys()}")
return None
# Validate options
options = data.get("options", {})
if not isinstance(options, dict):
logger.warning("Options is not a dictionary")
options = _normalize_options(data.get("options"))
if not options:
logger.warning("Options cannot be normalized to A/B/C/D map")
return None
required_options = {"A", "B", "C", "D"}
@@ -177,20 +160,129 @@ def validate_and_create_question(data: Dict[str, Any]) -> Optional[GeneratedQues
logger.warning(f"Missing required options: {required_options - set(options.keys())}")
return None
# Validate correct answer
correct = str(data.get("correct", "")).upper()
correct = _normalize_correct_answer(
data.get("correct") or data.get("correct_answer") or data.get("answer")
)
if correct not in required_options:
logger.warning(f"Invalid correct answer: {correct}")
return None
return GeneratedQuestion(
stem=str(data["stem"]).strip(),
options={k: str(v).strip() for k, v in options.items()},
stem=stem,
options=options,
correct=correct,
explanation=str(data.get("explanation", "")).strip() or None,
explanation=str(data.get("explanation") or data.get("rationale") or "").strip() or None,
)
def _extract_json_candidates(text: str) -> list[str]:
    """Collect plausible JSON snippets from raw model output.

    Order of preference: fenced code blocks (```json ... ``` or bare
    fences), the first balanced brace-delimited object, then the whole
    text. Duplicates and empty snippets are dropped while preserving
    first-seen order.
    """
    found: list[str] = []
    for block in re.findall(r"```(?:json)?\s*([\s\S]*?)\s*```", text):
        stripped = block.strip()
        if stripped:
            found.append(stripped)
    balanced_obj = _extract_first_balanced_object(text)
    if balanced_obj:
        found.append(balanced_obj)
    found.append(text.strip())
    # De-duplicate while keeping the original preference order.
    unique: list[str] = []
    known: set[str] = set()
    for snippet in found:
        if snippet and snippet not in known:
            known.add(snippet)
            unique.append(snippet)
    return unique
def _extract_first_balanced_object(text: str) -> str | None:
start = text.find("{")
if start == -1:
return None
depth = 0
in_string = False
escape_next = False
for idx in range(start, len(text)):
ch = text[idx]
if escape_next:
escape_next = False
continue
if ch == "\\" and in_string:
escape_next = True
continue
if ch == '"':
in_string = not in_string
continue
if in_string:
continue
if ch == "{":
depth += 1
elif ch == "}":
depth -= 1
if depth == 0:
return text[start: idx + 1]
return None
def _sanitize_json_candidate(candidate: str) -> str:
cleaned = candidate.strip().lstrip("\ufeff")
cleaned = cleaned.replace("", '"').replace("", '"').replace("", "'")
cleaned = re.sub(r",\s*([}\]])", r"\1", cleaned)
return cleaned
def _try_parse_json_like(candidate: str) -> Any:
try:
return json.loads(candidate)
except json.JSONDecodeError:
pass
try:
# Fallback for Python-like dict outputs using single quotes.
return ast.literal_eval(candidate)
except (ValueError, SyntaxError):
return None
def _normalize_options(raw_options: Any) -> dict[str, str]:
if isinstance(raw_options, dict):
normalized = {str(k).strip().upper(): str(v).strip() for k, v in raw_options.items()}
return {k: normalized.get(k, "") for k in ["A", "B", "C", "D"] if normalized.get(k, "")}
if isinstance(raw_options, list):
mapped: dict[str, str] = {}
for idx, opt in enumerate(raw_options):
if isinstance(opt, dict):
key = str(opt.get("increment") or opt.get("key") or "").strip().upper()
text = str(opt.get("text") or opt.get("label") or opt.get("value") or "").strip()
else:
key = ""
text = str(opt).strip()
if not key and idx < 4:
key = ["A", "B", "C", "D"][idx]
if key in {"A", "B", "C", "D"} and text:
mapped[key] = text
return mapped
return {}
def _normalize_correct_answer(raw_correct: Any) -> str:
if raw_correct is None:
return ""
raw_text = str(raw_correct).strip().upper()
if raw_text in {"A", "B", "C", "D"}:
return raw_text
if raw_text.isdigit():
idx = int(raw_text)
if 1 <= idx <= 4:
return ["A", "B", "C", "D"][idx - 1]
if 0 <= idx <= 3:
return ["A", "B", "C", "D"][idx]
if raw_text in {"OPTION A", "OPTION B", "OPTION C", "OPTION D"}:
return raw_text[-1]
return raw_text[:1]
async def call_openrouter_api(
prompt: str,
model: str,
@@ -309,22 +401,24 @@ async def generate_question(
target_level=target_level,
)
# Call OpenRouter API
max_generation_attempts = 2
for attempt in range(1, max_generation_attempts + 1):
response_text = await call_openrouter_api(prompt, ai_model)
if not response_text:
logger.error("No response from OpenRouter API")
return None
continue
# Parse response
generated = parse_ai_response(response_text)
if not generated:
logger.error("Failed to parse AI response")
return None
if generated:
return generated
logger.warning(
"Failed to parse AI response (attempt %s/%s), retrying",
attempt,
max_generation_attempts,
)
return None
async def check_cache_reuse(
tryout_id: str,