250 lines
6.9 KiB
Python
250 lines
6.9 KiB
Python
"""
Admin API router for custom admin actions.

Provides admin-specific endpoints for triggering IRT calibration,
toggling AI question generation, and resetting normalization.
"""
|
|
|
|
from typing import Any, Dict, Optional

from fastapi import APIRouter, Depends, Header, HTTPException, status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.config import get_settings
from app.database import get_db
from app.models import Tryout, TryoutStats
from app.services.irt_calibration import (
    CALIBRATION_SAMPLE_THRESHOLD,
    calibrate_all,
)
|
|
|
|
# All endpoints below are mounted under the "/admin" prefix and grouped under
# the "admin" tag in the generated OpenAPI schema.
router = APIRouter(
    prefix="/admin",
    tags=["admin"],
)

# Application settings, evaluated once when this module is imported.
settings = get_settings()
|
|
|
|
|
|
def get_admin_website_id(
    x_website_id: Optional[str] = Header(None, alias="X-Website-ID"),
) -> int:
    """
    Extract and validate the website ID from the ``X-Website-ID`` header.

    Used as a FastAPI dependency (via ``Depends``) by the admin endpoints in
    this module.

    Args:
        x_website_id: Raw header value, or ``None`` when the header is absent.

    Returns:
        The header value parsed with ``int()``.

    Raises:
        HTTPException: 400 if the header is missing or is not a valid integer.
    """
    if x_website_id is None:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="X-Website-ID header is required",
        )
    # NOTE(review): int() also accepts surrounding whitespace, a sign and
    # digit-group underscores (e.g. " -1_0 "); confirm whether non-positive
    # IDs should be rejected here.
    try:
        return int(x_website_id)
    except ValueError:
        # `from None` suppresses the implicit chaining to the internal
        # ValueError: the 400 response already describes the problem, and the
        # "During handling of the above exception..." context is just noise.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="X-Website-ID must be a valid integer",
        ) from None
|
|
|
|
|
|
@router.post(
    "/{tryout_id}/calibrate",
    summary="Trigger IRT calibration",
    description="Trigger IRT calibration for all items in this tryout with sufficient response data.",
)
async def admin_trigger_calibration(
    tryout_id: str,
    db: AsyncSession = Depends(get_db),
    website_id: int = Depends(get_admin_website_id),
) -> Dict[str, Any]:
    """
    Trigger IRT calibration for all items in a tryout.

    Looks up the tryout by ``(website_id, tryout_id)``, then delegates to
    ``calibrate_all``. The minimum sample size comes from
    ``tryout.min_calibration_sample``. When that value is unset or falsy,
    ``CALIBRATION_SAMPLE_THRESHOLD`` is used instead.
    (``calibrate_all`` is expected to update ``item.irt_b``,
    ``item.irt_se`` and ``item.calibrated``; confirm in that service.)

    Args:
        tryout_id: Tryout identifier (path parameter).
        db: Database session.
        website_id: Website ID from the ``X-Website-ID`` header.

    Returns:
        Calibration summary containing:
            - item counts (total / calibrated / failed);
            - the calibration percentage, multiplied by 100 and rounded to
              2 decimals;
            - the ``ready_for_irt`` flag;
            - a human-readable message.

    Raises:
        HTTPException: 404 if the tryout does not exist for this website.
            Exceptions raised by ``calibrate_all`` are not caught here and
            propagate unchanged.
    """
    # Verify the tryout exists for this website.
    tryout_result = await db.execute(
        select(Tryout).where(
            Tryout.website_id == website_id,
            Tryout.tryout_id == tryout_id,
        )
    )
    tryout = tryout_result.scalar_one_or_none()

    if tryout is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Tryout {tryout_id} not found for website {website_id}",
        )

    # NOTE(review): nothing is committed here; this assumes calibrate_all
    # persists its own changes. Confirm.
    result = await calibrate_all(
        tryout_id=tryout_id,
        website_id=website_id,
        db=db,
        min_sample_size=tryout.min_calibration_sample or CALIBRATION_SAMPLE_THRESHOLD,
    )

    return {
        "tryout_id": tryout_id,
        "total_items": result.total_items,
        "calibrated_items": result.calibrated_items,
        "failed_items": result.failed_items,
        # Presumably calibrate_all reports a 0-1 fraction; scaled to percent here.
        "calibration_percentage": round(result.calibration_percentage * 100, 2),
        "ready_for_irt": result.ready_for_irt,
        "message": f"Calibration complete: {result.calibrated_items}/{result.total_items} items calibrated",
    }
