1125 lines
34 KiB
Python
1125 lines
34 KiB
Python
"""
|
||
IRT Calibration Service for Item Response Theory calculations.
|
||
|
||
Provides theta estimation, item calibration, and Fisher information calculations
|
||
for the 1PL (Rasch) IRT model.
|
||
"""
|
||
|
||
import math
|
||
from typing import Optional
|
||
|
||
import numpy as np
|
||
from scipy.optimize import minimize_scalar
|
||
from sqlalchemy import select
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
|
||
from app.models import Item, Session, UserAnswer
|
||
|
||
|
||
class IRTCalibrationError(Exception):
    """Raised when an IRT calibration or estimation step cannot proceed."""
|
||
|
||
|
||
def calculate_fisher_information(theta: float, b: float) -> float:
    """
    Fisher information for the 1PL (Rasch) model at a given ability level.

    I(θ) = P(θ) * (1 - P(θ)), where P(θ) = 1 / (1 + e^-(θ-b)).
    The value peaks at 0.25 exactly when θ equals the item difficulty b.

    Args:
        theta: Student ability estimate
        b: Item difficulty parameter

    Returns:
        Fisher information value
    """
    prob_correct = calculate_probability(theta, b)
    prob_incorrect = 1 - prob_correct
    return prob_correct * prob_incorrect
|
||
|
||
|
||
def calculate_probability(theta: float, b: float) -> float:
    """
    Probability of a correct response under the 1PL Rasch model.

    P(θ) = 1 / (1 + e^-(θ-b))

    Args:
        theta: Student ability estimate
        b: Item difficulty parameter

    Returns:
        Probability of a correct response in [0, 1]
    """
    # Clip the logit so exp() cannot overflow (numerical stability).
    z = theta - b
    if z > 30:
        z = 30
    elif z < -30:
        z = -30
    return 1.0 / (1.0 + math.exp(-z))
|
||
|
||
|
||
def estimate_theta_mle(
    responses: list[int],
    b_params: list[float],
    initial_theta: float = 0.0
) -> tuple[float, float]:
    """
    Estimate student ability (theta) for the 1PL model via MLE.

    Maximizes the Bernoulli log-likelihood of the observed responses over
    theta in [-3, 3] using bounded scalar optimization.

    Args:
        responses: Binary responses, e.g. [0, 1, 1, 0, ...]
        b_params: Item difficulty parameters, one per response
        initial_theta: Fallback theta used only if the optimizer fails
            (default 0.0)

    Returns:
        Tuple of (theta, standard_error)

    Raises:
        IRTCalibrationError: If responses and b_params differ in length
    """
    resp = np.asarray(responses, dtype=float)
    diffs = np.asarray(b_params, dtype=float)

    # Nothing answered yet: neutral ability with maximal uncertainty.
    if len(resp) == 0 or len(diffs) == 0:
        return 0.0, 3.0

    if len(resp) != len(diffs):
        raise IRTCalibrationError("responses and b_params must have same length")

    n_items = len(resp)
    n_correct = np.sum(resp)

    # Perfect score: the MLE diverges to +inf, so pin at the scale ceiling.
    if n_correct == n_items:
        return 3.0, 1.5

    # Zero score: the MLE diverges to -inf, so pin at the scale floor.
    if n_correct == 0:
        return -3.0, 1.5

    def objective(theta: float) -> float:
        """Negative log-likelihood of the response vector at theta."""
        z = np.clip(theta - diffs, -30, 30)  # numerical stability
        p = np.clip(1.0 / (1.0 + np.exp(-z)), 1e-10, 1 - 1e-10)  # avoid log(0)
        return -np.sum(resp * np.log(p) + (1 - resp) * np.log(1 - p))

    opt = minimize_scalar(objective, bounds=(-3, 3), method='bounded')
    ability = float(opt.x) if opt.success else initial_theta

    # Standard error from summed Fisher information at the estimate.
    std_err = calculate_theta_se(ability, list(diffs))

    # Clamp theta to the reporting range before returning.
    return max(-3.0, min(3.0, ability)), std_err
|
||
|
||
|
||
def calculate_theta_se(theta: float, b_params: list[float]) -> float:
    """
    Standard error of a theta estimate from total Fisher information.

    SE = 1 / sqrt(Σ_j I_j(θ)), with I_j(θ) = P_j(θ) * (1 - P_j(θ)).

    Args:
        theta: Current theta estimate
        b_params: Item difficulty parameters of the answered items

    Returns:
        Standard error, capped at 3.0; 3.0 is also returned when there
        are no items or the total information is non-positive.
    """
    if not b_params:
        return 3.0  # no data -> maximal uncertainty

    information = 0.0
    for difficulty in b_params:
        prob = calculate_probability(theta, difficulty)
        information += prob * (1.0 - prob)

    if information <= 0:
        return 3.0  # degenerate case -> maximal uncertainty

    # Cap the SE at a reasonable maximum.
    return min(1.0 / math.sqrt(information), 3.0)
