531 lines
17 KiB
Python
531 lines
17 KiB
Python
"""
|
|
AI Generation Router.
|
|
|
|
Admin endpoints for AI question generation playground.
|
|
"""
|
|
|
|
import logging
|
|
from typing import Annotated
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.core.config import get_settings
|
|
from app.core.auth import (
|
|
AuthContext,
|
|
ensure_website_scope_matches,
|
|
get_auth_context,
|
|
require_website_auth,
|
|
)
|
|
from app.core.rate_limit import enforce_rate_limit
|
|
from app.database import get_db
|
|
from app.models.item import Item
|
|
from app.schemas.ai import (
|
|
AIBatchGeneratedItem,
|
|
AIGenerateBatchRequest,
|
|
AIGenerateBatchResponse,
|
|
AIGeneratePreviewRequest,
|
|
AIGeneratePreviewResponse,
|
|
AISaveRequest,
|
|
AISaveResponse,
|
|
AIStatsResponse,
|
|
)
|
|
from app.services.ai_generation import (
|
|
SUPPORTED_MODELS,
|
|
combine_usage,
|
|
create_generation_run,
|
|
generate_question,
|
|
generate_questions_batch,
|
|
generated_matches_basis_options,
|
|
get_ai_stats,
|
|
get_model_pricing,
|
|
save_ai_question,
|
|
validate_ai_model,
|
|
)
|
|
|
|
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
|
|
# Application settings, resolved once at import time; list_models reads the
# configured OpenRouter model ids from this object.
settings = get_settings()
|
|
|
|
# Router for the admin AI-generation endpoints; every route registered on it
# is served under the /admin/ai prefix.
router = APIRouter(prefix="/admin/ai", tags=["admin", "ai-generation"])
|
|
|
|
|
|
def _validate_original_basis_item(basis_item: Item) -> None:
    """
    Ensure an item is acceptable as the basis for AI generation.

    The item must be at the 'sedang' level and must not itself be marked as
    generated by AI. The level is checked first; the first failing rule wins.

    Raises:
        HTTPException: 400 with a message describing the violated rule.
    """
    problem = None
    if basis_item.level != "sedang":
        problem = f"Basis item must be 'sedang' level, got: {basis_item.level}"
    elif basis_item.generated_by == "ai":
        problem = "Basis item must be an original question, not an AI-generated variant."

    if problem is not None:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=problem,
        )
