459 lines
13 KiB
Python
459 lines
13 KiB
Python
"""
|
|
Tryout API router for tryout configuration and management.
|
|
|
|
Endpoints:
|
|
- GET /tryout/{tryout_id}/config: Get tryout configuration
|
|
- PUT /tryout/{tryout_id}/normalization: Update normalization settings
|
|
- GET /tryout: List tryouts for a website
|
|
"""
|
|
|
|
from typing import List, Optional
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, Header, status
|
|
from sqlalchemy import select, func
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy.orm import selectinload
|
|
|
|
from app.database import get_db
|
|
from app.models.item import Item
|
|
from app.models.tryout import Tryout
|
|
from app.models.tryout_stats import TryoutStats
|
|
from app.schemas.tryout import (
|
|
NormalizationUpdateRequest,
|
|
NormalizationUpdateResponse,
|
|
TryoutConfigBrief,
|
|
TryoutConfigResponse,
|
|
TryoutStatsResponse,
|
|
)
|
|
|
|
# Shared router for this module: every route below is mounted under the
# "/tryout" prefix and grouped under the "tryouts" OpenAPI tag.
router = APIRouter(tags=["tryouts"], prefix="/tryout")
|
|
|
|
|
|
def get_website_id_from_header(
    x_website_id: Optional[str] = Header(None, alias="X-Website-ID"),
) -> int:
    """
    Extract and validate the website ID from the ``X-Website-ID`` request header.

    Used as a FastAPI dependency by every endpoint in this router so that all
    queries are scoped to a single website.

    Args:
        x_website_id: Raw header value, or ``None`` when the header is absent.

    Returns:
        The header value parsed with ``int()``.

    Raises:
        HTTPException: 400 if the header is missing or is not a valid integer.
    """
    if x_website_id is None:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="X-Website-ID header is required",
        )
    try:
        return int(x_website_id)
    except ValueError:
        # Suppress the implicit exception context: this is a client input
        # error, and the 400 detail already explains it, so chaining the
        # internal ValueError traceback only adds noise to logs.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="X-Website-ID must be a valid integer",
        ) from None
|
|
|
|
|
|
@router.get(
    "/{tryout_id}/config",
    response_model=TryoutConfigResponse,
    summary="Get tryout configuration",
    description="Retrieve tryout configuration including scoring mode, normalization settings, and current stats.",
)
async def get_tryout_config(
    tryout_id: str,
    db: AsyncSession = Depends(get_db),
    website_id: int = Depends(get_website_id_from_header),
) -> TryoutConfigResponse:
    """
    Return the full configuration of one tryout, scoped to the caller's website.

    The tryout's ``stats`` relationship is eager-loaded (``selectinload``) so
    the response can embed its aggregate statistics without a lazy load.

    Args:
        tryout_id: Tryout identifier from the URL path.
        db: Async database session.
        website_id: Website ID taken from the ``X-Website-ID`` header.

    Returns:
        TryoutConfigResponse carrying the scoring, selection and normalization
        settings, the IRT-related options, timestamps and ``current_stats``
        (``None`` when the tryout has no stats).

    Raises:
        HTTPException: 404 if no tryout matches ``website_id`` and ``tryout_id``.
    """
    query = (
        select(Tryout)
        .where(
            Tryout.website_id == website_id,
            Tryout.tryout_id == tryout_id,
        )
        .options(selectinload(Tryout.stats))
    )
    tryout = (await db.execute(query)).scalar_one_or_none()

    if tryout is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Tryout {tryout_id} not found for website {website_id}",
        )

    stats = tryout.stats
    current_stats = (
        TryoutStatsResponse(
            participant_count=stats.participant_count,
            rataan=stats.rataan,
            sb=stats.sb,
            min_nm=stats.min_nm,
            max_nm=stats.max_nm,
            last_calculated=stats.last_calculated,
        )
        if stats
        else None
    )

    return TryoutConfigResponse(
        id=tryout.id,
        website_id=tryout.website_id,
        tryout_id=tryout.tryout_id,
        name=tryout.name,
        description=tryout.description,
        scoring_mode=tryout.scoring_mode,
        selection_mode=tryout.selection_mode,
        normalization_mode=tryout.normalization_mode,
        min_sample_for_dynamic=tryout.min_sample_for_dynamic,
        static_rataan=tryout.static_rataan,
        static_sb=tryout.static_sb,
        ai_generation_enabled=tryout.ai_generation_enabled,
        hybrid_transition_slot=tryout.hybrid_transition_slot,
        min_calibration_sample=tryout.min_calibration_sample,
        theta_estimation_method=tryout.theta_estimation_method,
        fallback_to_ctt_on_error=tryout.fallback_to_ctt_on_error,
        current_stats=current_stats,
        created_at=tryout.created_at,
        updated_at=tryout.updated_at,
    )