|
||
|
||
|
||
def estimate_b_from_ctt_p(ctt_p: float) -> float:
    """
    Convert a CTT difficulty (p-value) into an IRT difficulty (b parameter).

    Uses the logit approximation b ≈ -ln((1-p)/p), clamped to [-3, +3].

    Args:
        ctt_p: CTT difficulty (proportion correct) in [0, 1]; None is
            treated as "unknown" and maps to the neutral difficulty 0.0.

    Returns:
        IRT difficulty parameter b in [-3, +3]
    """
    if ctt_p is None:
        return 0.0

    # Saturated proportions map straight to the scale edges.
    if ctt_p >= 1.0:
        return -3.0  # Very easy
    if ctt_p <= 0.0:
        return 3.0  # Very hard

    # Keep the proportion away from 0/1 so the logit stays finite.
    proportion = min(0.99, max(0.01, ctt_p))
    difficulty = -math.log((1 - proportion) / proportion)

    return max(-3.0, min(3.0, difficulty))
|
||
|
||
|
||
async def get_session_responses(
    db: AsyncSession,
    session_id: str
) -> tuple[list[int], list[float]]:
    """
    Collect binary responses and item difficulties for one session.

    Joins UserAnswer with Item (ordered by answer id) and, per answer,
    picks the best available difficulty: the calibrated IRT b if present,
    otherwise a CTT-derived estimate, otherwise the neutral default 0.0.

    Args:
        db: Database session
        session_id: Session identifier

    Returns:
        Tuple of (responses, b_params), index-aligned per answer.
    """
    stmt = (
        select(UserAnswer, Item)
        .join(Item, UserAnswer.item_id == Item.id)
        .where(UserAnswer.session_id == session_id)
        .order_by(UserAnswer.id)
    )
    rows = (await db.execute(stmt)).all()

    responses: list[int] = []
    b_params: list[float] = []

    for answer, item in rows:
        responses.append(1 if answer.is_correct else 0)
        # Prefer the calibrated IRT difficulty; fall back to a CTT-based
        # estimate, then to the neutral default.
        if item.calibrated and item.irt_b is not None:
            difficulty = item.irt_b
        elif item.ctt_p is not None:
            difficulty = estimate_b_from_ctt_p(item.ctt_p)
        else:
            difficulty = 0.0  # Default difficulty
        b_params.append(difficulty)

    return responses, b_params
|
||
|
||
|
||
async def update_session_theta(
    db: AsyncSession,
    session_id: str,
    force_recalculate: bool = False
) -> tuple[float, float]:
    """
    Update session theta estimate based on all responses.

    Loads the session, re-estimates theta from every stored response,
    writes (theta, theta_se) back onto the Session row, and commits.

    Args:
        db: Database session
        session_id: Session identifier
        force_recalculate: Force recalculation even if theta exists.
            NOTE(review): currently ignored -- theta is recalculated on
            every call; either implement the skip or drop the parameter.

    Returns:
        Tuple of (theta, theta_se)

    Raises:
        IRTCalibrationError: If no session with this id exists.
    """
    # Get session
    session_query = select(Session).where(Session.session_id == session_id)
    session_result = await db.execute(session_query)
    session = session_result.scalar_one_or_none()

    if not session:
        raise IRTCalibrationError(f"Session {session_id} not found")

    # Get responses and b-parameters
    responses, b_params = await get_session_responses(db, session_id)

    if not responses:
        # No responses yet: persist the neutral prior (theta 0, max SE).
        session.theta = 0.0
        session.theta_se = 3.0
        await db.commit()
        return 0.0, 3.0

    # Estimate theta. initial_theta is only used as a fallback when the
    # optimizer inside estimate_theta_mle fails.
    initial_theta = session.theta if session.theta is not None else 0.0
    theta, se = estimate_theta_mle(responses, b_params, initial_theta)

    # Persist the new estimate on the session.
    session.theta = theta
    session.theta_se = se
    await db.commit()

    return theta, se
