Checkpoint React frontend migration
This commit is contained in:
950
backend/app/services/ai_generation.py
Normal file
950
backend/app/services/ai_generation.py
Normal file
@@ -0,0 +1,950 @@
|
||||
"""
|
||||
AI Question Generation Service.
|
||||
|
||||
Handles OpenRouter API integration for generating question variants.
|
||||
Implements caching, user-level reuse checking, and prompt engineering.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import ast
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Literal, Optional, Union
|
||||
|
||||
import httpx
|
||||
from sqlalchemy import and_, func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.config import get_settings
|
||||
from app.models.item import Item
|
||||
from app.models.ai_generation_run import AIGenerationRun
|
||||
from app.models.tryout import Tryout
|
||||
from app.models.user_answer import UserAnswer
|
||||
from app.schemas.ai import AIModelPricing, AIUsageInfo, GeneratedQuestion
|
||||
|
||||
# Module-wide logger and the application settings object used by the helpers below.
logger = logging.getLogger(__name__)
settings = get_settings()

# OpenRouter HTTP endpoints: chat completions and the model catalogue.
OPENROUTER_API_URL = "https://openrouter.ai/api/v1/chat/completions"
OPENROUTER_MODELS_URL = "https://openrouter.ai/api/v1/models"

# Model identifiers (taken from settings) accepted by this service, mapped to
# the human-readable label shown for each of them.
SUPPORTED_MODELS: Dict[str, str] = {
    settings.OPENROUTER_MODEL_CHEAP: "Mistral Small 4 (Cheap / Fast)",
    settings.OPENROUTER_MODEL_QWEN: "Qwen 2.5 32B Instruct (Balanced)",
    settings.OPENROUTER_MODEL_LLAMA: "Llama 3.3 70B (Premium)",
}

# Wording inserted into the prompt for each difficulty level key.
LEVEL_DESCRIPTIONS: Dict[str, str] = dict(
    mudah="easier (simpler concepts, more straightforward calculations)",
    sedang="medium difficulty",
    sulit="harder (more complex concepts, multi-step reasoning)",
)

# Canonical option labels "A".."Z", in display order.
OPTION_LABELS = tuple(chr(code) for code in range(ord("A"), ord("Z") + 1))

# How long a pricing lookup result (including a failed one) is kept: 30 minutes.
MODEL_PRICING_CACHE_TTL_SECONDS = 30 * 60

# model id -> (time.monotonic() timestamp of the lookup, pricing or None).
_model_pricing_cache: dict[str, tuple[float, Optional[AIModelPricing]]] = {}
|
||||
|
||||
|
||||
@dataclass
class OpenRouterCallResult:
    """Result of one successful OpenRouter chat-completion call."""

    # Message content of the first choice in the API response.
    content: str
    # Token and cost accounting for the call; None when no usage data was built.
    usage: Optional[AIUsageInfo] = None
|
||||
|
||||
|
||||
def get_option_labels(options: Dict[str, str] | None) -> list[str]:
    """
    Return the labels of the non-empty options, in canonical A..Z order.

    Keys are stripped and upper-cased before being compared with
    OPTION_LABELS; keys outside that set are ignored. An option counts as
    present only when its value is not None and is non-blank after stripping.

    Args:
        options: Mapping of option label to option text, or None

    Returns:
        Sorted (A..Z) list of labels that carry a non-empty option
    """
    present: set[str] = set()
    for key, value in (options or {}).items():
        # A null value must not be stringified to "None" and counted as text.
        if value is None:
            continue
        label = str(key).strip().upper()
        if label and str(value).strip():
            present.add(label)
    return [label for label in OPTION_LABELS if label in present]
|
||||
|
||||
|
||||
def _parse_openrouter_price(value: Any) -> float | None:
|
||||
if value is None:
|
||||
return None
|
||||
try:
|
||||
price = float(value)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
return price if price >= 0 else None
|
||||
|
||||
|
||||
def _build_pricing(raw_pricing: dict[str, Any] | None) -> AIModelPricing | None:
    """
    Build an AIModelPricing from a raw pricing mapping.

    Reads the "prompt" and "completion" entries, parses each with
    _parse_openrouter_price, and also exposes each price multiplied by
    1,000,000.

    Args:
        raw_pricing: Raw pricing mapping, or None

    Returns:
        AIModelPricing, or None when the mapping is empty or neither price
        could be parsed
    """
    if not raw_pricing:
        return None

    per_token = {
        field: _parse_openrouter_price(raw_pricing.get(field))
        for field in ("prompt", "completion")
    }
    if all(price is None for price in per_token.values()):
        return None

    def per_million(price: float | None) -> float | None:
        return None if price is None else price * 1_000_000

    return AIModelPricing(
        prompt=per_token["prompt"],
        completion=per_token["completion"],
        prompt_per_million=per_million(per_token["prompt"]),
        completion_per_million=per_million(per_token["completion"]),
    )
|
||||
|
||||
|
||||
async def get_model_pricing(model_id: str) -> AIModelPricing | None:
    """
    Look up pricing for a model, using the in-process cache when fresh.

    The outcome of every lookup, including failures and "model not listed",
    is cached under the model id for MODEL_PRICING_CACHE_TTL_SECONDS.

    Args:
        model_id: Model identifier to look up

    Returns:
        AIModelPricing for the model, or None when it is unavailable
    """
    now = time.monotonic()
    cached = _model_pricing_cache.get(model_id)
    if cached is not None and now - cached[0] < MODEL_PRICING_CACHE_TTL_SECONDS:
        return cached[1]

    headers = {"Content-Type": "application/json"}
    if settings.OPENROUTER_API_KEY:
        headers["Authorization"] = f"Bearer {settings.OPENROUTER_API_KEY}"

    pricing: AIModelPricing | None = None
    try:
        # Pricing is auxiliary information, so never wait longer than 5 seconds.
        timeout = httpx.Timeout(min(settings.OPENROUTER_TIMEOUT, 5))
        async with httpx.AsyncClient(timeout=timeout) as client:
            response = await client.get(OPENROUTER_MODELS_URL, headers=headers)

        if response.status_code != 200:
            logger.warning(
                "OpenRouter models pricing request failed: %s - %s",
                response.status_code,
                response.text[:240],
            )
        else:
            listed = next(
                (
                    entry
                    for entry in response.json().get("data", [])
                    if entry.get("id") == model_id
                ),
                None,
            )
            if listed is not None:
                pricing = _build_pricing(listed.get("pricing"))
    except Exception as exc:
        # Best effort: any failure is logged and cached as "no pricing".
        logger.warning("OpenRouter model pricing lookup failed for %s: %s", model_id, exc)

    _model_pricing_cache[model_id] = (now, pricing)
    return pricing
|
||||
|
||||
|
||||
def _calculate_usage_cost(
    prompt_tokens: int | None,
    completion_tokens: int | None,
    pricing: AIModelPricing | None,
    provider_cost: Any = None,
) -> float | None:
    """
    Work out the cost of one call.

    A parseable provider-reported cost wins. Otherwise the cost is the sum of
    tokens multiplied by the per-token price, for every side (prompt,
    completion) where both the token count and the price are known.

    Args:
        prompt_tokens: Prompt token count, if known
        completion_tokens: Completion token count, if known
        pricing: Per-token pricing for the model, if known
        provider_cost: Raw cost value reported by the provider, if any

    Returns:
        Cost as a float, or None when nothing could be computed
    """
    reported = _parse_openrouter_price(provider_cost)
    if reported is not None:
        return reported
    if pricing is None:
        return None

    components = [
        tokens * rate
        for tokens, rate in (
            (prompt_tokens, pricing.prompt),
            (completion_tokens, pricing.completion),
        )
        if tokens is not None and rate is not None
    ]
    if not components:
        return None
    return sum(components, 0.0)
|
||||
|
||||
|
||||
async def build_usage_info(raw_usage: dict[str, Any] | None, model_id: str) -> AIUsageInfo | None:
    """
    Turn a raw usage mapping into an AIUsageInfo with token counts and cost.

    Args:
        raw_usage: Raw usage mapping (keys read: prompt_tokens,
            completion_tokens, total_tokens, cost), or None
        model_id: Model identifier used to look up pricing

    Returns:
        AIUsageInfo, or None when raw_usage is empty
    """
    if not raw_usage:
        return None

    def read_count(field: str) -> int | None:
        # int(None) raises TypeError, so a missing key also yields None here.
        try:
            return int(raw_usage.get(field))
        except (TypeError, ValueError):
            return None

    prompt_tokens = read_count("prompt_tokens")
    completion_tokens = read_count("completion_tokens")
    total_tokens = read_count("total_tokens")

    # Derive the total when it is missing but at least one side is known.
    if total_tokens is None and not (prompt_tokens is None and completion_tokens is None):
        total_tokens = sum(count or 0 for count in (prompt_tokens, completion_tokens))

    model_pricing = await get_model_pricing(model_id)
    return AIUsageInfo(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
        total_tokens=total_tokens,
        cost_usd=_calculate_usage_cost(
            prompt_tokens,
            completion_tokens,
            model_pricing,
            provider_cost=raw_usage.get("cost"),
        ),
    )
|
||||
|
||||
|
||||
def combine_usage(usages: list[AIUsageInfo | None]) -> AIUsageInfo | None:
    """
    Add up several usage records field by field.

    None entries are skipped. A field of the result is None when no record
    carries a value for it.

    Args:
        usages: Usage records to combine; entries may be None

    Returns:
        Combined AIUsageInfo, or None when there is no non-None record
    """
    records = [entry for entry in usages if entry is not None]
    if not records:
        return None

    totals: dict[str, int | float | None] = {}
    for field in ("prompt_tokens", "completion_tokens", "total_tokens", "cost_usd"):
        field_values = [getattr(entry, field) for entry in records]
        known = [value for value in field_values if value is not None]
        totals[field] = sum(known) if known else None

    return AIUsageInfo(**totals)
|
||||
|
||||
|
||||
def get_prompt_template(
    basis_stem: str,
    basis_options: Dict[str, str],
    basis_correct: str,
    basis_explanation: Optional[str],
    target_level: Literal["mudah", "sulit"],
    operator_notes: Optional[str] = None,
) -> str:
    """
    Generate standardized prompt for AI question generation.

    The required option labels are the non-empty labels of the basis options;
    when there are none, the prompt falls back to A, B, C, D.

    Args:
        basis_stem: The basis question stem
        basis_options: The basis question options (None is treated as empty)
        basis_correct: The basis correct answer
        basis_explanation: The basis explanation
        target_level: Target difficulty level
        operator_notes: Optional free-text notes added to the prompt as
            style constraints; ignored when blank

    Returns:
        Formatted prompt string
    """
    level_desc = LEVEL_DESCRIPTIONS.get(target_level, target_level)
    option_labels = get_option_labels(basis_options) or ["A", "B", "C", "D"]
    option_count = len(option_labels)
    option_label_text = ", ".join(option_labels)
    example_options = {label: f"Option {label} text" for label in option_labels}

    # Guard against a missing options mapping: the label helper above already
    # tolerates None, so the rendering must not crash on it either.
    options_text = "\n".join(
        [f" {key}: {value}" for key, value in (basis_options or {}).items()]
    )

    explanation_text = (
        f"Explanation: {basis_explanation}"
        if basis_explanation
        else "Explanation: (not provided)"
    )

    notes_block = ""
    if operator_notes and operator_notes.strip():
        notes_block = f"""
