fix: harden admin access, repair ORM joins, and add migration/tests

This commit is contained in:
dwindown
2026-04-01 14:59:54 +07:00
parent de592d140e
commit 16ab13e911
21 changed files with 1275 additions and 368 deletions

View File

@@ -5,18 +5,29 @@ Provides admin panel for managing tryouts, items, sessions, users, and tryout st
Includes custom actions for calibration, AI generation toggle, and normalization reset.
"""
import secrets
import uuid
from dataclasses import dataclass
from typing import Any, Dict, Optional
from fastapi import Request
import aioredis
from fastapi import Depends, Form, HTTPException, Request
from fastapi_admin import constants
from fastapi_admin.app import app as admin_app
from fastapi_admin.depends import get_current_admin, get_resources
from fastapi_admin.providers import Provider
from fastapi_admin.resources import (
Field,
Link,
Model,
)
from fastapi_admin.template import templates
from fastapi_admin.widgets import displays, inputs
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.responses import RedirectResponse
from starlette.status import HTTP_303_SEE_OTHER, HTTP_401_UNAUTHORIZED
from app.core.config import get_settings
from app.database import get_db
@@ -29,77 +40,175 @@ settings = get_settings()
# Authentication Provider
# =============================================================================
class AdminAuthProvider:
"""
Authentication provider for FastAPI Admin.
@dataclass
class AdminPrincipal:
"""Minimal admin user object expected by fastapi-admin templates."""
Supports two modes:
1. WordPress JWT token integration (production)
2. Basic auth for testing (development)
pk: str
username: str
avatar: str = ""
class EnvCredentialProvider(Provider):
"""
FastAPI-Admin provider backed by env credentials and Redis session tokens.
Compatible with fastapi-admin 1.0.x provider API without requiring
Tortoise admin models.
"""
async def login(
name = "env_credential_provider"
access_token = "access_token"
def __init__(
self,
username: str,
password: str,
) -> Optional[str]:
"""
Authenticate user and return token.
login_path: str = "/login",
logout_path: str = "/logout",
login_title: str = "Admin Login",
login_logo_url: str | None = None,
expire_seconds: int = 3600,
template: str = "providers/login/login.html",
) -> None:
self.username = username
self.password = password
self.login_path = login_path
self.logout_path = logout_path
self.login_title = login_title
self.login_logo_url = login_logo_url
self.expire_seconds = expire_seconds
self.template = template
Args:
username: Username
password: Password
async def register(self, app: "FastAPIAdmin") -> None:
await super().register(app)
app.get(self.login_path)(self.login_view)
app.post(self.login_path)(self.login)
app.get(self.logout_path)(self.logout)
app.get("/password")(self.password_view)
app.post("/password")(self.password)
app.add_middleware(BaseHTTPMiddleware, dispatch=self.authenticate)
Returns:
Access token if authentication successful, None otherwise
"""
# Development mode: basic auth
if settings.ENVIRONMENT == "development":
# Allow admin/admin or admin/password for testing
if (username == "admin" and password in ["admin", "password"]):
return f"dev_token_{username}"
async def login_view(self, request: Request):
return templates.TemplateResponse(
self.template,
context={
"request": request,
"login_logo_url": self.login_logo_url,
"login_title": self.login_title,
},
)
# Production mode: WordPress JWT token validation
# For now, return None - implement WordPress integration when needed
return None
async def login(
self,
request: Request,
username: str = Form(...),
password: str = Form(...),
remember_me: Optional[str] = Form(None),
):
if not (
secrets.compare_digest(username, self.username)
and secrets.compare_digest(password, self.password)
):
return templates.TemplateResponse(
self.template,
status_code=HTTP_401_UNAUTHORIZED,
context={
"request": request,
"error": "Invalid username or password",
"login_logo_url": self.login_logo_url,
"login_title": self.login_title,
},
)
async def logout(self, request: Request) -> bool:
"""
Logout user.
response = RedirectResponse(url=request.app.admin_path, status_code=HTTP_303_SEE_OTHER)
expire = self.expire_seconds
if remember_me == "on":
expire = max(self.expire_seconds, 3600 * 24 * 30)
response.set_cookie("remember_me", "on")
else:
response.delete_cookie("remember_me")
Args:
request: FastAPI request
token = uuid.uuid4().hex
response.set_cookie(
self.access_token,
token,
expires=expire,
path=request.app.admin_path,
httponly=True,
)
await request.app.redis.set(constants.LOGIN_USER.format(token=token), self.username, ex=expire)
return response
Returns:
True if logout successful
"""
return True
async def authenticate(self, request: Request, call_next: RequestResponseEndpoint):
token = request.cookies.get(self.access_token)
path = request.scope["path"]
admin = None
async def get_current_user(self, request: Request) -> Optional[dict]:
"""
Get current authenticated user.
if token:
key = constants.LOGIN_USER.format(token=token)
username = await request.app.redis.get(key)
if username:
admin = AdminPrincipal(pk=str(username), username=str(username))
Args:
request: FastAPI request
request.state.admin = admin
Returns:
User data if authenticated, None otherwise
"""
token = request.cookies.get("admin_token") or request.headers.get("Authorization")
if path.endswith(self.login_path) and admin:
return RedirectResponse(url=request.app.admin_path, status_code=HTTP_303_SEE_OTHER)
if not token:
return None
return await call_next(request)
# Development mode: validate dev token
if settings.ENVIRONMENT == "development" and token.startswith("dev_token_"):
username = token.replace("dev_token_", "")
return {
"id": 1,
"username": username,
"is_superuser": True,
}
async def logout(self, request: Request):
response = RedirectResponse(
url=request.app.admin_path + self.login_path,
status_code=HTTP_303_SEE_OTHER,
)
token = request.cookies.get(self.access_token)
if token:
await request.app.redis.delete(constants.LOGIN_USER.format(token=token))
response.delete_cookie(self.access_token, path=request.app.admin_path)
return response
return None
async def password_view(self, request: Request, resources=Depends(get_resources)):
return templates.TemplateResponse(
"providers/login/password.html",
context={"request": request, "resources": resources},
)
async def password(
self,
request: Request,
old_password: str = Form(...),
new_password: str = Form(...),
re_new_password: str = Form(...),
admin: AdminPrincipal = Depends(get_current_admin),
resources=Depends(get_resources),
):
_ = admin
if not secrets.compare_digest(old_password, self.password):
return templates.TemplateResponse(
"providers/login/password.html",
context={
"request": request,
"resources": resources,
"error": "Old password is incorrect",
},
)
if new_password != re_new_password:
return templates.TemplateResponse(
"providers/login/password.html",
context={
"request": request,
"resources": resources,
"error": "New passwords do not match",
},
)
# Password is env-configured and immutable at runtime.
raise HTTPException(
status_code=400,
detail="Password rotation via UI is disabled. Update ADMIN_PASSWORD in environment.",
)
# =============================================================================
@@ -604,7 +713,8 @@ def create_admin_app() -> Any:
# admin_app.settings.site_description = "Admin Panel for Adaptive Question Bank System"
# Register authentication provider
# admin_app.settings.auth_provider = AdminAuthProvider()
# NOTE: fastapi-admin 1.0.4 requires provider registration via app.configure(...).
# Keep provider implementation here for future integration during startup configure.
# Register model resources
admin_app.register(TryoutResource)
@@ -621,5 +731,55 @@ def create_admin_app() -> Any:
return admin_app
_admin_configured = False
_admin_redis = None
async def configure_admin_app() -> None:
"""Configure fastapi-admin runtime (redis + auth provider)."""
global _admin_configured, _admin_redis
if _admin_configured:
return
if not settings.ADMIN_USERNAME or not settings.ADMIN_PASSWORD:
raise RuntimeError(
"ENABLE_ADMIN=true requires ADMIN_USERNAME and ADMIN_PASSWORD to be set."
)
_admin_redis = aioredis.from_url(
settings.REDIS_URL,
encoding="utf-8",
decode_responses=True,
)
provider = EnvCredentialProvider(
username=settings.ADMIN_USERNAME,
password=settings.ADMIN_PASSWORD,
login_title="IRT Bank Soal Admin",
expire_seconds=settings.ADMIN_SESSION_EXPIRE_SECONDS,
)
await admin_app.configure(
redis=_admin_redis,
admin_path="/admin",
providers=[provider],
)
_admin_configured = True
async def shutdown_admin_app() -> None:
"""Close admin redis client cleanly."""
global _admin_redis
if _admin_redis is None:
return
try:
await _admin_redis.close()
finally:
_admin_redis = None
# Export admin app for mounting in main.py
admin = create_admin_app()

View File

@@ -35,6 +35,22 @@ class Settings(BaseSettings):
ENVIRONMENT: Literal["development", "staging", "production"] = Field(
default="development", description="Environment name"
)
ENABLE_ADMIN: bool = Field(
default=False,
description="Enable admin UI and admin-only API routes",
)
ADMIN_USERNAME: str = Field(
default="",
description="Admin panel username",
)
ADMIN_PASSWORD: str = Field(
default="",
description="Admin panel password (plain env value)",
)
ADMIN_SESSION_EXPIRE_SECONDS: int = Field(
default=3600,
description="Admin session lifetime in seconds",
)
# OpenRouter (AI Generation)
OPENROUTER_API_KEY: str = Field(

View File

@@ -16,7 +16,6 @@ from typing import AsyncGenerator
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from app.admin import admin as admin_app
from app.core.config import get_settings
from app.database import close_db, init_db
from app.routers import (
@@ -41,10 +40,18 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
"""
# Startup: Initialize database
await init_db()
if settings.ENABLE_ADMIN:
from app.admin import configure_admin_app
await configure_admin_app()
yield
# Shutdown: Close database connections
if settings.ENABLE_ADMIN:
from app.admin import shutdown_admin_app
await shutdown_admin_app()
await close_db()
@@ -162,25 +169,27 @@ app.include_router(
wordpress_router,
prefix=f"{settings.API_V1_STR}",
)
app.include_router(
ai_router,
prefix=f"{settings.API_V1_STR}",
)
app.include_router(
reports_router,
prefix=f"{settings.API_V1_STR}",
)
if settings.ENABLE_ADMIN:
from app.admin import admin as admin_app
# Mount FastAPI Admin panel
app.mount("/admin", admin_app)
app.include_router(
ai_router,
prefix=f"{settings.API_V1_STR}",
)
# Mount FastAPI Admin panel
app.mount("/admin", admin_app)
# Include admin API router for custom actions
app.include_router(
admin_router,
prefix=f"{settings.API_V1_STR}",
)
# Include admin API router for custom actions
app.include_router(
admin_router,
prefix=f"{settings.API_V1_STR}",
)
# Placeholder routers for future implementation

View File

@@ -14,11 +14,13 @@ from sqlalchemy import (
DateTime,
Float,
ForeignKey,
ForeignKeyConstraint,
Index,
Integer,
JSON,
String,
Text,
func,
)
from sqlalchemy.orm import Mapped, mapped_column, relationship
@@ -156,13 +158,13 @@ class Item(Base):
# Timestamps
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default="NOW()"
DateTime(timezone=True), nullable=False, server_default=func.now()
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
nullable=False,
server_default="NOW()",
onupdate="NOW()",
server_default=func.now(),
onupdate=func.now(),
)
# Relationships
@@ -188,6 +190,13 @@ class Item(Base):
# Constraints and indexes
__table_args__ = (
ForeignKeyConstraint(
["website_id", "tryout_id"],
["tryouts.website_id", "tryouts.tryout_id"],
name="fk_items_tryout",
ondelete="CASCADE",
onupdate="CASCADE",
),
Index(
"ix_items_tryout_id_website_id_slot",
"tryout_id",

View File

@@ -13,9 +13,11 @@ from sqlalchemy import (
DateTime,
Float,
ForeignKey,
ForeignKeyConstraint,
Index,
Integer,
String,
func,
)
from sqlalchemy.orm import Mapped, mapped_column, relationship
@@ -82,7 +84,7 @@ class Session(Base):
# Timestamps
start_time: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default="NOW()"
DateTime(timezone=True), nullable=False, server_default=func.now()
)
end_time: Mapped[Union[datetime, None]] = mapped_column(
DateTime(timezone=True), nullable=True, comment="Session end timestamp"
@@ -144,21 +146,27 @@ class Session(Base):
# Timestamps
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default="NOW()"
DateTime(timezone=True), nullable=False, server_default=func.now()
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
nullable=False,
server_default="NOW()",
onupdate="NOW()",
server_default=func.now(),
onupdate=func.now(),
)
# Relationships
user: Mapped["User"] = relationship(
"User", back_populates="sessions", lazy="selectin"
"User",
back_populates="sessions",
lazy="selectin",
overlaps="tryout,sessions",
)
tryout: Mapped["Tryout"] = relationship(
"Tryout", back_populates="sessions", lazy="selectin"
"Tryout",
back_populates="sessions",
lazy="selectin",
overlaps="user",
)
user_answers: Mapped[list["UserAnswer"]] = relationship(
"UserAnswer", back_populates="session", lazy="selectin", cascade="all, delete-orphan"
@@ -166,6 +174,20 @@ class Session(Base):
# Constraints and indexes
__table_args__ = (
ForeignKeyConstraint(
["website_id", "tryout_id"],
["tryouts.website_id", "tryouts.tryout_id"],
name="fk_sessions_tryout",
ondelete="CASCADE",
onupdate="CASCADE",
),
ForeignKeyConstraint(
["wp_user_id", "website_id"],
["users.wp_user_id", "users.website_id"],
name="fk_sessions_user",
ondelete="CASCADE",
onupdate="CASCADE",
),
Index("ix_sessions_wp_user_id", "wp_user_id"),
Index("ix_sessions_website_id", "website_id"),
Index("ix_sessions_tryout_id", "tryout_id"),

View File

@@ -7,7 +7,17 @@ Represents tryout exams with configurable scoring, selection, and normalization
from datetime import datetime
from typing import Literal, Union
from sqlalchemy import Boolean, CheckConstraint, DateTime, Float, ForeignKey, Index, Integer, String
from sqlalchemy import (
Boolean,
CheckConstraint,
DateTime,
Float,
ForeignKey,
Integer,
String,
UniqueConstraint,
func,
)
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.database import Base
@@ -146,13 +156,13 @@ class Tryout(Base):
# Timestamps
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default="NOW()"
DateTime(timezone=True), nullable=False, server_default=func.now()
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
nullable=False,
server_default="NOW()",
onupdate="NOW()",
server_default=func.now(),
onupdate=func.now(),
)
# Relationships
@@ -163,7 +173,11 @@ class Tryout(Base):
"Item", back_populates="tryout", lazy="selectin", cascade="all, delete-orphan"
)
sessions: Mapped[list["Session"]] = relationship(
"Session", back_populates="tryout", lazy="selectin", cascade="all, delete-orphan"
"Session",
back_populates="tryout",
lazy="selectin",
cascade="all, delete-orphan",
overlaps="user",
)
stats: Mapped["TryoutStats"] = relationship(
"TryoutStats", back_populates="tryout", lazy="selectin", uselist=False
@@ -171,8 +185,10 @@ class Tryout(Base):
# Constraints and indexes
__table_args__ = (
Index(
"ix_tryouts_website_id_tryout_id", "website_id", "tryout_id", unique=True
UniqueConstraint(
"website_id",
"tryout_id",
name="uq_tryouts_website_id_tryout_id",
),
CheckConstraint("min_sample_for_dynamic > 0", "ck_min_sample_positive"),
CheckConstraint("static_rataan > 0", "ck_static_rataan_positive"),

View File

@@ -7,7 +7,17 @@ Maintains running statistics for dynamic normalization and reporting.
from datetime import datetime
from typing import Union
from sqlalchemy import CheckConstraint, DateTime, Float, ForeignKey, Index, Integer, String
from sqlalchemy import (
CheckConstraint,
DateTime,
Float,
ForeignKey,
ForeignKeyConstraint,
Index,
Integer,
String,
func,
)
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.database import Base
@@ -107,13 +117,13 @@ class TryoutStats(Base):
comment="Timestamp of last statistics update",
)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default="NOW()"
DateTime(timezone=True), nullable=False, server_default=func.now()
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
nullable=False,
server_default="NOW()",
onupdate="NOW()",
server_default=func.now(),
onupdate=func.now(),
)
# Relationships
@@ -123,6 +133,13 @@ class TryoutStats(Base):
# Constraints and indexes
__table_args__ = (
ForeignKeyConstraint(
["website_id", "tryout_id"],
["tryouts.website_id", "tryouts.tryout_id"],
name="fk_tryout_stats_tryout",
ondelete="CASCADE",
onupdate="CASCADE",
),
Index(
"ix_tryout_stats_website_id_tryout_id",
"website_id",

View File

@@ -6,7 +6,7 @@ Represents users from WordPress that can take tryouts.
from datetime import datetime
from sqlalchemy import DateTime, ForeignKey, Index, String
from sqlalchemy import DateTime, ForeignKey, Index, String, UniqueConstraint, func
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.database import Base
@@ -31,7 +31,7 @@ class User(Base):
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
# WordPress user ID (unique within website context)
wp_user_id: Mapped[int] = mapped_column(
wp_user_id: Mapped[str] = mapped_column(
String(255), nullable=False, index=True, comment="WordPress user ID"
)
@@ -44,13 +44,13 @@ class User(Base):
# Timestamps
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default="NOW()"
DateTime(timezone=True), nullable=False, server_default=func.now()
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
nullable=False,
server_default="NOW()",
onupdate="NOW()",
server_default=func.now(),
onupdate=func.now(),
)
# Relationships
@@ -58,12 +58,20 @@ class User(Base):
"Website", back_populates="users", lazy="selectin"
)
sessions: Mapped[list["Session"]] = relationship(
"Session", back_populates="user", lazy="selectin", cascade="all, delete-orphan"
"Session",
back_populates="user",
lazy="selectin",
cascade="all, delete-orphan",
overlaps="sessions,tryout",
)
# Indexes
__table_args__ = (
Index("ix_users_wp_user_id_website_id", "wp_user_id", "website_id", unique=True),
UniqueConstraint(
"wp_user_id",
"website_id",
name="uq_users_wp_user_id_website_id",
),
Index("ix_users_website_id", "website_id"),
)

View File

@@ -7,7 +7,7 @@ Represents a student's response to a single question with scoring metadata.
from datetime import datetime
from typing import Literal, Union
from sqlalchemy import Boolean, CheckConstraint, DateTime, Float, ForeignKey, Index, Integer, String
from sqlalchemy import Boolean, CheckConstraint, DateTime, Float, ForeignKey, Index, Integer, String, func
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.database import Base
@@ -94,13 +94,13 @@ class UserAnswer(Base):
# Timestamps
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default="NOW()"
DateTime(timezone=True), nullable=False, server_default=func.now()
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
nullable=False,
server_default="NOW()",
onupdate="NOW()",
server_default=func.now(),
onupdate=func.now(),
)
# Relationships

View File

@@ -6,7 +6,7 @@ Represents WordPress websites that use the IRT Bank Soal system.
from datetime import datetime
from sqlalchemy import DateTime, String
from sqlalchemy import DateTime, String, func
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.database import Base
@@ -48,13 +48,13 @@ class Website(Base):
# Timestamps
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default="NOW()"
DateTime(timezone=True), nullable=False, server_default=func.now()
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
nullable=False,
server_default="NOW()",
onupdate="NOW()",
server_default=func.now(),
onupdate=func.now(),
)
# Relationships

View File

@@ -341,6 +341,7 @@ async def get_session(
async def create_session(
request: SessionCreateRequest,
db: AsyncSession = Depends(get_db),
website_id: int = Depends(get_website_id_from_header),
) -> SessionResponse:
"""
Create a new session.
@@ -355,10 +356,19 @@ async def create_session(
Raises:
HTTPException: If tryout not found or session already exists
"""
if request.website_id != website_id:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=(
"Website mismatch between payload and X-Website-ID header: "
f"{request.website_id} != {website_id}"
),
)
# Verify tryout exists
tryout_result = await db.execute(
select(Tryout).where(
Tryout.website_id == request.website_id,
Tryout.website_id == website_id,
Tryout.tryout_id == request.tryout_id,
)
)
@@ -367,7 +377,7 @@ async def create_session(
if tryout is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Tryout {request.tryout_id} not found for website {request.website_id}",
detail=f"Tryout {request.tryout_id} not found for website {website_id}",
)
# Check if session already exists
@@ -386,7 +396,7 @@ async def create_session(
session = Session(
session_id=request.session_id,
wp_user_id=request.wp_user_id,
website_id=request.website_id,
website_id=website_id,
tryout_id=request.tryout_id,
scoring_mode_used=request.scoring_mode,
start_time=datetime.now(timezone.utc),

View File

@@ -10,7 +10,7 @@ Endpoints:
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, Header, status
from sqlalchemy import select, func
from sqlalchemy import Integer, cast, func, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
@@ -292,7 +292,7 @@ async def get_calibration_status(
stats_result = await db.execute(
select(
func.count().label("total_items"),
func.sum(func.cast(Item.calibrated, type_=func.INTEGER)).label("calibrated_items"),
func.sum(cast(Item.calibrated, Integer)).label("calibrated_items"),
func.avg(Item.calibration_sample_size).label("avg_sample_size"),
).where(
Item.website_id == website_id,

View File

@@ -14,7 +14,7 @@ import math
from datetime import datetime, timezone
from typing import Optional
from sqlalchemy import func, select
from sqlalchemy import Integer, cast, func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.item import Item
@@ -190,7 +190,7 @@ async def calculate_ctt_p_for_item(
result = await db.execute(
select(
func.count().label("total"),
func.sum(func.cast(UserAnswer.is_correct, type_=func.INTEGER)).label("correct"),
func.sum(cast(UserAnswer.is_correct, Integer)).label("correct"),
).where(UserAnswer.item_id == item_id)
)
row = result.first()

View File

@@ -308,7 +308,7 @@ async def get_normalization_params(
Tryout.tryout_id == tryout_id,
)
)
row = result.scalar_one_or_none()
row = result.one_or_none()
if row is None:
raise ValueError(
@@ -352,7 +352,7 @@ async def get_normalization_params(
Tryout.tryout_id == tryout_id,
)
)
row = result.scalar_one_or_none()
row = result.one_or_none()
if row is None:
raise ValueError(
f"Tryout {tryout_id} not found for website {website_id}"
@@ -369,7 +369,7 @@ async def get_normalization_params(
Tryout.tryout_id == tryout_id,
)
)
row = result.scalar_one_or_none()
row = result.one_or_none()
if row is None:
raise ValueError(
f"Tryout {tryout_id} not found for website {website_id}"

View File

@@ -18,7 +18,7 @@ from dataclasses import dataclass, field
import logging
import pandas as pd
from sqlalchemy import select, func, and_, or_
from sqlalchemy import Integer, and_, cast, func, or_, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
@@ -415,7 +415,7 @@ async def generate_item_analysis_report(
resp_result = await db.execute(
select(
func.count().label("total"),
func.sum(func.cast(UserAnswer.is_correct, type_=func.INTEGER)).label("correct")
func.sum(cast(UserAnswer.is_correct, Integer)).label("correct")
).where(UserAnswer.item_id == item.id)
)
resp_stats = resp_result.first()
@@ -678,7 +678,7 @@ async def generate_tryout_comparison_report(
cal_result = await db.execute(
select(
func.count().label("total"),
func.sum(func.cast(Item.calibrated, type_=func.INTEGER)).label("calibrated")
func.sum(cast(Item.calibrated, Integer)).label("calibrated")
).where(
Item.tryout_id == tryout_id,
Item.website_id == website_id,
@@ -704,15 +704,56 @@ async def generate_tryout_comparison_report(
if tryout:
date_str = tryout.created_at.strftime("%Y-%m-%d")
session_result = await db.execute(
select(
func.count(Session.id).label("participant_count"),
func.avg(Session.NM).label("avg_nm"),
func.avg(Session.NN).label("avg_nn"),
func.avg(Session.theta).label("avg_theta"),
func.stddev_pop(Session.NM).label("std_nm"),
).where(
Session.tryout_id == tryout_id,
Session.website_id == website_id,
Session.is_completed.is_(True),
)
)
session_stats = session_result.first()
participant_count = (
int(session_stats.participant_count)
if session_stats and session_stats.participant_count
else (stats.participant_count if stats else 0)
)
avg_nm = (
round(float(session_stats.avg_nm), 2)
if session_stats and session_stats.avg_nm is not None
else (round(float(stats.rataan), 2) if stats and stats.rataan is not None else None)
)
avg_nn = (
round(float(session_stats.avg_nn), 2)
if session_stats and session_stats.avg_nn is not None
else None
)
avg_theta = (
round(float(session_stats.avg_theta), 4)
if session_stats and session_stats.avg_theta is not None
else None
)
std_nm = (
round(float(session_stats.std_nm), 2)
if session_stats and session_stats.std_nm is not None
else (round(float(stats.sb), 2) if stats and stats.sb is not None else None)
)
record = TryoutComparisonRecord(
tryout_id=tryout_id,
date=date_str,
subject=subject,
participant_count=stats.participant_count if stats else 0,
avg_nm=round(stats.rataan, 2) if stats and stats.rataan else None,
avg_nn=round(stats.rataan + 500, 2) if stats and stats.rataan else None,
avg_theta=None, # Would need to calculate from sessions
std_nm=round(stats.sb, 2) if stats and stats.sb else None,
participant_count=participant_count,
avg_nm=avg_nm,
avg_nn=avg_nn,
avg_theta=avg_theta,
std_nm=std_nm,
calibration_percentage=round(cal_percentage, 2),
)
comparison_records.append(record)