|
||
|
||
|
||
async def update_theta_after_response(
    db: AsyncSession,
    session_id: str,
    item_id: int,
    is_correct: bool
) -> tuple[float, float]:
    """
    Update session theta after a single response.

    Incremental update for real-time theta tracking: resolves the item's
    difficulty, appends the new response to the session's stored responses,
    re-estimates theta, and commits the result to the session.

    NOTE(review): the new response is appended unconditionally. If the
    UserAnswer row for (session_id, item_id) has already been persisted,
    get_session_responses() will have returned it too and this response
    is double-counted -- confirm callers invoke this BEFORE persisting
    the answer (or deduplicate here).

    Args:
        db: Database session
        session_id: Session identifier
        item_id: Item that was answered
        is_correct: Whether the answer was correct

    Returns:
        Tuple of (theta, theta_se)

    Raises:
        IRTCalibrationError: If the session or the item does not exist.
    """
    # Get session
    session_query = select(Session).where(Session.session_id == session_id)
    session_result = await db.execute(session_query)
    session = session_result.scalar_one_or_none()

    if not session:
        raise IRTCalibrationError(f"Session {session_id} not found")

    # Get the answered item so we can resolve its difficulty.
    item_query = select(Item).where(Item.id == item_id)
    item_result = await db.execute(item_query)
    item = item_result.scalar_one_or_none()

    if not item:
        raise IRTCalibrationError(f"Item {item_id} not found")

    # Resolve b: calibrated IRT value, else CTT-derived, else neutral 0.0.
    if item.calibrated and item.irt_b is not None:
        b = item.irt_b
    elif item.ctt_p is not None:
        b = estimate_b_from_ctt_p(item.ctt_p)
    else:
        b = 0.0

    # Get all stored responses for this session.
    responses, b_params = await get_session_responses(db, session_id)

    # Append the current response (see double-count NOTE in the docstring).
    responses.append(1 if is_correct else 0)
    b_params.append(b)

    # Estimate theta; the previous theta is only the optimizer-failure fallback.
    initial_theta = session.theta if session.theta is not None else 0.0
    theta, se = estimate_theta_mle(responses, b_params, initial_theta)

    # Persist the new estimate on the session.
    session.theta = theta
    session.theta_se = se
    await db.commit()

    return theta, se
|
||
|
||
|
||
def theta_to_nn(theta: float) -> int:
    """
    Map an IRT theta in [-3, 3] onto the 0-1000 NN score scale.

    Formula: NN = 500 + (θ / 3) × 500

    Args:
        theta: IRT ability estimate; values outside [-3, +3] are clamped.

    Returns:
        Integer NN score in [0, 1000]
    """
    clamped = min(3.0, max(-3.0, theta))
    raw_score = 500 + (clamped / 3.0) * 500
    # Final clamp is defensive; the theta clamp already bounds the result.
    return int(min(1000, max(0, raw_score)))
|
||
|
||
|
||
def nn_to_theta(nn: int) -> float:
    """
    Map a 0-1000 NN score back onto the IRT theta scale [-3, +3].

    Formula: θ = ((NN - 500) / 500) × 3

    Args:
        nn: NN score; values outside [0, 1000] are clamped.

    Returns:
        IRT theta in [-3, +3]
    """
    score = min(1000, max(0, nn))
    ability = ((score - 500) / 500.0) * 3.0
    # Final clamp is defensive; the score clamp already bounds the result.
    return min(3.0, max(-3.0, ability))
|
||
|
||
|
||
def calculate_item_information(theta: float, b: float) -> float:
    """
    Calculate the item information function at a given theta.

    For the 1PL model this coincides with the Fisher information,
    I(θ) = P(θ)(1 - P(θ)); maximum information occurs when θ = b.

    Args:
        theta: Ability level
        b: Item difficulty

    Returns:
        Item information value
    """
    # 1PL item information == Fisher information (no discrimination term).
    return calculate_fisher_information(theta, b)
|
||
|
||
|
||
# =============================================================================
# Joint MLE Calibration for b-parameters (EM-style iterative)
# =============================================================================

# Constants from PRD
THETA_MIN = -3.0  # Lower bound of the reported ability scale
THETA_MAX = 3.0  # Upper bound of the reported ability scale
B_MIN = -3.0  # Lower bound of the item difficulty scale
B_MAX = 3.0  # Upper bound of the item difficulty scale
CALIBRATION_SAMPLE_THRESHOLD = 500  # PRD requirement: 500+ responses for calibration
IRT_ROLLOUT_THRESHOLD = 0.90  # PRD requirement: 90% items calibrated for IRT rollout
SE_PRECISION_THRESHOLD = 0.5  # PRD requirement: SE < 0.5 after 15 items
MLE_BOUNDS = (-6.0, 6.0)  # Optimization bounds (wider than final clamp)
EDGE_CASE_THETA_HIGH = 4.0  # All correct responses
EDGE_CASE_THETA_LOW = -4.0  # All incorrect responses
NUMERICAL_CLIP = 30  # Exponent clipping for numerical stability
|
||
|
||
|
||
from dataclasses import dataclass
|
||
from datetime import datetime
|
||
from enum import Enum
|
||
from typing import Union
|
||
import logging
|
||
|
||
from sqlalchemy import func
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class CalibrationStatus(Enum):
    """Calibration status for items and tryouts."""
    NOT_CALIBRATED = "not_calibrated"  # no calibration attempted yet
    INSUFFICIENT_DATA = "insufficient_data"  # below sample-size / session thresholds
    CONVERGED = "converged"  # joint MLE succeeded; b and SE available
    FAILED = "failed"  # calibration attempted but errored
    FALLBACK_CTT = "fallback_ctt"  # degraded to CTT scoring (see fallback_to_ctt)