ADDITIONAL OPERATOR NOTES:
{operator_notes.strip()}

Apply these notes as style constraints as long as they do not conflict with correctness.
"""

    prompt = f"""You are an educational content creator specializing in creating assessment questions.

Given a "Sedang" (medium difficulty) question, generate a new question at a different difficulty level.

BASIS QUESTION (Sedang level):
Question: {basis_stem}
Options:
{options_text}
Correct Answer: {basis_correct}
{explanation_text}

TASK:
Generate 1 new question that is {level_desc} than the basis question above.
{notes_block}

REQUIREMENTS:
1. Keep the SAME topic/subject matter as the basis question
2. Use similar context and terminology
3. Create exactly {option_count} answer options with labels exactly: {option_label_text}
4. Preserve the basis option count and option labels. Do not omit, add, rename, or merge answer options.
5. Only ONE correct answer, and it must be one of: {option_label_text}
6. Include a clear explanation of why the correct answer is correct
7. Make the question noticeably {level_desc} - not just a minor variation
8. Follow and preserve the basis question's inline HTML style. Keep structural and inline tags such as <p>, <br>, <strong>, <b>, <em>, <i>, <u>, <sub>, <sup>, and simple inline attributes such as text alignment when the basis uses them.
9. Do not escape HTML tags as text. Return HTML markup in the JSON string values exactly as markup.