|
|
|
|
|
|
@router.post(
    "/generate-preview",
    response_model=AIGeneratePreviewResponse,
    summary="Preview AI-generated question",
    description="""
    Generate a question preview using AI without saving to database.

    This is an admin playground endpoint for testing AI generation quality.
    Admins can retry unlimited times until satisfied with the result.

    Requirements:
    - basis_item_id must reference an existing item at 'sedang' level
    - target_level must be 'mudah' or 'sulit'
    - ai_model must be a supported OpenRouter model
    """,
    responses={
        200: {"description": "Question generated successfully (preview mode)"},
        400: {"description": "Invalid request (wrong level, unsupported model)"},
        404: {"description": "Basis item not found"},
        500: {"description": "AI generation failed"},
    },
)
async def generate_preview(
    request_http: Request,
    request: AIGeneratePreviewRequest,
    db: Annotated[AsyncSession, Depends(get_db)],
    auth: AuthContext = Depends(get_auth_context),
) -> AIGeneratePreviewResponse:
    """
    Generate one AI question variant for preview; this handler saves nothing.

    Args:
        request_http: Raw HTTP request, passed to the rate limiter.
        request: Preview parameters (basis_item_id, target_level, ai_model).
        db: Async database session used to load the basis item.
        auth: Authentication context of the caller.

    Returns:
        A response with ``success=True`` and the generated stem, options,
        correct answer, explanation and usage, or ``success=False`` with an
        ``error`` message when generation yields nothing or raises.

    Raises:
        HTTPException: 400 if the model is not supported or the basis item
            fails ``_validate_original_basis_item``; 404 if the basis item
            does not exist. The auth, rate-limit and website-scope helpers
            are defined elsewhere and may raise on their own.
    """
    website_id = require_website_auth(auth, allowed_roles={"admin", "system_admin"})
    await enforce_rate_limit(
        request_http,
        scope="ai.generate_preview",
        max_requests=40,
        window_seconds=300,
    )

    # Validate AI model
    if not validate_ai_model(request.ai_model):
        supported = ", ".join(SUPPORTED_MODELS.keys())
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Unsupported AI model: {request.ai_model}. "
            f"Supported models: {supported}",
        )

    # Fetch basis item
    result = await db.execute(
        select(Item).where(Item.id == request.basis_item_id)
    )
    basis_item = result.scalar_one_or_none()

    if not basis_item:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Basis item not found: {request.basis_item_id}",
        )
    ensure_website_scope_matches(website_id, basis_item.website_id)

    _validate_original_basis_item(basis_item)

    def _failure(message: str) -> AIGeneratePreviewResponse:
        # Single place that builds the "generation did not work" payload.
        return AIGeneratePreviewResponse(
            success=False,
            error=message,
            ai_model=request.ai_model,
            basis_item_id=request.basis_item_id,
            target_level=request.target_level,
        )

    # Generate question. Failures are reported in the response body rather
    # than as an HTTP error, so the try block also covers building the
    # success response (as it did before).
    try:
        generated = await generate_question(
            basis_item=basis_item,
            target_level=request.target_level,
            ai_model=request.ai_model,
        )

        if not generated:
            return _failure("AI generation failed. Please check logs or try again.")

        return AIGeneratePreviewResponse(
            success=True,
            stem=generated.stem,
            options=generated.options,
            correct=generated.correct,
            explanation=generated.explanation,
            usage=generated.usage,
            ai_model=request.ai_model,
            basis_item_id=request.basis_item_id,
            target_level=request.target_level,
            cached=False,
        )

    except Exception as e:
        # logger.exception keeps the traceback, which logger.error(f"...")
        # discarded; the message text itself is unchanged.
        logger.exception("AI preview generation failed: %s", e)
        return _failure(f"AI generation error: {str(e)}")
