Harden auth and persist report schedules
This commit is contained in:
@@ -11,6 +11,7 @@ from typing import Literal, Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.database import get_db
|
||||
@@ -20,7 +21,7 @@ from app.core.auth import (
|
||||
get_auth_context,
|
||||
require_website_auth,
|
||||
)
|
||||
from app.models import Item, Session, Tryout
|
||||
from app.models import Item, Session, Tryout, UserAnswer
|
||||
from app.services.cat_selection import (
|
||||
CATSelectionError,
|
||||
get_next_item,
|
||||
@@ -65,9 +66,6 @@ class SubmitAnswerRequest(BaseModel):
|
||||
|
||||
class SubmitAnswerResponse(BaseModel):
|
||||
"""Response for submitting an answer."""
|
||||
is_correct: bool
|
||||
correct_answer: str
|
||||
explanation: Optional[str] = None
|
||||
theta: Optional[float] = None
|
||||
theta_se: Optional[float] = None
|
||||
|
||||
@@ -283,6 +281,18 @@ async def submit_answer_endpoint(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Item {request.item_id} not found"
|
||||
)
|
||||
|
||||
existing_answer_result = await db.execute(
|
||||
select(UserAnswer.id).where(
|
||||
UserAnswer.session_id == session_id,
|
||||
UserAnswer.item_id == request.item_id,
|
||||
)
|
||||
)
|
||||
if existing_answer_result.scalar_one_or_none() is not None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail="Item was already answered for this session",
|
||||
)
|
||||
|
||||
# Check correctness
|
||||
is_correct = request.response.upper() == item.correct_answer.upper()
|
||||
@@ -290,9 +300,6 @@ async def submit_answer_endpoint(
|
||||
# Update theta
|
||||
theta, theta_se = await update_theta(db, session_id, request.item_id, is_correct)
|
||||
|
||||
# Create user answer record
|
||||
from app.models import UserAnswer
|
||||
|
||||
user_answer = UserAnswer(
|
||||
session_id=session_id,
|
||||
wp_user_id=session.wp_user_id,
|
||||
@@ -307,12 +314,15 @@ async def submit_answer_endpoint(
|
||||
)
|
||||
|
||||
db.add(user_answer)
|
||||
await db.commit()
|
||||
try:
|
||||
await db.commit()
|
||||
except IntegrityError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail="Item was already answered for this session",
|
||||
) from exc
|
||||
|
||||
return SubmitAnswerResponse(
|
||||
is_correct=is_correct,
|
||||
correct_answer=item.correct_answer,
|
||||
explanation=item.explanation,
|
||||
theta=theta,
|
||||
theta_se=theta_se
|
||||
)
|
||||
|
||||
@@ -4,14 +4,21 @@ Lightweight in-process rate limiting helpers.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from collections import defaultdict, deque
|
||||
|
||||
from fastapi import HTTPException, Request, status
|
||||
from redis.asyncio import Redis
|
||||
|
||||
from app.core.config import get_settings
|
||||
|
||||
_lock = threading.Lock()
|
||||
_hits: dict[str, deque[float]] = defaultdict(deque)
|
||||
_redis_client: Redis | None = None
|
||||
_redis_unavailable = False
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _client_ip(request: Request) -> str:
|
||||
@@ -20,16 +27,26 @@ def _client_ip(request: Request) -> str:
|
||||
return "unknown"
|
||||
|
||||
|
||||
def enforce_rate_limit(
|
||||
request: Request,
|
||||
def _get_redis_client() -> Redis | None:
|
||||
global _redis_client
|
||||
if _redis_unavailable:
|
||||
return None
|
||||
if _redis_client is None:
|
||||
settings = get_settings()
|
||||
if not settings.REDIS_URL:
|
||||
return None
|
||||
_redis_client = Redis.from_url(settings.REDIS_URL, decode_responses=True)
|
||||
return _redis_client
|
||||
|
||||
|
||||
def _enforce_in_memory_rate_limit(
|
||||
*,
|
||||
key: str,
|
||||
scope: str,
|
||||
max_requests: int,
|
||||
window_seconds: int,
|
||||
) -> None:
|
||||
now = time.time()
|
||||
ip = _client_ip(request)
|
||||
key = f"{scope}:{ip}"
|
||||
cutoff = now - window_seconds
|
||||
|
||||
with _lock:
|
||||
@@ -43,3 +60,62 @@ def enforce_rate_limit(
|
||||
)
|
||||
dq.append(now)
|
||||
|
||||
|
||||
async def enforce_rate_limit(
|
||||
request: Request,
|
||||
*,
|
||||
scope: str,
|
||||
max_requests: int,
|
||||
window_seconds: int,
|
||||
) -> None:
|
||||
global _redis_unavailable
|
||||
|
||||
ip = _client_ip(request)
|
||||
key = f"{scope}:{ip}"
|
||||
|
||||
redis = _get_redis_client()
|
||||
if redis is not None:
|
||||
try:
|
||||
current = await redis.incr(key)
|
||||
if current == 1:
|
||||
await redis.expire(key, window_seconds)
|
||||
if current > max_requests:
|
||||
ttl = await redis.ttl(key)
|
||||
retry_after = ttl if ttl and ttl > 0 else window_seconds
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
||||
detail=f"Too many requests for {scope}. Please try again later.",
|
||||
headers={"Retry-After": str(retry_after)},
|
||||
)
|
||||
return
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
_redis_unavailable = True
|
||||
logger.warning("Redis rate limiter unavailable; falling back to memory: %s", exc)
|
||||
|
||||
_enforce_in_memory_rate_limit(
|
||||
key=key,
|
||||
scope=scope,
|
||||
max_requests=max_requests,
|
||||
window_seconds=window_seconds,
|
||||
)
|
||||
|
||||
|
||||
async def close_rate_limit() -> None:
|
||||
global _redis_client
|
||||
if _redis_client is None:
|
||||
return
|
||||
try:
|
||||
await _redis_client.aclose()
|
||||
finally:
|
||||
_redis_client = None
|
||||
|
||||
|
||||
def reset_rate_limit_state() -> None:
|
||||
"""Reset local limiter state for tests."""
|
||||
global _redis_client, _redis_unavailable
|
||||
_redis_client = None
|
||||
_redis_unavailable = False
|
||||
with _lock:
|
||||
_hits.clear()
|
||||
|
||||
@@ -76,6 +76,9 @@ async def init_db() -> None:
|
||||
Note: In production, use Alembic migrations instead.
|
||||
This is useful for development and testing.
|
||||
"""
|
||||
if settings.ENVIRONMENT == "production":
|
||||
return
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
|
||||
@@ -20,6 +20,7 @@ from app.api.v1.session import (
|
||||
admin_router as adaptive_admin_router,
|
||||
router as adaptive_session_router,
|
||||
)
|
||||
from app.core.rate_limit import close_rate_limit
|
||||
from app.admin_web import (
|
||||
configure_admin_web,
|
||||
router as admin_web_router,
|
||||
@@ -86,6 +87,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
# Shutdown: Close database connections
|
||||
if settings.ENABLE_ADMIN:
|
||||
await shutdown_admin_web()
|
||||
await close_rate_limit()
|
||||
await close_db()
|
||||
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ Exports all SQLAlchemy ORM models for use in the application.
|
||||
from app.database import Base
|
||||
from app.models.ai_generation_run import AIGenerationRun
|
||||
from app.models.item import Item
|
||||
from app.models.report_schedule import ReportScheduleModel
|
||||
from app.models.session import Session
|
||||
from app.models.tryout import Tryout
|
||||
from app.models.tryout_import_snapshot import TryoutImportSnapshot
|
||||
@@ -25,6 +26,7 @@ __all__ = [
|
||||
"TryoutImportSnapshot",
|
||||
"TryoutSnapshotQuestion",
|
||||
"Item",
|
||||
"ReportScheduleModel",
|
||||
"Session",
|
||||
"UserAnswer",
|
||||
"TryoutStats",
|
||||
|
||||
46
app/models/report_schedule.py
Normal file
46
app/models/report_schedule.py
Normal file
@@ -0,0 +1,46 @@
|
||||
"""
|
||||
Persistent report schedule model.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import Boolean, DateTime, ForeignKey, Index, JSON, String, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class ReportScheduleModel(Base):
|
||||
"""Database-backed report schedule configuration."""
|
||||
|
||||
__tablename__ = "report_schedules"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
|
||||
schedule_id: Mapped[str] = mapped_column(
|
||||
String(36),
|
||||
nullable=False,
|
||||
unique=True,
|
||||
index=True,
|
||||
comment="Public schedule identifier",
|
||||
)
|
||||
report_type: Mapped[str] = mapped_column(String(50), nullable=False)
|
||||
schedule: Mapped[str] = mapped_column(String(20), nullable=False)
|
||||
tryout_ids: Mapped[list[str]] = mapped_column(JSON, nullable=False)
|
||||
website_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("websites.id", ondelete="CASCADE", onupdate="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
recipients: Mapped[list[str]] = mapped_column(JSON, nullable=False)
|
||||
format: Mapped[str] = mapped_column(String(10), nullable=False, default="xlsx")
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
last_run: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
next_run: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_report_schedules_website_active", "website_id", "is_active"),
|
||||
)
|
||||
@@ -5,12 +5,13 @@ Provides admin-specific endpoints for triggering calibration,
|
||||
toggling AI generation, and resetting normalization.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Dict
|
||||
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, status
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.auth import AuthContext, get_auth_context, require_website_auth
|
||||
from app.core.config import get_settings
|
||||
from app.database import get_db
|
||||
from app.models import Tryout, TryoutStats
|
||||
@@ -23,35 +24,6 @@ router = APIRouter(prefix="/admin", tags=["admin"])
|
||||
settings = get_settings()
|
||||
|
||||
|
||||
def get_admin_website_id(
|
||||
x_website_id: Optional[str] = Header(None, alias="X-Website-ID"),
|
||||
) -> int:
|
||||
"""
|
||||
Extract and validate website_id from request header for admin operations.
|
||||
|
||||
Args:
|
||||
x_website_id: Website ID from header
|
||||
|
||||
Returns:
|
||||
Validated website ID as integer
|
||||
|
||||
Raises:
|
||||
HTTPException: If header is missing or invalid
|
||||
"""
|
||||
if x_website_id is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="X-Website-ID header is required",
|
||||
)
|
||||
try:
|
||||
return int(x_website_id)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="X-Website-ID must be a valid integer",
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/{tryout_id}/calibrate",
|
||||
summary="Trigger IRT calibration",
|
||||
@@ -60,7 +32,7 @@ def get_admin_website_id(
|
||||
async def admin_trigger_calibration(
|
||||
tryout_id: str,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
website_id: int = Depends(get_admin_website_id),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Trigger IRT calibration for all items in a tryout.
|
||||
@@ -79,6 +51,8 @@ async def admin_trigger_calibration(
|
||||
Raises:
|
||||
HTTPException: If tryout not found or calibration fails
|
||||
"""
|
||||
website_id = require_website_auth(auth, allowed_roles={"admin", "system_admin"})
|
||||
|
||||
# Verify tryout exists
|
||||
tryout_result = await db.execute(
|
||||
select(Tryout).where(
|
||||
@@ -121,7 +95,7 @@ async def admin_trigger_calibration(
|
||||
async def admin_toggle_ai_generation(
|
||||
tryout_id: str,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
website_id: int = Depends(get_admin_website_id),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Toggle AI generation for a tryout.
|
||||
@@ -139,6 +113,8 @@ async def admin_toggle_ai_generation(
|
||||
Raises:
|
||||
HTTPException: If tryout not found
|
||||
"""
|
||||
website_id = require_website_auth(auth, allowed_roles={"admin", "system_admin"})
|
||||
|
||||
# Get tryout
|
||||
result = await db.execute(
|
||||
select(Tryout).where(
|
||||
@@ -175,7 +151,7 @@ async def admin_toggle_ai_generation(
|
||||
async def admin_reset_normalization(
|
||||
tryout_id: str,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
website_id: int = Depends(get_admin_website_id),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Reset normalization for a tryout.
|
||||
@@ -193,6 +169,8 @@ async def admin_reset_normalization(
|
||||
Raises:
|
||||
HTTPException: If tryout or stats not found
|
||||
"""
|
||||
website_id = require_website_auth(auth, allowed_roles={"admin", "system_admin"})
|
||||
|
||||
# Get tryout stats
|
||||
stats_result = await db.execute(
|
||||
select(TryoutStats).where(
|
||||
|
||||
@@ -78,7 +78,7 @@ async def generate_preview(
|
||||
- **ai_model**: OpenRouter model to use (default: qwen/qwen2.5-32b-instruct)
|
||||
"""
|
||||
website_id = require_website_auth(auth, allowed_roles={"admin", "system_admin"})
|
||||
enforce_rate_limit(
|
||||
await enforce_rate_limit(
|
||||
request_http,
|
||||
scope="ai.generate_preview",
|
||||
max_requests=40,
|
||||
@@ -196,7 +196,7 @@ async def generate_save(
|
||||
- **ai_model**: AI model used for generation
|
||||
"""
|
||||
website_id = require_website_auth(auth, allowed_roles={"admin", "system_admin"})
|
||||
enforce_rate_limit(
|
||||
await enforce_rate_limit(
|
||||
request_http,
|
||||
scope="ai.generate_save",
|
||||
max_requests=40,
|
||||
@@ -291,8 +291,8 @@ async def get_stats(
|
||||
"""
|
||||
Get AI generation statistics.
|
||||
"""
|
||||
require_website_auth(auth, allowed_roles={"admin", "system_admin"})
|
||||
stats = await get_ai_stats(db)
|
||||
website_id = require_website_auth(auth, allowed_roles={"admin", "system_admin"})
|
||||
stats = await get_ai_stats(db, website_id=website_id)
|
||||
|
||||
return AIStatsResponse(
|
||||
total_ai_items=stats["total_ai_items"],
|
||||
|
||||
@@ -77,7 +77,7 @@ async def preview_import(
|
||||
HTTPException: If file format is invalid or parsing fails
|
||||
"""
|
||||
website_id = require_website_auth(auth, allowed_roles={"admin", "system_admin"})
|
||||
enforce_rate_limit(
|
||||
await enforce_rate_limit(
|
||||
request,
|
||||
scope="import.preview",
|
||||
max_requests=30,
|
||||
@@ -181,7 +181,7 @@ async def import_questions(
|
||||
HTTPException: If file format is invalid, validation fails, or import fails
|
||||
"""
|
||||
website_id = require_website_auth(auth, allowed_roles={"admin", "system_admin"})
|
||||
enforce_rate_limit(
|
||||
await enforce_rate_limit(
|
||||
request,
|
||||
scope="import.questions",
|
||||
max_requests=20,
|
||||
@@ -351,7 +351,7 @@ async def preview_tryout_json(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> dict:
|
||||
website_id = require_website_auth(auth, allowed_roles={"admin", "system_admin"})
|
||||
enforce_rate_limit(
|
||||
await enforce_rate_limit(
|
||||
request,
|
||||
scope="import.tryout_json_preview",
|
||||
max_requests=30,
|
||||
@@ -394,7 +394,7 @@ async def import_tryout_json(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> dict:
|
||||
website_id = require_website_auth(auth, allowed_roles={"admin", "system_admin"})
|
||||
enforce_rate_limit(
|
||||
await enforce_rate_limit(
|
||||
request,
|
||||
scope="import.tryout_json",
|
||||
max_requests=20,
|
||||
|
||||
@@ -85,6 +85,15 @@ async def get_student_performance_report(
|
||||
Returns individual student records and/or aggregate statistics.
|
||||
"""
|
||||
website_id = require_website_auth(auth, allowed_roles={"student", "admin", "system_admin"})
|
||||
scoped_wp_user_id = None
|
||||
if auth.role == "student":
|
||||
if not auth.wp_user_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Student reports require an authenticated WordPress user",
|
||||
)
|
||||
scoped_wp_user_id = auth.wp_user_id
|
||||
|
||||
date_range = None
|
||||
if date_start or date_end:
|
||||
date_range = {}
|
||||
@@ -99,6 +108,7 @@ async def get_student_performance_report(
|
||||
db=db,
|
||||
date_range=date_range,
|
||||
format_type=format_type,
|
||||
wp_user_id=scoped_wp_user_id,
|
||||
)
|
||||
|
||||
return _convert_student_performance_report(report)
|
||||
@@ -361,7 +371,8 @@ async def create_report_schedule(
|
||||
"""
|
||||
website_id = require_website_auth(auth, allowed_roles={"admin", "system_admin"})
|
||||
ensure_website_scope_matches(website_id, request.website_id)
|
||||
schedule_id = schedule_report(
|
||||
schedule_id = await schedule_report(
|
||||
db,
|
||||
report_type=request.report_type,
|
||||
schedule=request.schedule,
|
||||
tryout_ids=request.tryout_ids,
|
||||
@@ -370,7 +381,7 @@ async def create_report_schedule(
|
||||
export_format=request.export_format,
|
||||
)
|
||||
|
||||
scheduled = get_scheduled_report(schedule_id)
|
||||
scheduled = await get_scheduled_report(db, schedule_id)
|
||||
|
||||
return ReportScheduleResponse(
|
||||
schedule_id=schedule_id,
|
||||
@@ -387,6 +398,7 @@ async def create_report_schedule(
|
||||
)
|
||||
async def get_scheduled_report_details(
|
||||
schedule_id: str,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
) -> ReportScheduleOutput:
|
||||
"""
|
||||
@@ -395,7 +407,7 @@ async def get_scheduled_report_details(
|
||||
Returns the configuration and status of a scheduled report.
|
||||
"""
|
||||
website_id = require_website_auth(auth, allowed_roles={"admin", "system_admin"})
|
||||
scheduled = get_scheduled_report(schedule_id)
|
||||
scheduled = await get_scheduled_report(db, schedule_id)
|
||||
|
||||
if not scheduled:
|
||||
raise HTTPException(
|
||||
@@ -431,6 +443,7 @@ async def get_scheduled_report_details(
|
||||
description="List all scheduled reports for a website.",
|
||||
)
|
||||
async def list_scheduled_reports_endpoint(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
) -> List[ReportScheduleOutput]:
|
||||
"""
|
||||
@@ -439,7 +452,7 @@ async def list_scheduled_reports_endpoint(
|
||||
Returns all scheduled reports for the current website.
|
||||
"""
|
||||
website_id = require_website_auth(auth, allowed_roles={"admin", "system_admin"})
|
||||
reports = list_scheduled_reports(website_id=website_id)
|
||||
reports = await list_scheduled_reports(db, website_id=website_id)
|
||||
|
||||
return [
|
||||
ReportScheduleOutput(
|
||||
@@ -466,6 +479,7 @@ async def list_scheduled_reports_endpoint(
|
||||
)
|
||||
async def cancel_scheduled_report_endpoint(
|
||||
schedule_id: str,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
) -> dict:
|
||||
"""
|
||||
@@ -474,7 +488,7 @@ async def cancel_scheduled_report_endpoint(
|
||||
Removes the scheduled report from the system.
|
||||
"""
|
||||
website_id = require_website_auth(auth, allowed_roles={"admin", "system_admin"})
|
||||
scheduled = get_scheduled_report(schedule_id)
|
||||
scheduled = await get_scheduled_report(db, schedule_id)
|
||||
|
||||
if not scheduled:
|
||||
raise HTTPException(
|
||||
@@ -488,7 +502,7 @@ async def cancel_scheduled_report_endpoint(
|
||||
detail="Access denied to this scheduled report",
|
||||
)
|
||||
|
||||
success = cancel_scheduled_report(schedule_id)
|
||||
success = await cancel_scheduled_report(db, schedule_id)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
@@ -523,7 +537,7 @@ async def export_scheduled_report(
|
||||
Generates the report and returns it as a file download.
|
||||
"""
|
||||
website_id = require_website_auth(auth, allowed_roles={"admin", "system_admin"})
|
||||
scheduled = get_scheduled_report(schedule_id)
|
||||
scheduled = await get_scheduled_report(db, schedule_id)
|
||||
|
||||
if not scheduled:
|
||||
raise HTTPException(
|
||||
@@ -536,6 +550,11 @@ async def export_scheduled_report(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied to this scheduled report",
|
||||
)
|
||||
if not scheduled.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Scheduled report is inactive",
|
||||
)
|
||||
|
||||
# Generate report based on type
|
||||
report = None
|
||||
|
||||
@@ -10,6 +10,7 @@ Endpoints:
|
||||
from datetime import datetime, timezone
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
@@ -122,6 +123,27 @@ async def complete_session(
|
||||
items = {item.id: item for item in items_result.scalars().all()}
|
||||
|
||||
# Process each answer
|
||||
submitted_item_ids = [answer.item_id for answer in request.user_answers]
|
||||
if len(submitted_item_ids) != len(set(submitted_item_ids)):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail="Duplicate item answers are not allowed in a session completion",
|
||||
)
|
||||
|
||||
existing_answers_result = await db.execute(
|
||||
select(UserAnswer.item_id).where(UserAnswer.session_id == session.session_id)
|
||||
)
|
||||
existing_answered_item_ids = {row[0] for row in existing_answers_result.all()}
|
||||
duplicate_existing_ids = sorted(set(submitted_item_ids) & existing_answered_item_ids)
|
||||
if duplicate_existing_ids:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail={
|
||||
"message": "One or more items were already answered for this session",
|
||||
"item_ids": duplicate_existing_ids,
|
||||
},
|
||||
)
|
||||
|
||||
total_benar = 0
|
||||
total_bobot_earned = 0.0
|
||||
user_answer_records = []
|
||||
@@ -234,7 +256,13 @@ async def complete_session(
|
||||
await update_tryout_stats(db, website_id, session.tryout_id, nm)
|
||||
|
||||
# Commit all changes
|
||||
await db.commit()
|
||||
try:
|
||||
await db.commit()
|
||||
except IntegrityError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail="Duplicate item answer detected for this session",
|
||||
) from exc
|
||||
|
||||
# Refresh to get updated relationships
|
||||
await db.refresh(session)
|
||||
@@ -261,7 +289,6 @@ async def complete_session(
|
||||
id=ua.id,
|
||||
item_id=ua.item_id,
|
||||
response=ua.response,
|
||||
is_correct=ua.is_correct,
|
||||
time_spent=ua.time_spent,
|
||||
bobot_earned=ua.bobot_earned,
|
||||
scoring_mode_used=ua.scoring_mode_used,
|
||||
|
||||
@@ -15,7 +15,13 @@ from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.database import get_db
|
||||
from app.core.auth import issue_access_token
|
||||
from app.core.auth import (
|
||||
AuthContext,
|
||||
ensure_website_scope_matches,
|
||||
get_auth_context,
|
||||
issue_access_token,
|
||||
require_website_auth,
|
||||
)
|
||||
from app.models.user import User
|
||||
from app.models.website import Website
|
||||
from app.schemas.wordpress import (
|
||||
@@ -44,6 +50,16 @@ logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/wordpress", tags=["wordpress"])
|
||||
|
||||
|
||||
def _api_role_from_wordpress_roles(roles: list[str]) -> str:
|
||||
"""Map WordPress roles to API roles used by route authorization."""
|
||||
normalized_roles = {str(role).strip().lower() for role in roles}
|
||||
if normalized_roles & {"super_admin", "system_admin"}:
|
||||
return "system_admin"
|
||||
if normalized_roles & {"administrator", "admin"}:
|
||||
return "admin"
|
||||
return "student"
|
||||
|
||||
|
||||
def get_website_id_from_header(
|
||||
x_website_id: Optional[str] = Header(None, alias="X-Website-ID"),
|
||||
) -> int:
|
||||
@@ -132,7 +148,7 @@ async def sync_users_endpoint(
|
||||
Raises:
|
||||
HTTPException: If website not found, token invalid, or API error
|
||||
"""
|
||||
enforce_rate_limit(
|
||||
await enforce_rate_limit(
|
||||
request,
|
||||
scope="wordpress.sync_users",
|
||||
max_requests=20,
|
||||
@@ -230,7 +246,7 @@ async def verify_session_endpoint(
|
||||
Raises:
|
||||
HTTPException: If website not found or API error
|
||||
"""
|
||||
enforce_rate_limit(
|
||||
await enforce_rate_limit(
|
||||
http_request,
|
||||
scope="wordpress.verify_session",
|
||||
max_requests=60,
|
||||
@@ -273,7 +289,7 @@ async def verify_session_endpoint(
|
||||
},
|
||||
access_token=issue_access_token(
|
||||
website_id=request.website_id,
|
||||
role="student",
|
||||
role=_api_role_from_wordpress_roles(wp_user_info.roles),
|
||||
wp_user_id=request.wp_user_id,
|
||||
expires_in_seconds=3600 * 24,
|
||||
),
|
||||
@@ -310,6 +326,7 @@ async def verify_session_endpoint(
|
||||
async def get_website_users(
|
||||
website_id: int,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
page: int = 1,
|
||||
page_size: int = 50,
|
||||
) -> UserListResponse:
|
||||
@@ -328,6 +345,9 @@ async def get_website_users(
|
||||
Raises:
|
||||
HTTPException: If website not found
|
||||
"""
|
||||
auth_website_id = require_website_auth(auth, allowed_roles={"admin", "system_admin"})
|
||||
ensure_website_scope_matches(auth_website_id, website_id)
|
||||
|
||||
# Validate website exists
|
||||
await get_valid_website(website_id, db)
|
||||
|
||||
@@ -374,6 +394,7 @@ async def get_user_endpoint(
|
||||
website_id: int,
|
||||
wp_user_id: str,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
) -> WordPressUserResponse:
|
||||
"""
|
||||
Get a specific user by WordPress user ID.
|
||||
@@ -389,6 +410,16 @@ async def get_user_endpoint(
|
||||
Raises:
|
||||
HTTPException: If website or user not found
|
||||
"""
|
||||
auth_website_id = require_website_auth(
|
||||
auth, allowed_roles={"student", "admin", "system_admin"}
|
||||
)
|
||||
ensure_website_scope_matches(auth_website_id, website_id)
|
||||
if auth.role == "student" and auth.wp_user_id != wp_user_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="User does not belong to this authenticated user",
|
||||
)
|
||||
|
||||
# Validate website exists
|
||||
await get_valid_website(website_id, db)
|
||||
|
||||
|
||||
@@ -29,7 +29,6 @@ class UserAnswerOutput(BaseModel):
|
||||
id: int
|
||||
item_id: int
|
||||
response: str
|
||||
is_correct: bool
|
||||
time_spent: int
|
||||
bobot_earned: float
|
||||
scoring_mode_used: str
|
||||
@@ -37,6 +36,12 @@ class UserAnswerOutput(BaseModel):
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class UserAnswerReviewOutput(UserAnswerOutput):
|
||||
"""Review output for a single answer."""
|
||||
|
||||
is_correct: bool
|
||||
|
||||
|
||||
class SessionCompleteResponse(BaseModel):
|
||||
"""Response schema for completed session with CTT scores."""
|
||||
|
||||
@@ -66,6 +71,12 @@ class SessionCompleteResponse(BaseModel):
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class SessionCompleteAdminResponse(SessionCompleteResponse):
|
||||
"""Completed session response with answer correctness for admin/review contexts."""
|
||||
|
||||
user_answers: List[UserAnswerReviewOutput]
|
||||
|
||||
|
||||
class SessionCreateRequest(BaseModel):
|
||||
"""Request schema for creating a new session."""
|
||||
|
||||
|
||||
@@ -715,7 +715,7 @@ async def generate_questions_batch(
|
||||
return generated_items
|
||||
|
||||
|
||||
async def get_ai_stats(db: AsyncSession) -> Dict[str, Any]:
|
||||
async def get_ai_stats(db: AsyncSession, website_id: int | None = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Get AI generation statistics.
|
||||
|
||||
@@ -725,16 +725,18 @@ async def get_ai_stats(db: AsyncSession) -> Dict[str, Any]:
|
||||
Returns:
|
||||
Statistics dictionary
|
||||
"""
|
||||
filters = [Item.generated_by == "ai"]
|
||||
if website_id is not None:
|
||||
filters.append(Item.website_id == website_id)
|
||||
|
||||
# Total AI-generated items
|
||||
total_result = await db.execute(
|
||||
select(func.count(Item.id)).where(Item.generated_by == "ai")
|
||||
)
|
||||
total_result = await db.execute(select(func.count(Item.id)).where(*filters))
|
||||
total_ai_items = total_result.scalar() or 0
|
||||
|
||||
# Items by model
|
||||
model_result = await db.execute(
|
||||
select(Item.ai_model, func.count(Item.id))
|
||||
.where(Item.generated_by == "ai")
|
||||
.where(*filters)
|
||||
.where(Item.ai_model.isnot(None))
|
||||
.group_by(Item.ai_model)
|
||||
)
|
||||
|
||||
@@ -23,6 +23,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.models.item import Item
|
||||
from app.models.report_schedule import ReportScheduleModel
|
||||
from app.models.session import Session
|
||||
from app.models.tryout import Tryout
|
||||
from app.models.tryout_stats import TryoutStats
|
||||
@@ -256,7 +257,8 @@ async def generate_student_performance_report(
|
||||
website_id: int,
|
||||
db: AsyncSession,
|
||||
date_range: Optional[Dict[str, datetime]] = None,
|
||||
format_type: Literal["individual", "aggregate", "both"] = "both"
|
||||
format_type: Literal["individual", "aggregate", "both"] = "both",
|
||||
wp_user_id: Optional[str] = None,
|
||||
) -> StudentPerformanceReport:
|
||||
"""
|
||||
Generate student performance report.
|
||||
@@ -267,6 +269,7 @@ async def generate_student_performance_report(
|
||||
db: Database session
|
||||
date_range: Optional date range filter {"start": datetime, "end": datetime}
|
||||
format_type: Report format - individual, aggregate, or both
|
||||
wp_user_id: Optional WordPress user filter for student-scoped reports
|
||||
|
||||
Returns:
|
||||
StudentPerformanceReport with aggregate stats and/or individual records
|
||||
@@ -287,6 +290,9 @@ async def generate_student_performance_report(
|
||||
query = query.where(Session.start_time >= date_range["start"])
|
||||
if date_range.get("end"):
|
||||
query = query.where(Session.start_time <= date_range["end"])
|
||||
|
||||
if wp_user_id is not None:
|
||||
query = query.where(Session.wp_user_id == wp_user_id)
|
||||
|
||||
query = query.order_by(Session.NN.desc().nullslast())
|
||||
|
||||
@@ -1382,11 +1388,34 @@ class ReportSchedule:
|
||||
is_active: bool = True
|
||||
|
||||
|
||||
# In-memory store for scheduled reports (in production, use database)
|
||||
_scheduled_reports: Dict[str, ReportSchedule] = {}
|
||||
def _calculate_next_run(schedule: Literal["daily", "weekly", "monthly"]) -> datetime:
|
||||
now = datetime.now(timezone.utc)
|
||||
if schedule == "daily":
|
||||
return now + timedelta(days=1)
|
||||
if schedule == "weekly":
|
||||
return now + timedelta(weeks=1)
|
||||
return now + timedelta(days=30)
|
||||
|
||||
|
||||
def schedule_report(
|
||||
def _schedule_from_model(row: ReportScheduleModel) -> ReportSchedule:
|
||||
return ReportSchedule(
|
||||
schedule_id=row.schedule_id,
|
||||
report_type=row.report_type,
|
||||
schedule=row.schedule,
|
||||
tryout_ids=list(row.tryout_ids or []),
|
||||
website_id=row.website_id,
|
||||
recipients=list(row.recipients or []),
|
||||
format=row.format,
|
||||
created_at=row.created_at,
|
||||
last_run=row.last_run,
|
||||
next_run=row.next_run,
|
||||
is_active=row.is_active,
|
||||
)
|
||||
|
||||
|
||||
async def schedule_report(
|
||||
db: AsyncSession,
|
||||
*,
|
||||
report_type: Literal["student_performance", "item_analysis", "calibration_status", "tryout_comparison"],
|
||||
schedule: Literal["daily", "weekly", "monthly"],
|
||||
tryout_ids: List[str],
|
||||
@@ -1412,16 +1441,7 @@ def schedule_report(
|
||||
|
||||
schedule_id = str(uuid.uuid4())
|
||||
|
||||
# Calculate next run time
|
||||
now = datetime.now(timezone.utc)
|
||||
if schedule == "daily":
|
||||
next_run = now + timedelta(days=1)
|
||||
elif schedule == "weekly":
|
||||
next_run = now + timedelta(weeks=1)
|
||||
else: # monthly
|
||||
next_run = now + timedelta(days=30)
|
||||
|
||||
report_schedule = ReportSchedule(
|
||||
report_schedule = ReportScheduleModel(
|
||||
schedule_id=schedule_id,
|
||||
report_type=report_type,
|
||||
schedule=schedule,
|
||||
@@ -1429,35 +1449,54 @@ def schedule_report(
|
||||
website_id=website_id,
|
||||
recipients=recipients,
|
||||
format=export_format,
|
||||
next_run=next_run,
|
||||
next_run=_calculate_next_run(schedule),
|
||||
is_active=True,
|
||||
)
|
||||
|
||||
_scheduled_reports[schedule_id] = report_schedule
|
||||
db.add(report_schedule)
|
||||
await db.flush()
|
||||
logger.info(f"Scheduled report {schedule_id}: {report_type} {schedule}")
|
||||
|
||||
return schedule_id
|
||||
|
||||
|
||||
def get_scheduled_report(schedule_id: str) -> Optional[ReportSchedule]:
|
||||
async def get_scheduled_report(db: AsyncSession, schedule_id: str) -> Optional[ReportSchedule]:
|
||||
"""Get a scheduled report by ID."""
|
||||
return _scheduled_reports.get(schedule_id)
|
||||
result = await db.execute(
|
||||
select(ReportScheduleModel).where(ReportScheduleModel.schedule_id == schedule_id)
|
||||
)
|
||||
row = result.scalar_one_or_none()
|
||||
return _schedule_from_model(row) if row else None
|
||||
|
||||
|
||||
def list_scheduled_reports(website_id: Optional[int] = None) -> List[ReportSchedule]:
|
||||
async def list_scheduled_reports(
|
||||
db: AsyncSession,
|
||||
website_id: Optional[int] = None,
|
||||
) -> List[ReportSchedule]:
|
||||
"""List all scheduled reports, optionally filtered by website."""
|
||||
reports = list(_scheduled_reports.values())
|
||||
if website_id:
|
||||
reports = [r for r in reports if r.website_id == website_id]
|
||||
return reports
|
||||
query = (
|
||||
select(ReportScheduleModel)
|
||||
.where(ReportScheduleModel.is_active == True)
|
||||
.order_by(ReportScheduleModel.created_at.desc())
|
||||
)
|
||||
if website_id is not None:
|
||||
query = query.where(ReportScheduleModel.website_id == website_id)
|
||||
result = await db.execute(query)
|
||||
return [_schedule_from_model(row) for row in result.scalars().all()]
|
||||
|
||||
|
||||
def cancel_scheduled_report(schedule_id: str) -> bool:
|
||||
async def cancel_scheduled_report(db: AsyncSession, schedule_id: str) -> bool:
|
||||
"""Cancel a scheduled report."""
|
||||
if schedule_id in _scheduled_reports:
|
||||
del _scheduled_reports[schedule_id]
|
||||
logger.info(f"Cancelled scheduled report {schedule_id}")
|
||||
return True
|
||||
return False
|
||||
result = await db.execute(
|
||||
select(ReportScheduleModel).where(ReportScheduleModel.schedule_id == schedule_id)
|
||||
)
|
||||
row = result.scalar_one_or_none()
|
||||
if row is None:
|
||||
return False
|
||||
row.is_active = False
|
||||
await db.flush()
|
||||
logger.info(f"Cancelled scheduled report {schedule_id}")
|
||||
return True
|
||||
|
||||
|
||||
# Export public API
|
||||
|
||||
Reference in New Issue
Block a user