|
|
|
|
|
|
@router.post(
    "/{tryout_id}/toggle-ai-generation",
    summary="Toggle AI generation",
    description="Toggle AI question generation for a tryout.",
)
async def admin_toggle_ai_generation(
    tryout_id: str,
    db: AsyncSession = Depends(get_db),
    website_id: int = Depends(get_admin_website_id),
) -> Dict[str, Any]:
    """
    Toggle AI question generation for a tryout.

    Flips ``Tryout.ai_generation_enabled``, commits, and refreshes the row.

    Args:
        tryout_id: Tryout identifier (path parameter).
        db: Database session.
        website_id: Website ID from the ``X-Website-ID`` header.

    Returns:
        The tryout ID, the new ``ai_generation_enabled`` value, and a
        human-readable message.

    Raises:
        HTTPException: 404 if the tryout does not exist for this website.
    """
    result = await db.execute(
        select(Tryout).where(
            Tryout.website_id == website_id,
            Tryout.tryout_id == tryout_id,
        )
    )
    tryout = result.scalar_one_or_none()

    if tryout is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Tryout {tryout_id} not found for website {website_id}",
        )

    # Read-modify-write toggle of the stored flag.
    tryout.ai_generation_enabled = not tryout.ai_generation_enabled
    await db.commit()
    await db.refresh(tryout)

    # Deliberately not named `status`. Assigning to that name anywhere in this
    # function would make it local and shadow `fastapi.status`. The 404 branch
    # above would then raise UnboundLocalError instead of HTTPException.
    state_label = "enabled" if tryout.ai_generation_enabled else "disabled"
    return {
        "tryout_id": tryout_id,
        "ai_generation_enabled": tryout.ai_generation_enabled,
        "message": f"AI generation {state_label} for tryout {tryout_id}",
    }
|
|
|
|
|
|
@router.post(
    "/{tryout_id}/reset-normalization",
    summary="Reset normalization",
    description="Reset normalization to static values and clear incremental stats.",
)
async def admin_reset_normalization(
    tryout_id: str,
    db: AsyncSession = Depends(get_db),
    website_id: int = Depends(get_admin_website_id),
) -> Dict[str, Any]:
    """
    Reset normalization for a tryout.

    Sets ``TryoutStats.rataan`` and ``TryoutStats.sb`` from the tryout's
    ``static_rataan`` / ``static_sb``. If the tryout row itself is missing,
    the fallback values 500.0 / 100.0 are used instead. Also clears the
    incremental fields: participant count, NM sums, min/max NM, and
    ``last_calculated``.

    Args:
        tryout_id: Tryout identifier (path parameter).
        db: Database session.
        website_id: Website ID from the ``X-Website-ID`` header.

    Returns:
        The new ``rataan`` / ``sb``, the participant count before the reset,
        and a human-readable message.

    Raises:
        HTTPException: 404 if no TryoutStats row exists for this tryout and
            website. A missing Tryout row does not raise; the fallback
            values are used instead.
    """
    stats_result = await db.execute(
        select(TryoutStats).where(
            TryoutStats.website_id == website_id,
            TryoutStats.tryout_id == tryout_id,
        )
    )
    stats = stats_result.scalar_one_or_none()

    if stats is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"TryoutStats for {tryout_id} not found for website {website_id}",
        )

    # The tryout supplies the static values to reset to.
    tryout_result = await db.execute(
        select(Tryout).where(
            Tryout.website_id == website_id,
            Tryout.tryout_id == tryout_id,
        )
    )
    tryout = tryout_result.scalar_one_or_none()

    if tryout:
        # NOTE(review): if static_rataan/static_sb are nullable, this can
        # write None into stats. Confirm the column constraints.
        stats.rataan = tryout.static_rataan
        stats.sb = tryout.static_sb
    else:
        # Hard-coded fallback when the tryout row is missing.
        stats.rataan = 500.0
        stats.sb = 100.0

    # Clear the incremental stats, remembering the old count for the response.
    old_participant_count = stats.participant_count
    stats.participant_count = 0
    stats.total_nm_sum = 0.0
    stats.total_nm_sq_sum = 0.0
    stats.min_nm = None
    stats.max_nm = None
    stats.last_calculated = None

    await db.commit()
    await db.refresh(stats)

    # NOTE(review): the message says "static values" even when the
    # 500.0/100.0 fallback was used.
    return {
        "tryout_id": tryout_id,
        "rataan": stats.rataan,
        "sb": stats.sb,
        "cleared_stats": {
            "previous_participant_count": old_participant_count,
        },
        "message": f"Normalization reset to static values (rataan={stats.rataan}, sb={stats.sb}). Incremental stats cleared.",
    }
|