|
|
|
|
|
|
@router.put(
    "/{tryout_id}/normalization",
    response_model=NormalizationUpdateResponse,
    summary="Update normalization settings",
    description="Update normalization mode and static values for a tryout.",
)
async def update_normalization(
    tryout_id: str,
    request: NormalizationUpdateRequest,
    db: AsyncSession = Depends(get_db),
    website_id: int = Depends(get_website_id_from_header),
) -> NormalizationUpdateResponse:
    """
    Partially update a tryout's normalization settings and persist them.

    Only request fields that are not ``None`` overwrite the stored values.
    A field therefore cannot be cleared back to ``None`` through this endpoint.
    The handler performs no further validation of the supplied values.

    Args:
        tryout_id: Tryout identifier from the URL path.
        request: Body with optional ``normalization_mode``, ``static_rataan``
            and ``static_sb``.
        db: Async database session; committed by this handler.
        website_id: Website ID taken from the ``X-Website-ID`` header.

    Returns:
        NormalizationUpdateResponse echoing the stored settings. It also reports
        ``min_sample_for_dynamic`` as ``will_switch_to_dynamic_at`` and the
        current participant count, which is 0 when no stats row exists.

    Raises:
        HTTPException: 404 if no tryout matches ``website_id`` and ``tryout_id``.
    """
    lookup = await db.execute(
        select(Tryout).where(
            Tryout.website_id == website_id,
            Tryout.tryout_id == tryout_id,
        )
    )
    tryout = lookup.scalar_one_or_none()

    if tryout is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Tryout {tryout_id} not found for website {website_id}",
        )

    # Apply each provided (non-None) field onto the ORM object, in order.
    for field_name in ("normalization_mode", "static_rataan", "static_sb"):
        new_value = getattr(request, field_name)
        if new_value is not None:
            setattr(tryout, field_name, new_value)

    # Read the participant count before committing.
    stats_lookup = await db.execute(
        select(TryoutStats).where(
            TryoutStats.website_id == website_id,
            TryoutStats.tryout_id == tryout_id,
        )
    )
    tryout_stats = stats_lookup.scalar_one_or_none()
    participant_count = 0 if tryout_stats is None else tryout_stats.participant_count

    await db.commit()
    # Reload the row so the response reflects what was actually persisted.
    await db.refresh(tryout)

    return NormalizationUpdateResponse(
        tryout_id=tryout.tryout_id,
        normalization_mode=tryout.normalization_mode,
        static_rataan=tryout.static_rataan,
        static_sb=tryout.static_sb,
        will_switch_to_dynamic_at=tryout.min_sample_for_dynamic,
        current_participant_count=participant_count,
    )
|
|
|
|
|
|
@router.get(
    "/",
    response_model=List[TryoutConfigBrief],
    summary="List tryouts",
    description="List all tryouts for a website.",
)
async def list_tryouts(
    db: AsyncSession = Depends(get_db),
    website_id: int = Depends(get_website_id_from_header),
) -> List[TryoutConfigBrief]:
    """
    List all tryouts for a website.

    Tryouts are returned ordered by their primary key ``id``. Without an
    explicit ``ORDER BY`` the database is free to return rows in any order,
    so clients could see the list reshuffle between requests.

    Args:
        db: Async database session.
        website_id: Website ID taken from the ``X-Website-ID`` header.

    Returns:
        List of TryoutConfigBrief. ``participant_count`` is 0 for tryouts that
        have no stats.
    """
    # Eager-load stats so participant_count can be read without lazy loads,
    # which are not available on an AsyncSession.
    result = await db.execute(
        select(Tryout)
        .options(selectinload(Tryout.stats))
        .where(Tryout.website_id == website_id)
        .order_by(Tryout.id)
    )
    tryouts = result.scalars().all()

    return [
        TryoutConfigBrief(
            tryout_id=t.tryout_id,
            name=t.name,
            scoring_mode=t.scoring_mode,
            selection_mode=t.selection_mode,
            normalization_mode=t.normalization_mode,
            participant_count=t.stats.participant_count if t.stats else 0,
        )
        for t in tryouts
    ]