|
||
|
||
|
||
@dataclass
class CalibrationResult:
    """Result of a single item calibration.

    Returned by calibrate_item(); callers inspect `status` (and `message`)
    rather than relying on exceptions.
    """
    item_id: int  # item the result refers to
    status: CalibrationStatus  # outcome of the calibration attempt
    irt_b: Optional[float] = None  # estimated difficulty, clamped to [B_MIN, B_MAX]
    irt_se: Optional[float] = None  # standard error of irt_b (may be None or NaN)
    sample_size: int = 0  # number of scored responses considered
    message: str = ""  # human-readable detail for logs/UI

    @property
    def is_calibrated(self) -> bool:
        # Only a CONVERGED run counts as a successful calibration.
        return self.status == CalibrationStatus.CONVERGED
|
||
|
||
|
||
@dataclass
class BatchCalibrationResult:
    """Result of batch calibration for a tryout (see calibrate_all)."""
    tryout_id: str  # tryout whose items were processed
    website_id: int  # owning website
    total_items: int  # items found for the tryout
    calibrated_items: int  # items with status CONVERGED
    failed_items: int  # items with status FAILED
    results: list[CalibrationResult]  # per-item detail, in slot order
    ready_for_irt: bool  # calibration_percentage >= IRT_ROLLOUT_THRESHOLD
    calibration_percentage: float  # fraction in [0, 1], not percent

    @property
    def success_rate(self) -> float:
        # Fraction of items that converged; 0.0 for an empty tryout.
        if self.total_items == 0:
            return 0.0
        return self.calibrated_items / self.total_items
|
||
|
||
|
||
def estimate_b(
    responses_matrix: list[list[int]],
    max_iterations: int = 20,
    convergence_threshold: float = 0.001
) -> tuple[Optional[np.ndarray], Optional[np.ndarray]]:
    """
    Estimate item difficulty parameters using joint MLE for the 1PL IRT model.

    EM-style alternating optimization:
    1. Initialize theta = 0 for all students, b = 0 for all items.
    2. Each iteration: re-estimate every theta given current b, then every
       b given current theta.
    3. Stop when max |Δb| < convergence_threshold or max_iterations is hit.

    Parameters:
    -----------
    responses_matrix : list[list[int]]
        Response matrix where rows=students, cols=items; entries are 0 or 1.
    max_iterations : int
        Maximum EM iterations (default: 20)
    convergence_threshold : float
        Convergence threshold for b parameters (default: 0.001)

    Returns:
    --------
    tuple[Optional[np.ndarray], Optional[np.ndarray]]
        (b_parameters, se_parameters)
        - b clamped to [B_MIN, B_MAX]
        - SE from Fisher information (None if that computation fails)
        An empty matrix yields (np.array([]), None).

    Raises:
    -------
    IRTCalibrationError
        If responses_matrix is not 2-dimensional.
    """
    responses_matrix = np.asarray(responses_matrix, dtype=float)

    # Edge case: empty matrix. This also covers any zero-length dimension,
    # since a (0, k) or (k, 0) array has size 0 -- the previous separate
    # `n_students == 0 or n_items == 0` branch was unreachable and was removed.
    if responses_matrix.size == 0:
        return np.array([]), None

    if responses_matrix.ndim != 2:
        raise IRTCalibrationError("responses_matrix must be 2-dimensional")

    n_students, n_items = responses_matrix.shape

    # Initialize abilities and difficulties at the scale midpoint.
    theta = np.zeros(n_students)
    b = np.zeros(n_items)

    def _neg_ll(logits: np.ndarray, observed: np.ndarray) -> float:
        """Negative Bernoulli log-likelihood for clipped 1PL logits."""
        z = np.clip(logits, -NUMERICAL_CLIP, NUMERICAL_CLIP)
        p = np.clip(1.0 / (1.0 + np.exp(-z)), 1e-10, 1 - 1e-10)  # avoid log(0)
        return -np.sum(observed * np.log(p) + (1 - observed) * np.log(1 - p))

    for iteration in range(max_iterations):
        b_old = b.copy()

        # Step 1: update theta for each student given current b.
        for i in range(n_students):
            resp_i = responses_matrix[i, :]
            sum_resp = np.sum(resp_i)

            if sum_resp == n_items:
                # Perfect score: the MLE diverges, pin at the edge value.
                theta[i] = EDGE_CASE_THETA_HIGH
            elif sum_resp == 0:
                theta[i] = EDGE_CASE_THETA_LOW
            else:
                res = minimize_scalar(
                    lambda t: _neg_ll(t - b, resp_i),
                    bounds=MLE_BOUNDS, method='bounded'
                )
                theta[i] = res.x if res.success else 0.0

        # Step 2: update b for each item given current theta.
        for j in range(n_items):
            resp_j = responses_matrix[:, j]
            sum_resp = np.sum(resp_j)

            if sum_resp == n_students:
                b[j] = -EDGE_CASE_THETA_HIGH  # Easy item (everyone correct)
            elif sum_resp == 0:
                b[j] = EDGE_CASE_THETA_HIGH  # Hard item (everyone incorrect)
            else:
                res = minimize_scalar(
                    lambda bj: _neg_ll(theta - bj, resp_j),
                    bounds=MLE_BOUNDS, method='bounded'
                )
                b[j] = res.x if res.success else 0.0

        # Converged when no b moved more than the threshold this iteration.
        if np.max(np.abs(b - b_old)) < convergence_threshold:
            logger.debug("Joint MLE converged at iteration %d", iteration + 1)
            break

    # Clamp b to the reporting range (vectorized; replaces per-element loop).
    b = np.clip(b, B_MIN, B_MAX)

    # Standard errors from Fisher information at the final estimates.
    se = _calculate_b_se_batch(b, theta)

    return b, se