OUTPUT FORMAT:
Return ONLY a valid JSON object with this exact structure (no markdown, no code blocks):
{{"stem": "Your question text here", "options": {json.dumps(example_options, ensure_ascii=False)}, "correct": "{option_labels[0]}", "explanation": "Explanation text here"}}

Remember: The correct field must be exactly one of: {option_label_text}."""

    return prompt
|
||||
|
||||
|
||||
def parse_ai_response(response_text: str) -> Optional[GeneratedQuestion]:
    """
    Parse AI response to extract question data.

    Handles various response formats including JSON code blocks. Each
    candidate is parsed as-is first and only then in sanitised form:
    sanitising rewrites typographic quotes, which would corrupt a candidate
    that is already valid JSON and merely contains such quotes inside its
    string values.

    Args:
        response_text: Raw AI response text

    Returns:
        GeneratedQuestion if parsing successful, None otherwise
    """
    if not response_text:
        return None

    cleaned = response_text.strip()
    for candidate in _extract_json_candidates(cleaned):
        sanitized = _sanitize_json_candidate(candidate)
        # Skip the second attempt when sanitising changed nothing.
        variants = [candidate] if sanitized == candidate else [candidate, sanitized]
        for variant in variants:
            parsed = _try_parse_json_like(variant)
            if not isinstance(parsed, dict):
                continue
            question = validate_and_create_question(parsed)
            if question:
                return question

    logger.warning(f"Failed to parse AI response: {cleaned[:240]}...")
    return None
|
||||
|
||||
|
||||
def validate_and_create_question(data: Dict[str, Any]) -> Optional[GeneratedQuestion]:
    """
    Check a parsed response mapping and turn it into a GeneratedQuestion.

    Accepted key aliases: "stem"/"question" for the stem,
    "correct"/"correct_answer"/"answer" for the answer label, and
    "explanation"/"rationale" for the explanation.

    Args:
        data: Parsed JSON data

    Returns:
        GeneratedQuestion if valid, None when the stem is missing, the options
        cannot be normalized, or the answer is not one of the option labels
    """
    raw_stem = data.get("stem") or data.get("question") or ""
    stem = str(raw_stem).strip()
    if not stem:
        logger.warning(f"Missing question stem in AI response: {data.keys()}")
        return None

    options = _normalize_options(data.get("options"))
    if not options:
        logger.warning("Options cannot be normalized to a labeled option map")
        return None

    raw_correct = data.get("correct") or data.get("correct_answer") or data.get("answer")
    correct = _normalize_correct_answer(raw_correct)
    if correct not in options:
        logger.warning(f"Invalid correct answer: {correct}")
        return None

    raw_explanation = data.get("explanation") or data.get("rationale") or ""
    explanation = str(raw_explanation).strip() or None

    return GeneratedQuestion(
        stem=stem,
        options=options,
        correct=correct,
        explanation=explanation,
    )
|
||||
|
||||
|
||||
def _extract_json_candidates(text: str) -> list[str]:
    """
    Collect the substrings of a response that might hold the JSON payload.

    Candidates, in order: the contents of every fenced code block, the first
    balanced {...} object, and finally the whole stripped text. Empty strings
    and repeats are dropped, keeping the first occurrence.

    Args:
        text: Response text to search

    Returns:
        Ordered list of distinct, non-empty candidate strings
    """
    fenced = re.findall(r"```(?:json)?\s*([\s\S]*?)\s*```", text)
    found = [block.strip() for block in fenced if block.strip()]

    first_object = _extract_first_balanced_object(text)
    if first_object:
        found.append(first_object)

    found.append(text.strip())

    # dict.fromkeys removes duplicates while keeping first-seen order.
    return [entry for entry in dict.fromkeys(found) if entry]
|
||||
|
||||
|
||||
def _extract_first_balanced_object(text: str) -> str | None:
|
||||
start = text.find("{")
|
||||
if start == -1:
|
||||
return None
|
||||
depth = 0
|
||||
in_string = False
|
||||
escape_next = False
|
||||
for idx in range(start, len(text)):
|
||||
ch = text[idx]
|
||||
if escape_next:
|
||||
escape_next = False
|
||||
continue
|
||||
if ch == "\\" and in_string:
|
||||
escape_next = True
|
||||
continue
|
||||
if ch == '"':
|
||||
in_string = not in_string
|
||||
continue
|
||||
if in_string:
|
||||
continue
|
||||
if ch == "{":
|
||||
depth += 1
|
||||
elif ch == "}":
|
||||
depth -= 1
|
||||
if depth == 0:
|
||||
return text[start: idx + 1]
|
||||
return None
|
||||
|
||||
|
||||
def _sanitize_json_candidate(candidate: str) -> str:
|
||||
cleaned = candidate.strip().lstrip("\ufeff")
|
||||
cleaned = cleaned.replace("“", '"').replace("”", '"').replace("’", "'")
|
||||
cleaned = re.sub(r",\s*([}\]])", r"\1", cleaned)
|
||||
return cleaned
|
||||
|
||||
|
||||
def _try_parse_json_like(candidate: str) -> Any:
|
||||
try:
|
||||
return json.loads(candidate)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
try:
|
||||
# Fallback for Python-like dict outputs using single quotes.
|
||||
return ast.literal_eval(candidate)
|
||||
except (ValueError, SyntaxError):
|
||||
return None
|
||||
|
||||
|
||||
def _normalize_options(raw_options: Any) -> dict[str, str]:
    """
    Convert options from an AI response into a {label: text} mapping.

    Two shapes are accepted:
    - a dict of label -> text; keys are stripped and upper-cased
    - a list whose entries are either plain values or dicts. For a dict entry
      the label is read from "increment" or "key" and the text from "text",
      "label" or "value". An entry without a label gets the label matching
      its position (A, B, C, ...).

    Only labels in OPTION_LABELS with non-empty text are kept. A None (JSON
    null) value counts as empty rather than being turned into the text "None".

    Args:
        raw_options: Options value from the parsed response

    Returns:
        Mapping of label to option text; empty when nothing usable was found
    """
    if isinstance(raw_options, dict):
        normalized = {
            str(k).strip().upper(): str(v).strip()
            for k, v in raw_options.items()
            if v is not None
        }
        return {k: normalized[k] for k in OPTION_LABELS if normalized.get(k, "")}

    if isinstance(raw_options, list):
        mapped: dict[str, str] = {}
        for idx, opt in enumerate(raw_options):
            if isinstance(opt, dict):
                # NOTE(review): "increment" is an unusual key name for a label —
                # confirm it is intended.
                key = str(opt.get("increment") or opt.get("key") or "").strip().upper()
                text = str(opt.get("text") or opt.get("label") or opt.get("value") or "").strip()
            else:
                key = ""
                text = "" if opt is None else str(opt).strip()
            if not key and idx < len(OPTION_LABELS):
                key = OPTION_LABELS[idx]
            if key in OPTION_LABELS and text:
                mapped[key] = text
        return mapped

    return {}
|
||||
|
||||
|
||||
def _normalize_correct_answer(raw_correct: Any) -> str:
    """
    Reduce the "correct answer" value of an AI response to one option label.

    Recognised forms (case-insensitive, surrounding whitespace ignored):
    - a bare label: "b"
    - a number: 1..26 is a 1-based position, 0 maps to the first label
    - a label standing on its own at the start, optionally wrapped or
      followed by punctuation or text: "(B)", "B.", "C) 12 cm"
    - the same with a leading ANSWER / OPTION / CHOICE word:
      "Answer: C", "Option B is right"

    The first letter of an arbitrary word is never used as a label, so
    "Answer: C" gives "C" (not "A") and "Both" gives "".

    Args:
        raw_correct: Raw answer value from the parsed response

    Returns:
        A single upper-case label, or "" when no label can be determined
    """
    if raw_correct is None:
        return ""
    raw_text = str(raw_correct).strip().upper()
    if raw_text in OPTION_LABELS:
        return raw_text
    # isdecimal (not isdigit): only accept characters that int() can convert.
    if raw_text.isdecimal():
        idx = int(raw_text)
        if 1 <= idx <= len(OPTION_LABELS):
            return OPTION_LABELS[idx - 1]
        if idx == 0:
            return OPTION_LABELS[0]
        return ""
    # Optional lead-in word, optional opening bracket, then a single letter
    # that is not the start of a longer word.
    match = re.match(
        r"(?:(?:ANSWER|OPTION|CHOICE)\b(?:\s+IS\b)?[\s:=\-]*)?[(\[]?([A-Z])(?!\w)",
        raw_text,
    )
    return match.group(1) if match else ""
|
||||
|
||||
|
||||
def generated_matches_basis_options(generated: GeneratedQuestion, basis_item: Item) -> bool:
    """
    Check that a generated question uses the same option labels as its basis.

    Args:
        generated: Question produced by the AI
        basis_item: Item the question was derived from

    Returns:
        True when both have the same non-empty option labels and the generated
        correct answer is one of them; False (after a warning) otherwise
    """
    expected_labels = get_option_labels(basis_item.options)
    actual_labels = get_option_labels(generated.options)

    if actual_labels != expected_labels:
        logger.warning(
            "Generated option labels do not match basis: basis=%s generated=%s",
            expected_labels,
            actual_labels,
        )
        return False

    if generated.correct not in expected_labels:
        logger.warning(
            "Generated correct answer %s is outside basis labels %s",
            generated.correct,
            expected_labels,
        )
        return False

    return True
|
||||
|
||||
|
||||
async def call_openrouter_api(
    prompt: str,
    model: str,
    max_retries: int = 3,
) -> Optional[OpenRouterCallResult]:
    """
    Send one prompt to the OpenRouter chat-completions endpoint.

    Retries happen on HTTP 429 (after sleeping 2 ** attempt seconds), on
    timeouts, and on any other exception raised during the call. Any other
    non-200 status, or a 200 response without usable content, ends the call
    immediately.

    Args:
        prompt: The prompt to send
        model: AI model to use (must be a key of SUPPORTED_MODELS)
        max_retries: Maximum number of attempts

    Returns:
        OpenRouterCallResult with response text and usage, or None if failed
    """
    import asyncio

    if not settings.OPENROUTER_API_KEY:
        logger.error("OPENROUTER_API_KEY not configured")
        return None

    if model not in SUPPORTED_MODELS:
        logger.error(f"Unsupported AI model: {model}")
        return None

    headers = {
        "Authorization": f"Bearer {settings.OPENROUTER_API_KEY}",
        "Content-Type": "application/json",
        "HTTP-Referer": "https://github.com/irt-bank-soal",
        "X-Title": "IRT Bank Soal",
    }

    payload: dict[str, Any] = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": 2000,
        "temperature": 0.7,
    }
    # Optional provider routing: only sent when at least one provider is configured.
    provider_order = [
        provider for provider in settings.OPENROUTER_PROVIDER_ORDER if provider.strip()
    ]
    if provider_order:
        payload["provider"] = {
            "order": provider_order,
            "allow_fallbacks": settings.OPENROUTER_ALLOW_PROVIDER_FALLBACKS,
        }

    timeout = httpx.Timeout(settings.OPENROUTER_TIMEOUT)

    for attempt in range(max_retries):
        attempts_remain = attempt < max_retries - 1
        try:
            async with httpx.AsyncClient(timeout=timeout) as client:
                response = await client.post(
                    OPENROUTER_API_URL,
                    headers=headers,
                    json=payload,
                )
                status = response.status_code

                if status == 429:
                    # Rate limited - back off exponentially, then try again.
                    logger.warning(f"Rate limited, attempt {attempt + 1}/{max_retries}")
                    if not attempts_remain:
                        return None
                    await asyncio.sleep(2 ** attempt)
                    continue

                if status != 200:
                    logger.error(
                        f"OpenRouter API error: {status} - {response.text}"
                    )
                    return None

                data = response.json()
                choices = data.get("choices", [])
                if not choices:
                    logger.warning("No choices in OpenRouter response")
                    return None

                content = choices[0].get("message", {}).get("content")
                if not content:
                    logger.warning("OpenRouter response had no message content")
                    return None

                usage = await build_usage_info(data.get("usage"), model)
                return OpenRouterCallResult(content=content, usage=usage)

        except httpx.TimeoutException:
            logger.warning(f"OpenRouter timeout, attempt {attempt + 1}/{max_retries}")
            if not attempts_remain:
                return None

        except Exception as e:
            logger.error(f"OpenRouter API call failed: {e}")
            if not attempts_remain:
                return None

    return None
|
||||
|
||||
|
||||
async def generate_question(
    basis_item: Item,
    target_level: Literal["mudah", "sulit"],
    ai_model: str = settings.OPENROUTER_MODEL_QWEN,
    operator_notes: Optional[str] = None,
) -> Optional[GeneratedQuestion]:
    """
    Produce one question variant from a basis item.

    Builds the prompt once, then makes up to three attempts. An attempt
    succeeds when the API returns a response that parses into a question whose
    option labels match the basis item. The usage of the successful call is
    attached to the returned question.

    Args:
        basis_item: The item to derive from (the prompt presents it as the
            sedang-level question)
        target_level: Target difficulty level
        ai_model: AI model to use
        operator_notes: Optional extra notes passed into the prompt

    Returns:
        GeneratedQuestion if successful, None otherwise
    """
    prompt = get_prompt_template(
        basis_stem=basis_item.stem,
        basis_options=basis_item.options,
        basis_correct=basis_item.correct_answer,
        basis_explanation=basis_item.explanation,
        target_level=target_level,
        operator_notes=operator_notes,
    )

    max_generation_attempts = 3
    attempt = 0
    while attempt < max_generation_attempts:
        attempt += 1

        api_result = await call_openrouter_api(prompt, ai_model)
        if not api_result:
            logger.error("No response from OpenRouter API")
            continue

        candidate = parse_ai_response(api_result.content)
        if candidate and generated_matches_basis_options(candidate, basis_item):
            return candidate.model_copy(update={"usage": api_result.usage})

        logger.warning(
            "Failed to parse or validate AI response (attempt %s/%s), retrying",
            attempt,
            max_generation_attempts,
        )

    return None
|
||||
|
||||
|
||||
async def check_cache_reuse(
    tryout_id: str,
    slot: int,
    level: str,
    wp_user_id: str,
    website_id: int,
    db: AsyncSession,
) -> Optional[Item]:
    """
    Check if there's a cached item that the user hasn't answered yet.

    Loads the items matching (tryout_id, website_id, slot, level), then
    fetches in one query the ids of those items for which the user has at
    least one answer row. Any number of answer rows per item is handled.

    Args:
        tryout_id: Tryout identifier
        slot: Question slot
        level: Difficulty level
        wp_user_id: WordPress user ID
        website_id: Website identifier
        db: Database session

    Returns:
        Cached item if found and user hasn't answered, None otherwise
    """
    # Find existing items at this slot/level
    result = await db.execute(
        select(Item).where(
            and_(
                Item.tryout_id == tryout_id,
                Item.website_id == website_id,
                Item.slot == slot,
                Item.level == level,
            )
        )
    )
    existing_items = result.scalars().all()

    if not existing_items:
        return None

    # One query for all candidates. Selecting only item_id and collecting into
    # a set avoids scalar_one_or_none(), which raises when a user has more
    # than one answer row for the same item.
    answered_result = await db.execute(
        select(UserAnswer.item_id).where(
            and_(
                UserAnswer.item_id.in_([item.id for item in existing_items]),
                UserAnswer.wp_user_id == wp_user_id,
            )
        )
    )
    answered_item_ids = set(answered_result.scalars().all())

    for item in existing_items:
        if item.id not in answered_item_ids:
            # User hasn't answered this item - can reuse
            logger.info(
                f"Cache hit for tryout={tryout_id}, slot={slot}, level={level}, "
                f"item_id={item.id}, user={wp_user_id}"
            )
            return item

    # All items have been answered by this user
    logger.info(
        f"Cache miss (user answered all) for tryout={tryout_id}, slot={slot}, "
        f"level={level}, user={wp_user_id}"
    )
    return None
|
||||
|
||||
|
||||
async def generate_with_cache_check(
    tryout_id: str,
    slot: int,
    level: Literal["mudah", "sulit"],
    wp_user_id: str,
    website_id: int,
    db: AsyncSession,
    ai_model: str = settings.OPENROUTER_MODEL_QWEN,
) -> tuple[Optional[Union[Item, GeneratedQuestion]], bool]:
    """
    Return a question for the slot, reusing a stored item when possible.

    Order of work: look up the tryout, look for a stored item the user has
    not answered, and only on a miss generate a new question from the
    sedang-level item at the same slot. The cache is consulted even when AI
    generation is disabled for the tryout; in that case a miss returns nothing.

    Args:
        tryout_id: Tryout identifier
        slot: Question slot
        level: Target difficulty level
        wp_user_id: WordPress user ID
        website_id: Website identifier
        db: Database session
        ai_model: AI model to use

    Returns:
        Tuple of (item/question or None, is_cached)
    """
    tryout_query = select(Tryout).where(
        and_(
            Tryout.tryout_id == tryout_id,
            Tryout.website_id == website_id,
        )
    )
    tryout = (await db.execute(tryout_query)).scalar_one_or_none()

    generation_disabled = bool(tryout and not tryout.ai_generation_enabled)
    if generation_disabled:
        logger.info(f"AI generation disabled for tryout={tryout_id}")

    # The cache lookup runs in both the enabled and the disabled case.
    cached_item = await check_cache_reuse(
        tryout_id, slot, level, wp_user_id, website_id, db
    )
    if cached_item:
        return cached_item, True
    if generation_disabled:
        return None, False

    # Cache miss - the sedang-level item at the same slot is the basis.
    basis_query = select(Item).where(
        and_(
            Item.tryout_id == tryout_id,
            Item.website_id == website_id,
            Item.slot == slot,
            Item.level == "sedang",
        )
    ).limit(1)
    basis_item = (await db.execute(basis_query)).scalar_one_or_none()

    if not basis_item:
        logger.error(
            f"No basis item found for tryout={tryout_id}, slot={slot}"
        )
        return None, False

    generated = await generate_question(basis_item, level, ai_model)
    if generated:
        return generated, False

    logger.error(
        f"Failed to generate question for tryout={tryout_id}, slot={slot}, level={level}"
    )
    return None, False
|
||||
|
||||
|
||||
async def save_ai_question(
    generated_data: GeneratedQuestion,
    tryout_id: str,
    website_id: int,
    basis_item_id: int,
    slot: int,
    level: Literal["mudah", "sedang", "sulit"],
    ai_model: str,
    db: AsyncSession,
    generation_run_id: int | None = None,
    source_snapshot_question_id: int | None = None,
    variant_status: str = "draft",
) -> Optional[int]:
    """
    Store an AI-generated question as a new Item.

    The item is added to the session and flushed (not committed) so that its
    id is available. It is marked as generated by "ai", uncalibrated, and with
    all CTT/IRT statistics unset.

    Args:
        generated_data: Generated question data
        tryout_id: Tryout identifier
        website_id: Website identifier
        basis_item_id: Basis item ID
        slot: Question slot
        level: Difficulty level
        ai_model: AI model used
        db: Database session
        generation_run_id: Generation run this item belongs to, if any
        source_snapshot_question_id: Source snapshot question ID, if any
        variant_status: Status stored on the new item

    Returns:
        Created item ID or None if failed
    """
    try:
        item_fields = dict(
            # Placement of the item.
            tryout_id=tryout_id,
            website_id=website_id,
            slot=slot,
            level=level,
            # Question content.
            stem=generated_data.stem,
            options=generated_data.options,
            correct_answer=generated_data.correct,
            explanation=generated_data.explanation,
            # Provenance.
            generated_by="ai",
            ai_model=ai_model,
            basis_item_id=basis_item_id,
            generation_run_id=generation_run_id,
            source_snapshot_question_id=source_snapshot_question_id,
            variant_status=variant_status,
            # Calibration statistics start out empty.
            calibrated=False,
            ctt_p=None,
            ctt_bobot=None,
            ctt_category=None,
            irt_b=None,
            irt_se=None,
            calibration_sample_size=0,
        )
        new_item = Item(**item_fields)

        db.add(new_item)
        await db.flush()  # Get the ID without committing

        logger.info(
            f"Saved AI-generated item: id={new_item.id}, tryout={tryout_id}, "
            f"slot={slot}, level={level}, model={ai_model}"
        )
        return new_item.id

    except Exception as e:
        logger.error(f"Failed to save AI-generated question: {e}")
        return None
|
||||
|
||||
|
||||
async def create_generation_run(
    basis_item_id: int,
    target_level: Literal["mudah", "sulit"],
    requested_count: int,
    model: str,
    created_by: str,
    db: AsyncSession,
    source_snapshot_question_id: int | None = None,
    operator_notes: str | None = None,
    prompt_version: str = "v1",
) -> int:
    """
    Record a new AI generation run and return its id.

    The run is added to the session and flushed (not committed).

    Args:
        basis_item_id: Basis item ID
        target_level: Target difficulty level
        requested_count: Number of questions requested
        model: AI model identifier
        created_by: Identifier of whoever started the run
        db: Database session
        source_snapshot_question_id: Source snapshot question ID, if any
        operator_notes: Optional operator notes stored with the run
        prompt_version: Prompt version label stored with the run

    Returns:
        ID of the created run
    """
    run_fields = dict(
        basis_item_id=basis_item_id,
        source_snapshot_question_id=source_snapshot_question_id,
        target_level=target_level,
        requested_count=requested_count,
        model=model,
        prompt_version=prompt_version,
        operator_notes=operator_notes,
        created_by=created_by,
    )
    generation_run = AIGenerationRun(**run_fields)
    db.add(generation_run)
    # Flush so the id is assigned.
    await db.flush()
    return int(generation_run.id)
|
||||
|
||||
|
||||
async def generate_questions_batch(
    basis_item: Item,
    target_level: Literal["mudah", "sulit"],
    ai_model: str,
    count: int,
    operator_notes: Optional[str] = None,
) -> list[GeneratedQuestion]:
    """
    Generate several question variants from one basis item, one after another.

    Args:
        basis_item: Item to derive the variants from
        target_level: Target difficulty level
        ai_model: AI model to use
        count: Number of generation calls to make
        operator_notes: Optional extra notes passed into each prompt

    Returns:
        The successfully generated questions; failed generations are left
        out, so the list may be shorter than count
    """
    outcomes = [
        await generate_question(
            basis_item=basis_item,
            target_level=target_level,
            ai_model=ai_model,
            operator_notes=operator_notes,
        )
        for _ in range(count)
    ]
    return [question for question in outcomes if question is not None]
|
||||
|
||||
|
||||
async def get_ai_stats(db: AsyncSession, website_id: int | None = None) -> Dict[str, Any]:
    """
    Collect counts of AI-generated items.

    Args:
        db: Database session
        website_id: When given, only items of this website are counted

    Returns:
        Dictionary with the total number of AI-generated items, a per-model
        breakdown, and cache statistics (currently fixed placeholder zeros)
    """
    conditions = [Item.generated_by == "ai"]
    if website_id is not None:
        conditions.append(Item.website_id == website_id)

    # Total AI-generated items
    total_query = select(func.count(Item.id)).where(*conditions)
    total_ai_items = (await db.execute(total_query)).scalar() or 0

    # Items grouped by the model that produced them (rows without a model are skipped)
    per_model_query = (
        select(Item.ai_model, func.count(Item.id))
        .where(*conditions)
        .where(Item.ai_model.isnot(None))
        .group_by(Item.ai_model)
    )
    per_model_rows = (await db.execute(per_model_query)).all()
    items_by_model = {row[0]: row[1] for row in per_model_rows}

    # Cache hit rate is not tracked anywhere yet; these three values are placeholders.
    stats: Dict[str, Any] = {
        "total_ai_items": total_ai_items,
        "items_by_model": items_by_model,
        "cache_hit_rate": 0.0,
        "total_cache_hits": 0,
        "total_requests": 0,
    }
    return stats
|
||||
|
||||
|
||||
def validate_ai_model(model: str) -> bool:
    """
    Tell whether a model identifier is one of the supported models.

    Args:
        model: AI model identifier

    Returns:
        True if model is a key of SUPPORTED_MODELS, False otherwise
    """
    is_supported = model in SUPPORTED_MODELS
    return is_supported
|
||||
Reference in New Issue
Block a user