|
|
|
|
|
|
@router.get(
    "/{tryout_id}/calibration-status",
    summary="Get calibration status",
    description="Get IRT calibration status for items in this tryout.",
)
async def get_calibration_status(
    tryout_id: str,
    db: AsyncSession = Depends(get_db),
    website_id: int = Depends(get_website_id_from_header),
):
    """
    Get calibration status for items in a tryout.

    A single aggregate query over the tryout's items computes three values:
    the item count, the number of items whose ``calibrated`` flag is true, and
    the mean ``calibration_sample_size``.

    Args:
        tryout_id: Tryout identifier from the URL path.
        db: Async database session.
        website_id: Website ID taken from the ``X-Website-ID`` header.

    Returns:
        Dict with these keys:
        - ``tryout_id``
        - ``total_items``
        - ``calibrated_items``
        - ``calibration_percentage`` (0-100, rounded to 2 decimals)
        - ``avg_sample_size`` (rounded to 2 decimals, 0 when unavailable)
        - ``min_calibration_sample``
        - ``ready_for_irt`` (True when at least 90% of items are calibrated)

    Raises:
        HTTPException: 404 if no tryout matches ``website_id`` and ``tryout_id``.
    """
    # ``cast`` needs a real SQLAlchemy type. ``func`` only builds SQL function
    # calls, so ``func.INTEGER`` cannot serve as the target type of a CAST.
    from sqlalchemy import Integer, cast

    # Verify tryout exists
    tryout_result = await db.execute(
        select(Tryout).where(
            Tryout.website_id == website_id,
            Tryout.tryout_id == tryout_id,
        )
    )
    tryout = tryout_result.scalar_one_or_none()

    if tryout is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Tryout {tryout_id} not found for website {website_id}",
        )

    # Aggregate calibration statistics. CAST(calibrated AS INTEGER) turns the
    # flag into 0/1, so SUM counts calibrated items. NULL flags are ignored.
    stats_result = await db.execute(
        select(
            func.count().label("total_items"),
            func.sum(cast(Item.calibrated, Integer)).label("calibrated_items"),
            func.avg(Item.calibration_sample_size).label("avg_sample_size"),
        ).where(
            Item.website_id == website_id,
            Item.tryout_id == tryout_id,
        )
    )
    stats = stats_result.first()

    # Some drivers return SUM/AVG as Decimal. Normalize to plain numbers so
    # the arithmetic and the JSON output are consistent across backends.
    total_items = int(stats.total_items or 0)
    calibrated_items = int(stats.calibrated_items or 0)
    calibration_percentage = (
        calibrated_items / total_items * 100 if total_items > 0 else 0
    )
    avg_sample_size = (
        round(float(stats.avg_sample_size), 2) if stats.avg_sample_size else 0
    )

    return {
        "tryout_id": tryout_id,
        "total_items": total_items,
        "calibrated_items": calibrated_items,
        "calibration_percentage": round(calibration_percentage, 2),
        "avg_sample_size": avg_sample_size,
        "min_calibration_sample": tryout.min_calibration_sample,
        # NOTE(review): the 90% readiness threshold is hard-coded here; confirm
        # it matches the ready_for_irt rule used by app.services.irt_calibration.
        "ready_for_irt": calibration_percentage >= 90,
    }
