Harden auth and persist report schedules
This commit is contained in:
53
alembic/versions/20260405_000004_report_schedules.py
Normal file
53
alembic/versions/20260405_000004_report_schedules.py
Normal file
@@ -0,0 +1,53 @@
|
||||
"""add persistent report schedules
|
||||
|
||||
Revision ID: 20260405_000004
|
||||
Revises: 20260404_000003
|
||||
Create Date: 2026-04-05 09:00:00
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
revision: str = "20260405_000004"
|
||||
down_revision: Union[str, None] = "20260404_000003"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"report_schedules",
|
||||
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
|
||||
sa.Column("schedule_id", sa.String(length=36), nullable=False),
|
||||
sa.Column("report_type", sa.String(length=50), nullable=False),
|
||||
sa.Column("schedule", sa.String(length=20), nullable=False),
|
||||
sa.Column("tryout_ids", sa.JSON(), nullable=False),
|
||||
sa.Column("website_id", sa.Integer(), nullable=False),
|
||||
sa.Column("recipients", sa.JSON(), nullable=False),
|
||||
sa.Column("format", sa.String(length=10), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
|
||||
sa.Column("last_run", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("next_run", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("is_active", sa.Boolean(), nullable=False),
|
||||
sa.ForeignKeyConstraint(["website_id"], ["websites.id"], ondelete="CASCADE", onupdate="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("schedule_id"),
|
||||
)
|
||||
op.create_index("ix_report_schedules_schedule_id", "report_schedules", ["schedule_id"], unique=True)
|
||||
op.create_index("ix_report_schedules_website_id", "report_schedules", ["website_id"], unique=False)
|
||||
op.create_index(
|
||||
"ix_report_schedules_website_active",
|
||||
"report_schedules",
|
||||
["website_id", "is_active"],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("ix_report_schedules_website_active", table_name="report_schedules")
|
||||
op.drop_index("ix_report_schedules_website_id", table_name="report_schedules")
|
||||
op.drop_index("ix_report_schedules_schedule_id", table_name="report_schedules")
|
||||
op.drop_table("report_schedules")
|
||||
@@ -11,6 +11,7 @@ from typing import Literal, Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.database import get_db
|
||||
@@ -20,7 +21,7 @@ from app.core.auth import (
|
||||
get_auth_context,
|
||||
require_website_auth,
|
||||
)
|
||||
from app.models import Item, Session, Tryout
|
||||
from app.models import Item, Session, Tryout, UserAnswer
|
||||
from app.services.cat_selection import (
|
||||
CATSelectionError,
|
||||
get_next_item,
|
||||
@@ -65,9 +66,6 @@ class SubmitAnswerRequest(BaseModel):
|
||||
|
||||
class SubmitAnswerResponse(BaseModel):
|
||||
"""Response for submitting an answer."""
|
||||
is_correct: bool
|
||||
correct_answer: str
|
||||
explanation: Optional[str] = None
|
||||
theta: Optional[float] = None
|
||||
theta_se: Optional[float] = None
|
||||
|
||||
@@ -284,15 +282,24 @@ async def submit_answer_endpoint(
|
||||
detail=f"Item {request.item_id} not found"
|
||||
)
|
||||
|
||||
existing_answer_result = await db.execute(
|
||||
select(UserAnswer.id).where(
|
||||
UserAnswer.session_id == session_id,
|
||||
UserAnswer.item_id == request.item_id,
|
||||
)
|
||||
)
|
||||
if existing_answer_result.scalar_one_or_none() is not None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail="Item was already answered for this session",
|
||||
)
|
||||
|
||||
# Check correctness
|
||||
is_correct = request.response.upper() == item.correct_answer.upper()
|
||||
|
||||
# Update theta
|
||||
theta, theta_se = await update_theta(db, session_id, request.item_id, is_correct)
|
||||
|
||||
# Create user answer record
|
||||
from app.models import UserAnswer
|
||||
|
||||
user_answer = UserAnswer(
|
||||
session_id=session_id,
|
||||
wp_user_id=session.wp_user_id,
|
||||
@@ -307,12 +314,15 @@ async def submit_answer_endpoint(
|
||||
)
|
||||
|
||||
db.add(user_answer)
|
||||
await db.commit()
|
||||
try:
|
||||
await db.commit()
|
||||
except IntegrityError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail="Item was already answered for this session",
|
||||
) from exc
|
||||
|
||||
return SubmitAnswerResponse(
|
||||
is_correct=is_correct,
|
||||
correct_answer=item.correct_answer,
|
||||
explanation=item.explanation,
|
||||
theta=theta,
|
||||
theta_se=theta_se
|
||||
)
|
||||
|
||||
@@ -4,14 +4,21 @@ Lightweight in-process rate limiting helpers.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from collections import defaultdict, deque
|
||||
|
||||
from fastapi import HTTPException, Request, status
|
||||
from redis.asyncio import Redis
|
||||
|
||||
from app.core.config import get_settings
|
||||
|
||||
_lock = threading.Lock()
|
||||
_hits: dict[str, deque[float]] = defaultdict(deque)
|
||||
_redis_client: Redis | None = None
|
||||
_redis_unavailable = False
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _client_ip(request: Request) -> str:
|
||||
@@ -20,16 +27,26 @@ def _client_ip(request: Request) -> str:
|
||||
return "unknown"
|
||||
|
||||
|
||||
def enforce_rate_limit(
|
||||
request: Request,
|
||||
def _get_redis_client() -> Redis | None:
|
||||
global _redis_client
|
||||
if _redis_unavailable:
|
||||
return None
|
||||
if _redis_client is None:
|
||||
settings = get_settings()
|
||||
if not settings.REDIS_URL:
|
||||
return None
|
||||
_redis_client = Redis.from_url(settings.REDIS_URL, decode_responses=True)
|
||||
return _redis_client
|
||||
|
||||
|
||||
def _enforce_in_memory_rate_limit(
|
||||
*,
|
||||
key: str,
|
||||
scope: str,
|
||||
max_requests: int,
|
||||
window_seconds: int,
|
||||
) -> None:
|
||||
now = time.time()
|
||||
ip = _client_ip(request)
|
||||
key = f"{scope}:{ip}"
|
||||
cutoff = now - window_seconds
|
||||
|
||||
with _lock:
|
||||
@@ -43,3 +60,62 @@ def enforce_rate_limit(
|
||||
)
|
||||
dq.append(now)
|
||||
|
||||
|
||||
async def enforce_rate_limit(
|
||||
request: Request,
|
||||
*,
|
||||
scope: str,
|
||||
max_requests: int,
|
||||
window_seconds: int,
|
||||
) -> None:
|
||||
global _redis_unavailable
|
||||
|
||||
ip = _client_ip(request)
|
||||
key = f"{scope}:{ip}"
|
||||
|
||||
redis = _get_redis_client()
|
||||
if redis is not None:
|
||||
try:
|
||||
current = await redis.incr(key)
|
||||
if current == 1:
|
||||
await redis.expire(key, window_seconds)
|
||||
if current > max_requests:
|
||||
ttl = await redis.ttl(key)
|
||||
retry_after = ttl if ttl and ttl > 0 else window_seconds
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
||||
detail=f"Too many requests for {scope}. Please try again later.",
|
||||
headers={"Retry-After": str(retry_after)},
|
||||
)
|
||||
return
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
_redis_unavailable = True
|
||||
logger.warning("Redis rate limiter unavailable; falling back to memory: %s", exc)
|
||||
|
||||
_enforce_in_memory_rate_limit(
|
||||
key=key,
|
||||
scope=scope,
|
||||
max_requests=max_requests,
|
||||
window_seconds=window_seconds,
|
||||
)
|
||||
|
||||
|
||||
async def close_rate_limit() -> None:
|
||||
global _redis_client
|
||||
if _redis_client is None:
|
||||
return
|
||||
try:
|
||||
await _redis_client.aclose()
|
||||
finally:
|
||||
_redis_client = None
|
||||
|
||||
|
||||
def reset_rate_limit_state() -> None:
|
||||
"""Reset local limiter state for tests."""
|
||||
global _redis_client, _redis_unavailable
|
||||
_redis_client = None
|
||||
_redis_unavailable = False
|
||||
with _lock:
|
||||
_hits.clear()
|
||||
|
||||
@@ -76,6 +76,9 @@ async def init_db() -> None:
|
||||
Note: In production, use Alembic migrations instead.
|
||||
This is useful for development and testing.
|
||||
"""
|
||||
if settings.ENVIRONMENT == "production":
|
||||
return
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
|
||||
@@ -20,6 +20,7 @@ from app.api.v1.session import (
|
||||
admin_router as adaptive_admin_router,
|
||||
router as adaptive_session_router,
|
||||
)
|
||||
from app.core.rate_limit import close_rate_limit
|
||||
from app.admin_web import (
|
||||
configure_admin_web,
|
||||
router as admin_web_router,
|
||||
@@ -86,6 +87,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
# Shutdown: Close database connections
|
||||
if settings.ENABLE_ADMIN:
|
||||
await shutdown_admin_web()
|
||||
await close_rate_limit()
|
||||
await close_db()
|
||||
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ Exports all SQLAlchemy ORM models for use in the application.
|
||||
from app.database import Base
|
||||
from app.models.ai_generation_run import AIGenerationRun
|
||||
from app.models.item import Item
|
||||
from app.models.report_schedule import ReportScheduleModel
|
||||
from app.models.session import Session
|
||||
from app.models.tryout import Tryout
|
||||
from app.models.tryout_import_snapshot import TryoutImportSnapshot
|
||||
@@ -25,6 +26,7 @@ __all__ = [
|
||||
"TryoutImportSnapshot",
|
||||
"TryoutSnapshotQuestion",
|
||||
"Item",
|
||||
"ReportScheduleModel",
|
||||
"Session",
|
||||
"UserAnswer",
|
||||
"TryoutStats",
|
||||
|
||||
46
app/models/report_schedule.py
Normal file
46
app/models/report_schedule.py
Normal file
@@ -0,0 +1,46 @@
|
||||
"""
|
||||
Persistent report schedule model.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import Boolean, DateTime, ForeignKey, Index, JSON, String, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class ReportScheduleModel(Base):
|
||||
"""Database-backed report schedule configuration."""
|
||||
|
||||
__tablename__ = "report_schedules"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
|
||||
schedule_id: Mapped[str] = mapped_column(
|
||||
String(36),
|
||||
nullable=False,
|
||||
unique=True,
|
||||
index=True,
|
||||
comment="Public schedule identifier",
|
||||
)
|
||||
report_type: Mapped[str] = mapped_column(String(50), nullable=False)
|
||||
schedule: Mapped[str] = mapped_column(String(20), nullable=False)
|
||||
tryout_ids: Mapped[list[str]] = mapped_column(JSON, nullable=False)
|
||||
website_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("websites.id", ondelete="CASCADE", onupdate="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
recipients: Mapped[list[str]] = mapped_column(JSON, nullable=False)
|
||||
format: Mapped[str] = mapped_column(String(10), nullable=False, default="xlsx")
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
last_run: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
next_run: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_report_schedules_website_active", "website_id", "is_active"),
|
||||
)
|
||||
@@ -5,12 +5,13 @@ Provides admin-specific endpoints for triggering calibration,
|
||||
toggling AI generation, and resetting normalization.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Dict
|
||||
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, status
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.auth import AuthContext, get_auth_context, require_website_auth
|
||||
from app.core.config import get_settings
|
||||
from app.database import get_db
|
||||
from app.models import Tryout, TryoutStats
|
||||
@@ -23,35 +24,6 @@ router = APIRouter(prefix="/admin", tags=["admin"])
|
||||
settings = get_settings()
|
||||
|
||||
|
||||
def get_admin_website_id(
|
||||
x_website_id: Optional[str] = Header(None, alias="X-Website-ID"),
|
||||
) -> int:
|
||||
"""
|
||||
Extract and validate website_id from request header for admin operations.
|
||||
|
||||
Args:
|
||||
x_website_id: Website ID from header
|
||||
|
||||
Returns:
|
||||
Validated website ID as integer
|
||||
|
||||
Raises:
|
||||
HTTPException: If header is missing or invalid
|
||||
"""
|
||||
if x_website_id is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="X-Website-ID header is required",
|
||||
)
|
||||
try:
|
||||
return int(x_website_id)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="X-Website-ID must be a valid integer",
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/{tryout_id}/calibrate",
|
||||
summary="Trigger IRT calibration",
|
||||
@@ -60,7 +32,7 @@ def get_admin_website_id(
|
||||
async def admin_trigger_calibration(
|
||||
tryout_id: str,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
website_id: int = Depends(get_admin_website_id),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Trigger IRT calibration for all items in a tryout.
|
||||
@@ -79,6 +51,8 @@ async def admin_trigger_calibration(
|
||||
Raises:
|
||||
HTTPException: If tryout not found or calibration fails
|
||||
"""
|
||||
website_id = require_website_auth(auth, allowed_roles={"admin", "system_admin"})
|
||||
|
||||
# Verify tryout exists
|
||||
tryout_result = await db.execute(
|
||||
select(Tryout).where(
|
||||
@@ -121,7 +95,7 @@ async def admin_trigger_calibration(
|
||||
async def admin_toggle_ai_generation(
|
||||
tryout_id: str,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
website_id: int = Depends(get_admin_website_id),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Toggle AI generation for a tryout.
|
||||
@@ -139,6 +113,8 @@ async def admin_toggle_ai_generation(
|
||||
Raises:
|
||||
HTTPException: If tryout not found
|
||||
"""
|
||||
website_id = require_website_auth(auth, allowed_roles={"admin", "system_admin"})
|
||||
|
||||
# Get tryout
|
||||
result = await db.execute(
|
||||
select(Tryout).where(
|
||||
@@ -175,7 +151,7 @@ async def admin_toggle_ai_generation(
|
||||
async def admin_reset_normalization(
|
||||
tryout_id: str,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
website_id: int = Depends(get_admin_website_id),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Reset normalization for a tryout.
|
||||
@@ -193,6 +169,8 @@ async def admin_reset_normalization(
|
||||
Raises:
|
||||
HTTPException: If tryout or stats not found
|
||||
"""
|
||||
website_id = require_website_auth(auth, allowed_roles={"admin", "system_admin"})
|
||||
|
||||
# Get tryout stats
|
||||
stats_result = await db.execute(
|
||||
select(TryoutStats).where(
|
||||
|
||||
@@ -78,7 +78,7 @@ async def generate_preview(
|
||||
- **ai_model**: OpenRouter model to use (default: qwen/qwen2.5-32b-instruct)
|
||||
"""
|
||||
website_id = require_website_auth(auth, allowed_roles={"admin", "system_admin"})
|
||||
enforce_rate_limit(
|
||||
await enforce_rate_limit(
|
||||
request_http,
|
||||
scope="ai.generate_preview",
|
||||
max_requests=40,
|
||||
@@ -196,7 +196,7 @@ async def generate_save(
|
||||
- **ai_model**: AI model used for generation
|
||||
"""
|
||||
website_id = require_website_auth(auth, allowed_roles={"admin", "system_admin"})
|
||||
enforce_rate_limit(
|
||||
await enforce_rate_limit(
|
||||
request_http,
|
||||
scope="ai.generate_save",
|
||||
max_requests=40,
|
||||
@@ -291,8 +291,8 @@ async def get_stats(
|
||||
"""
|
||||
Get AI generation statistics.
|
||||
"""
|
||||
require_website_auth(auth, allowed_roles={"admin", "system_admin"})
|
||||
stats = await get_ai_stats(db)
|
||||
website_id = require_website_auth(auth, allowed_roles={"admin", "system_admin"})
|
||||
stats = await get_ai_stats(db, website_id=website_id)
|
||||
|
||||
return AIStatsResponse(
|
||||
total_ai_items=stats["total_ai_items"],
|
||||
|
||||
@@ -77,7 +77,7 @@ async def preview_import(
|
||||
HTTPException: If file format is invalid or parsing fails
|
||||
"""
|
||||
website_id = require_website_auth(auth, allowed_roles={"admin", "system_admin"})
|
||||
enforce_rate_limit(
|
||||
await enforce_rate_limit(
|
||||
request,
|
||||
scope="import.preview",
|
||||
max_requests=30,
|
||||
@@ -181,7 +181,7 @@ async def import_questions(
|
||||
HTTPException: If file format is invalid, validation fails, or import fails
|
||||
"""
|
||||
website_id = require_website_auth(auth, allowed_roles={"admin", "system_admin"})
|
||||
enforce_rate_limit(
|
||||
await enforce_rate_limit(
|
||||
request,
|
||||
scope="import.questions",
|
||||
max_requests=20,
|
||||
@@ -351,7 +351,7 @@ async def preview_tryout_json(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> dict:
|
||||
website_id = require_website_auth(auth, allowed_roles={"admin", "system_admin"})
|
||||
enforce_rate_limit(
|
||||
await enforce_rate_limit(
|
||||
request,
|
||||
scope="import.tryout_json_preview",
|
||||
max_requests=30,
|
||||
@@ -394,7 +394,7 @@ async def import_tryout_json(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> dict:
|
||||
website_id = require_website_auth(auth, allowed_roles={"admin", "system_admin"})
|
||||
enforce_rate_limit(
|
||||
await enforce_rate_limit(
|
||||
request,
|
||||
scope="import.tryout_json",
|
||||
max_requests=20,
|
||||
|
||||
@@ -85,6 +85,15 @@ async def get_student_performance_report(
|
||||
Returns individual student records and/or aggregate statistics.
|
||||
"""
|
||||
website_id = require_website_auth(auth, allowed_roles={"student", "admin", "system_admin"})
|
||||
scoped_wp_user_id = None
|
||||
if auth.role == "student":
|
||||
if not auth.wp_user_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Student reports require an authenticated WordPress user",
|
||||
)
|
||||
scoped_wp_user_id = auth.wp_user_id
|
||||
|
||||
date_range = None
|
||||
if date_start or date_end:
|
||||
date_range = {}
|
||||
@@ -99,6 +108,7 @@ async def get_student_performance_report(
|
||||
db=db,
|
||||
date_range=date_range,
|
||||
format_type=format_type,
|
||||
wp_user_id=scoped_wp_user_id,
|
||||
)
|
||||
|
||||
return _convert_student_performance_report(report)
|
||||
@@ -361,7 +371,8 @@ async def create_report_schedule(
|
||||
"""
|
||||
website_id = require_website_auth(auth, allowed_roles={"admin", "system_admin"})
|
||||
ensure_website_scope_matches(website_id, request.website_id)
|
||||
schedule_id = schedule_report(
|
||||
schedule_id = await schedule_report(
|
||||
db,
|
||||
report_type=request.report_type,
|
||||
schedule=request.schedule,
|
||||
tryout_ids=request.tryout_ids,
|
||||
@@ -370,7 +381,7 @@ async def create_report_schedule(
|
||||
export_format=request.export_format,
|
||||
)
|
||||
|
||||
scheduled = get_scheduled_report(schedule_id)
|
||||
scheduled = await get_scheduled_report(db, schedule_id)
|
||||
|
||||
return ReportScheduleResponse(
|
||||
schedule_id=schedule_id,
|
||||
@@ -387,6 +398,7 @@ async def create_report_schedule(
|
||||
)
|
||||
async def get_scheduled_report_details(
|
||||
schedule_id: str,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
) -> ReportScheduleOutput:
|
||||
"""
|
||||
@@ -395,7 +407,7 @@ async def get_scheduled_report_details(
|
||||
Returns the configuration and status of a scheduled report.
|
||||
"""
|
||||
website_id = require_website_auth(auth, allowed_roles={"admin", "system_admin"})
|
||||
scheduled = get_scheduled_report(schedule_id)
|
||||
scheduled = await get_scheduled_report(db, schedule_id)
|
||||
|
||||
if not scheduled:
|
||||
raise HTTPException(
|
||||
@@ -431,6 +443,7 @@ async def get_scheduled_report_details(
|
||||
description="List all scheduled reports for a website.",
|
||||
)
|
||||
async def list_scheduled_reports_endpoint(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
) -> List[ReportScheduleOutput]:
|
||||
"""
|
||||
@@ -439,7 +452,7 @@ async def list_scheduled_reports_endpoint(
|
||||
Returns all scheduled reports for the current website.
|
||||
"""
|
||||
website_id = require_website_auth(auth, allowed_roles={"admin", "system_admin"})
|
||||
reports = list_scheduled_reports(website_id=website_id)
|
||||
reports = await list_scheduled_reports(db, website_id=website_id)
|
||||
|
||||
return [
|
||||
ReportScheduleOutput(
|
||||
@@ -466,6 +479,7 @@ async def list_scheduled_reports_endpoint(
|
||||
)
|
||||
async def cancel_scheduled_report_endpoint(
|
||||
schedule_id: str,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
) -> dict:
|
||||
"""
|
||||
@@ -474,7 +488,7 @@ async def cancel_scheduled_report_endpoint(
|
||||
Removes the scheduled report from the system.
|
||||
"""
|
||||
website_id = require_website_auth(auth, allowed_roles={"admin", "system_admin"})
|
||||
scheduled = get_scheduled_report(schedule_id)
|
||||
scheduled = await get_scheduled_report(db, schedule_id)
|
||||
|
||||
if not scheduled:
|
||||
raise HTTPException(
|
||||
@@ -488,7 +502,7 @@ async def cancel_scheduled_report_endpoint(
|
||||
detail="Access denied to this scheduled report",
|
||||
)
|
||||
|
||||
success = cancel_scheduled_report(schedule_id)
|
||||
success = await cancel_scheduled_report(db, schedule_id)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
@@ -523,7 +537,7 @@ async def export_scheduled_report(
|
||||
Generates the report and returns it as a file download.
|
||||
"""
|
||||
website_id = require_website_auth(auth, allowed_roles={"admin", "system_admin"})
|
||||
scheduled = get_scheduled_report(schedule_id)
|
||||
scheduled = await get_scheduled_report(db, schedule_id)
|
||||
|
||||
if not scheduled:
|
||||
raise HTTPException(
|
||||
@@ -536,6 +550,11 @@ async def export_scheduled_report(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied to this scheduled report",
|
||||
)
|
||||
if not scheduled.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Scheduled report is inactive",
|
||||
)
|
||||
|
||||
# Generate report based on type
|
||||
report = None
|
||||
|
||||
@@ -10,6 +10,7 @@ Endpoints:
|
||||
from datetime import datetime, timezone
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
@@ -122,6 +123,27 @@ async def complete_session(
|
||||
items = {item.id: item for item in items_result.scalars().all()}
|
||||
|
||||
# Process each answer
|
||||
submitted_item_ids = [answer.item_id for answer in request.user_answers]
|
||||
if len(submitted_item_ids) != len(set(submitted_item_ids)):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail="Duplicate item answers are not allowed in a session completion",
|
||||
)
|
||||
|
||||
existing_answers_result = await db.execute(
|
||||
select(UserAnswer.item_id).where(UserAnswer.session_id == session.session_id)
|
||||
)
|
||||
existing_answered_item_ids = {row[0] for row in existing_answers_result.all()}
|
||||
duplicate_existing_ids = sorted(set(submitted_item_ids) & existing_answered_item_ids)
|
||||
if duplicate_existing_ids:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail={
|
||||
"message": "One or more items were already answered for this session",
|
||||
"item_ids": duplicate_existing_ids,
|
||||
},
|
||||
)
|
||||
|
||||
total_benar = 0
|
||||
total_bobot_earned = 0.0
|
||||
user_answer_records = []
|
||||
@@ -234,7 +256,13 @@ async def complete_session(
|
||||
await update_tryout_stats(db, website_id, session.tryout_id, nm)
|
||||
|
||||
# Commit all changes
|
||||
await db.commit()
|
||||
try:
|
||||
await db.commit()
|
||||
except IntegrityError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail="Duplicate item answer detected for this session",
|
||||
) from exc
|
||||
|
||||
# Refresh to get updated relationships
|
||||
await db.refresh(session)
|
||||
@@ -261,7 +289,6 @@ async def complete_session(
|
||||
id=ua.id,
|
||||
item_id=ua.item_id,
|
||||
response=ua.response,
|
||||
is_correct=ua.is_correct,
|
||||
time_spent=ua.time_spent,
|
||||
bobot_earned=ua.bobot_earned,
|
||||
scoring_mode_used=ua.scoring_mode_used,
|
||||
|
||||
@@ -15,7 +15,13 @@ from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.database import get_db
|
||||
from app.core.auth import issue_access_token
|
||||
from app.core.auth import (
|
||||
AuthContext,
|
||||
ensure_website_scope_matches,
|
||||
get_auth_context,
|
||||
issue_access_token,
|
||||
require_website_auth,
|
||||
)
|
||||
from app.models.user import User
|
||||
from app.models.website import Website
|
||||
from app.schemas.wordpress import (
|
||||
@@ -44,6 +50,16 @@ logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/wordpress", tags=["wordpress"])
|
||||
|
||||
|
||||
def _api_role_from_wordpress_roles(roles: list[str]) -> str:
|
||||
"""Map WordPress roles to API roles used by route authorization."""
|
||||
normalized_roles = {str(role).strip().lower() for role in roles}
|
||||
if normalized_roles & {"super_admin", "system_admin"}:
|
||||
return "system_admin"
|
||||
if normalized_roles & {"administrator", "admin"}:
|
||||
return "admin"
|
||||
return "student"
|
||||
|
||||
|
||||
def get_website_id_from_header(
|
||||
x_website_id: Optional[str] = Header(None, alias="X-Website-ID"),
|
||||
) -> int:
|
||||
@@ -132,7 +148,7 @@ async def sync_users_endpoint(
|
||||
Raises:
|
||||
HTTPException: If website not found, token invalid, or API error
|
||||
"""
|
||||
enforce_rate_limit(
|
||||
await enforce_rate_limit(
|
||||
request,
|
||||
scope="wordpress.sync_users",
|
||||
max_requests=20,
|
||||
@@ -230,7 +246,7 @@ async def verify_session_endpoint(
|
||||
Raises:
|
||||
HTTPException: If website not found or API error
|
||||
"""
|
||||
enforce_rate_limit(
|
||||
await enforce_rate_limit(
|
||||
http_request,
|
||||
scope="wordpress.verify_session",
|
||||
max_requests=60,
|
||||
@@ -273,7 +289,7 @@ async def verify_session_endpoint(
|
||||
},
|
||||
access_token=issue_access_token(
|
||||
website_id=request.website_id,
|
||||
role="student",
|
||||
role=_api_role_from_wordpress_roles(wp_user_info.roles),
|
||||
wp_user_id=request.wp_user_id,
|
||||
expires_in_seconds=3600 * 24,
|
||||
),
|
||||
@@ -310,6 +326,7 @@ async def verify_session_endpoint(
|
||||
async def get_website_users(
|
||||
website_id: int,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
page: int = 1,
|
||||
page_size: int = 50,
|
||||
) -> UserListResponse:
|
||||
@@ -328,6 +345,9 @@ async def get_website_users(
|
||||
Raises:
|
||||
HTTPException: If website not found
|
||||
"""
|
||||
auth_website_id = require_website_auth(auth, allowed_roles={"admin", "system_admin"})
|
||||
ensure_website_scope_matches(auth_website_id, website_id)
|
||||
|
||||
# Validate website exists
|
||||
await get_valid_website(website_id, db)
|
||||
|
||||
@@ -374,6 +394,7 @@ async def get_user_endpoint(
|
||||
website_id: int,
|
||||
wp_user_id: str,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
) -> WordPressUserResponse:
|
||||
"""
|
||||
Get a specific user by WordPress user ID.
|
||||
@@ -389,6 +410,16 @@ async def get_user_endpoint(
|
||||
Raises:
|
||||
HTTPException: If website or user not found
|
||||
"""
|
||||
auth_website_id = require_website_auth(
|
||||
auth, allowed_roles={"student", "admin", "system_admin"}
|
||||
)
|
||||
ensure_website_scope_matches(auth_website_id, website_id)
|
||||
if auth.role == "student" and auth.wp_user_id != wp_user_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="User does not belong to this authenticated user",
|
||||
)
|
||||
|
||||
# Validate website exists
|
||||
await get_valid_website(website_id, db)
|
||||
|
||||
|
||||
@@ -29,7 +29,6 @@ class UserAnswerOutput(BaseModel):
|
||||
id: int
|
||||
item_id: int
|
||||
response: str
|
||||
is_correct: bool
|
||||
time_spent: int
|
||||
bobot_earned: float
|
||||
scoring_mode_used: str
|
||||
@@ -37,6 +36,12 @@ class UserAnswerOutput(BaseModel):
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class UserAnswerReviewOutput(UserAnswerOutput):
|
||||
"""Review output for a single answer."""
|
||||
|
||||
is_correct: bool
|
||||
|
||||
|
||||
class SessionCompleteResponse(BaseModel):
|
||||
"""Response schema for completed session with CTT scores."""
|
||||
|
||||
@@ -66,6 +71,12 @@ class SessionCompleteResponse(BaseModel):
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class SessionCompleteAdminResponse(SessionCompleteResponse):
|
||||
"""Completed session response with answer correctness for admin/review contexts."""
|
||||
|
||||
user_answers: List[UserAnswerReviewOutput]
|
||||
|
||||
|
||||
class SessionCreateRequest(BaseModel):
|
||||
"""Request schema for creating a new session."""
|
||||
|
||||
|
||||
@@ -715,7 +715,7 @@ async def generate_questions_batch(
|
||||
return generated_items
|
||||
|
||||
|
||||
async def get_ai_stats(db: AsyncSession) -> Dict[str, Any]:
|
||||
async def get_ai_stats(db: AsyncSession, website_id: int | None = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Get AI generation statistics.
|
||||
|
||||
@@ -725,16 +725,18 @@ async def get_ai_stats(db: AsyncSession) -> Dict[str, Any]:
|
||||
Returns:
|
||||
Statistics dictionary
|
||||
"""
|
||||
filters = [Item.generated_by == "ai"]
|
||||
if website_id is not None:
|
||||
filters.append(Item.website_id == website_id)
|
||||
|
||||
# Total AI-generated items
|
||||
total_result = await db.execute(
|
||||
select(func.count(Item.id)).where(Item.generated_by == "ai")
|
||||
)
|
||||
total_result = await db.execute(select(func.count(Item.id)).where(*filters))
|
||||
total_ai_items = total_result.scalar() or 0
|
||||
|
||||
# Items by model
|
||||
model_result = await db.execute(
|
||||
select(Item.ai_model, func.count(Item.id))
|
||||
.where(Item.generated_by == "ai")
|
||||
.where(*filters)
|
||||
.where(Item.ai_model.isnot(None))
|
||||
.group_by(Item.ai_model)
|
||||
)
|
||||
|
||||
@@ -23,6 +23,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.models.item import Item
|
||||
from app.models.report_schedule import ReportScheduleModel
|
||||
from app.models.session import Session
|
||||
from app.models.tryout import Tryout
|
||||
from app.models.tryout_stats import TryoutStats
|
||||
@@ -256,7 +257,8 @@ async def generate_student_performance_report(
|
||||
website_id: int,
|
||||
db: AsyncSession,
|
||||
date_range: Optional[Dict[str, datetime]] = None,
|
||||
format_type: Literal["individual", "aggregate", "both"] = "both"
|
||||
format_type: Literal["individual", "aggregate", "both"] = "both",
|
||||
wp_user_id: Optional[str] = None,
|
||||
) -> StudentPerformanceReport:
|
||||
"""
|
||||
Generate student performance report.
|
||||
@@ -267,6 +269,7 @@ async def generate_student_performance_report(
|
||||
db: Database session
|
||||
date_range: Optional date range filter {"start": datetime, "end": datetime}
|
||||
format_type: Report format - individual, aggregate, or both
|
||||
wp_user_id: Optional WordPress user filter for student-scoped reports
|
||||
|
||||
Returns:
|
||||
StudentPerformanceReport with aggregate stats and/or individual records
|
||||
@@ -288,6 +291,9 @@ async def generate_student_performance_report(
|
||||
if date_range.get("end"):
|
||||
query = query.where(Session.start_time <= date_range["end"])
|
||||
|
||||
if wp_user_id is not None:
|
||||
query = query.where(Session.wp_user_id == wp_user_id)
|
||||
|
||||
query = query.order_by(Session.NN.desc().nullslast())
|
||||
|
||||
result = await db.execute(query)
|
||||
@@ -1382,11 +1388,34 @@ class ReportSchedule:
|
||||
is_active: bool = True
|
||||
|
||||
|
||||
# In-memory store for scheduled reports (in production, use database)
|
||||
_scheduled_reports: Dict[str, ReportSchedule] = {}
|
||||
def _calculate_next_run(schedule: Literal["daily", "weekly", "monthly"]) -> datetime:
|
||||
now = datetime.now(timezone.utc)
|
||||
if schedule == "daily":
|
||||
return now + timedelta(days=1)
|
||||
if schedule == "weekly":
|
||||
return now + timedelta(weeks=1)
|
||||
return now + timedelta(days=30)
|
||||
|
||||
|
||||
def schedule_report(
|
||||
def _schedule_from_model(row: ReportScheduleModel) -> ReportSchedule:
|
||||
return ReportSchedule(
|
||||
schedule_id=row.schedule_id,
|
||||
report_type=row.report_type,
|
||||
schedule=row.schedule,
|
||||
tryout_ids=list(row.tryout_ids or []),
|
||||
website_id=row.website_id,
|
||||
recipients=list(row.recipients or []),
|
||||
format=row.format,
|
||||
created_at=row.created_at,
|
||||
last_run=row.last_run,
|
||||
next_run=row.next_run,
|
||||
is_active=row.is_active,
|
||||
)
|
||||
|
||||
|
||||
async def schedule_report(
|
||||
db: AsyncSession,
|
||||
*,
|
||||
report_type: Literal["student_performance", "item_analysis", "calibration_status", "tryout_comparison"],
|
||||
schedule: Literal["daily", "weekly", "monthly"],
|
||||
tryout_ids: List[str],
|
||||
@@ -1412,16 +1441,7 @@ def schedule_report(
|
||||
|
||||
schedule_id = str(uuid.uuid4())
|
||||
|
||||
# Calculate next run time
|
||||
now = datetime.now(timezone.utc)
|
||||
if schedule == "daily":
|
||||
next_run = now + timedelta(days=1)
|
||||
elif schedule == "weekly":
|
||||
next_run = now + timedelta(weeks=1)
|
||||
else: # monthly
|
||||
next_run = now + timedelta(days=30)
|
||||
|
||||
report_schedule = ReportSchedule(
|
||||
report_schedule = ReportScheduleModel(
|
||||
schedule_id=schedule_id,
|
||||
report_type=report_type,
|
||||
schedule=schedule,
|
||||
@@ -1429,35 +1449,54 @@ def schedule_report(
|
||||
website_id=website_id,
|
||||
recipients=recipients,
|
||||
format=export_format,
|
||||
next_run=next_run,
|
||||
next_run=_calculate_next_run(schedule),
|
||||
is_active=True,
|
||||
)
|
||||
|
||||
_scheduled_reports[schedule_id] = report_schedule
|
||||
db.add(report_schedule)
|
||||
await db.flush()
|
||||
logger.info(f"Scheduled report {schedule_id}: {report_type} {schedule}")
|
||||
|
||||
return schedule_id
|
||||
|
||||
|
||||
def get_scheduled_report(schedule_id: str) -> Optional[ReportSchedule]:
|
||||
async def get_scheduled_report(db: AsyncSession, schedule_id: str) -> Optional[ReportSchedule]:
|
||||
"""Get a scheduled report by ID."""
|
||||
return _scheduled_reports.get(schedule_id)
|
||||
result = await db.execute(
|
||||
select(ReportScheduleModel).where(ReportScheduleModel.schedule_id == schedule_id)
|
||||
)
|
||||
row = result.scalar_one_or_none()
|
||||
return _schedule_from_model(row) if row else None
|
||||
|
||||
|
||||
def list_scheduled_reports(website_id: Optional[int] = None) -> List[ReportSchedule]:
|
||||
async def list_scheduled_reports(
|
||||
db: AsyncSession,
|
||||
website_id: Optional[int] = None,
|
||||
) -> List[ReportSchedule]:
|
||||
"""List all scheduled reports, optionally filtered by website."""
|
||||
reports = list(_scheduled_reports.values())
|
||||
if website_id:
|
||||
reports = [r for r in reports if r.website_id == website_id]
|
||||
return reports
|
||||
query = (
|
||||
select(ReportScheduleModel)
|
||||
.where(ReportScheduleModel.is_active == True)
|
||||
.order_by(ReportScheduleModel.created_at.desc())
|
||||
)
|
||||
if website_id is not None:
|
||||
query = query.where(ReportScheduleModel.website_id == website_id)
|
||||
result = await db.execute(query)
|
||||
return [_schedule_from_model(row) for row in result.scalars().all()]
|
||||
|
||||
|
||||
def cancel_scheduled_report(schedule_id: str) -> bool:
|
||||
async def cancel_scheduled_report(db: AsyncSession, schedule_id: str) -> bool:
|
||||
"""Cancel a scheduled report."""
|
||||
if schedule_id in _scheduled_reports:
|
||||
del _scheduled_reports[schedule_id]
|
||||
logger.info(f"Cancelled scheduled report {schedule_id}")
|
||||
return True
|
||||
return False
|
||||
result = await db.execute(
|
||||
select(ReportScheduleModel).where(ReportScheduleModel.schedule_id == schedule_id)
|
||||
)
|
||||
row = result.scalar_one_or_none()
|
||||
if row is None:
|
||||
return False
|
||||
row.is_active = False
|
||||
await db.flush()
|
||||
logger.info(f"Cancelled scheduled report {schedule_id}")
|
||||
return True
|
||||
|
||||
|
||||
# Export public API
|
||||
|
||||
212
tests/test_operational_hardening.py
Normal file
212
tests/test_operational_hardening.py
Normal file
@@ -0,0 +1,212 @@
|
||||
import asyncio
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
from app.core import rate_limit
|
||||
from app.core.config import Settings
|
||||
from app.models.report_schedule import ReportScheduleModel
|
||||
from app.services import ai_generation
|
||||
from app.services.reporting import (
|
||||
cancel_scheduled_report,
|
||||
get_scheduled_report,
|
||||
list_scheduled_reports,
|
||||
schedule_report,
|
||||
)
|
||||
|
||||
|
||||
class DummyRequest:
|
||||
client = SimpleNamespace(host="127.0.0.1")
|
||||
|
||||
|
||||
class DummyScalarResult:
|
||||
def __init__(self, value):
|
||||
self._value = value
|
||||
|
||||
def scalar_one_or_none(self):
|
||||
return self._value
|
||||
|
||||
def scalar(self):
|
||||
return self._value
|
||||
|
||||
|
||||
class DummyScalars:
|
||||
def __init__(self, values):
|
||||
self._values = values
|
||||
|
||||
def all(self):
|
||||
return self._values
|
||||
|
||||
|
||||
class DummyListResult:
|
||||
def __init__(self, values):
|
||||
self._values = values
|
||||
|
||||
def scalars(self):
|
||||
return DummyScalars(self._values)
|
||||
|
||||
|
||||
class DummyRowsResult:
|
||||
def __init__(self, values):
|
||||
self._values = values
|
||||
|
||||
def all(self):
|
||||
return self._values
|
||||
|
||||
|
||||
class DummyDb:
|
||||
def __init__(self, execute_results=None):
|
||||
self.execute_results = list(execute_results or [])
|
||||
self.added = []
|
||||
self.flushed = False
|
||||
|
||||
def add(self, row):
|
||||
self.added.append(row)
|
||||
|
||||
async def flush(self):
|
||||
self.flushed = True
|
||||
|
||||
async def execute(self, _query):
|
||||
return self.execute_results.pop(0)
|
||||
|
||||
|
||||
class DummyRedis:
|
||||
def __init__(self):
|
||||
self.calls = 0
|
||||
|
||||
async def incr(self, _key):
|
||||
self.calls += 1
|
||||
return self.calls
|
||||
|
||||
async def expire(self, _key, _seconds):
|
||||
return True
|
||||
|
||||
async def ttl(self, _key):
|
||||
return 60
|
||||
|
||||
|
||||
def test_ai_stats_accepts_website_scope(monkeypatch):
|
||||
captured_queries = []
|
||||
|
||||
class CaptureDb:
|
||||
async def execute(self, query):
|
||||
captured_queries.append(str(query))
|
||||
if len(captured_queries) == 1:
|
||||
return DummyScalarResult(0)
|
||||
return DummyRowsResult([])
|
||||
|
||||
asyncio.run(ai_generation.get_ai_stats(CaptureDb(), website_id=9))
|
||||
|
||||
assert all("items.website_id" in query for query in captured_queries)
|
||||
|
||||
|
||||
def test_production_init_db_skips_create_all(monkeypatch):
|
||||
import app.database as database
|
||||
|
||||
class ExplodingEngine:
|
||||
def begin(self):
|
||||
raise AssertionError("create_all should not run in production")
|
||||
|
||||
monkeypatch.setattr(database, "settings", Settings(ENVIRONMENT="production"))
|
||||
monkeypatch.setattr(database, "engine", ExplodingEngine())
|
||||
|
||||
asyncio.run(database.init_db())
|
||||
|
||||
|
||||
def test_rate_limit_uses_redis_and_blocks_when_limit_exceeded(monkeypatch):
|
||||
dummy_redis = DummyRedis()
|
||||
rate_limit.reset_rate_limit_state()
|
||||
monkeypatch.setattr(rate_limit, "_get_redis_client", lambda: dummy_redis)
|
||||
|
||||
asyncio.run(
|
||||
rate_limit.enforce_rate_limit(
|
||||
DummyRequest(),
|
||||
scope="test.redis",
|
||||
max_requests=1,
|
||||
window_seconds=60,
|
||||
)
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
asyncio.run(
|
||||
rate_limit.enforce_rate_limit(
|
||||
DummyRequest(),
|
||||
scope="test.redis",
|
||||
max_requests=1,
|
||||
window_seconds=60,
|
||||
)
|
||||
)
|
||||
assert exc_info.value.status_code == 429
|
||||
|
||||
|
||||
def test_rate_limit_falls_back_to_memory_when_redis_unavailable(monkeypatch):
|
||||
rate_limit.reset_rate_limit_state()
|
||||
monkeypatch.setattr(rate_limit, "_get_redis_client", lambda: None)
|
||||
|
||||
asyncio.run(
|
||||
rate_limit.enforce_rate_limit(
|
||||
DummyRequest(),
|
||||
scope="test.memory",
|
||||
max_requests=1,
|
||||
window_seconds=60,
|
||||
)
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
asyncio.run(
|
||||
rate_limit.enforce_rate_limit(
|
||||
DummyRequest(),
|
||||
scope="test.memory",
|
||||
max_requests=1,
|
||||
window_seconds=60,
|
||||
)
|
||||
)
|
||||
assert exc_info.value.status_code == 429
|
||||
|
||||
|
||||
def test_schedule_report_persists_model_row():
|
||||
db = DummyDb()
|
||||
|
||||
schedule_id = asyncio.run(
|
||||
schedule_report(
|
||||
db,
|
||||
report_type="student_performance",
|
||||
schedule="daily",
|
||||
tryout_ids=["t1"],
|
||||
website_id=3,
|
||||
recipients=["ops@example.com"],
|
||||
export_format="xlsx",
|
||||
)
|
||||
)
|
||||
|
||||
assert db.flushed is True
|
||||
assert isinstance(db.added[0], ReportScheduleModel)
|
||||
assert db.added[0].schedule_id == schedule_id
|
||||
assert db.added[0].website_id == 3
|
||||
|
||||
|
||||
def test_schedule_helpers_read_list_and_soft_cancel():
|
||||
row = ReportScheduleModel(
|
||||
schedule_id="sched-1",
|
||||
report_type="student_performance",
|
||||
schedule="daily",
|
||||
tryout_ids=["t1"],
|
||||
website_id=3,
|
||||
recipients=["ops@example.com"],
|
||||
format="xlsx",
|
||||
is_active=True,
|
||||
)
|
||||
|
||||
get_db = DummyDb([DummyScalarResult(row)])
|
||||
listed_db = DummyDb([DummyListResult([row])])
|
||||
cancel_db = DummyDb([DummyScalarResult(row)])
|
||||
|
||||
got = asyncio.run(get_scheduled_report(get_db, "sched-1"))
|
||||
listed = asyncio.run(list_scheduled_reports(listed_db, website_id=3))
|
||||
cancelled = asyncio.run(cancel_scheduled_report(cancel_db, "sched-1"))
|
||||
|
||||
assert got.schedule_id == "sched-1"
|
||||
assert listed[0].website_id == 3
|
||||
assert cancelled is True
|
||||
assert row.is_active is False
|
||||
132
tests/test_security_regressions.py
Normal file
132
tests/test_security_regressions.py
Normal file
@@ -0,0 +1,132 @@
|
||||
import asyncio
|
||||
import inspect
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from fastapi.params import Depends
|
||||
|
||||
from app.api.v1.session import SubmitAnswerResponse
|
||||
from app.core.auth import AuthContext, get_auth_context
|
||||
from app.routers import admin as admin_router
|
||||
from app.routers import reports as reports_router
|
||||
from app.routers import wordpress as wordpress_router
|
||||
from app.schemas.session import SessionCompleteResponse, UserAnswerOutput
|
||||
from app.services.reporting import AggregatePerformanceStats, StudentPerformanceReport
|
||||
|
||||
|
||||
def _depends_on_auth(callable_obj, parameter_name: str = "auth") -> bool:
|
||||
parameter = inspect.signature(callable_obj).parameters[parameter_name]
|
||||
default = parameter.default
|
||||
return isinstance(default, Depends) and default.dependency is get_auth_context
|
||||
|
||||
|
||||
def test_admin_actions_require_signed_auth_context():
|
||||
assert _depends_on_auth(admin_router.admin_trigger_calibration)
|
||||
assert _depends_on_auth(admin_router.admin_toggle_ai_generation)
|
||||
assert _depends_on_auth(admin_router.admin_reset_normalization)
|
||||
|
||||
|
||||
def test_wordpress_user_lookup_routes_require_signed_auth_context():
|
||||
assert _depends_on_auth(wordpress_router.get_website_users)
|
||||
assert _depends_on_auth(wordpress_router.get_user_endpoint)
|
||||
|
||||
|
||||
def test_wordpress_roles_map_to_api_admin_roles():
|
||||
assert wordpress_router._api_role_from_wordpress_roles(["subscriber"]) == "student"
|
||||
assert wordpress_router._api_role_from_wordpress_roles(["administrator"]) == "admin"
|
||||
assert wordpress_router._api_role_from_wordpress_roles(["super_admin"]) == "system_admin"
|
||||
|
||||
|
||||
def test_adaptive_submit_response_does_not_expose_answer_key_or_correctness():
|
||||
payload = SubmitAnswerResponse(theta=0.12, theta_se=0.8).model_dump()
|
||||
|
||||
assert "is_correct" not in payload
|
||||
assert "correct_answer" not in payload
|
||||
assert "explanation" not in payload
|
||||
|
||||
|
||||
def test_session_completion_answer_output_does_not_expose_correctness():
|
||||
answer_payload = UserAnswerOutput(
|
||||
id=1,
|
||||
item_id=10,
|
||||
response="A",
|
||||
time_spent=12,
|
||||
bobot_earned=0.5,
|
||||
scoring_mode_used="ctt",
|
||||
).model_dump()
|
||||
|
||||
assert "is_correct" not in answer_payload
|
||||
|
||||
response_payload = SessionCompleteResponse(
|
||||
id=1,
|
||||
session_id="s-1",
|
||||
wp_user_id="wp-1",
|
||||
website_id=2,
|
||||
tryout_id="tryout-1",
|
||||
start_time=datetime.now(timezone.utc),
|
||||
end_time=datetime.now(timezone.utc),
|
||||
is_completed=True,
|
||||
scoring_mode_used="ctt",
|
||||
total_benar=1,
|
||||
total_bobot_earned=0.5,
|
||||
NM=500,
|
||||
NN=500,
|
||||
rataan_used=500,
|
||||
sb_used=100,
|
||||
user_answers=[
|
||||
UserAnswerOutput(
|
||||
id=1,
|
||||
item_id=10,
|
||||
response="A",
|
||||
time_spent=12,
|
||||
bobot_earned=0.5,
|
||||
scoring_mode_used="ctt",
|
||||
)
|
||||
],
|
||||
).model_dump()
|
||||
|
||||
assert "is_correct" not in response_payload["user_answers"][0]
|
||||
|
||||
|
||||
def test_student_performance_report_is_scoped_to_student_user(monkeypatch):
|
||||
captured = {}
|
||||
|
||||
async def fake_generate_student_performance_report(**kwargs):
|
||||
captured.update(kwargs)
|
||||
return StudentPerformanceReport(
|
||||
generated_at=datetime.now(timezone.utc),
|
||||
tryout_id=kwargs["tryout_id"],
|
||||
website_id=kwargs["website_id"],
|
||||
date_range=kwargs["date_range"],
|
||||
aggregate=AggregatePerformanceStats(
|
||||
tryout_id=kwargs["tryout_id"],
|
||||
participant_count=0,
|
||||
avg_nm=None,
|
||||
std_nm=None,
|
||||
min_nm=None,
|
||||
max_nm=None,
|
||||
median_nm=None,
|
||||
avg_nn=None,
|
||||
std_nn=None,
|
||||
avg_theta=None,
|
||||
pass_rate=0.0,
|
||||
avg_time_spent=0.0,
|
||||
),
|
||||
individual_records=[],
|
||||
)
|
||||
|
||||
monkeypatch.setattr(
|
||||
reports_router,
|
||||
"generate_student_performance_report",
|
||||
fake_generate_student_performance_report,
|
||||
)
|
||||
|
||||
asyncio.run(
|
||||
reports_router.get_student_performance_report(
|
||||
tryout_id="tryout-1",
|
||||
db=object(),
|
||||
auth=AuthContext(website_id=5, role="student", wp_user_id="wp-1"),
|
||||
)
|
||||
)
|
||||
|
||||
assert captured["website_id"] == 5
|
||||
assert captured["wp_user_id"] == "wp-1"
|
||||
Reference in New Issue
Block a user