|
||
|
||
|
||
def _calculate_b_se_batch(b_params: np.ndarray, thetas: np.ndarray) -> Optional[np.ndarray]:
    """
    Calculate standard errors for all b parameters using Fisher information.

    For the 1PL model, Fisher information for item j is
        I(b_j) = Σ_i P(θ_i) * (1 - P(θ_i))
    and SE(b_j) = 1 / sqrt(I(b_j)).

    Parameters:
    -----------
    b_params : np.ndarray
        Item difficulty parameters
    thetas : np.ndarray
        Student ability estimates

    Returns:
    --------
    Optional[np.ndarray]
        Standard error per b parameter (NaN where information is zero),
        or None if the calculation fails.
    """
    try:
        # Broadcast to an (n_items, n_students) logit matrix and reduce in a
        # single vectorized pass (replaces the per-item Python loop).
        exponent = np.clip(
            thetas[np.newaxis, :] - b_params[:, np.newaxis],
            -NUMERICAL_CLIP, NUMERICAL_CLIP,
        )
        p = 1.0 / (1.0 + np.exp(-exponent))

        # Fisher information per item: sum over students.
        information = np.sum(p * (1 - p), axis=1)

        # Zero information means SE is undefined; mark NaN as before.
        # errstate suppresses the divide-by-zero warning from the NaN branch.
        with np.errstate(divide='ignore'):
            se = np.where(information > 0, 1.0 / np.sqrt(information), np.nan)
        return se
    except Exception as e:
        # Best-effort: callers treat None as "SE unavailable".
        logger.warning(f"Failed to calculate b SE batch: {e}")
        return None