|
|
|
|
|
|
@router.post(
    "/generate-save",
    response_model=AISaveResponse,
    summary="Save AI-generated question",
    description="""
    Save an AI-generated question to the database.

    This endpoint creates a new Item record with:
    - generated_by='ai'
    - ai_model from request
    - basis_item_id linking to original question
    - calibrated=False (will be calculated later)
    """,
    responses={
        200: {"description": "Question saved successfully"},
        400: {"description": "Invalid request data"},
        404: {"description": "Basis item or tryout not found"},
        500: {"description": "Database save failed"},
    },
)
async def generate_save(
    request_http: Request,
    request: AISaveRequest,
    db: Annotated[AsyncSession, Depends(get_db)],
    auth: AuthContext = Depends(get_auth_context),
) -> AISaveResponse:
    """
    Persist a previously generated question as a variant of a basis item.

    Steps, in order: authenticate and rate-limit the caller, check the
    requested website against the caller's scope, load and validate the basis
    item, check that the submitted options match the basis item, record a
    generation run with ``requested_count=1``, then save the question.

    Args:
        request_http: Raw HTTP request, passed to the rate limiter.
        request: Question content (stem, options, correct, explanation) plus
            tryout_id, website_id, basis_item_id, slot, level, ai_model and
            variant_status.
        db: Async database session.
        auth: Authentication context of the caller.

    Returns:
        ``AISaveResponse`` with ``success=True``, the new ``item_id`` and the
        ``run_id`` of the generation run.

    Raises:
        HTTPException: 404 if the basis item does not exist; 400 if the basis
            item is not a valid original or the options do not match it;
            500 if saving returns a falsy item id.
    """
    # NOTE(review): unlike the preview and batch endpoints, ai_model is not
    # checked with validate_ai_model here -- confirm that is intentional.
    # NOTE(review): the documented 404 mentions a tryout, but no tryout lookup
    # happens in this handler -- confirm whether save_ai_question covers it.
    from app.schemas.ai import GeneratedQuestion

    website_id = require_website_auth(auth, allowed_roles={"admin", "system_admin"})
    await enforce_rate_limit(
        request_http,
        scope="ai.generate_save",
        max_requests=40,
        window_seconds=300,
    )
    ensure_website_scope_matches(website_id, request.website_id)

    # Load the basis item and make sure the caller may build on it.
    lookup = await db.execute(select(Item).where(Item.id == request.basis_item_id))
    basis_item = lookup.scalar_one_or_none()
    if not basis_item:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Basis item not found: {request.basis_item_id}",
        )
    ensure_website_scope_matches(website_id, basis_item.website_id)
    _validate_original_basis_item(basis_item)

    # Re-wrap the submitted content so it can be checked and saved by the
    # same service functions that handle freshly generated questions.
    candidate = GeneratedQuestion(
        stem=request.stem,
        options=request.options,
        correct=request.correct,
        explanation=request.explanation,
    )
    if not generated_matches_basis_options(candidate, basis_item):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Generated options must match the basis question option labels exactly.",
        )

    snapshot_question_id = basis_item.source_snapshot_question_id
    run_id = await create_generation_run(
        basis_item_id=basis_item.id,
        source_snapshot_question_id=snapshot_question_id,
        target_level=request.level,
        requested_count=1,
        model=request.ai_model,
        created_by=auth.wp_user_id or auth.role,
        db=db,
    )

    # Save the question, linked to the run recorded above.
    new_item_id = await save_ai_question(
        generated_data=candidate,
        tryout_id=request.tryout_id,
        website_id=request.website_id,
        basis_item_id=request.basis_item_id,
        slot=request.slot,
        level=request.level,
        ai_model=request.ai_model,
        generation_run_id=run_id,
        source_snapshot_question_id=snapshot_question_id,
        variant_status=request.variant_status,
        db=db,
    )
    if not new_item_id:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to save AI-generated question",
        )

    return AISaveResponse(success=True, item_id=new_item_id, run_id=run_id)