|
|
|
|
|
|
@router.post(
    "/{tryout_id}/calibrate",
    summary="Trigger IRT calibration",
    description="Trigger IRT calibration for all items in this tryout with sufficient response data.",
)
async def trigger_calibration(
    tryout_id: str,
    db: AsyncSession = Depends(get_db),
    website_id: int = Depends(get_website_id_from_header),
):
    """
    Run IRT calibration for every item of a tryout and summarize the outcome.

    The minimum sample size is the tryout's ``min_calibration_sample``. When
    that is unset (any falsy value), the service's
    ``CALIBRATION_SAMPLE_THRESHOLD`` is used instead. The calibration work is
    delegated to ``calibrate_all``, and any exception it raises propagates
    unchanged.

    Args:
        tryout_id: Tryout identifier from the URL path.
        db: Async database session, passed through to the calibration service.
        website_id: Website ID taken from the ``X-Website-ID`` header.

    Returns:
        Dict with item counts, the calibration percentage (the service value
        times 100, rounded to 2 decimals), ``ready_for_irt`` and a
        human-readable message.

    Raises:
        HTTPException: 404 if no tryout matches ``website_id`` and ``tryout_id``.
    """
    # Local import: the calibration service is loaded only when this endpoint runs.
    from app.services.irt_calibration import (
        CALIBRATION_SAMPLE_THRESHOLD,
        calibrate_all,
    )

    lookup = await db.execute(
        select(Tryout).where(
            Tryout.website_id == website_id,
            Tryout.tryout_id == tryout_id,
        )
    )
    tryout = lookup.scalar_one_or_none()

    if tryout is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Tryout {tryout_id} not found for website {website_id}",
        )

    threshold = tryout.min_calibration_sample or CALIBRATION_SAMPLE_THRESHOLD
    summary = await calibrate_all(
        tryout_id=tryout_id,
        website_id=website_id,
        db=db,
        min_sample_size=threshold,
    )

    total = summary.total_items
    calibrated = summary.calibrated_items
    return {
        "tryout_id": tryout_id,
        "total_items": total,
        "calibrated_items": calibrated,
        "failed_items": summary.failed_items,
        # Scaled by 100; the service value is presumably a 0-1 fraction.
        "calibration_percentage": round(summary.calibration_percentage * 100, 2),
        "ready_for_irt": summary.ready_for_irt,
        "message": f"Calibration complete: {calibrated}/{total} items calibrated",
    }
|
|
|
|
|
|
@router.post(
    "/{tryout_id}/calibrate/{item_id}",
    summary="Trigger IRT calibration for single item",
    description="Trigger IRT calibration for a specific item.",
)
async def trigger_item_calibration(
    tryout_id: str,
    item_id: int,
    db: AsyncSession = Depends(get_db),
    website_id: int = Depends(get_website_id_from_header),
):
    """
    Run IRT calibration for one item after checking it belongs to the tryout.

    Before calibrating, the handler checks two things: the tryout exists for
    this website, and the item is linked to both this website and this
    tryout. ``calibrate_item`` receives only the item ID, so this check is
    what scopes the operation. The minimum sample size falls back to
    ``CALIBRATION_SAMPLE_THRESHOLD`` when the tryout's
    ``min_calibration_sample`` is falsy.

    Args:
        tryout_id: Tryout identifier from the URL path.
        item_id: Primary key of the item to calibrate.
        db: Async database session, passed through to the calibration service.
        website_id: Website ID taken from the ``X-Website-ID`` header.

    Returns:
        Dict with ``item_id``, ``status`` (the enum's ``.value``), ``irt_b``,
        ``irt_se``, ``sample_size`` and ``message`` from the service result.

    Raises:
        HTTPException: 404 if the tryout or the item is not found.
    """
    from app.services.irt_calibration import CALIBRATION_SAMPLE_THRESHOLD, calibrate_item

    tryout_lookup = await db.execute(
        select(Tryout).where(
            Tryout.website_id == website_id,
            Tryout.tryout_id == tryout_id,
        )
    )
    tryout = tryout_lookup.scalar_one_or_none()
    if tryout is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Tryout {tryout_id} not found for website {website_id}",
        )

    item_lookup = await db.execute(
        select(Item).where(
            Item.id == item_id,
            Item.website_id == website_id,
            Item.tryout_id == tryout_id,
        )
    )
    if item_lookup.scalar_one_or_none() is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Item {item_id} not found in tryout {tryout_id}",
        )

    threshold = tryout.min_calibration_sample or CALIBRATION_SAMPLE_THRESHOLD
    outcome = await calibrate_item(
        item_id=item_id,
        db=db,
        min_sample_size=threshold,
    )

    return {
        "item_id": outcome.item_id,
        "status": outcome.status.value,
        "irt_b": outcome.irt_b,
        "irt_se": outcome.irt_se,
        "sample_size": outcome.sample_size,
        "message": outcome.message,
    }
|