|
||
|
||
|
||
async def calibrate_item(
    item_id: int,
    db: AsyncSession,
    min_sample_size: int = CALIBRATION_SAMPLE_THRESHOLD
) -> CalibrationResult:
    """
    Calibrate a single item using the IRT 1PL model.

    Fetches all scored UserAnswers for the item, builds a joint response
    matrix over the sessions that answered it, estimates the b-parameter
    via joint MLE (estimate_b), and persists the result on the Item row.

    Parameters:
    -----------
    item_id : int
        Item ID to calibrate
    db : AsyncSession
        Database session
    min_sample_size : int
        Minimum sample size for calibration (default: 500)

    Returns:
    --------
    CalibrationResult
        Calibration result with status, b-parameter, SE, and sample size.
        Never raises: all failures are reported via the status field.
    """
    try:
        # Fetch item
        result = await db.execute(select(Item).where(Item.id == item_id))
        item = result.scalar_one_or_none()

        if not item:
            return CalibrationResult(
                item_id=item_id,
                status=CalibrationStatus.FAILED,
                message=f"Item {item_id} not found"
            )

        # Fetch all scored answers for this item (is_correct NULL = unscored).
        result = await db.execute(
            select(UserAnswer)
            .where(UserAnswer.item_id == item_id)
            .where(UserAnswer.is_correct.isnot(None))
        )
        answers = result.scalars().all()

        sample_size = len(answers)

        if sample_size < min_sample_size:
            # Insufficient data - fall back to a CTT-based initial b if possible.
            if item.ctt_p is not None:
                initial_b = estimate_b_from_ctt_p(item.ctt_p)
                return CalibrationResult(
                    item_id=item_id,
                    status=CalibrationStatus.INSUFFICIENT_DATA,
                    irt_b=initial_b,
                    sample_size=sample_size,
                    message=f"Insufficient data ({sample_size} < {min_sample_size}). "
                            f"Using CTT-based initial estimate."
                )
            return CalibrationResult(
                item_id=item_id,
                status=CalibrationStatus.INSUFFICIENT_DATA,
                sample_size=sample_size,
                message=f"Insufficient data ({sample_size} < {min_sample_size})"
            )

        # Sessions that answered the target item define the student rows.
        # (Replaces an unused nested dict; dict.fromkeys keeps first-seen order.)
        session_ids = list(dict.fromkeys(a.session_id for a in answers))

        if len(session_ids) < 10:
            return CalibrationResult(
                item_id=item_id,
                status=CalibrationStatus.INSUFFICIENT_DATA,
                sample_size=sample_size,
                message="Not enough unique sessions for calibration"
            )

        # Fetch every scored answer from these sessions for joint calibration.
        result = await db.execute(
            select(UserAnswer)
            .where(UserAnswer.session_id.in_(session_ids))
            .where(UserAnswer.is_correct.isnot(None))
        )
        all_answers = result.scalars().all()

        # Build the full response matrix (sessions x items).
        item_ids = sorted(set(a.item_id for a in all_answers))
        item_id_to_idx = {iid: idx for idx, iid in enumerate(item_ids)}

        # Group answers by session once (O(n)) instead of re-scanning
        # all_answers for every session (previously O(sessions * answers)).
        answers_by_session: dict = {}
        for ans in all_answers:
            answers_by_session.setdefault(ans.session_id, []).append(ans)

        responses_matrix = []
        for session_id in session_ids:
            row = [0] * len(item_ids)
            for ans in answers_by_session.get(session_id, []):
                idx = item_id_to_idx.get(ans.item_id)
                if idx is not None:
                    row[idx] = 1 if ans.is_correct else 0
            responses_matrix.append(row)

        # Run joint MLE calibration
        b_params, se_params = estimate_b(responses_matrix)

        if b_params is None or len(b_params) == 0:
            return CalibrationResult(
                item_id=item_id,
                status=CalibrationStatus.FAILED,
                sample_size=sample_size,
                message="MLE estimation failed"
            )

        # Extract b and SE for the target item.
        target_idx = item_id_to_idx.get(item_id)
        if target_idx is None:
            return CalibrationResult(
                item_id=item_id,
                status=CalibrationStatus.FAILED,
                sample_size=sample_size,
                message="Item not found in response matrix"
            )

        irt_b = float(b_params[target_idx])
        irt_se = float(se_params[target_idx]) if se_params is not None else None

        # Clamp out-of-range estimates rather than failing the calibration.
        if not (B_MIN <= irt_b <= B_MAX):
            logger.warning(f"b-parameter {irt_b} out of range for item {item_id}")
            irt_b = max(B_MIN, min(B_MAX, irt_b))

        # Persist calibration on the item.
        item.irt_b = irt_b
        item.irt_se = irt_se
        item.calibration_sample_size = sample_size
        # Always True here (guarded above); kept as an expression for clarity.
        item.calibrated = sample_size >= min_sample_size

        await db.commit()

        return CalibrationResult(
            item_id=item_id,
            status=CalibrationStatus.CONVERGED,
            irt_b=irt_b,
            irt_se=irt_se,
            sample_size=sample_size,
            message=f"Successfully calibrated with {sample_size} responses"
        )

    except Exception as e:
        logger.error(f"Calibration failed for item {item_id}: {e}")
        return CalibrationResult(
            item_id=item_id,
            status=CalibrationStatus.FAILED,
            message=f"Calibration error: {str(e)}"
        )