|
|
|
|
|
|
@router.post(
    "/generate-batch",
    response_model=AIGenerateBatchResponse,
    summary="Generate and save AI question batch",
    description="Generate multiple trusted active variants from one medium-level basis question and track the run.",
)
async def generate_batch(
    request_http: Request,
    request: AIGenerateBatchRequest,
    db: Annotated[AsyncSession, Depends(get_db)],
    auth: AuthContext = Depends(get_auth_context),
) -> AIGenerateBatchResponse:
    """
    Generate several variants from one basis item and save each of them.

    A generation run is recorded first. Every generated question is then
    saved with ``variant_status="active"``, taking tryout_id, website_id and
    slot from the basis item. Questions whose save returns ``None`` are left
    out of the response.

    Args:
        request_http: Raw HTTP request, passed to the rate limiter.
        request: Batch parameters (basis_item_id, target_level, ai_model,
            count, operator_notes).
        db: Async database session.
        auth: Authentication context of the caller.

    Returns:
        ``success=True`` with the saved ids, items, count and combined usage
        when at least one variant was saved; otherwise ``success=False`` with
        ``generated_count=0`` and an error message. Both carry the ``run_id``.

    Raises:
        HTTPException: 400 if the model is unsupported or the basis item is
            not a valid original; 404 if the basis item does not exist.
    """
    website_id = require_website_auth(auth, allowed_roles={"admin", "system_admin"})
    await enforce_rate_limit(
        request_http,
        scope="ai.generate_batch",
        max_requests=10,
        window_seconds=300,
    )

    model = request.ai_model
    if not validate_ai_model(model):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=(
                f"Unsupported AI model: {model}. "
                f"Supported models: {', '.join(SUPPORTED_MODELS.keys())}"
            ),
        )

    lookup = await db.execute(select(Item).where(Item.id == request.basis_item_id))
    basis_item = lookup.scalar_one_or_none()
    if not basis_item:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Basis item not found: {request.basis_item_id}",
        )
    ensure_website_scope_matches(website_id, basis_item.website_id)
    _validate_original_basis_item(basis_item)

    # Record the run before generating so every saved variant can point at it.
    run_id = await create_generation_run(
        basis_item_id=basis_item.id,
        source_snapshot_question_id=basis_item.source_snapshot_question_id,
        target_level=request.target_level,
        requested_count=request.count,
        model=model,
        created_by=auth.wp_user_id or auth.role,
        operator_notes=request.operator_notes,
        db=db,
    )

    candidates = await generate_questions_batch(
        basis_item=basis_item,
        target_level=request.target_level,
        ai_model=model,
        count=request.count,
        operator_notes=request.operator_notes,
    )

    saved_ids: list[int] = []
    saved_items: list[AIBatchGeneratedItem] = []
    for candidate in candidates:
        new_id = await save_ai_question(
            generated_data=candidate,
            tryout_id=basis_item.tryout_id,
            website_id=basis_item.website_id,
            basis_item_id=basis_item.id,
            slot=basis_item.slot,
            level=request.target_level,
            ai_model=model,
            db=db,
            generation_run_id=run_id,
            source_snapshot_question_id=basis_item.source_snapshot_question_id,
            variant_status="active",
        )
        if new_id is None:
            # Not saved: leave this candidate out of the response.
            continue

        saved_ids.append(new_id)
        saved_items.append(
            AIBatchGeneratedItem(
                item_id=new_id,
                stem=candidate.stem,
                options=candidate.options,
                correct=candidate.correct,
                explanation=candidate.explanation,
                level=request.target_level,
                variant_status="active",
                usage=candidate.usage,
            )
        )

    if saved_ids:
        return AIGenerateBatchResponse(
            success=True,
            run_id=run_id,
            item_ids=saved_ids,
            items=saved_items,
            generated_count=len(saved_ids),
            usage=combine_usage([entry.usage for entry in saved_items]),
        )

    return AIGenerateBatchResponse(
        success=False,
        run_id=run_id,
        generated_count=0,
        error="AI generation failed. No variants were saved.",
    )
|
|
|
|
|
|
@router.get(
    "/stats",
    response_model=AIStatsResponse,
    summary="Get AI generation statistics",
    description="""
    Get statistics about AI-generated questions.

    Returns:
    - Total AI-generated items count
    - Items count by model
    - Cache hit rate (placeholder)
    """,
)
async def get_stats(
    db: Annotated[AsyncSession, Depends(get_db)],
    auth: AuthContext = Depends(get_auth_context),
) -> AIStatsResponse:
    """
    Return AI generation statistics for the caller's website scope.

    The values come from ``get_ai_stats``; this handler only copies the five
    expected entries into the response model.
    """
    website_id = require_website_auth(auth, allowed_roles={"admin", "system_admin"})
    stats = await get_ai_stats(db, website_id=website_id)

    field_names = (
        "total_ai_items",
        "items_by_model",
        "cache_hit_rate",
        "total_cache_hits",
        "total_requests",
    )
    return AIStatsResponse(**{name: stats[name] for name in field_names})
|
|
|
|
|
|
@router.get(
    "/models",
    summary="List supported AI models",
    description="Returns list of supported AI models for question generation.",
)
async def list_models(auth: AuthContext = Depends(get_auth_context)) -> dict:
    """
    List the configured AI models together with their pricing.

    Returns:
        ``{"models": [...]}`` where each entry has ``id``, ``name``,
        ``description`` and ``pricing``. Pricing is fetched one model at a
        time, in the order the models are listed below.
    """
    require_website_auth(auth, allowed_roles={"admin", "system_admin"})

    # (model id from settings, display name, description)
    catalogue = (
        (
            settings.OPENROUTER_MODEL_CHEAP,
            "Mistral Small 4",
            "Cheap and fast option for routine variant generation",
        ),
        (
            settings.OPENROUTER_MODEL_QWEN,
            "Qwen 2.5 32B Instruct",
            "Balanced default for structured soal generation",
        ),
        (
            settings.OPENROUTER_MODEL_LLAMA,
            "Llama 3.3 70B",
            "Premium fallback when you want better quality over cost",
        ),
    )

    models = [
        {
            "id": model_id,
            "name": display_name,
            "description": blurb,
            "pricing": await get_model_pricing(model_id),
        }
        for model_id, display_name, blurb in catalogue
    ]
    return {"models": models}
