diff --git a/app/services/ai_generation.py b/app/services/ai_generation.py index 1c2cf5c..097c056 100644 --- a/app/services/ai_generation.py +++ b/app/services/ai_generation.py @@ -8,6 +8,7 @@ Implements caching, user-level reuse checking, and prompt engineering. import json import logging import re +import ast from typing import Any, Dict, Literal, Optional, Union import httpx @@ -120,34 +121,17 @@ def parse_ai_response(response_text: str) -> Optional[GeneratedQuestion]: if not response_text: return None - # Clean the response text cleaned = response_text.strip() + candidates = _extract_json_candidates(cleaned) + for candidate in candidates: + candidate_clean = _sanitize_json_candidate(candidate) + parsed = _try_parse_json_like(candidate_clean) + if isinstance(parsed, dict): + question = validate_and_create_question(parsed) + if question: + return question - # Try to extract JSON from code blocks if present - json_patterns = [ - r"```json\s*([\s\S]*?)\s*```", # ```json ... ``` - r"```\s*([\s\S]*?)\s*```", # ``` ... ``` - r"(\{[\s\S]*\})", # Raw JSON object - ] - - for pattern in json_patterns: - match = re.search(pattern, cleaned) - if match: - json_str = match.group(1).strip() - try: - data = json.loads(json_str) - return validate_and_create_question(data) - except json.JSONDecodeError: - continue - - # Try parsing the entire response as JSON - try: - data = json.loads(cleaned) - return validate_and_create_question(data) - except json.JSONDecodeError: - pass - - logger.warning(f"Failed to parse AI response: {cleaned[:200]}...") + logger.warning(f"Failed to parse AI response: {cleaned[:240]}...") return None @@ -161,15 +145,14 @@ def validate_and_create_question(data: Dict[str, Any]) -> Optional[GeneratedQues Returns: GeneratedQuestion if valid, None otherwise """ - required_fields = ["stem", "options", "correct"] - if not all(field in data for field in required_fields): - logger.warning(f"Missing required fields in AI response: {data.keys()}") + stem = str(data.get("stem") or data.get("question") or "").strip() + if not stem: + logger.warning(f"Missing question stem in AI response: {data.keys()}") return None - # Validate options - options = data.get("options", {}) - if not isinstance(options, dict): - logger.warning("Options is not a dictionary") + options = _normalize_options(data.get("options")) + if not options: + logger.warning("Options cannot be normalized to A/B/C/D map") return None required_options = {"A", "B", "C", "D"} @@ -177,20 +160,129 @@ def validate_and_create_question(data: Dict[str, Any]) -> Optional[GeneratedQues logger.warning(f"Missing required options: {required_options - set(options.keys())}") return None - # Validate correct answer - correct = str(data.get("correct", "")).upper() + correct = _normalize_correct_answer( + data.get("correct") or data.get("correct_answer") or data.get("answer") + ) if correct not in required_options: logger.warning(f"Invalid correct answer: {correct}") return None return GeneratedQuestion( - stem=str(data["stem"]).strip(), - options={k: str(v).strip() for k, v in options.items()}, + stem=stem, + options=options, correct=correct, - explanation=str(data.get("explanation", "")).strip() or None, + explanation=str(data.get("explanation") or data.get("rationale") or "").strip() or None, ) +def _extract_json_candidates(text: str) -> list[str]: + candidates: list[str] = [] + + code_blocks = re.findall(r"```(?:json)?\s*([\s\S]*?)\s*```", text) + candidates.extend(block.strip() for block in code_blocks if block.strip()) + + balanced = _extract_first_balanced_object(text) + if balanced: + candidates.append(balanced) + + candidates.append(text.strip()) + deduped: list[str] = [] + seen = set() + for candidate in candidates: + if candidate and candidate not in seen: + deduped.append(candidate) + seen.add(candidate) + return deduped + + +def _extract_first_balanced_object(text: str) -> str | None: + start = text.find("{") + if start == -1: + return None + depth = 0 + in_string = False + escape_next = False + for idx in range(start, len(text)): + ch = text[idx] + if escape_next: + escape_next = False + continue + if ch == "\\" and in_string: + escape_next = True + continue + if ch == '"': + in_string = not in_string + continue + if in_string: + continue + if ch == "{": + depth += 1 + elif ch == "}": + depth -= 1 + if depth == 0: + return text[start: idx + 1] + return None + + +def _sanitize_json_candidate(candidate: str) -> str: + cleaned = candidate.strip().lstrip("\ufeff") + cleaned = cleaned.replace("“", '"').replace("”", '"').replace("’", "'") + cleaned = re.sub(r",\s*([}\]])", r"\1", cleaned) + return cleaned + + +def _try_parse_json_like(candidate: str) -> Any: + try: + return json.loads(candidate) + except json.JSONDecodeError: + pass + try: + # Fallback for Python-like dict outputs using single quotes. + return ast.literal_eval(candidate) + except (ValueError, SyntaxError): + return None + + +def _normalize_options(raw_options: Any) -> dict[str, str]: + if isinstance(raw_options, dict): + normalized = {str(k).strip().upper(): str(v).strip() for k, v in raw_options.items()} + return {k: normalized.get(k, "") for k in ["A", "B", "C", "D"] if normalized.get(k, "")} + + if isinstance(raw_options, list): + mapped: dict[str, str] = {} + for idx, opt in enumerate(raw_options): + if isinstance(opt, dict): + key = str(opt.get("increment") or opt.get("key") or "").strip().upper() + text = str(opt.get("text") or opt.get("label") or opt.get("value") or "").strip() + else: + key = "" + text = str(opt).strip() + if not key and idx < 4: + key = ["A", "B", "C", "D"][idx] + if key in {"A", "B", "C", "D"} and text: + mapped[key] = text + return mapped + + return {} + + +def _normalize_correct_answer(raw_correct: Any) -> str: + if raw_correct is None: + return "" + raw_text = str(raw_correct).strip().upper() + if raw_text in {"A", "B", "C", "D"}: + return raw_text + if raw_text.isdigit(): + idx = int(raw_text) + if 1 <= idx <= 4: + return ["A", "B", "C", "D"][idx - 1] + if 0 <= idx <= 3: + return ["A", "B", "C", "D"][idx] + if raw_text in {"OPTION A", "OPTION B", "OPTION C", "OPTION D"}: + return raw_text[-1] + return raw_text[:1] + + async def call_openrouter_api( prompt: str, model: str, @@ -309,21 +401,23 @@ async def generate_question( target_level=target_level, ) - # Call OpenRouter API - response_text = await call_openrouter_api(prompt, ai_model) + max_generation_attempts = 2 + for attempt in range(1, max_generation_attempts + 1): + response_text = await call_openrouter_api(prompt, ai_model) + if not response_text: + logger.error("No response from OpenRouter API") + continue - if not response_text: - logger.error("No response from OpenRouter API") - return None + generated = parse_ai_response(response_text) + if generated: + return generated - # Parse response - generated = parse_ai_response(response_text) - - if not generated: - logger.error("Failed to parse AI response") - return None - - return generated + logger.warning( + "Failed to parse AI response (attempt %s/%s), retrying", + attempt, + max_generation_attempts, + ) + return None async def check_cache_reuse(