|
||
|
||
|
||
async def calibrate_all(
    tryout_id: str,
    website_id: int,
    db: AsyncSession,
    min_sample_size: int = CALIBRATION_SAMPLE_THRESHOLD
) -> BatchCalibrationResult:
    """
    Calibrate all items in a tryout using the IRT 1PL model.

    For every item in the tryout (in slot order): runs calibrate_item() if
    it has enough responses and is not yet calibrated, records an
    "already calibrated" result if it is, and an insufficient-data result
    otherwise. Never raises: failures are folded into the returned result.

    NOTE(review): response counts are fetched with one query per item
    (N+1); a single GROUP BY count over UserAnswer would do.

    Parameters:
    -----------
    tryout_id : str
        Tryout identifier
    website_id : int
        Website identifier
    db : AsyncSession
        Database session
    min_sample_size : int
        Minimum sample size for calibration (default: 500)

    Returns:
    --------
    BatchCalibrationResult
        Batch calibration result with status for each item
    """
    results = []

    try:
        # Find all items for this tryout
        result = await db.execute(
            select(Item)
            .where(Item.tryout_id == tryout_id)
            .where(Item.website_id == website_id)
            .order_by(Item.slot)
        )
        items = result.scalars().all()

        total_items = len(items)

        if total_items == 0:
            # Nothing to do; report an empty, not-ready batch.
            return BatchCalibrationResult(
                tryout_id=tryout_id,
                website_id=website_id,
                total_items=0,
                calibrated_items=0,
                failed_items=0,
                results=[],
                ready_for_irt=False,
                calibration_percentage=0.0
            )

        # Get response counts per item (see N+1 NOTE in the docstring).
        item_response_counts = {}
        for item in items:
            result = await db.execute(
                select(func.count(UserAnswer.id))
                .where(UserAnswer.item_id == item.id)
            )
            count = result.scalar() or 0
            item_response_counts[item.id] = count

        # Calibrate items with sufficient data
        for item in items:
            response_count = item_response_counts.get(item.id, 0)

            if response_count >= min_sample_size and not item.calibrated:
                cal_result = await calibrate_item(item.id, db, min_sample_size)
                results.append(cal_result)
            elif item.calibrated:
                # Already calibrated: echo the stored parameters.
                results.append(CalibrationResult(
                    item_id=item.id,
                    status=CalibrationStatus.CONVERGED,
                    irt_b=item.irt_b,
                    irt_se=item.irt_se,
                    sample_size=item.calibration_sample_size,
                    message="Already calibrated"
                ))
            else:
                # Insufficient data
                results.append(CalibrationResult(
                    item_id=item.id,
                    status=CalibrationStatus.INSUFFICIENT_DATA,
                    sample_size=response_count,
                    message=f"Insufficient data ({response_count} < {min_sample_size})"
                ))

        # Count results
        calibrated_items = sum(1 for r in results if r.is_calibrated)
        failed_items = sum(1 for r in results if r.status == CalibrationStatus.FAILED)
        calibration_percentage = calibrated_items / total_items if total_items > 0 else 0.0

        # Update TryoutStats if exists.
        # NOTE(review): `stats` is fetched but never modified -- only a log
        # line is emitted; either persist the calibration stats here or
        # drop the query.
        try:
            from app.models import TryoutStats
            result = await db.execute(
                select(TryoutStats)
                .where(TryoutStats.tryout_id == tryout_id)
                .where(TryoutStats.website_id == website_id)
            )
            stats = result.scalar_one_or_none()

            if stats:
                logger.info(
                    f"Tryout {tryout_id}: {calibrated_items}/{total_items} items calibrated "
                    f"({calibration_percentage:.1%})"
                )
        except Exception as e:
            # Best-effort: stats reporting must not break calibration.
            logger.warning(f"Could not update TryoutStats: {e}")

        ready_for_irt = calibration_percentage >= IRT_ROLLOUT_THRESHOLD

        return BatchCalibrationResult(
            tryout_id=tryout_id,
            website_id=website_id,
            total_items=total_items,
            calibrated_items=calibrated_items,
            failed_items=failed_items,
            results=results,
            ready_for_irt=ready_for_irt,
            calibration_percentage=calibration_percentage
        )

    except Exception as e:
        # Report whatever partial results were gathered before the failure.
        logger.error(f"Batch calibration failed for tryout {tryout_id}: {e}")
        return BatchCalibrationResult(
            tryout_id=tryout_id,
            website_id=website_id,
            total_items=len(results),
            calibrated_items=sum(1 for r in results if r.is_calibrated),
            failed_items=sum(1 for r in results if r.status == CalibrationStatus.FAILED),
            results=results,
            ready_for_irt=False,
            calibration_percentage=0.0
        )
|
||
|
||
|
||
def fallback_to_ctt(reason: str, context: Optional[dict] = None) -> dict:
    """
    Generate fallback response for CTT mode when IRT fails.

    Provides a graceful degradation mechanism: logs the fallback and
    returns a structured payload with a reason-specific recommendation.

    Parameters:
    -----------
    reason : str
        Reason for fallback (insufficient_data, convergence_error,
        numerical_instability, missing_parameters); any other value maps
        to the "default" recommendation.
    context : Optional[dict]
        Additional context (item_id, tryout_id, etc.)

    Returns:
    --------
    dict
        Fallback response with:
        - fallback_mode: "ctt"
        - reason: str
        - recommendation: str
        - context: dict
        - timestamp: ISO-8601 string (naive UTC)
    """
    context = context or {}

    recommendations = {
        "insufficient_data": (
            "Continue collecting response data. "
            f"Need {CALIBRATION_SAMPLE_THRESHOLD}+ responses per item for IRT calibration. "
            "Use CTT scoring until threshold is reached."
        ),
        "convergence_error": (
            "MLE optimization failed to converge. "
            "Check for response patterns (all correct/incorrect). "
            "Use CTT scoring as fallback."
        ),
        "numerical_instability": (
            "Numerical instability detected in MLE calculation. "
            "Verify data quality and response patterns. "
            "Use CTT scoring as fallback."
        ),
        "missing_parameters": (
            "Required IRT parameters not available. "
            "Ensure items are calibrated before using IRT mode. "
            "Use CTT scoring until calibration is complete."
        ),
        "default": (
            "IRT scoring unavailable. "
            "Falling back to CTT scoring mode. "
            "Check logs for details."
        )
    }

    # Unknown reasons fall back to the generic recommendation.
    recommendation = recommendations.get(reason, recommendations["default"])

    logger.warning(
        f"IRT fallback to CTT - Reason: {reason}, Context: {context}"
    )

    return {
        "fallback_mode": "ctt",
        "reason": reason,
        "recommendation": recommendation,
        "context": context,
        # NOTE(review): datetime.utcnow() is naive and deprecated since
        # Python 3.12; consider datetime.now(timezone.utc) -- but that
        # changes the ISO string format (+00:00 suffix) for consumers.
        "timestamp": datetime.utcnow().isoformat()
    }
