Harden AI response parsing and retry generation on malformed JSON

This commit is contained in:
dwindown
2026-04-28 18:50:58 +07:00
parent c3f7a4463b
commit f91c54f197

View File

@@ -8,6 +8,7 @@ Implements caching, user-level reuse checking, and prompt engineering.
import json import json
import logging import logging
import re import re
import ast
from typing import Any, Dict, Literal, Optional, Union from typing import Any, Dict, Literal, Optional, Union
import httpx import httpx
@@ -120,34 +121,17 @@ def parse_ai_response(response_text: str) -> Optional[GeneratedQuestion]:
if not response_text: if not response_text:
return None return None
# Clean the response text
cleaned = response_text.strip() cleaned = response_text.strip()
candidates = _extract_json_candidates(cleaned)
for candidate in candidates:
candidate_clean = _sanitize_json_candidate(candidate)
parsed = _try_parse_json_like(candidate_clean)
if isinstance(parsed, dict):
question = validate_and_create_question(parsed)
if question:
return question
# Try to extract JSON from code blocks if present logger.warning(f"Failed to parse AI response: {cleaned[:240]}...")
json_patterns = [
r"```json\s*([\s\S]*?)\s*```", # ```json ... ```
r"```\s*([\s\S]*?)\s*```", # ``` ... ```
r"(\{[\s\S]*\})", # Raw JSON object
]
for pattern in json_patterns:
match = re.search(pattern, cleaned)
if match:
json_str = match.group(1).strip()
try:
data = json.loads(json_str)
return validate_and_create_question(data)
except json.JSONDecodeError:
continue
# Try parsing the entire response as JSON
try:
data = json.loads(cleaned)
return validate_and_create_question(data)
except json.JSONDecodeError:
pass
logger.warning(f"Failed to parse AI response: {cleaned[:200]}...")
return None return None
@@ -161,15 +145,14 @@ def validate_and_create_question(data: Dict[str, Any]) -> Optional[GeneratedQues
Returns: Returns:
GeneratedQuestion if valid, None otherwise GeneratedQuestion if valid, None otherwise
""" """
required_fields = ["stem", "options", "correct"] stem = str(data.get("stem") or data.get("question") or "").strip()
if not all(field in data for field in required_fields): if not stem:
logger.warning(f"Missing required fields in AI response: {data.keys()}") logger.warning(f"Missing question stem in AI response: {data.keys()}")
return None return None
# Validate options options = _normalize_options(data.get("options"))
options = data.get("options", {}) if not options:
if not isinstance(options, dict): logger.warning("Options cannot be normalized to A/B/C/D map")
logger.warning("Options is not a dictionary")
return None return None
required_options = {"A", "B", "C", "D"} required_options = {"A", "B", "C", "D"}
@@ -177,20 +160,129 @@ def validate_and_create_question(data: Dict[str, Any]) -> Optional[GeneratedQues
logger.warning(f"Missing required options: {required_options - set(options.keys())}") logger.warning(f"Missing required options: {required_options - set(options.keys())}")
return None return None
# Validate correct answer correct = _normalize_correct_answer(
correct = str(data.get("correct", "")).upper() data.get("correct") or data.get("correct_answer") or data.get("answer")
)
if correct not in required_options: if correct not in required_options:
logger.warning(f"Invalid correct answer: {correct}") logger.warning(f"Invalid correct answer: {correct}")
return None return None
return GeneratedQuestion( return GeneratedQuestion(
stem=str(data["stem"]).strip(), stem=stem,
options={k: str(v).strip() for k, v in options.items()}, options=options,
correct=correct, correct=correct,
explanation=str(data.get("explanation", "")).strip() or None, explanation=str(data.get("explanation") or data.get("rationale") or "").strip() or None,
) )
def _extract_json_candidates(text: str) -> list[str]:
    """Collect plausible JSON snippets from *text*, most specific first.

    Candidate order: contents of fenced ``` code blocks (with or without a
    ``json`` tag), then the first balanced ``{...}`` object, then the whole
    stripped text. Empty strings and duplicates are dropped while keeping
    first-seen order, so callers try the most targeted snippet first.
    """
    found: list[str] = []
    for fenced in re.findall(r"```(?:json)?\s*([\s\S]*?)\s*```", text):
        snippet = fenced.strip()
        if snippet:
            found.append(snippet)
    balanced_obj = _extract_first_balanced_object(text)
    if balanced_obj:
        found.append(balanced_obj)
    found.append(text.strip())
    # dict.fromkeys preserves insertion order while deduplicating;
    # the comprehension then filters out any empty candidate.
    return [candidate for candidate in dict.fromkeys(found) if candidate]
def _extract_first_balanced_object(text: str) -> str | None:
start = text.find("{")
if start == -1:
return None
depth = 0
in_string = False
escape_next = False
for idx in range(start, len(text)):
ch = text[idx]
if escape_next:
escape_next = False
continue
if ch == "\\" and in_string:
escape_next = True
continue
if ch == '"':
in_string = not in_string
continue
if in_string:
continue
if ch == "{":
depth += 1
elif ch == "}":
depth -= 1
if depth == 0:
return text[start: idx + 1]
return None
def _sanitize_json_candidate(candidate: str) -> str:
cleaned = candidate.strip().lstrip("\ufeff")
cleaned = cleaned.replace("", '"').replace("", '"').replace("", "'")
cleaned = re.sub(r",\s*([}\]])", r"\1", cleaned)
return cleaned
def _try_parse_json_like(candidate: str) -> Any:
try:
return json.loads(candidate)
except json.JSONDecodeError:
pass
try:
# Fallback for Python-like dict outputs using single quotes.
return ast.literal_eval(candidate)
except (ValueError, SyntaxError):
return None
def _normalize_options(raw_options: Any) -> dict[str, str]:
if isinstance(raw_options, dict):
normalized = {str(k).strip().upper(): str(v).strip() for k, v in raw_options.items()}
return {k: normalized.get(k, "") for k in ["A", "B", "C", "D"] if normalized.get(k, "")}
if isinstance(raw_options, list):
mapped: dict[str, str] = {}
for idx, opt in enumerate(raw_options):
if isinstance(opt, dict):
key = str(opt.get("increment") or opt.get("key") or "").strip().upper()
text = str(opt.get("text") or opt.get("label") or opt.get("value") or "").strip()
else:
key = ""
text = str(opt).strip()
if not key and idx < 4:
key = ["A", "B", "C", "D"][idx]
if key in {"A", "B", "C", "D"} and text:
mapped[key] = text
return mapped
return {}
def _normalize_correct_answer(raw_correct: Any) -> str:
if raw_correct is None:
return ""
raw_text = str(raw_correct).strip().upper()
if raw_text in {"A", "B", "C", "D"}:
return raw_text
if raw_text.isdigit():
idx = int(raw_text)
if 1 <= idx <= 4:
return ["A", "B", "C", "D"][idx - 1]
if 0 <= idx <= 3:
return ["A", "B", "C", "D"][idx]
if raw_text in {"OPTION A", "OPTION B", "OPTION C", "OPTION D"}:
return raw_text[-1]
return raw_text[:1]
async def call_openrouter_api( async def call_openrouter_api(
prompt: str, prompt: str,
model: str, model: str,
@@ -309,22 +401,24 @@ async def generate_question(
target_level=target_level, target_level=target_level,
) )
# Call OpenRouter API max_generation_attempts = 2
for attempt in range(1, max_generation_attempts + 1):
response_text = await call_openrouter_api(prompt, ai_model) response_text = await call_openrouter_api(prompt, ai_model)
if not response_text: if not response_text:
logger.error("No response from OpenRouter API") logger.error("No response from OpenRouter API")
return None continue
# Parse response
generated = parse_ai_response(response_text) generated = parse_ai_response(response_text)
if generated:
if not generated:
logger.error("Failed to parse AI response")
return None
return generated return generated
logger.warning(
"Failed to parse AI response (attempt %s/%s), retrying",
attempt,
max_generation_attempts,
)
return None
async def check_cache_reuse( async def check_cache_reuse(
tryout_id: str, tryout_id: str,