diff --git a/app/admin_web.py b/app/admin_web.py index 96e746d..6064bc1 100644 --- a/app/admin_web.py +++ b/app/admin_web.py @@ -20,7 +20,7 @@ from sqlalchemy import func, select from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession from starlette.responses import HTMLResponse, RedirectResponse -from starlette.status import HTTP_303_SEE_OTHER, HTTP_401_UNAUTHORIZED +from starlette.status import HTTP_303_SEE_OTHER, HTTP_401_UNAUTHORIZED, HTTP_429_TOO_MANY_REQUESTS from app.core.config import get_settings from app.database import get_db @@ -53,9 +53,13 @@ settings = get_settings() router = APIRouter(prefix="/admin", tags=["admin-web"]) SESSION_COOKIE = "access_token" +CSRF_COOKIE = "admin_csrf_token" SESSION_PREFIX = "admin:session:" IMPORT_PREVIEW_PREFIX = "admin:import-preview:" IMPORT_PREVIEW_TTL_SECONDS = 900 +LOGIN_RATE_LIMIT_PREFIX = "admin:login:attempts:" +LOGIN_RATE_LIMIT_MAX_ATTEMPTS = 10 +LOGIN_RATE_LIMIT_WINDOW_SECONDS = 300 _admin_redis = None @@ -153,10 +157,27 @@ def _render_auth_page( """ - return HTMLResponse(html, status_code=status_code) + csrf_token = request.cookies.get(CSRF_COOKIE) or secrets.token_urlsafe(24) + csrf_input = f'' + html = re.sub( + r'(]*method="post"[^>]*>)', + r"\1" + csrf_input, + html, + flags=re.IGNORECASE, + ) + response = HTMLResponse(html, status_code=status_code) + response.set_cookie( + CSRF_COOKIE, + csrf_token, + path="/admin", + httponly=False, + secure=settings.ENVIRONMENT == "production", + samesite="lax", + ) + return response -def _render_admin_page(title: str, page_title: str, body: str) -> HTMLResponse: +def _render_admin_page(request: Request, title: str, page_title: str, body: str) -> HTMLResponse: html = f""" @@ -213,7 +234,46 @@ def _render_admin_page(title: str, page_title: str, body: str) -> HTMLResponse: """ - return HTMLResponse(html) + csrf_token = request.cookies.get(CSRF_COOKIE) or secrets.token_urlsafe(24) + csrf_input = f'' + html = re.sub( + r'(]*method="post"[^>]*>)', + r"\1" + csrf_input, + html, + flags=re.IGNORECASE, + ) + response = HTMLResponse(html) + response.set_cookie( + CSRF_COOKIE, + csrf_token, + path="/admin", + httponly=False, + secure=settings.ENVIRONMENT == "production", + samesite="lax", + ) + return response + + +def _verify_csrf(request: Request, csrf_token: str | None) -> None: + cookie_token = request.cookies.get(CSRF_COOKIE) + if not cookie_token or not csrf_token: + raise HTTPException(status_code=403, detail="CSRF validation failed") + if not secrets.compare_digest(cookie_token, csrf_token): + raise HTTPException(status_code=403, detail="CSRF validation failed") + + +async def _enforce_csrf(request: Request) -> None: + form = await request.form() + _verify_csrf(request, form.get("csrf_token")) + + +async def _csrf_route_guard(request: Request) -> None: + if request.method.upper() != "POST": + return + await _enforce_csrf(request) + + +router.dependencies.append(Depends(_csrf_route_guard)) def _table(headers: list[str], rows: list[list[Any]]) -> str: @@ -1052,10 +1112,58 @@ async def login_submit( password: str = Form(...), remember_me: str | None = Form(None), ): + + if _admin_redis is None: + body = """ +
Admin backend is temporarily unavailable. Please try again.
+
+ + + + + + +
+ """ + return _render_auth_page( + request, + "Admin Login", + "Use the configured admin credentials to access the dashboard.", + body, + status_code=503, + ) + + client_ip = request.client.host if request.client else "unknown" + rate_limit_key = f"{LOGIN_RATE_LIMIT_PREFIX}{client_ip}" + attempts_raw = await _admin_redis.get(rate_limit_key) + attempts = int(attempts_raw) if attempts_raw else 0 + if attempts >= LOGIN_RATE_LIMIT_MAX_ATTEMPTS: + body = """ +
Too many login attempts. Please wait a few minutes and try again.
+
+ + + + + + +
+ """ + return _render_auth_page( + request, + "Admin Login", + "Use the configured admin credentials to access the dashboard.", + body, + status_code=HTTP_429_TOO_MANY_REQUESTS, + ) + if not ( secrets.compare_digest(username, settings.ADMIN_USERNAME) and secrets.compare_digest(password, settings.ADMIN_PASSWORD) ): + attempts = await _admin_redis.incr(rate_limit_key) + if attempts == 1: + await _admin_redis.expire(rate_limit_key, LOGIN_RATE_LIMIT_WINDOW_SECONDS) body = f"""
Invalid username or password.
@@ -1075,11 +1183,21 @@ async def login_submit( status_code=HTTP_401_UNAUTHORIZED, ) + await _admin_redis.delete(rate_limit_key) + expire = settings.ADMIN_SESSION_EXPIRE_SECONDS response = _dashboard_redirect() + secure_cookie = settings.ENVIRONMENT == "production" if remember_me == "on": expire = max(expire, 3600 * 24 * 30) - response.set_cookie("remember_me", "on", expires=expire, path="/admin") + response.set_cookie( + "remember_me", + "on", + expires=expire, + path="/admin", + secure=secure_cookie, + samesite="lax", + ) else: response.delete_cookie("remember_me", path="/admin") @@ -1090,6 +1208,7 @@ async def login_submit( expires=expire, path="/admin", httponly=True, + secure=secure_cookie, samesite="lax", ) await _admin_redis.set(f"{SESSION_PREFIX}{token}", settings.ADMIN_USERNAME, ex=expire) @@ -1179,7 +1298,7 @@ async def dashboard_view(request: Request, db: AsyncSession = Depends(get_db)):

Open AI Playground

""" - return _render_admin_page("IRT Bank Soal Admin", "Dashboard", body) + return _render_admin_page(request, "IRT Bank Soal Admin", "Dashboard", body) @router.get("/websites", include_in_schema=False) @@ -1191,7 +1310,7 @@ async def websites_view(request: Request, db: AsyncSession = Depends(get_db)): result = await db.execute(select(Website).order_by(Website.id.asc())) websites = list(result.scalars().all()) body = _websites_form_body(websites) - return _render_admin_page("Websites", "Websites", body) + return _render_admin_page(request, "Websites", "Websites", body) @router.post("/websites", include_in_schema=False) @@ -1217,7 +1336,7 @@ async def websites_submit( site_name=site_name, site_url=site_url, ) - return _render_admin_page("Websites", "Websites", body) + return _render_admin_page(request, "Websites", "Websites", body) if not normalized_url.startswith(("http://", "https://")): result = await db.execute(select(Website).order_by(Website.id.asc())) @@ -1228,7 +1347,7 @@ async def websites_submit( site_name=site_name, site_url=site_url, ) - return _render_admin_page("Websites", "Websites", body) + return _render_admin_page(request, "Websites", "Websites", body) website = Website(site_name=normalized_name, site_url=normalized_url) db.add(website) @@ -1244,7 +1363,7 @@ async def websites_submit( site_name=site_name, site_url=site_url, ) - return _render_admin_page("Websites", "Websites", body) + return _render_admin_page(request, "Websites", "Websites", body) result = await db.execute(select(Website).order_by(Website.id.asc())) websites = list(result.scalars().all()) @@ -1252,7 +1371,7 @@ async def websites_submit( websites, success=f"Website added successfully with ID {website.id}.", ) - return _render_admin_page("Websites", "Websites", body) + return _render_admin_page(request, "Websites", "Websites", body) @router.get("/websites/{website_id}/edit", include_in_schema=False) @@ -1270,10 +1389,10 @@ async def website_edit_view( result = await db.execute(select(Website).order_by(Website.id.asc())) websites = list(result.scalars().all()) body = _websites_form_body(websites, error=f"Website not found: {website_id}") - return _render_admin_page("Websites", "Websites", body) + return _render_admin_page(request, "Websites", "Websites", body) body = _website_edit_form_body(website) - return _render_admin_page("Edit Website", "Edit Website", body) + return _render_admin_page(request, "Edit Website", "Edit Website", body) @router.post("/websites/{website_id}/edit", include_in_schema=False) @@ -1293,7 +1412,7 @@ async def website_edit_submit( result = await db.execute(select(Website).order_by(Website.id.asc())) websites = list(result.scalars().all()) body = _websites_form_body(websites, error=f"Website not found: {website_id}") - return _render_admin_page("Websites", "Websites", body) + return _render_admin_page(request, "Websites", "Websites", body) normalized_name = site_name.strip() normalized_url = site_url.strip().rstrip("/") @@ -1305,7 +1424,7 @@ async def website_edit_submit( site_name=site_name, site_url=site_url, ) - return _render_admin_page("Edit Website", "Edit Website", body) + return _render_admin_page(request, "Edit Website", "Edit Website", body) if not normalized_url.startswith(("http://", "https://")): body = _website_edit_form_body( @@ -1314,7 +1433,7 @@ async def website_edit_submit( site_name=site_name, site_url=site_url, ) - return _render_admin_page("Edit Website", "Edit Website", body) + return _render_admin_page(request, "Edit Website", "Edit Website", body) website.site_name = normalized_name website.site_url = normalized_url @@ -1328,14 +1447,14 @@ async def website_edit_submit( site_name=site_name, site_url=site_url, ) - return _render_admin_page("Edit Website", "Edit Website", body) + return _render_admin_page(request, "Edit Website", "Edit Website", body) await db.refresh(website) body = _website_edit_form_body( website, success=f"Website #{website.id} updated successfully.", ) - return _render_admin_page("Edit Website", "Edit Website", body) + return _render_admin_page(request, "Edit Website", "Edit Website", body) @router.post("/websites/{website_id}/delete", include_in_schema=False) @@ -1353,7 +1472,7 @@ async def website_delete_submit( result = await db.execute(select(Website).order_by(Website.id.asc())) websites = list(result.scalars().all()) body = _websites_form_body(websites, error=f"Website not found: {website_id}") - return _render_admin_page("Websites", "Websites", body) + return _render_admin_page(request, "Websites", "Websites", body) deleted_label = f"{website.site_name} ({website.site_url})" await db.delete(website) @@ -1365,7 +1484,7 @@ async def website_delete_submit( websites, success=f"Website deleted successfully: {deleted_label}", ) - return _render_admin_page("Websites", "Websites", body) + return _render_admin_page(request, "Websites", "Websites", body) @router.get("/tryout-import", include_in_schema=False) @@ -1377,7 +1496,7 @@ async def tryout_import_view(request: Request, db: AsyncSession = Depends(get_db websites = await _load_websites(db) snapshots = await _recent_snapshots(db) body = _tryout_import_form_body(websites, snapshots) - return _render_admin_page("Tryout Import", "Tryout Import", body) + return _render_admin_page(request, "Tryout Import", "Tryout Import", body) @router.post("/tryout-import/preview", include_in_schema=False) @@ -1401,7 +1520,7 @@ async def tryout_import_preview( error="File must be .json format.", selected_website_id=website_id, ) - return _render_admin_page("Tryout Import", "Tryout Import", body) + return _render_admin_page(request, "Tryout Import", "Tryout Import", body) try: payload_bytes = await file.read() @@ -1414,7 +1533,7 @@ async def tryout_import_preview( error="File must be UTF-8 encoded JSON.", selected_website_id=website_id, ) - return _render_admin_page("Tryout Import", "Tryout Import", body) + return _render_admin_page(request, "Tryout Import", "Tryout Import", body) except json.JSONDecodeError as exc: body = _tryout_import_form_body( websites, @@ -1422,7 +1541,7 @@ async def tryout_import_preview( error=f"Invalid JSON file: {exc}", selected_website_id=website_id, ) - return _render_admin_page("Tryout Import", "Tryout Import", body) + return _render_admin_page(request, "Tryout Import", "Tryout Import", body) try: preview = await preview_tryout_json_import(payload, website_id, db) @@ -1433,7 +1552,7 @@ async def tryout_import_preview( error=str(exc), selected_website_id=website_id, ) - return _render_admin_page("Tryout Import", "Tryout Import", body) + return _render_admin_page(request, "Tryout Import", "Tryout Import", body) preview_token = uuid.uuid4().hex await _admin_redis.set( @@ -1449,7 +1568,7 @@ async def tryout_import_preview( preview_token=preview_token, upload_filename=file.filename or "", ) - return _render_admin_page("Tryout Import", "Tryout Import", body) + return _render_admin_page(request, "Tryout Import", "Tryout Import", body) @router.post("/tryout-import", include_in_schema=False) @@ -1474,7 +1593,7 @@ async def tryout_import_submit( error="Preview token expired. Upload the JSON again and preview before importing.", selected_website_id=website_id, ) - return _render_admin_page("Tryout Import", "Tryout Import", body) + return _render_admin_page(request, "Tryout Import", "Tryout Import", body) try: payload = json.loads(payload_text) @@ -1488,7 +1607,7 @@ async def tryout_import_submit( error=str(exc), selected_website_id=website_id, ) - return _render_admin_page("Tryout Import", "Tryout Import", body) + return _render_admin_page(request, "Tryout Import", "Tryout Import", body) except Exception: await db.rollback() raise @@ -1506,7 +1625,7 @@ async def tryout_import_submit( ), selected_website_id=website_id, ) - return _render_admin_page("Tryout Import", "Tryout Import", body) + return _render_admin_page(request, "Tryout Import", "Tryout Import", body) @router.get("/snapshot-questions", include_in_schema=False) @@ -1528,11 +1647,11 @@ async def snapshot_questions_view( snapshots, error=f"Snapshot not found: {snapshot_id}", ) - return _render_admin_page("Tryout Import", "Tryout Import", body) + return _render_admin_page(request, "Tryout Import", "Tryout Import", body) questions, promoted_items_by_slot, _ = await _load_snapshot_question_context(snapshot, db) body = _snapshot_questions_body(snapshot, questions, promoted_items_by_slot) - return _render_admin_page("Snapshot Questions", "Snapshot Questions", body) + return _render_admin_page(request, "Snapshot Questions", "Snapshot Questions", body) @router.post("/snapshot-questions/promote-bulk", include_in_schema=False) @@ -1555,7 +1674,7 @@ async def snapshot_question_promote_bulk( snapshots, error=f"Snapshot not found: {snapshot_id}", ) - return _render_admin_page("Tryout Import", "Tryout Import", body) + return _render_admin_page(request, "Tryout Import", "Tryout Import", body) if not snapshot_question_ids: questions, promoted_items_by_slot, _ = await _load_snapshot_question_context(snapshot, db) @@ -1565,7 +1684,7 @@ async def snapshot_question_promote_bulk( promoted_items_by_slot, error="Select at least one snapshot question to promote.", ) - return _render_admin_page("Snapshot Questions", "Snapshot Questions", body) + return _render_admin_page(request, "Snapshot Questions", "Snapshot Questions", body) question_result = await db.execute( select(TryoutSnapshotQuestion).where( @@ -1607,7 +1726,7 @@ async def snapshot_question_promote_bulk( success_message += f" Latest basis item ID: {created_items[-1].id}." body = _snapshot_questions_body(snapshot, questions, promoted_items_by_slot, success=success_message) - return _render_admin_page("Snapshot Questions", "Snapshot Questions", body) + return _render_admin_page(request, "Snapshot Questions", "Snapshot Questions", body) @router.get("/calibration-status", include_in_schema=False) @@ -1637,7 +1756,7 @@ async def calibration_status_view(request: Request, db: AsyncSession = Depends(g ["Tryout ID", "Name", "Total Items", "Calibrated", "Calibration %", "Ready for IRT"], rows, ) - return _render_admin_page("Calibration Status", "Calibration Status", body) + return _render_admin_page(request, "Calibration Status", "Calibration Status", body) @router.get("/item-statistics", include_in_schema=False) @@ -1672,7 +1791,7 @@ async def item_statistics_view(request: Request, db: AsyncSession = Depends(get_ ["Level", "Total Items", "Calibrated", "Calibration %", "Responses", "Avg Correctness"], rows, ) - return _render_admin_page("Item Statistics", "Item Statistics", body) + return _render_admin_page(request, "Item Statistics", "Item Statistics", body) @router.get("/session-overview", include_in_schema=False) @@ -1702,7 +1821,7 @@ async def session_overview_view(request: Request, db: AsyncSession = Depends(get ["Session ID", "WP User", "Tryout", "Completed", "Mode", "Benar", "NM", "NN", "Theta"], rows, ) - return _render_admin_page("Session Overview", "Session Overview", body) + return _render_admin_page(request, "Session Overview", "Session Overview", body) @router.get("/basis-items", include_in_schema=False) @@ -1719,7 +1838,7 @@ async def basis_items_view(request: Request, db: AsyncSession = Depends(get_db)) ) basis_items = list(result.scalars().all()) body = _basis_items_list_body(basis_items) - return _render_admin_page("Basis Items", "Basis Items", body) + return _render_admin_page(request, "Basis Items", "Basis Items", body) @router.get("/basis-items/{basis_item_id}", include_in_schema=False) @@ -1752,7 +1871,7 @@ async def basis_item_workspace_view( .limit(200) ) body = _basis_items_list_body(list(result.scalars().all())) - return _render_admin_page("Basis Items", "Basis Items", body) + return _render_admin_page(request, "Basis Items", "Basis Items", body) run_result = await db.execute( select(AIGenerationRun) @@ -1794,7 +1913,7 @@ async def basis_item_workspace_view( family_stats, filters, ) - return _render_admin_page( + return _render_admin_page(request, f"Basis Item #{basis_item.id}", f"Basis Item Workspace #{basis_item.id}", body, @@ -1856,7 +1975,7 @@ async def basis_item_generate_submit( include_note_for_admin=note_for_admin, include_note_in_prompt=note_in_prompt, ) - return _render_admin_page( + return _render_admin_page(request, f"Basis Item #{basis_item.id}", f"Basis Item Workspace #{basis_item.id}", body, @@ -1951,7 +2070,7 @@ async def basis_item_generate_submit( include_note_for_admin=note_for_admin, include_note_in_prompt=note_in_prompt, ) - return _render_admin_page( + return _render_admin_page(request, f"Basis Item #{basis_item.id}", f"Basis Item Workspace #{basis_item.id}", body, @@ -2016,7 +2135,7 @@ async def basis_item_review_bulk( filters, success=f"Applied status '{action}' to selected variants.", ) - return _render_admin_page( + return _render_admin_page(request, f"Basis Item #{basis_item.id}", f"Basis Item Workspace #{basis_item.id}", body, @@ -2202,7 +2321,7 @@ async def ai_playground_view(request: Request, db: AsyncSession = Depends(get_db generated_variants=generated_variants, basis_item_id=str(basis_item_id or ""), ) - return _render_admin_page("AI Playground", "AI Playground", body) + return _render_admin_page(request, "AI Playground", "AI Playground", body) @router.post("/ai-playground/seed-demo", include_in_schema=False) @@ -2225,7 +2344,7 @@ async def ai_playground_seed_demo(request: Request, db: AsyncSession = Depends(g generated_variants=generated_variants, basis_item_id=str(demo_item.id), ) - return _render_admin_page("AI Playground", "AI Playground", body) + return _render_admin_page(request, "AI Playground", "AI Playground", body) @router.post("/ai-playground", include_in_schema=False) @@ -2268,7 +2387,7 @@ async def ai_playground_submit( include_note_for_admin=note_for_admin, include_note_in_prompt=note_in_prompt, ) - return _render_admin_page("AI Playground", "AI Playground", body) + return _render_admin_page(request, "AI Playground", "AI Playground", body) if target_level not in {"mudah", "sulit"}: body = _ai_form_body( @@ -2286,7 +2405,7 @@ async def ai_playground_submit( include_note_for_admin=note_for_admin, include_note_in_prompt=note_in_prompt, ) - return _render_admin_page("AI Playground", "AI Playground", body) + return _render_admin_page(request, "AI Playground", "AI Playground", body) if not validate_ai_model(ai_model): body = _ai_form_body( @@ -2304,7 +2423,7 @@ async def ai_playground_submit( include_note_for_admin=note_for_admin, include_note_in_prompt=note_in_prompt, ) - return _render_admin_page("AI Playground", "AI Playground", body) + return _render_admin_page(request, "AI Playground", "AI Playground", body) result = await db.execute(select(Item).where(Item.id == basis_item_id)) basis_item = result.scalar_one_or_none() @@ -2324,7 +2443,7 @@ async def ai_playground_submit( include_note_for_admin=note_for_admin, include_note_in_prompt=note_in_prompt, ) - return _render_admin_page("AI Playground", "AI Playground", body) + return _render_admin_page(request, "AI Playground", "AI Playground", body) if basis_item.level != "sedang": body = _ai_form_body( @@ -2342,7 +2461,7 @@ async def ai_playground_submit( include_note_for_admin=note_for_admin, include_note_in_prompt=note_in_prompt, ) - return _render_admin_page("AI Playground", "AI Playground", body) + return _render_admin_page(request, "AI Playground", "AI Playground", body) if generation_count < 1 or generation_count > 50: body = _ai_form_body( @@ -2360,7 +2479,7 @@ async def ai_playground_submit( include_note_for_admin=note_for_admin, include_note_in_prompt=note_in_prompt, ) - return _render_admin_page("AI Playground", "AI Playground", body) + return _render_admin_page(request, "AI Playground", "AI Playground", body) run_id = await create_generation_run( basis_item_id=basis_item.id, @@ -2428,7 +2547,7 @@ async def ai_playground_submit( include_note_for_admin=note_for_admin, include_note_in_prompt=note_in_prompt, ) - return _render_admin_page("AI Playground", "AI Playground", body) + return _render_admin_page(request, "AI Playground", "AI Playground", body) body = _ai_form_body( True, @@ -2451,7 +2570,7 @@ async def ai_playground_submit( include_note_for_admin=note_for_admin, include_note_in_prompt=note_in_prompt, ) - return _render_admin_page("AI Playground", "AI Playground", body) + return _render_admin_page(request, "AI Playground", "AI Playground", body) @router.post("/ai-playground/save", include_in_schema=False) @@ -2483,7 +2602,7 @@ async def ai_playground_save( error="Only mudah or sulit generated items can be saved from the playground.", basis_items=basis_items, ) - return _render_admin_page("AI Playground", "AI Playground", body) + return _render_admin_page(request, "AI Playground", "AI Playground", body) try: options = json.loads(options_json) @@ -2494,7 +2613,7 @@ async def ai_playground_save( error="Generated options payload is invalid.", basis_items=basis_items, ) - return _render_admin_page("AI Playground", "AI Playground", body) + return _render_admin_page(request, "AI Playground", "AI Playground", body) from app.schemas.ai import GeneratedQuestion @@ -2521,7 +2640,7 @@ async def ai_playground_save( error="Failed to save generated item. Check server logs for the database error.", basis_items=basis_items, ) - return _render_admin_page("AI Playground", "AI Playground", body) + return _render_admin_page(request, "AI Playground", "AI Playground", body) await db.commit() updated_stats = await get_ai_stats(db) @@ -2535,7 +2654,7 @@ async def ai_playground_save( target_level=target_level, ai_model=ai_model, ) - return _render_admin_page("AI Playground", "AI Playground", body) + return _render_admin_page(request, "AI Playground", "AI Playground", body) @router.post("/ai-playground/review-bulk", include_in_schema=False) @@ -2564,7 +2683,7 @@ async def ai_playground_review_bulk( generation_runs=generation_runs, generated_variants=generated_variants, ) - return _render_admin_page("AI Playground", "AI Playground", body) + return _render_admin_page(request, "AI Playground", "AI Playground", body) if not item_ids: body = _ai_form_body( @@ -2575,7 +2694,7 @@ async def ai_playground_review_bulk( generation_runs=generation_runs, generated_variants=generated_variants, ) - return _render_admin_page("AI Playground", "AI Playground", body) + return _render_admin_page(request, "AI Playground", "AI Playground", body) result = await db.execute( select(Item).where(Item.id.in_(item_ids), Item.generated_by == "ai") @@ -2590,7 +2709,7 @@ async def ai_playground_review_bulk( generation_runs=generation_runs, generated_variants=generated_variants, ) - return _render_admin_page("AI Playground", "AI Playground", body) + return _render_admin_page(request, "AI Playground", "AI Playground", body) reviewed_at = datetime.now(timezone.utc) for item in items: @@ -2612,7 +2731,7 @@ async def ai_playground_review_bulk( generation_runs=updated_runs, generated_variants=updated_variants, ) - return _render_admin_page("AI Playground", "AI Playground", body) + return _render_admin_page(request, "AI Playground", "AI Playground", body) @router.get("/tryout/list", include_in_schema=False) diff --git a/app/api/v1/session.py b/app/api/v1/session.py index 9da4945..e3deca4 100644 --- a/app/api/v1/session.py +++ b/app/api/v1/session.py @@ -14,6 +14,12 @@ from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from app.database import get_db +from app.core.auth import ( + AuthContext, + ensure_website_scope_matches, + get_auth_context, + require_website_auth, +) from app.models import Item, Session, Tryout from app.services.cat_selection import ( CATSelectionError, @@ -106,7 +112,8 @@ class CATTestResponse(BaseModel): ) async def get_next_item_endpoint( session_id: str, - db: AsyncSession = Depends(get_db) + db: AsyncSession = Depends(get_db), + auth: AuthContext = Depends(get_auth_context), ) -> NextItemResponse: """ Get the next item for a session. @@ -116,8 +123,13 @@ async def get_next_item_endpoint( Calls appropriate selection function based on selection_mode. Returns item or completion status. """ + website_id = require_website_auth(auth, allowed_roles={"student", "admin", "system_admin"}) + # Get session - session_query = select(Session).where(Session.session_id == session_id) + session_query = select(Session).where( + Session.session_id == session_id, + Session.website_id == website_id, + ) session_result = await db.execute(session_query) session = session_result.scalar_one_or_none() @@ -126,6 +138,11 @@ async def get_next_item_endpoint( status_code=status.HTTP_404_NOT_FOUND, detail=f"Session {session_id} not found" ) + if auth.role == "student" and session.wp_user_id != auth.wp_user_id: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Session does not belong to this authenticated user", + ) if session.is_completed: return NextItemResponse( @@ -214,7 +231,8 @@ async def get_next_item_endpoint( async def submit_answer_endpoint( session_id: str, request: SubmitAnswerRequest, - db: AsyncSession = Depends(get_db) + db: AsyncSession = Depends(get_db), + auth: AuthContext = Depends(get_auth_context), ) -> SubmitAnswerResponse: """ Submit an answer for an item. @@ -224,8 +242,13 @@ async def submit_answer_endpoint( Updates theta estimate. Records response time. """ + website_id = require_website_auth(auth, allowed_roles={"student", "admin", "system_admin"}) + # Get session - session_query = select(Session).where(Session.session_id == session_id) + session_query = select(Session).where( + Session.session_id == session_id, + Session.website_id == website_id, + ) session_result = await db.execute(session_query) session = session_result.scalar_one_or_none() @@ -234,6 +257,11 @@ async def submit_answer_endpoint( status_code=status.HTTP_404_NOT_FOUND, detail=f"Session {session_id} not found" ) + if auth.role == "student" and session.wp_user_id != auth.wp_user_id: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Session does not belong to this authenticated user", + ) if session.is_completed: raise HTTPException( @@ -242,7 +270,11 @@ async def submit_answer_endpoint( ) # Get item - item_query = select(Item).where(Item.id == request.item_id) + item_query = select(Item).where( + Item.id == request.item_id, + Item.website_id == session.website_id, + Item.tryout_id == session.tryout_id, + ) item_result = await db.execute(item_query) item = item_result.scalar_one_or_none() @@ -296,7 +328,8 @@ async def submit_answer_endpoint( ) async def test_cat_endpoint( request: CATTestRequest, - db: AsyncSession = Depends(get_db) + db: AsyncSession = Depends(get_db), + auth: AuthContext = Depends(get_auth_context), ) -> CATTestResponse: """ Test CAT selection algorithm. @@ -304,10 +337,13 @@ async def test_cat_endpoint( Simulates CAT selection for a tryout and returns the sequence of selected items with theta progression. """ + website_id = require_website_auth(auth, allowed_roles={"admin", "system_admin"}) + ensure_website_scope_matches(website_id, request.website_id) + # Verify tryout exists tryout_query = select(Tryout).where( Tryout.tryout_id == request.tryout_id, - Tryout.website_id == request.website_id + Tryout.website_id == website_id ) tryout_result = await db.execute(tryout_query) tryout = tryout_result.scalar_one_or_none() @@ -315,14 +351,14 @@ async def test_cat_endpoint( if not tryout: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail=f"Tryout {request.tryout_id} not found for website {request.website_id}" + detail=f"Tryout {request.tryout_id} not found for website {website_id}" ) # Run simulation result = await simulate_cat_selection( db, tryout_id=request.tryout_id, - website_id=request.website_id, + website_id=website_id, initial_theta=request.initial_theta, selection_mode=request.selection_mode, max_items=request.max_items, @@ -346,13 +382,19 @@ async def test_cat_endpoint( ) async def get_session_status_endpoint( session_id: str, - db: AsyncSession = Depends(get_db) + db: AsyncSession = Depends(get_db), + auth: AuthContext = Depends(get_auth_context), ) -> dict: """ Get session status for admin monitoring. """ + website_id = require_website_auth(auth, allowed_roles={"admin", "system_admin"}) + # Get session - session_query = select(Session).where(Session.session_id == session_id) + session_query = select(Session).where( + Session.session_id == session_id, + Session.website_id == website_id, + ) session_result = await db.execute(session_query) session = session_result.scalar_one_or_none() diff --git a/app/core/auth.py b/app/core/auth.py new file mode 100644 index 0000000..a39245f --- /dev/null +++ b/app/core/auth.py @@ -0,0 +1,144 @@ +""" +Token-based authentication helpers for website-scoped access control. +""" + +from __future__ import annotations + +import base64 +import hashlib +import hmac +import json +import time +from dataclasses import dataclass +from typing import Optional + +from fastapi import Header, HTTPException, status + +from app.core.config import get_settings + +settings = get_settings() + + +@dataclass +class AuthContext: + website_id: int + role: str + wp_user_id: Optional[str] = None + + +def _b64url_encode(raw: bytes) -> str: + return base64.urlsafe_b64encode(raw).rstrip(b"=").decode("ascii") + + +def _b64url_decode(raw: str) -> bytes: + padding = "=" * (-len(raw) % 4) + return base64.urlsafe_b64decode((raw + padding).encode("ascii")) + + +def issue_access_token( + website_id: int, + role: str = "student", + wp_user_id: str | None = None, + expires_in_seconds: int = 3600, +) -> str: + payload = { + "website_id": int(website_id), + "role": role, + "wp_user_id": wp_user_id, + "exp": int(time.time()) + int(expires_in_seconds), + } + payload_bytes = json.dumps(payload, separators=(",", ":"), sort_keys=True).encode("utf-8") + payload_b64 = _b64url_encode(payload_bytes) + sig = hmac.new(settings.SECRET_KEY.encode("utf-8"), payload_b64.encode("ascii"), hashlib.sha256).digest() + return f"{payload_b64}.{_b64url_encode(sig)}" + + +def decode_access_token(token: str) -> AuthContext: + try: + payload_b64, sig_b64 = token.split(".", 1) + except ValueError as exc: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid access token format", + ) from exc + + expected_sig = hmac.new( + settings.SECRET_KEY.encode("utf-8"), + payload_b64.encode("ascii"), + hashlib.sha256, + ).digest() + provided_sig = _b64url_decode(sig_b64) + if not hmac.compare_digest(provided_sig, expected_sig): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid access token signature", + ) + + try: + payload = json.loads(_b64url_decode(payload_b64).decode("utf-8")) + except Exception as exc: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid access token payload", + ) from exc + + exp = int(payload.get("exp", 0)) + if exp <= int(time.time()): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Access token has expired", + ) + + website_id = payload.get("website_id") + role = payload.get("role") + if website_id is None or not role: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Access token missing required claims", + ) + + return AuthContext( + website_id=int(website_id), + role=str(role), + wp_user_id=payload.get("wp_user_id"), + ) + + +def get_auth_context( + authorization: str | None = Header(None, alias="Authorization"), +) -> AuthContext: + if authorization is None: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Authorization header is required", + ) + parts = authorization.split() + if len(parts) != 2 or parts[0].lower() != "bearer": + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid Authorization header format. Use: Bearer {token}", + ) + return decode_access_token(parts[1]) + + +def require_website_auth( + auth: AuthContext, + allowed_roles: set[str] | None = None, +) -> int: + if allowed_roles is not None and auth.role not in allowed_roles: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Insufficient permissions for this endpoint", + ) + return auth.website_id + + +def ensure_website_scope_matches( + auth_website_id: int, + payload_website_id: int, +) -> None: + if int(auth_website_id) != int(payload_website_id): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="website_id in payload must match authenticated website scope", + ) diff --git a/app/core/rate_limit.py b/app/core/rate_limit.py new file mode 100644 index 0000000..3656ca6 --- /dev/null +++ b/app/core/rate_limit.py @@ -0,0 +1,45 @@ +""" +Lightweight in-process rate limiting helpers. +""" + +from __future__ import annotations + +import threading +import time +from collections import defaultdict, deque + +from fastapi import HTTPException, Request, status + +_lock = threading.Lock() +_hits: dict[str, deque[float]] = defaultdict(deque) + + +def _client_ip(request: Request) -> str: + if request.client and request.client.host: + return request.client.host + return "unknown" + + +def enforce_rate_limit( + request: Request, + *, + scope: str, + max_requests: int, + window_seconds: int, +) -> None: + now = time.time() + ip = _client_ip(request) + key = f"{scope}:{ip}" + cutoff = now - window_seconds + + with _lock: + dq = _hits[key] + while dq and dq[0] <= cutoff: + dq.popleft() + if len(dq) >= max_requests: + raise HTTPException( + status_code=status.HTTP_429_TOO_MANY_REQUESTS, + detail=f"Too many requests for {scope}. Please try again later.", + ) + dq.append(now) + diff --git a/app/main.py b/app/main.py index 4a98021..2566c84 100644 --- a/app/main.py +++ b/app/main.py @@ -40,6 +40,33 @@ from app.routers import ( settings = get_settings() +def validate_security_config() -> None: + """ + Enforce minimum security requirements for production deployments. + """ + if settings.ENVIRONMENT != "production": + return + + insecure_secret_values = { + "", + "dev-secret-key-change-in-production", + "your-secret-key-here-change-in-production", + } + if settings.SECRET_KEY in insecure_secret_values: + raise RuntimeError( + "In production, SECRET_KEY must be set to a strong non-default value." + ) + + if settings.ENABLE_ADMIN and ( + not settings.ADMIN_USERNAME + or not settings.ADMIN_PASSWORD + or settings.ADMIN_PASSWORD == "change-me" + ): + raise RuntimeError( + "In production with ENABLE_ADMIN=true, ADMIN_USERNAME and ADMIN_PASSWORD must be configured securely." + ) + + @asynccontextmanager async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: """ @@ -47,6 +74,8 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: Handles startup and shutdown events. """ + validate_security_config() + # Startup: Initialize database await init_db() if settings.ENABLE_ADMIN: diff --git a/app/routers/ai.py b/app/routers/ai.py index f65536a..093e64b 100644 --- a/app/routers/ai.py +++ b/app/routers/ai.py @@ -7,11 +7,18 @@ Admin endpoints for AI question generation playground. import logging from typing import Annotated -from fastapi import APIRouter, Depends, HTTPException, status +from fastapi import APIRouter, Depends, HTTPException, Request, status from sqlalchemy import and_, select from sqlalchemy.ext.asyncio import AsyncSession from app.core.config import get_settings +from app.core.auth import ( + AuthContext, + ensure_website_scope_matches, + get_auth_context, + require_website_auth, +) +from app.core.rate_limit import enforce_rate_limit from app.database import get_db from app.models.item import Item from app.schemas.ai import ( @@ -58,8 +65,10 @@ router = APIRouter(prefix="/admin/ai", tags=["admin", "ai-generation"]) }, ) async def generate_preview( + request_http: Request, request: AIGeneratePreviewRequest, db: Annotated[AsyncSession, Depends(get_db)], + auth: AuthContext = Depends(get_auth_context), ) -> AIGeneratePreviewResponse: """ Generate AI question preview (no database save). @@ -68,6 +77,14 @@ async def generate_preview( - **target_level**: Target difficulty (mudah/sulit) - **ai_model**: OpenRouter model to use (default: qwen/qwen2.5-32b-instruct) """ + website_id = require_website_auth(auth, allowed_roles={"admin", "system_admin"}) + enforce_rate_limit( + request_http, + scope="ai.generate_preview", + max_requests=40, + window_seconds=300, + ) + # Validate AI model if not validate_ai_model(request.ai_model): supported = ", ".join(SUPPORTED_MODELS.keys()) @@ -88,6 +105,7 @@ async def generate_preview( status_code=status.HTTP_404_NOT_FOUND, detail=f"Basis item not found: {request.basis_item_id}", ) + ensure_website_scope_matches(website_id, basis_item.website_id) # Validate basis item is sedang level if basis_item.level != "sedang": @@ -158,8 +176,10 @@ async def generate_preview( }, ) async def generate_save( + request_http: Request, request: AISaveRequest, db: Annotated[AsyncSession, Depends(get_db)], + auth: AuthContext = Depends(get_auth_context), ) -> AISaveResponse: """ Save AI-generated question to database. @@ -175,6 +195,15 @@ async def generate_save( - **level**: Difficulty level - **ai_model**: AI model used for generation """ + website_id = require_website_auth(auth, allowed_roles={"admin", "system_admin"}) + enforce_rate_limit( + request_http, + scope="ai.generate_save", + max_requests=40, + window_seconds=300, + ) + ensure_website_scope_matches(website_id, request.website_id) + # Verify basis item exists basis_result = await db.execute( select(Item).where(Item.id == request.basis_item_id) @@ -186,6 +215,7 @@ async def generate_save( status_code=status.HTTP_404_NOT_FOUND, detail=f"Basis item not found: {request.basis_item_id}", ) + ensure_website_scope_matches(website_id, basis_item.website_id) # Check for duplicate (same tryout, website, slot, level) existing_result = await db.execute( @@ -256,10 +286,12 @@ async def generate_save( ) async def get_stats( db: Annotated[AsyncSession, Depends(get_db)], + auth: AuthContext = Depends(get_auth_context), ) -> AIStatsResponse: """ Get AI generation statistics. """ + require_website_auth(auth, allowed_roles={"admin", "system_admin"}) stats = await get_ai_stats(db) return AIStatsResponse( @@ -276,10 +308,11 @@ async def get_stats( summary="List supported AI models", description="Returns list of supported AI models for question generation.", ) -async def list_models() -> dict: +async def list_models(auth: AuthContext = Depends(get_auth_context)) -> dict: """ List supported AI models. """ + require_website_auth(auth, allowed_roles={"admin", "system_admin"}) return { "models": [ { diff --git a/app/routers/import_export.py b/app/routers/import_export.py index a579c40..42210f6 100644 --- a/app/routers/import_export.py +++ b/app/routers/import_export.py @@ -12,12 +12,12 @@ Endpoints: import os import tempfile import json -from typing import Optional - -from fastapi import APIRouter, Depends, File, Form, Header, HTTPException, UploadFile, status +from fastapi import APIRouter, Depends, File, Form, HTTPException, Request, UploadFile, status from fastapi.responses import FileResponse from sqlalchemy.ext.asyncio import AsyncSession +from app.core.auth import AuthContext, get_auth_context, require_website_auth +from app.core.rate_limit import enforce_rate_limit from app.database import get_db from app.models import Website from app.services.excel_import import ( @@ -35,35 +35,6 @@ from app.services.tryout_json_import import ( router = APIRouter(prefix="/api/v1/import-export", tags=["import-export"]) -def get_website_id_from_header( - x_website_id: Optional[str] = Header(None, alias="X-Website-ID"), -) -> int: - """ - Extract and validate website_id from request header. - - Args: - x_website_id: Website ID from header - - Returns: - Validated website ID as integer - - Raises: - HTTPException: If header is missing or invalid - """ - if x_website_id is None: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="X-Website-ID header is required", - ) - try: - return int(x_website_id) - except ValueError: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="X-Website-ID must be a valid integer", - ) - - async def ensure_website_exists( website_id: int, db: AsyncSession, @@ -85,8 +56,9 @@ async def ensure_website_exists( description="Parse Excel file and return preview without saving to database.", ) async def preview_import( + request: Request, file: UploadFile = File(..., description="Excel file (.xlsx)"), - website_id: int = Depends(get_website_id_from_header), + auth: AuthContext = Depends(get_auth_context), ) -> dict: """ Preview Excel import without saving to database. @@ -104,6 +76,14 @@ async def preview_import( Raises: HTTPException: If file format is invalid or parsing fails """ + website_id = require_website_auth(auth, allowed_roles={"admin", "system_admin"}) + enforce_rate_limit( + request, + scope="import.preview", + max_requests=30, + window_seconds=300, + ) + # Validate file format if not file.filename or not file.filename.lower().endswith('.xlsx'): raise HTTPException( @@ -173,8 +153,9 @@ async def preview_import( description="Parse Excel file and import questions to database with 100% data integrity.", ) async def import_questions( + request: Request, file: UploadFile = File(..., description="Excel file (.xlsx)"), - website_id: int = Depends(get_website_id_from_header), + auth: AuthContext = Depends(get_auth_context), tryout_id: str = Form(..., description="Tryout identifier"), db: AsyncSession = Depends(get_db), ) -> dict: @@ -199,6 +180,14 @@ async def import_questions( Raises: HTTPException: If file format is invalid, validation fails, or import fails """ + website_id = require_website_auth(auth, allowed_roles={"admin", "system_admin"}) + enforce_rate_limit( + request, + scope="import.questions", + max_requests=20, + window_seconds=300, + ) + # Validate file format if not file.filename or not file.filename.lower().endswith('.xlsx'): raise HTTPException( @@ -297,7 +286,7 @@ async def import_questions( ) async def export_questions( tryout_id: str, - website_id: int = Depends(get_website_id_from_header), + auth: AuthContext = Depends(get_auth_context), db: AsyncSession = Depends(get_db), ) -> FileResponse: """ @@ -320,6 +309,8 @@ async def export_questions( Raises: HTTPException: If tryout has no questions or export fails """ + website_id = require_website_auth(auth, allowed_roles={"admin", "system_admin"}) + try: # Export questions to Excel output_path = await export_questions_to_excel( @@ -354,10 +345,18 @@ async def export_questions( description="Parse a Sejoli tryout export JSON file and show snapshot diff without writing to database.", ) async def preview_tryout_json( + request: Request, file: UploadFile = File(..., description="Sejoli tryout export JSON"), - website_id: int = Depends(get_website_id_from_header), + auth: AuthContext = Depends(get_auth_context), db: AsyncSession = Depends(get_db), ) -> dict: + website_id = require_website_auth(auth, allowed_roles={"admin", "system_admin"}) + enforce_rate_limit( + request, + scope="import.tryout_json_preview", + max_requests=30, + window_seconds=300, + ) if not file.filename or not file.filename.lower().endswith(".json"): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -389,10 +388,18 @@ async def preview_tryout_json( description="Store Sejoli tryout export JSON as read-only snapshot data and upsert normalized reference questions.", ) async def import_tryout_json( + request: Request, file: UploadFile = File(..., description="Sejoli tryout export JSON"), - website_id: int = Depends(get_website_id_from_header), + auth: AuthContext = Depends(get_auth_context), db: AsyncSession = Depends(get_db), ) -> dict: + website_id = require_website_auth(auth, allowed_roles={"admin", "system_admin"}) + enforce_rate_limit( + request, + scope="import.tryout_json", + max_requests=20, + window_seconds=300, + ) if not file.filename or not file.filename.lower().endswith(".json"): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, diff --git a/app/routers/reports.py b/app/routers/reports.py index 3e31815..45df3fd 100644 --- a/app/routers/reports.py +++ b/app/routers/reports.py @@ -14,11 +14,17 @@ import os from datetime import datetime from typing import List, Literal, Optional -from fastapi import APIRouter, Depends, HTTPException, Header, status +from fastapi import APIRouter, Depends, HTTPException, status from fastapi.responses import FileResponse from sqlalchemy.ext.asyncio import AsyncSession from app.database import get_db +from app.core.auth import ( + AuthContext, + ensure_website_scope_matches, + get_auth_context, + require_website_auth, +) from app.schemas.report import ( StudentPerformanceReportOutput, AggregatePerformanceStatsOutput, @@ -55,35 +61,6 @@ from app.services.reporting import ( router = APIRouter(prefix="/reports", tags=["reports"]) -def get_website_id_from_header( - x_website_id: Optional[str] = Header(None, alias="X-Website-ID"), -) -> int: - """ - Extract and validate website_id from request header. - - Args: - x_website_id: Website ID from header - - Returns: - Validated website ID as integer - - Raises: - HTTPException: If header is missing or invalid - """ - if x_website_id is None: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="X-Website-ID header is required", - ) - try: - return int(x_website_id) - except ValueError: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="X-Website-ID must be a valid integer", - ) - - # ============================================================================= # Student Performance Report Endpoints # ============================================================================= @@ -97,7 +74,7 @@ def get_website_id_from_header( async def get_student_performance_report( tryout_id: str, db: AsyncSession = Depends(get_db), - website_id: int = Depends(get_website_id_from_header), + auth: AuthContext = Depends(get_auth_context), date_start: Optional[datetime] = None, date_end: Optional[datetime] = None, format_type: Literal["individual", "aggregate", "both"] = "both", @@ -107,6 +84,7 @@ async def get_student_performance_report( Returns individual student records and/or aggregate statistics. """ + website_id = require_website_auth(auth, allowed_roles={"student", "admin", "system_admin"}) date_range = None if date_start or date_end: date_range = {} @@ -190,7 +168,7 @@ def _convert_student_performance_report(report: StudentPerformanceReport) -> Stu async def get_item_analysis_report( tryout_id: str, db: AsyncSession = Depends(get_db), - website_id: int = Depends(get_website_id_from_header), + auth: AuthContext = Depends(get_auth_context), filter_by: Optional[Literal["difficulty", "calibrated", "discrimination"]] = None, difficulty_level: Optional[Literal["mudah", "sedang", "sulit"]] = None, ) -> ItemAnalysisReportOutput: @@ -199,6 +177,7 @@ async def get_item_analysis_report( Returns item difficulty, discrimination, and information function data. """ + website_id = require_website_auth(auth, allowed_roles={"admin", "system_admin"}) report = await generate_item_analysis_report( tryout_id=tryout_id, website_id=website_id, @@ -248,13 +227,14 @@ async def get_item_analysis_report( async def get_calibration_status_report( tryout_id: str, db: AsyncSession = Depends(get_db), - website_id: int = Depends(get_website_id_from_header), + auth: AuthContext = Depends(get_auth_context), ) -> CalibrationStatusReportOutput: """ Get calibration status report. Returns calibration progress, items awaiting calibration, and IRT readiness status. """ + website_id = require_website_auth(auth, allowed_roles={"admin", "system_admin"}) report = await generate_calibration_status_report( tryout_id=tryout_id, website_id=website_id, @@ -313,7 +293,7 @@ async def get_calibration_status_report( async def get_tryout_comparison_report( tryout_ids: str, # Comma-separated list db: AsyncSession = Depends(get_db), - website_id: int = Depends(get_website_id_from_header), + auth: AuthContext = Depends(get_auth_context), group_by: Literal["date", "subject"] = "date", ) -> TryoutComparisonReportOutput: """ @@ -321,6 +301,7 @@ async def get_tryout_comparison_report( Compares tryouts across dates or subjects. """ + website_id = require_website_auth(auth, allowed_roles={"admin", "system_admin"}) tryout_id_list = [tid.strip() for tid in tryout_ids.split(",")] if len(tryout_id_list) < 2: @@ -371,12 +352,15 @@ async def get_tryout_comparison_report( async def create_report_schedule( request: ReportScheduleRequest, db: AsyncSession = Depends(get_db), + auth: AuthContext = Depends(get_auth_context), ) -> ReportScheduleResponse: """ Schedule a report. Creates a scheduled report that will be generated automatically. """ + website_id = require_website_auth(auth, allowed_roles={"admin", "system_admin"}) + ensure_website_scope_matches(website_id, request.website_id) schedule_id = schedule_report( report_type=request.report_type, schedule=request.schedule, @@ -403,13 +387,14 @@ async def create_report_schedule( ) async def get_scheduled_report_details( schedule_id: str, - website_id: int = Depends(get_website_id_from_header), + auth: AuthContext = Depends(get_auth_context), ) -> ReportScheduleOutput: """ Get scheduled report details. Returns the configuration and status of a scheduled report. """ + website_id = require_website_auth(auth, allowed_roles={"admin", "system_admin"}) scheduled = get_scheduled_report(schedule_id) if not scheduled: @@ -446,13 +431,14 @@ async def get_scheduled_report_details( description="List all scheduled reports for a website.", ) async def list_scheduled_reports_endpoint( - website_id: int = Depends(get_website_id_from_header), + auth: AuthContext = Depends(get_auth_context), ) -> List[ReportScheduleOutput]: """ List all scheduled reports. Returns all scheduled reports for the current website. """ + website_id = require_website_auth(auth, allowed_roles={"admin", "system_admin"}) reports = list_scheduled_reports(website_id=website_id) return [ @@ -480,13 +466,14 @@ async def list_scheduled_reports_endpoint( ) async def cancel_scheduled_report_endpoint( schedule_id: str, - website_id: int = Depends(get_website_id_from_header), + auth: AuthContext = Depends(get_auth_context), ) -> dict: """ Cancel a scheduled report. Removes the scheduled report from the system. """ + website_id = require_website_auth(auth, allowed_roles={"admin", "system_admin"}) scheduled = get_scheduled_report(schedule_id) if not scheduled: @@ -528,13 +515,14 @@ async def export_scheduled_report( schedule_id: str, format: Literal["csv", "xlsx", "pdf"], db: AsyncSession = Depends(get_db), - website_id: int = Depends(get_website_id_from_header), + auth: AuthContext = Depends(get_auth_context), ): """ Export a scheduled report. Generates the report and returns it as a file download. """ + website_id = require_website_auth(auth, allowed_roles={"admin", "system_admin"}) scheduled = get_scheduled_report(schedule_id) if not scheduled: @@ -628,11 +616,12 @@ async def export_student_performance_direct( format: Literal["csv", "xlsx", "pdf"], tryout_id: str, db: AsyncSession = Depends(get_db), - website_id: int = Depends(get_website_id_from_header), + auth: AuthContext = Depends(get_auth_context), date_start: Optional[datetime] = None, date_end: Optional[datetime] = None, ): """Export student performance report directly.""" + website_id = require_website_auth(auth, allowed_roles={"admin", "system_admin"}) date_range = None if date_start or date_end: date_range = {} @@ -676,11 +665,12 @@ async def export_item_analysis_direct( format: Literal["csv", "xlsx", "pdf"], tryout_id: str, db: AsyncSession = Depends(get_db), - website_id: int = Depends(get_website_id_from_header), + auth: AuthContext = Depends(get_auth_context), filter_by: Optional[Literal["difficulty", "calibrated", "discrimination"]] = None, difficulty_level: Optional[Literal["mudah", "sedang", "sulit"]] = None, ): """Export item analysis report directly.""" + website_id = require_website_auth(auth, allowed_roles={"admin", "system_admin"}) report = await generate_item_analysis_report( tryout_id=tryout_id, website_id=website_id, @@ -717,9 +707,10 @@ async def export_calibration_status_direct( format: Literal["csv", "xlsx", "pdf"], tryout_id: str, db: AsyncSession = Depends(get_db), - website_id: int = Depends(get_website_id_from_header), + auth: AuthContext = Depends(get_auth_context), ): """Export calibration status report directly.""" + website_id = require_website_auth(auth, allowed_roles={"admin", "system_admin"}) report = await generate_calibration_status_report( tryout_id=tryout_id, website_id=website_id, @@ -754,10 +745,11 @@ async def export_tryout_comparison_direct( format: Literal["csv", "xlsx", "pdf"], tryout_ids: str, # Comma-separated db: AsyncSession = Depends(get_db), - website_id: int = Depends(get_website_id_from_header), + auth: AuthContext = Depends(get_auth_context), group_by: Literal["date", "subject"] = "date", ): """Export tryout comparison report directly.""" + website_id = require_website_auth(auth, allowed_roles={"admin", "system_admin"}) tryout_id_list = [tid.strip() for tid in tryout_ids.split(",")] if len(tryout_id_list) < 2: diff --git a/app/routers/sessions.py b/app/routers/sessions.py index 56c6488..41217fa 100644 --- a/app/routers/sessions.py +++ b/app/routers/sessions.py @@ -8,14 +8,18 @@ Endpoints: """ from datetime import datetime, timezone -from typing import Optional - -from fastapi import APIRouter, Depends, HTTPException, Header, status +from fastapi import APIRouter, Depends, HTTPException, status from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload from app.database import get_db +from app.core.auth import ( + AuthContext, + ensure_website_scope_matches, + get_auth_context, + require_website_auth, +) from app.models.item import Item from app.models.session import Session from app.models.tryout import Tryout @@ -39,35 +43,6 @@ from app.services.ctt_scoring import ( router = APIRouter(prefix="/session", tags=["sessions"]) -def get_website_id_from_header( - x_website_id: Optional[str] = Header(None, alias="X-Website-ID"), -) -> int: - """ - Extract and validate website_id from request header. - - Args: - x_website_id: Website ID from header - - Returns: - Validated website ID as integer - - Raises: - HTTPException: If header is missing or invalid - """ - if x_website_id is None: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="X-Website-ID header is required", - ) - try: - return int(x_website_id) - except ValueError: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="X-Website-ID must be a valid integer", - ) - - @router.post( "/{session_id}/complete", response_model=SessionCompleteResponse, @@ -78,7 +53,7 @@ async def complete_session( session_id: str, request: SessionCompleteRequest, db: AsyncSession = Depends(get_db), - website_id: int = Depends(get_website_id_from_header), + auth: AuthContext = Depends(get_auth_context), ) -> SessionCompleteResponse: """ Complete a session by submitting answers and calculating CTT scores. @@ -104,6 +79,8 @@ async def complete_session( Raises: HTTPException: If session not found, already completed, or validation fails """ + website_id = require_website_auth(auth, allowed_roles={"student", "admin", "system_admin"}) + # Get session with tryout relationship result = await db.execute( select(Session) @@ -126,6 +103,11 @@ async def complete_session( status_code=status.HTTP_400_BAD_REQUEST, detail="Session is already completed", ) + if auth.role == "student" and session.wp_user_id != auth.wp_user_id: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Session does not belong to this authenticated user", + ) # Get tryout configuration tryout = session.tryout @@ -298,7 +280,7 @@ async def complete_session( async def get_session( session_id: str, db: AsyncSession = Depends(get_db), - website_id: int = Depends(get_website_id_from_header), + auth: AuthContext = Depends(get_auth_context), ) -> SessionResponse: """ Get session details. @@ -314,6 +296,8 @@ async def get_session( Raises: HTTPException: If session not found """ + website_id = require_website_auth(auth, allowed_roles={"student", "admin", "system_admin"}) + result = await db.execute( select(Session).where( Session.session_id == session_id, @@ -327,6 +311,11 @@ async def get_session( status_code=status.HTTP_404_NOT_FOUND, detail=f"Session {session_id} not found", ) + if auth.role == "student" and session.wp_user_id != auth.wp_user_id: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Session does not belong to this authenticated user", + ) return SessionResponse.model_validate(session) @@ -341,7 +330,7 @@ async def get_session( async def create_session( request: SessionCreateRequest, db: AsyncSession = Depends(get_db), - website_id: int = Depends(get_website_id_from_header), + auth: AuthContext = Depends(get_auth_context), ) -> SessionResponse: """ Create a new session. @@ -356,13 +345,13 @@ async def create_session( Raises: HTTPException: If tryout not found or session already exists """ - if request.website_id != website_id: + website_id = require_website_auth(auth, allowed_roles={"student", "admin", "system_admin"}) + + ensure_website_scope_matches(website_id, request.website_id) + if auth.role == "student" and request.wp_user_id != auth.wp_user_id: raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=( - "Website mismatch between payload and X-Website-ID header: " - f"{request.website_id} != {website_id}" - ), + status_code=status.HTTP_403_FORBIDDEN, + detail="wp_user_id must match authenticated user", ) # Verify tryout exists diff --git a/app/routers/tryouts.py b/app/routers/tryouts.py index b7be997..353bda5 100644 --- a/app/routers/tryouts.py +++ b/app/routers/tryouts.py @@ -7,14 +7,15 @@ Endpoints: - GET /tryout: List tryouts for a website """ -from typing import List, Optional +from typing import List -from fastapi import APIRouter, Depends, HTTPException, Header, status +from fastapi import APIRouter, Depends, HTTPException, status from sqlalchemy import Integer, cast, func, select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload from app.database import get_db +from app.core.auth import AuthContext, get_auth_context, require_website_auth from app.models.item import Item from app.models.tryout import Tryout from app.models.tryout_stats import TryoutStats @@ -29,35 +30,6 @@ from app.schemas.tryout import ( router = APIRouter(prefix="/tryout", tags=["tryouts"]) -def get_website_id_from_header( - x_website_id: Optional[str] = Header(None, alias="X-Website-ID"), -) -> int: - """ - Extract and validate website_id from request header. - - Args: - x_website_id: Website ID from header - - Returns: - Validated website ID as integer - - Raises: - HTTPException: If header is missing or invalid - """ - if x_website_id is None: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="X-Website-ID header is required", - ) - try: - return int(x_website_id) - except ValueError: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="X-Website-ID must be a valid integer", - ) - - @router.get( "/{tryout_id}/config", response_model=TryoutConfigResponse, @@ -67,7 +39,7 @@ def get_website_id_from_header( async def get_tryout_config( tryout_id: str, db: AsyncSession = Depends(get_db), - website_id: int = Depends(get_website_id_from_header), + auth: AuthContext = Depends(get_auth_context), ) -> TryoutConfigResponse: """ Get tryout configuration. @@ -78,6 +50,8 @@ async def get_tryout_config( Raises: HTTPException: If tryout not found """ + website_id = require_website_auth(auth, allowed_roles={"student", "admin", "system_admin"}) + # Get tryout with stats result = await db.execute( select(Tryout) @@ -140,7 +114,7 @@ async def update_normalization( tryout_id: str, request: NormalizationUpdateRequest, db: AsyncSession = Depends(get_db), - website_id: int = Depends(get_website_id_from_header), + auth: AuthContext = Depends(get_auth_context), ) -> NormalizationUpdateResponse: """ Update normalization settings for a tryout. @@ -157,6 +131,8 @@ async def update_normalization( Raises: HTTPException: If tryout not found or validation fails """ + website_id = require_website_auth(auth, allowed_roles={"admin", "system_admin"}) + # Get tryout result = await db.execute( select(Tryout).where( @@ -214,7 +190,7 @@ async def update_normalization( ) async def list_tryouts( db: AsyncSession = Depends(get_db), - website_id: int = Depends(get_website_id_from_header), + auth: AuthContext = Depends(get_auth_context), ) -> List[TryoutConfigBrief]: """ List all tryouts for a website. @@ -226,6 +202,8 @@ async def list_tryouts( Returns: List of TryoutConfigBrief """ + website_id = require_website_auth(auth, allowed_roles={"student", "admin", "system_admin"}) + # Get tryouts with stats result = await db.execute( select(Tryout) @@ -255,7 +233,7 @@ async def list_tryouts( async def get_calibration_status( tryout_id: str, db: AsyncSession = Depends(get_db), - website_id: int = Depends(get_website_id_from_header), + auth: AuthContext = Depends(get_auth_context), ): """ Get calibration status for items in a tryout. @@ -273,6 +251,8 @@ async def get_calibration_status( Raises: HTTPException: If tryout not found """ + website_id = require_website_auth(auth, allowed_roles={"admin", "system_admin"}) + # Verify tryout exists tryout_result = await db.execute( select(Tryout).where( @@ -324,7 +304,7 @@ async def get_calibration_status( async def trigger_calibration( tryout_id: str, db: AsyncSession = Depends(get_db), - website_id: int = Depends(get_website_id_from_header), + auth: AuthContext = Depends(get_auth_context), ): """ Trigger IRT calibration for all items in a tryout. @@ -343,6 +323,8 @@ async def trigger_calibration( Raises: HTTPException: If tryout not found or calibration fails """ + website_id = require_website_auth(auth, allowed_roles={"admin", "system_admin"}) + from app.services.irt_calibration import ( calibrate_all, CALIBRATION_SAMPLE_THRESHOLD, @@ -391,7 +373,7 @@ async def trigger_item_calibration( tryout_id: str, item_id: int, db: AsyncSession = Depends(get_db), - website_id: int = Depends(get_website_id_from_header), + auth: AuthContext = Depends(get_auth_context), ): """ Trigger IRT calibration for a single item. @@ -408,6 +390,8 @@ async def trigger_item_calibration( Raises: HTTPException: If tryout or item not found """ + website_id = require_website_auth(auth, allowed_roles={"admin", "system_admin"}) + from app.services.irt_calibration import calibrate_item, CALIBRATION_SAMPLE_THRESHOLD # Verify tryout exists diff --git a/app/routers/wordpress.py b/app/routers/wordpress.py index d1cfb21..bc59c24 100644 --- a/app/routers/wordpress.py +++ b/app/routers/wordpress.py @@ -10,11 +10,12 @@ Endpoints: import logging from typing import Optional -from fastapi import APIRouter, Depends, HTTPException, Header, status +from fastapi import APIRouter, Depends, HTTPException, Header, Request, status from sqlalchemy import func, select from sqlalchemy.ext.asyncio import AsyncSession from app.database import get_db +from app.core.auth import issue_access_token from app.models.user import User from app.models.website import Website from app.schemas.wordpress import ( @@ -36,6 +37,7 @@ from app.services.wordpress_auth import ( WordPressTokenInvalidError, WebsiteNotFoundError, ) +from app.core.rate_limit import enforce_rate_limit logger = logging.getLogger(__name__) @@ -104,6 +106,7 @@ async def get_valid_website( description="Fetch all users from WordPress API and sync to local database. Requires admin WordPress token.", ) async def sync_users_endpoint( + request: Request, db: AsyncSession = Depends(get_db), website_id: int = Depends(get_website_id_from_header), authorization: Optional[str] = Header(None, alias="Authorization"), @@ -129,6 +132,13 @@ async def sync_users_endpoint( Raises: HTTPException: If website not found, token invalid, or API error """ + enforce_rate_limit( + request, + scope="wordpress.sync_users", + max_requests=20, + window_seconds=300, + ) + # Validate website exists await get_valid_website(website_id, db) @@ -196,6 +206,7 @@ async def sync_users_endpoint( description="Verify WordPress JWT token and user identity.", ) async def verify_session_endpoint( + http_request: Request, request: VerifySessionRequest, db: AsyncSession = Depends(get_db), ) -> VerifySessionResponse: @@ -219,6 +230,13 @@ async def verify_session_endpoint( Raises: HTTPException: If website not found or API error """ + enforce_rate_limit( + http_request, + scope="wordpress.verify_session", + max_requests=60, + window_seconds=300, + ) + # Validate website exists await get_valid_website(request.website_id, db) @@ -253,6 +271,12 @@ async def verify_session_endpoint( "display_name": wp_user_info.display_name, "roles": wp_user_info.roles, }, + access_token=issue_access_token( + website_id=request.website_id, + role="student", + wp_user_id=request.wp_user_id, + expires_in_seconds=3600 * 24, + ), ) except WordPressTokenInvalidError as e: diff --git a/app/schemas/wordpress.py b/app/schemas/wordpress.py index eb6f2c1..a0a89c4 100644 --- a/app/schemas/wordpress.py +++ b/app/schemas/wordpress.py @@ -41,6 +41,10 @@ class VerifySessionResponse(BaseModel): wp_user_info: Optional[dict[str, Any]] = Field( default=None, description="WordPress user info from API" ) + access_token: Optional[str] = Field( + default=None, + description="Signed API access token for authenticated website-scoped calls", + ) class SyncUsersRequest(BaseModel): diff --git a/hands-off.md b/hands-off.md new file mode 100644 index 0000000..6bb70b5 --- /dev/null +++ b/hands-off.md @@ -0,0 +1,137 @@ +# Yellow Bank Soal Perfection Tasklist + +Date: 2026-04-29 +Purpose: hands-off development guide for hardening the system, improving correctness, and polishing the admin/user experience. + +## 1. Security and Auth + +- [x] Add centralized authentication dependencies for student, website admin, and system admin roles. +- [x] Replace raw `X-Website-ID` trust with token-derived website access. +- [x] Require authorization on reports, tryout configuration updates, imports, calibration, and session endpoints. +- [x] Add session ownership checks using verified WordPress identity. +- [x] Add rate limiting for admin login, AI generation, imports, and WordPress verification. +- [x] Add admin login rate limiting (IP-based, Redis-backed attempt window). +- [x] Add CSRF tokens to all admin POST forms. +- [x] Mark admin session cookies `secure` in production. +- [x] Fail production startup when default or empty secrets are used. +- [x] Add tests proving cross-website access is blocked. +- [x] Add token integrity tests (issue/decode, tamper rejection, expiry rejection). + +## 2. Session Integrity + +- [ ] Verify `submit_answer` item belongs to the session's `website_id` and `tryout_id`. +- [ ] Prevent answer submission for items not issued by `next_item`. +- [ ] Stop returning `correct_answer` during live adaptive sessions. +- [ ] Decide whether explanations should be shown only after completion or never during an active session. +- [ ] Add duplicate-answer validation before DB commit. +- [ ] Make repeated submissions return `409 Conflict` instead of DB errors. +- [ ] Validate or auto-create WordPress users before creating sessions. +- [ ] Add tests for invalid item IDs, foreign-tryout items, repeated answers, and completed sessions. + +## 3. Scoring Correctness + +- [ ] Revisit CTT `total_bobot_max` logic so earned and max weights use the same item set. +- [ ] Define scoring behavior for mixed-level tryouts. +- [ ] Confirm whether fixed tryouts should require every item to be answered before completion. +- [ ] Add tests for all-correct, all-wrong, partial, mixed-level, missing-bobot, and duplicate-answer cases. +- [ ] Add regression tests for static, dynamic, and hybrid normalization switching. +- [ ] Confirm NM, NN, theta, and report formulas against PRD examples. +- [ ] Add explicit handling for zero/near-zero standard deviation in reporting and normalization. + +## 4. Database and Migrations + +- [ ] Resolve model/migration drift for item uniqueness indexes. +- [ ] Decide whether items are unique by `(website_id, tryout_id, slot)` or `(website_id, tryout_id, slot, level)`. +- [ ] Align Excel import duplicate detection with the final uniqueness rule. +- [ ] Remove production `create_all` startup behavior or gate it to development only. +- [ ] Add migration smoke tests for fresh database upgrade to head. +- [ ] Add DB constraint tests for FK failures and uniqueness conflicts. +- [ ] Create seed/dev fixtures for websites, users, tryouts, items, and sessions. +- [ ] Document migration rollback expectations. + +## 5. API Reliability + +- [ ] Standardize error response shape across routers. +- [ ] Convert expected DB constraint failures into clear `400`, `404`, or `409` responses. +- [ ] Add request size limits for Excel and JSON imports. +- [ ] Add structured logging with request IDs. +- [ ] Add health checks that distinguish DB, Redis, WordPress, and OpenRouter status. +- [ ] Add OpenAPI examples for core workflows. +- [ ] Add pagination to list/report endpoints that can grow large. +- [ ] Add timeout and retry policy for external service calls. + +## 6. Import and Export + +- [ ] Validate website existence before Excel preview and import. +- [ ] Validate tryout existence before Excel question import. +- [ ] Add downloadable validation error reports. +- [ ] Add import preview diff for new records, skipped duplicates, and updates. +- [ ] Clean up generated export temp files after response lifecycle. +- [ ] Add tests for malformed Excel, duplicate slots, invalid p-values, invalid bobot values, and missing tryout. +- [x] Add tests for JSON snapshot import edge cases. +- [ ] Add file size/type hardening beyond extension checks. + +## 7. Reporting + +- [ ] Persist report schedules in the database instead of process memory. +- [ ] Add real scheduler/worker execution for scheduled reports. +- [ ] Add email delivery or remove recipient fields until delivery is implemented. +- [ ] Add report permission checks. +- [ ] Add tests for empty reports, partial data, and multi-tryout comparisons. +- [ ] Add pagination/export limits for large report datasets. +- [ ] Verify `avg_nn`, pass rate, medians, and standard deviations against fixture data. +- [ ] Add user-facing messages when report data is incomplete. + +## 8. Admin UI and UX + +- [ ] Add responsive mobile/tablet layout. +- [ ] Add active navigation state and breadcrumbs. +- [ ] Add pagination, sorting, and search to admin tables. +- [ ] Replace destructive browser confirms with safer confirmation modals. +- [ ] Add inline validation and success/error banners that persist after redirects. +- [ ] Add import progress indicators and clearer preview screens. +- [ ] Add empty states with recommended next actions. +- [ ] Improve visual hierarchy for dashboard stats and high-risk actions. +- [ ] Add accessibility pass: labels, focus states, contrast, keyboard navigation. + +## 9. Testing and Tooling + +- [ ] Add `pyproject.toml` or `pytest.ini` with test config. +- [ ] Add pinned dependency lock workflow. +- [ ] Add `make test`, `make lint`, `make migrate`, and `make dev` commands. +- [ ] Add CI for lint, tests, mapper config, Alembic upgrade, and import smoke tests. +- [ ] Add integration tests using a test database. +- [ ] Add auth boundary tests for every tenant-scoped endpoint. +- [x] Add regression tests for previously found defects. +- [ ] Document the canonical local setup path. + +## 10. Production Readiness + +- [ ] Validate required secrets in production startup. +- [ ] Document deployment environment variables. +- [ ] Add backup and restore guidance for PostgreSQL. +- [ ] Add observability: logs, metrics, traces, and error monitoring. +- [ ] Add operational runbooks for import failures, calibration failures, WordPress API outages, and AI provider outages. +- [ ] Add Redis availability checks when admin or background jobs are enabled. +- [ ] Add deployment checklist for migrations, admin credentials, CORS, HTTPS, and rollback. + +## Suggested Execution Order + +1. Security and auth hardening. +2. Session integrity and scoring correctness. +3. Database/migration alignment. +4. Test and tooling foundation. +5. Import/export and reporting reliability. +6. Admin UI/UX polish. +7. Production readiness and operations. + +## Definition of Perfect Enough + +- [ ] Every tenant-scoped endpoint has an authorization test. +- [ ] Every scoring path has deterministic fixture tests. +- [ ] Fresh database migration to head succeeds in CI. +- [ ] Admin destructive actions are CSRF-protected. +- [ ] Live sessions cannot reveal answers before completion. +- [ ] Imports fail safely with actionable validation output. +- [ ] Reports are reproducible, permissioned, and persisted where scheduled. +- [ ] The app can be installed, tested, migrated, and run from documented commands. diff --git a/tests/test_auth_scope.py b/tests/test_auth_scope.py new file mode 100644 index 0000000..0b268b7 --- /dev/null +++ b/tests/test_auth_scope.py @@ -0,0 +1,32 @@ +from pathlib import Path +import sys + +import pytest +from fastapi import HTTPException + +sys.path.insert(0, str(Path(__file__).resolve().parents[1])) + +from app.core.auth import ( # noqa: E402 + AuthContext, + ensure_website_scope_matches, + require_website_auth, +) + + +def test_require_website_auth_returns_scoped_website_for_allowed_role(): + auth = AuthContext(website_id=5, role="admin", wp_user_id=None) + website_id = require_website_auth(auth, allowed_roles={"admin", "system_admin"}) + assert website_id == 5 + + +def test_require_website_auth_rejects_disallowed_role(): + auth = AuthContext(website_id=5, role="student", wp_user_id="u1") + with pytest.raises(HTTPException) as exc_info: + require_website_auth(auth, allowed_roles={"admin", "system_admin"}) + assert exc_info.value.status_code == 403 + + +def test_cross_website_payload_mismatch_is_blocked(): + with pytest.raises(HTTPException) as exc_info: + ensure_website_scope_matches(auth_website_id=10, payload_website_id=11) + assert exc_info.value.status_code == 403 diff --git a/tests/test_auth_tokens.py b/tests/test_auth_tokens.py new file mode 100644 index 0000000..654ad53 --- /dev/null +++ b/tests/test_auth_tokens.py @@ -0,0 +1,50 @@ +from pathlib import Path +import sys +import time + +import pytest +from fastapi import HTTPException + +sys.path.insert(0, str(Path(__file__).resolve().parents[1])) + +from app.core.auth import decode_access_token, issue_access_token # noqa: E402 + + +def test_issue_and_decode_access_token_round_trip(): + token = issue_access_token( + website_id=42, + role="student", + wp_user_id="wp-1001", + expires_in_seconds=3600, + ) + auth = decode_access_token(token) + assert auth.website_id == 42 + assert auth.role == "student" + assert auth.wp_user_id == "wp-1001" + + +def test_decode_access_token_rejects_tampered_signature(): + token = issue_access_token( + website_id=7, + role="admin", + wp_user_id=None, + expires_in_seconds=3600, + ) + payload, signature = token.split(".", 1) + tampered_token = f"{payload}.{'A' if signature[0] != 'A' else 'B'}{signature[1:]}" + with pytest.raises(HTTPException) as exc_info: + decode_access_token(tampered_token) + assert exc_info.value.status_code == 401 + + +def test_decode_access_token_rejects_expired_token(): + token = issue_access_token( + website_id=9, + role="student", + wp_user_id="u-1", + expires_in_seconds=-1, + ) + time.sleep(0.01) + with pytest.raises(HTTPException) as exc_info: + decode_access_token(token) + assert exc_info.value.status_code == 401