|
||
|
||
|
||
def validate_irt_parameters(
    theta: Optional[float] = None,
    b: Optional[float] = None,
    se: Optional[float] = None
) -> tuple[bool, list[str]]:
    """
    Validate IRT parameters against PRD constraints.

    Any argument left as None is skipped. An SE above the precision
    threshold is only logged as a warning; it does not invalidate input.

    Parameters:
    -----------
    theta : Optional[float]
        Ability estimate to validate
    b : Optional[float]
        Difficulty parameter to validate
    se : Optional[float]
        Standard error to validate

    Returns:
    --------
    tuple[bool, list[str]]
        (is_valid, list of error messages)
    """
    errors: list[str] = []

    if theta is not None and not (THETA_MIN <= theta <= THETA_MAX):
        errors.append(f"Theta {theta} out of range [{THETA_MIN}, {THETA_MAX}]")

    if b is not None and not (B_MIN <= b <= B_MAX):
        errors.append(f"b-parameter {b} out of range [{B_MIN}, {B_MAX}]")

    if se is not None:
        if se < 0:
            errors.append(f"Standard error {se} must be non-negative")
        elif se >= SE_PRECISION_THRESHOLD:
            # Low precision: warn, but the value is still considered valid.
            logger.warning(f"Standard error {se} exceeds precision threshold {SE_PRECISION_THRESHOLD}")

    return not errors, errors
|
||
|
||
|
||
async def get_calibration_status(
    tryout_id: str,
    website_id: int,
    db: AsyncSession
) -> dict:
    """
    Summarize calibration progress for one tryout.

    Parameters:
    -----------
    tryout_id : str
        Tryout identifier
    website_id : int
        Website identifier
    db : AsyncSession
        Database session

    Returns:
    --------
    dict
        Calibration status including:
        - total_items: int
        - calibrated_items: int
        - calibration_percentage: float (0-100, one decimal)
        - ready_for_irt: bool
        - items: list of per-item status dicts
    """
    stmt = (
        select(Item)
        .where(Item.tryout_id == tryout_id)
        .where(Item.website_id == website_id)
        .order_by(Item.slot)
    )
    items = (await db.execute(stmt)).scalars().all()

    total_items = len(items)
    calibrated_items = sum(1 for item in items if item.calibrated)
    calibration_percentage = calibrated_items / total_items if total_items > 0 else 0.0

    # Per-item detail rows, in slot order.
    item_status = [
        {
            "item_id": item.id,
            "slot": item.slot,
            "level": item.level,
            "calibrated": item.calibrated,
            "irt_b": item.irt_b,
            "irt_se": item.irt_se,
            "calibration_sample_size": item.calibration_sample_size
        }
        for item in items
    ]

    return {
        "tryout_id": tryout_id,
        "website_id": website_id,
        "total_items": total_items,
        "calibrated_items": calibrated_items,
        "calibration_percentage": round(calibration_percentage * 100, 1),
        "ready_for_irt": calibration_percentage >= IRT_ROLLOUT_THRESHOLD,
        "items": item_status
    }
|
||
|
||
|
||
# Export public API (explicit module interface; helpers prefixed with
# an underscore, e.g. _calculate_b_se_batch, are intentionally excluded).
__all__ = [
    # Constants
    "THETA_MIN",
    "THETA_MAX",
    "B_MIN",
    "B_MAX",
    "CALIBRATION_SAMPLE_THRESHOLD",
    "IRT_ROLLOUT_THRESHOLD",
    "SE_PRECISION_THRESHOLD",
    # Enums
    "CalibrationStatus",
    # Data classes
    "CalibrationResult",
    "BatchCalibrationResult",
    # Exceptions
    "IRTCalibrationError",
    # Core functions
    "estimate_theta_mle",
    "estimate_b",
    "calibrate_item",
    "calibrate_all",
    "fallback_to_ctt",
    "validate_irt_parameters",
    "get_calibration_status",
    # Conversion functions
    "estimate_b_from_ctt_p",
    "theta_to_nn",
    "nn_to_theta",
    # Calculation functions
    "calculate_probability",
    "calculate_fisher_information",
    "calculate_theta_se",
    "calculate_item_information",
]
|