|
|
|
|
|
|
@router.get(
    "/pending-reviews",
    summary="Get pending AI generated questions",
    description="Retrieve all AI generated questions that are pending review (variant_status='draft').",
)
async def admin_get_pending_reviews(
    db: AsyncSession = Depends(get_db),
    auth: AuthContext = Depends(get_auth_context),
) -> dict:
    """
    List AI-generated items that are still in 'draft' status.

    Results are ordered newest first and capped at 200 rows. When the auth
    helper returns a website id, only that website's items are included.

    Returns:
        ``{"items": [...]}`` with one summary dict per draft item.
    """
    website_id = require_website_auth(auth, allowed_roles={"admin", "system_admin"})

    criteria = [Item.generated_by == "ai", Item.variant_status == "draft"]
    if website_id is not None:
        criteria.append(Item.website_id == website_id)

    statement = (
        select(Item)
        .where(*criteria)
        .order_by(Item.created_at.desc())
        .limit(200)
    )
    drafts = (await db.execute(statement)).scalars().all()

    def _summarize(row: Item) -> dict:
        # Use stem_text when the model has it; otherwise fall back to the
        # first 100 characters of the stem.
        if hasattr(row, 'stem_text'):
            stem_preview = row.stem_text
        else:
            stem_preview = row.stem[:100]
        return {
            "id": row.id,
            "tryout_id": row.tryout_id,
            "level": row.level,
            "stem_text": stem_preview,
            "ai_model": row.ai_model,
            "basis_item_id": row.basis_item_id,
            "created_at": row.created_at,
            "status": row.variant_status,
        }

    return {"items": [_summarize(row) for row in drafts]}
|
|
|
|
|
|
@router.post(
    "/review/{item_id}",
    summary="Approve or reject AI generated question",
    description="Update the variant_status of an AI generated question.",
)
async def admin_review_ai_question(
    item_id: int,
    status: str,  # "active", "rejected"
    db: AsyncSession = Depends(get_db),
    auth: AuthContext = Depends(get_auth_context),
) -> dict:
    """
    Approve or reject an AI-generated question by setting its variant_status.

    Args:
        item_id: ID of the item to review (path parameter).
        status: New variant status; must be "active" or "rejected".
        db: Async database session; the change is committed here.
        auth: Authentication context of the caller.

    Returns:
        ``{"success": True, "item_id": ..., "status": ...}`` after the commit.

    Raises:
        HTTPException: 400 if ``status`` is not an allowed value or the item
            is not AI-generated; 404 if the item does not exist; 403 if the
            item belongs to a different website than the caller's scope.
    """
    # The ``status`` parameter shadows fastapi's ``status`` module inside this
    # function, so HTTP codes are written as plain integers below.
    website_id = require_website_auth(auth, allowed_roles={"admin", "system_admin"})

    # Reject bad input before spending a database round trip on it.
    if status not in ("active", "rejected"):
        raise HTTPException(status_code=400, detail="Status must be active or rejected")

    result = await db.execute(select(Item).where(Item.id == item_id))
    item = result.scalar_one_or_none()

    if not item:
        raise HTTPException(status_code=404, detail="Item not found")

    if website_id is not None and item.website_id != website_id:
        raise HTTPException(status_code=403, detail="Not authorized for this website")

    # This endpoint reviews AI variants only; without this guard it could
    # flip the status of an original (non-AI) question.
    if item.generated_by != "ai":
        raise HTTPException(
            status_code=400,
            detail="Only AI-generated questions can be reviewed",
        )

    item.variant_status = status
    await db.commit()

    return {"success": True, "item_id": item_id, "status": status}
|