Files
yellow-bank-soal/app/services/irt_calibration.py
Dwindi Ramadhana cf193d7ea0 first commit
2026-03-21 23:32:59 +07:00

1125 lines
34 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
IRT Calibration Service for Item Response Theory calculations.
Provides theta estimation, item calibration, and Fisher information calculations
for the 1PL (Rasch) IRT model.
"""
import math
from typing import Optional
import numpy as np
from scipy.optimize import minimize_scalar
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.models import Item, Session, UserAnswer
class IRTCalibrationError(Exception):
    """Raised when an IRT calibration operation cannot be completed."""
def calculate_fisher_information(theta: float, b: float) -> float:
    """
    Fisher information of a single item under the 1PL (Rasch) model.

    I(θ) = P(θ) * (1 - P(θ)), where P(θ) = 1 / (1 + e^-(θ-b)).
    Information peaks at 0.25 when θ equals the item difficulty b.

    Args:
        theta: Student ability estimate
        b: Item difficulty parameter

    Returns:
        Fisher information value
    """
    prob = calculate_probability(theta, b)
    return prob * (1 - prob)
def calculate_probability(theta: float, b: float) -> float:
    """
    Probability of a correct response under the 1PL Rasch model.

    P(θ) = 1 / (1 + e^-(θ-b))

    Args:
        theta: Student ability estimate
        b: Item difficulty parameter

    Returns:
        Probability of correct response in [0, 1]
    """
    z = theta - b
    # Clip the logit so exp() cannot overflow for extreme ability gaps.
    if z > 30:
        z = 30
    elif z < -30:
        z = -30
    return 1.0 / (1.0 + math.exp(-z))
def estimate_theta_mle(
    responses: list[int],
    b_params: list[float],
    initial_theta: float = 0.0
) -> tuple[float, float]:
    """
    Estimate a student's ability (theta) by maximum likelihood.

    Maximizes the 1PL likelihood of the observed 0/1 responses over the
    bounded interval [-3, 3].

    Args:
        responses: Binary responses, e.g. [0, 1, 1, 0]
        b_params: Item difficulty parameters, aligned with responses
        initial_theta: Fallback theta if the optimizer fails (default 0.0)

    Returns:
        Tuple of (theta, standard_error)

    Raises:
        IRTCalibrationError: If the two inputs differ in length
    """
    resp = np.asarray(responses, dtype=float)
    diffs = np.asarray(b_params, dtype=float)

    # No data: return the prior mean with maximal uncertainty.
    if len(resp) == 0 or len(diffs) == 0:
        return 0.0, 3.0
    if len(resp) != len(diffs):
        raise IRTCalibrationError("responses and b_params must have same length")

    n_correct = np.sum(resp)
    # Perfect score: the MLE diverges to +inf, so clamp to the scale max.
    if n_correct == len(resp):
        return 3.0, 1.5
    # Zero score: the MLE diverges to -inf, so clamp to the scale min.
    if n_correct == 0:
        return -3.0, 1.5

    def objective(theta: float) -> float:
        """Negative log-likelihood of the response pattern at theta."""
        z = np.clip(theta - diffs, -30, 30)  # numerical stability
        p = np.clip(1.0 / (1.0 + np.exp(-z)), 1e-10, 1 - 1e-10)  # avoid log(0)
        return -np.sum(resp * np.log(p) + (1 - resp) * np.log(1 - p))

    fit = minimize_scalar(objective, bounds=(-3, 3), method='bounded')
    theta_hat = float(fit.x) if fit.success else initial_theta

    # SE from the test information at the (pre-clamp) estimate.
    se = calculate_theta_se(theta_hat, list(diffs))
    return max(-3.0, min(3.0, theta_hat)), se
def calculate_theta_se(theta: float, b_params: list[float]) -> float:
    """
    Standard error of a theta estimate from the test information function.

    SE = 1 / sqrt(Σ_j I_j(θ)), where I_j(θ) = P_j(θ) * (1 - P_j(θ)).

    Args:
        theta: Current theta estimate
        b_params: Item difficulty parameters of the answered items

    Returns:
        Standard error, capped at 3.0 (maximal uncertainty)
    """
    if not b_params:
        return 3.0  # nothing answered yet: maximal uncertainty

    probs = [calculate_probability(theta, b) for b in b_params]
    test_info = sum(p * (1 - p) for p in probs)

    if test_info <= 0:
        return 3.0  # degenerate information: maximal uncertainty
    return min(1.0 / math.sqrt(test_info), 3.0)
def estimate_b_from_ctt_p(ctt_p: float) -> float:
    """
    Convert CTT difficulty (p-value, proportion correct) to IRT difficulty b.

    Uses the logit approximation b ≈ ln((1-p)/p): under the 1PL model
    P = 1 / (1 + e^(b)) at θ = 0, so high p (easy item) maps to negative b.

    BUG FIX: the previous implementation negated the logit
    (b = -ln((1-p)/p)), which sent easy items toward +3 and hard items
    toward -3 — the opposite of the 1PL model and of this function's own
    edge cases (ctt_p >= 1.0 returns -3.0 "very easy").

    Args:
        ctt_p: CTT difficulty (proportion correct) in [0, 1], or None

    Returns:
        IRT difficulty parameter b clamped to [-3, +3]
    """
    if ctt_p is None:
        return 0.0
    # Handle edge cases
    if ctt_p >= 1.0:
        return -3.0  # Very easy
    if ctt_p <= 0.0:
        return 3.0  # Very hard
    # Clamp to avoid extreme logits
    ctt_p = max(0.01, min(0.99, ctt_p))
    b = math.log((1 - ctt_p) / ctt_p)
    # Clamp to valid range
    return max(-3.0, min(3.0, b))
async def get_session_responses(
    db: AsyncSession,
    session_id: str
) -> tuple[list[int], list[float]]:
    """
    Collect the scored responses and item difficulties for one session.

    Joins each UserAnswer in the session to its Item (ordered by answer id)
    and picks the best available difficulty per item: the calibrated IRT b
    when present, otherwise a CTT-derived estimate, otherwise 0.0.

    Args:
        db: Database session
        session_id: Session identifier

    Returns:
        Tuple of (responses, b_params), aligned index-by-index
    """
    stmt = (
        select(UserAnswer, Item)
        .join(Item, UserAnswer.item_id == Item.id)
        .where(UserAnswer.session_id == session_id)
        .order_by(UserAnswer.id)
    )
    rows = (await db.execute(stmt)).all()

    responses: list[int] = []
    b_params: list[float] = []
    for answer, item in rows:
        responses.append(1 if answer.is_correct else 0)
        if item.calibrated and item.irt_b is not None:
            difficulty = item.irt_b  # calibrated IRT parameter
        elif item.ctt_p is not None:
            difficulty = estimate_b_from_ctt_p(item.ctt_p)  # CTT fallback
        else:
            difficulty = 0.0  # no data: assume average difficulty
        b_params.append(difficulty)
    return responses, b_params
async def update_session_theta(
    db: AsyncSession,
    session_id: str,
    force_recalculate: bool = False
) -> tuple[float, float]:
    """
    Recompute and persist the session's theta estimate from all stored responses.

    Args:
        db: Database session
        session_id: Session identifier
        force_recalculate: Force recalculation even if theta exists.
            NOTE(review): currently ignored — theta is always recalculated
            regardless of this flag; confirm whether a skip-if-present fast
            path was intended.

    Returns:
        Tuple of (theta, theta_se)

    Raises:
        IRTCalibrationError: If the session does not exist
    """
    # Get session
    session_query = select(Session).where(Session.session_id == session_id)
    session_result = await db.execute(session_query)
    session = session_result.scalar_one_or_none()
    if not session:
        raise IRTCalibrationError(f"Session {session_id} not found")
    # Get responses and b-parameters
    responses, b_params = await get_session_responses(db, session_id)
    if not responses:
        # No responses yet: seed with prior mean and maximal uncertainty.
        session.theta = 0.0
        session.theta_se = 3.0
        await db.commit()
        return 0.0, 3.0
    # Estimate theta, warm-starting from the previous estimate if any.
    initial_theta = session.theta if session.theta is not None else 0.0
    theta, se = estimate_theta_mle(responses, b_params, initial_theta)
    # Persist the updated estimate
    session.theta = theta
    session.theta_se = se
    await db.commit()
    return theta, se
async def update_theta_after_response(
    db: AsyncSession,
    session_id: str,
    item_id: int,
    is_correct: bool
) -> tuple[float, float]:
    """
    Update session theta after a single response.

    This is an incremental update for real-time theta tracking: it re-fits
    theta over the stored responses plus the response passed in.

    Args:
        db: Database session
        session_id: Session identifier
        item_id: Item that was answered
        is_correct: Whether the answer was correct

    Returns:
        Tuple of (theta, theta_se)

    Raises:
        IRTCalibrationError: If the session or the item does not exist
    """
    # Get session
    session_query = select(Session).where(Session.session_id == session_id)
    session_result = await db.execute(session_query)
    session = session_result.scalar_one_or_none()
    if not session:
        raise IRTCalibrationError(f"Session {session_id} not found")
    # Get the answered item
    item_query = select(Item).where(Item.id == item_id)
    item_result = await db.execute(item_query)
    item = item_result.scalar_one_or_none()
    if not item:
        raise IRTCalibrationError(f"Item {item_id} not found")
    # Best available difficulty: calibrated IRT b, else CTT estimate, else 0.
    if item.calibrated and item.irt_b is not None:
        b = item.irt_b
    elif item.ctt_p is not None:
        b = estimate_b_from_ctt_p(item.ctt_p)
    else:
        b = 0.0
    # Get all stored responses for the session
    responses, b_params = await get_session_responses(db, session_id)
    # NOTE(review): this appends unconditionally. If the UserAnswer for
    # (session_id, item_id) has already been persisted before this call,
    # the response is counted twice — confirm the caller's ordering.
    responses.append(1 if is_correct else 0)
    b_params.append(b)
    # Estimate theta, warm-starting from the previous estimate if any.
    initial_theta = session.theta if session.theta is not None else 0.0
    theta, se = estimate_theta_mle(responses, b_params, initial_theta)
    # Persist the updated estimate
    session.theta = theta
    session.theta_se = se
    await db.commit()
    return theta, se
def theta_to_nn(theta: float) -> int:
    """
    Convert an IRT theta to the CTT-equivalent NN score.

    Linear map NN = 500 + (θ / 3) × 500, so θ=-3 → 0, θ=0 → 500, θ=+3 → 1000.

    Args:
        theta: IRT ability estimate in [-3, +3]

    Returns:
        NN score as an int in [0, 1000]
    """
    bounded = min(max(theta, -3.0), 3.0)
    raw_score = 500 + (bounded / 3.0) * 500
    return int(min(max(raw_score, 0), 1000))
def nn_to_theta(nn: int) -> float:
    """
    Convert a CTT NN score back to IRT theta.

    Inverse of theta_to_nn: θ = ((NN - 500) / 500) × 3.

    Args:
        nn: NN score in [0, 1000]

    Returns:
        IRT theta in [-3, +3]
    """
    bounded = min(max(nn, 0), 1000)
    ability = ((bounded - 500) / 500.0) * 3.0
    return min(max(ability, -3.0), 3.0)
def calculate_item_information(theta: float, b: float) -> float:
    """
    Calculate the item information function at a given theta.

    For the 1PL model this is exactly the Fisher information, which is
    maximized when θ = b (the item best measures matched ability).

    Args:
        theta: Ability level
        b: Item difficulty

    Returns:
        Item information value
    """
    return calculate_fisher_information(theta, b)
# =============================================================================
# Joint MLE Calibration for b-parameters (EM-style iterative)
# =============================================================================
# Constants from PRD
THETA_MIN = -3.0  # lower bound of the reported ability scale
THETA_MAX = 3.0   # upper bound of the reported ability scale
B_MIN = -3.0      # lower bound of the item difficulty scale
B_MAX = 3.0       # upper bound of the item difficulty scale
CALIBRATION_SAMPLE_THRESHOLD = 500  # PRD requirement: 500+ responses for calibration
IRT_ROLLOUT_THRESHOLD = 0.90  # PRD requirement: 90% items calibrated for IRT rollout
SE_PRECISION_THRESHOLD = 0.5  # PRD requirement: SE < 0.5 after 15 items
MLE_BOUNDS = (-6.0, 6.0)  # Optimization bounds (wider than final clamp)
EDGE_CASE_THETA_HIGH = 4.0  # Pinned theta/b magnitude for all-correct patterns
EDGE_CASE_THETA_LOW = -4.0  # Pinned theta for all-incorrect patterns
NUMERICAL_CLIP = 30  # Exponent clipping for numerical stability
# NOTE(review): these imports sit mid-file; consider moving them to the top
# of the module alongside the other imports.
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from typing import Union
import logging
from sqlalchemy import func
# Module-level logger for the calibration routines below.
logger = logging.getLogger(__name__)
class CalibrationStatus(Enum):
    """Calibration status for items and tryouts."""
    NOT_CALIBRATED = "not_calibrated"        # calibration never attempted
    INSUFFICIENT_DATA = "insufficient_data"  # below the response-count threshold
    CONVERGED = "converged"                  # MLE succeeded; b-parameter stored
    FAILED = "failed"                        # calibration attempted but errored
    FALLBACK_CTT = "fallback_ctt"            # degraded to CTT scoring
@dataclass
class CalibrationResult:
    """Result of a single item calibration."""
    item_id: int                   # item that was calibrated
    status: CalibrationStatus      # outcome of the calibration attempt
    irt_b: Optional[float] = None  # estimated difficulty, clamped to [-3, +3]
    irt_se: Optional[float] = None  # standard error of irt_b, when computable
    sample_size: int = 0           # number of scored responses used
    message: str = ""              # human-readable outcome detail
    @property
    def is_calibrated(self) -> bool:
        """True only when the MLE converged and a b-parameter was stored."""
        return self.status == CalibrationStatus.CONVERGED
@dataclass
class BatchCalibrationResult:
    """Result of batch calibration for a tryout."""
    tryout_id: str                    # tryout whose items were calibrated
    website_id: int                   # site/tenant scope of the tryout
    total_items: int                  # items considered in the batch
    calibrated_items: int             # items that converged
    failed_items: int                 # items whose calibration errored
    results: list[CalibrationResult]  # per-item detail
    ready_for_irt: bool               # calibration_percentage >= rollout threshold
    calibration_percentage: float     # calibrated_items / total_items
    @property
    def success_rate(self) -> float:
        """Fraction of items calibrated; 0.0 when the batch is empty."""
        if self.total_items == 0:
            return 0.0
        return self.calibrated_items / self.total_items
def estimate_b(
    responses_matrix: list[list[int]],
    max_iterations: int = 20,
    convergence_threshold: float = 0.001
) -> tuple[Optional[np.ndarray], Optional[np.ndarray]]:
    """
    Estimate item difficulty parameters using joint MLE for the 1PL model.

    EM-style alternating optimization:
      1. Initialize theta = 0 for all students and b = 0 for all items.
      2. Each iteration: re-estimate every theta given b, then every b
         given theta.
      3. Stop when the largest change in b falls below
         `convergence_threshold`, or after `max_iterations`.

    Improvements over the previous version: the final clamp uses a single
    vectorized np.clip (was a per-element list comprehension), the
    unreachable `n_students == 0 or n_items == 0` branch was removed
    (a zero-length axis makes size == 0, which is caught first), and the
    EM sub-steps are factored into named helpers.

    Parameters
    ----------
    responses_matrix : list[list[int]]
        Response matrix (rows=students, cols=items) of 0/1 entries.
    max_iterations : int
        Maximum EM iterations (default: 20).
    convergence_threshold : float
        Convergence threshold for b parameters (default: 0.001).

    Returns
    -------
    tuple[Optional[np.ndarray], Optional[np.ndarray]]
        (b clamped to [-3, +3], standard errors or None).

    Raises
    ------
    IRTCalibrationError
        If the matrix is not 2-dimensional.
    """
    matrix = np.asarray(responses_matrix, dtype=float)
    # Edge case: no data at all — nothing to estimate.
    if matrix.size == 0:
        return np.array([]), None
    if matrix.ndim != 2:
        raise IRTCalibrationError("responses_matrix must be 2-dimensional")

    n_students, n_items = matrix.shape
    theta = np.zeros(n_students)
    b = np.zeros(n_items)

    for iteration in range(max_iterations):
        b_old = b.copy()
        _em_update_abilities(theta, b, matrix)
        _em_update_difficulties(theta, b, matrix)
        if np.max(np.abs(b - b_old)) < convergence_threshold:
            logger.debug(f"Joint MLE converged at iteration {iteration + 1}")
            break

    # Clamp b to the reporting scale in one vectorized step.
    b = np.clip(b, B_MIN, B_MAX)
    # Standard errors via Fisher information at the final estimates.
    se = _calculate_b_se_batch(b, theta)
    return b, se


def _em_neg_log_likelihood(resp: np.ndarray, z: np.ndarray) -> float:
    """Negative Bernoulli log-likelihood with clipped logits (1PL model)."""
    z = np.clip(z, -NUMERICAL_CLIP, NUMERICAL_CLIP)
    p = np.clip(1.0 / (1.0 + np.exp(-z)), 1e-10, 1 - 1e-10)  # avoid log(0)
    return -float(np.sum(resp * np.log(p) + (1 - resp) * np.log(1 - p)))


def _em_update_abilities(theta: np.ndarray, b: np.ndarray, matrix: np.ndarray) -> None:
    """Re-estimate every student's theta in place, given current b."""
    n_students, n_items = matrix.shape
    for i in range(n_students):
        resp_i = matrix[i, :]
        n_correct = np.sum(resp_i)
        if n_correct == n_items:
            # Perfect score: MLE diverges, pin to the high edge value.
            theta[i] = EDGE_CASE_THETA_HIGH
        elif n_correct == 0:
            # Zero score: MLE diverges, pin to the low edge value.
            theta[i] = EDGE_CASE_THETA_LOW
        else:
            res = minimize_scalar(
                lambda t: _em_neg_log_likelihood(resp_i, t - b),
                bounds=MLE_BOUNDS, method='bounded'
            )
            theta[i] = res.x if res.success else 0.0


def _em_update_difficulties(theta: np.ndarray, b: np.ndarray, matrix: np.ndarray) -> None:
    """Re-estimate every item's b in place, given current theta."""
    n_students, n_items = matrix.shape
    for j in range(n_items):
        resp_j = matrix[:, j]
        n_correct = np.sum(resp_j)
        if n_correct == n_students:
            b[j] = -EDGE_CASE_THETA_HIGH  # everyone correct: very easy item
        elif n_correct == 0:
            b[j] = EDGE_CASE_THETA_HIGH   # everyone incorrect: very hard item
        else:
            res = minimize_scalar(
                lambda bj: _em_neg_log_likelihood(resp_j, theta - bj),
                bounds=MLE_BOUNDS, method='bounded'
            )
            b[j] = res.x if res.success else 0.0
def _calculate_b_se_batch(b_params: np.ndarray, thetas: np.ndarray) -> Optional[np.ndarray]:
    """
    Standard errors for all b parameters via Fisher information.

    For the 1PL model the information for item j is
        I(b_j) = Σ_i P(θ_i) * (1 - P(θ_i)),
    and SE(b_j) = 1 / sqrt(I(b_j)).

    Parameters
    ----------
    b_params : np.ndarray
        Item difficulty parameters.
    thetas : np.ndarray
        Student ability estimates.

    Returns
    -------
    Optional[np.ndarray]
        SE per item (NaN where information is non-positive), or None if
        the computation fails entirely.
    """
    try:
        errors = np.zeros(len(b_params))
        for idx, difficulty in enumerate(b_params):
            z = np.clip(thetas - difficulty, -NUMERICAL_CLIP, NUMERICAL_CLIP)
            prob = 1.0 / (1.0 + np.exp(-z))
            information = np.sum(prob * (1 - prob))
            # Degenerate information yields NaN rather than a bogus SE.
            errors[idx] = 1.0 / np.sqrt(information) if information > 0 else np.nan
        return errors
    except Exception as e:
        # Best-effort: SE is optional metadata, so log and degrade to None.
        logger.warning(f"Failed to calculate b SE batch: {e}")
        return None
async def calibrate_item(
    item_id: int,
    db: AsyncSession,
    min_sample_size: int = CALIBRATION_SAMPLE_THRESHOLD
) -> CalibrationResult:
    """
    Calibrate a single item using the IRT 1PL model.

    Fetches all scored UserAnswers for the item, builds a session × item
    response matrix over every item those sessions answered, runs joint
    MLE, and persists the resulting b-parameter on the item.

    Improvement: the response matrix is now built from a single grouping
    pass over the answers; the previous version re-scanned all answers for
    every session (O(sessions × answers)).

    Parameters
    ----------
    item_id : int
        Item ID to calibrate.
    db : AsyncSession
        Database session.
    min_sample_size : int
        Minimum sample size for calibration (default: 500).

    Returns
    -------
    CalibrationResult
        Status, b-parameter, SE, and sample size. Never raises: failures
        are reported via CalibrationStatus.FAILED.
    """
    try:
        # Fetch item
        result = await db.execute(select(Item).where(Item.id == item_id))
        item = result.scalar_one_or_none()
        if not item:
            return CalibrationResult(
                item_id=item_id,
                status=CalibrationStatus.FAILED,
                message=f"Item {item_id} not found"
            )
        # Fetch all scored answers for this item
        result = await db.execute(
            select(UserAnswer)
            .where(UserAnswer.item_id == item_id)
            .where(UserAnswer.is_correct.isnot(None))
        )
        answers = result.scalars().all()
        sample_size = len(answers)
        if sample_size < min_sample_size:
            # Insufficient data - use CTT p-value for initial b estimate
            if item.ctt_p is not None:
                initial_b = estimate_b_from_ctt_p(item.ctt_p)
                return CalibrationResult(
                    item_id=item_id,
                    status=CalibrationStatus.INSUFFICIENT_DATA,
                    irt_b=initial_b,
                    sample_size=sample_size,
                    message=f"Insufficient data ({sample_size} < {min_sample_size}). "
                            f"Using CTT-based initial estimate."
                )
            return CalibrationResult(
                item_id=item_id,
                status=CalibrationStatus.INSUFFICIENT_DATA,
                sample_size=sample_size,
                message=f"Insufficient data ({sample_size} < {min_sample_size})"
            )
        # Unique sessions in first-seen order; this fixes the matrix row order.
        session_ids = list(dict.fromkeys(a.session_id for a in answers))
        if len(session_ids) < 10:
            return CalibrationResult(
                item_id=item_id,
                status=CalibrationStatus.INSUFFICIENT_DATA,
                sample_size=sample_size,
                message="Not enough unique sessions for calibration"
            )
        # Fetch all items answered by these sessions (joint calibration)
        result = await db.execute(
            select(UserAnswer)
            .where(UserAnswer.session_id.in_(session_ids))
            .where(UserAnswer.is_correct.isnot(None))
        )
        all_answers = result.scalars().all()
        # Build full response matrix (sessions x items)
        item_ids = sorted(set(a.item_id for a in all_answers))
        item_id_to_idx = {iid: idx for idx, iid in enumerate(item_ids)}
        # Group answers per session ONCE (previously a per-session list
        # comprehension re-scanned every answer: O(sessions × answers)).
        answers_by_session: dict = {}
        for ans in all_answers:
            answers_by_session.setdefault(ans.session_id, []).append(ans)
        responses_matrix = []
        for session_id in session_ids:
            # NOTE(review): items a session did not answer stay 0, i.e. are
            # scored as incorrect — confirm this is the intended handling.
            row = [0] * len(item_ids)
            for ans in answers_by_session.get(session_id, []):
                if ans.item_id in item_id_to_idx:
                    row[item_id_to_idx[ans.item_id]] = 1 if ans.is_correct else 0
            responses_matrix.append(row)
        # Run joint MLE calibration
        b_params, se_params = estimate_b(responses_matrix)
        if b_params is None or len(b_params) == 0:
            return CalibrationResult(
                item_id=item_id,
                status=CalibrationStatus.FAILED,
                sample_size=sample_size,
                message="MLE estimation failed"
            )
        # Get b and SE for our target item
        target_idx = item_id_to_idx.get(item_id)
        if target_idx is None:
            return CalibrationResult(
                item_id=item_id,
                status=CalibrationStatus.FAILED,
                sample_size=sample_size,
                message="Item not found in response matrix"
            )
        irt_b = float(b_params[target_idx])
        irt_se = float(se_params[target_idx]) if se_params is not None else None
        # Defensive re-clamp (estimate_b already clamps to [B_MIN, B_MAX])
        if not (B_MIN <= irt_b <= B_MAX):
            logger.warning(f"b-parameter {irt_b} out of range for item {item_id}")
            irt_b = max(B_MIN, min(B_MAX, irt_b))
        # Persist the calibration on the item
        item.irt_b = irt_b
        item.irt_se = irt_se
        item.calibration_sample_size = sample_size
        item.calibrated = sample_size >= min_sample_size
        await db.commit()
        return CalibrationResult(
            item_id=item_id,
            status=CalibrationStatus.CONVERGED,
            irt_b=irt_b,
            irt_se=irt_se,
            sample_size=sample_size,
            message=f"Successfully calibrated with {sample_size} responses"
        )
    except Exception as e:
        logger.error(f"Calibration failed for item {item_id}: {e}")
        return CalibrationResult(
            item_id=item_id,
            status=CalibrationStatus.FAILED,
            message=f"Calibration error: {str(e)}"
        )
async def calibrate_all(
    tryout_id: str,
    website_id: int,
    db: AsyncSession,
    min_sample_size: int = CALIBRATION_SAMPLE_THRESHOLD
) -> BatchCalibrationResult:
    """
    Calibrate all items in a tryout using the IRT 1PL model.

    Calibrates every uncalibrated item that has enough responses, reports
    already-calibrated items as CONVERGED, and flags the rest as
    INSUFFICIENT_DATA. Never raises: batch-level errors yield a result
    with ready_for_irt=False.

    Improvement: per-item response counts are fetched with one grouped
    COUNT query instead of one query per item (N+1 pattern).

    Parameters
    ----------
    tryout_id : str
        Tryout identifier.
    website_id : int
        Website identifier.
    db : AsyncSession
        Database session.
    min_sample_size : int
        Minimum sample size for calibration (default: 500).

    Returns
    -------
    BatchCalibrationResult
        Batch calibration result with a per-item CalibrationResult list.
    """
    results = []
    try:
        # Find all items for this tryout
        result = await db.execute(
            select(Item)
            .where(Item.tryout_id == tryout_id)
            .where(Item.website_id == website_id)
            .order_by(Item.slot)
        )
        items = result.scalars().all()
        total_items = len(items)
        if total_items == 0:
            return BatchCalibrationResult(
                tryout_id=tryout_id,
                website_id=website_id,
                total_items=0,
                calibrated_items=0,
                failed_items=0,
                results=[],
                ready_for_irt=False,
                calibration_percentage=0.0
            )
        # Response counts for all items in ONE grouped query
        # (was one COUNT query per item — an N+1 pattern).
        count_result = await db.execute(
            select(UserAnswer.item_id, func.count(UserAnswer.id))
            .where(UserAnswer.item_id.in_([item.id for item in items]))
            .group_by(UserAnswer.item_id)
        )
        item_response_counts = dict(count_result.all())
        # Calibrate items with sufficient data
        for item in items:
            response_count = item_response_counts.get(item.id, 0)
            if response_count >= min_sample_size and not item.calibrated:
                cal_result = await calibrate_item(item.id, db, min_sample_size)
                results.append(cal_result)
            elif item.calibrated:
                # Already calibrated
                results.append(CalibrationResult(
                    item_id=item.id,
                    status=CalibrationStatus.CONVERGED,
                    irt_b=item.irt_b,
                    irt_se=item.irt_se,
                    sample_size=item.calibration_sample_size,
                    message="Already calibrated"
                ))
            else:
                # Insufficient data
                results.append(CalibrationResult(
                    item_id=item.id,
                    status=CalibrationStatus.INSUFFICIENT_DATA,
                    sample_size=response_count,
                    message=f"Insufficient data ({response_count} < {min_sample_size})"
                ))
        # Count results
        calibrated_items = sum(1 for r in results if r.is_calibrated)
        failed_items = sum(1 for r in results if r.status == CalibrationStatus.FAILED)
        calibration_percentage = calibrated_items / total_items if total_items > 0 else 0.0
        # Log progress against TryoutStats when it exists.
        try:
            from app.models import TryoutStats
            result = await db.execute(
                select(TryoutStats)
                .where(TryoutStats.tryout_id == tryout_id)
                .where(TryoutStats.website_id == website_id)
            )
            stats = result.scalar_one_or_none()
            # NOTE(review): `stats` is fetched but never modified — only a
            # progress log is emitted. Confirm whether TryoutStats was
            # meant to be updated here.
            if stats:
                logger.info(
                    f"Tryout {tryout_id}: {calibrated_items}/{total_items} items calibrated "
                    f"({calibration_percentage:.1%})"
                )
        except Exception as e:
            logger.warning(f"Could not update TryoutStats: {e}")
        ready_for_irt = calibration_percentage >= IRT_ROLLOUT_THRESHOLD
        return BatchCalibrationResult(
            tryout_id=tryout_id,
            website_id=website_id,
            total_items=total_items,
            calibrated_items=calibrated_items,
            failed_items=failed_items,
            results=results,
            ready_for_irt=ready_for_irt,
            calibration_percentage=calibration_percentage
        )
    except Exception as e:
        logger.error(f"Batch calibration failed for tryout {tryout_id}: {e}")
        return BatchCalibrationResult(
            tryout_id=tryout_id,
            website_id=website_id,
            total_items=len(results),
            calibrated_items=sum(1 for r in results if r.is_calibrated),
            failed_items=sum(1 for r in results if r.status == CalibrationStatus.FAILED),
            results=results,
            ready_for_irt=False,
            calibration_percentage=0.0
        )
def fallback_to_ctt(reason: str, context: Optional[dict] = None) -> dict:
    """
    Build the fallback response used when IRT scoring is unavailable.

    Logs the degradation and returns a structured payload telling the
    caller to score with CTT, why, and what to do about it.

    Fix: `datetime.utcnow()` is deprecated (Python 3.12+) and produces a
    naive timestamp; this now emits a timezone-aware UTC ISO-8601 string
    (note: it carries a "+00:00" offset suffix).

    Parameters
    ----------
    reason : str
        Reason for fallback (insufficient_data, convergence_error, etc.).
    context : Optional[dict]
        Additional context (item_id, tryout_id, etc.).

    Returns
    -------
    dict
        {"fallback_mode": "ctt", "reason": ..., "recommendation": ...,
         "context": ..., "timestamp": ISO-8601 UTC}
    """
    from datetime import timezone  # local import: module imports untouched

    context = context or {}
    recommendations = {
        "insufficient_data": (
            "Continue collecting response data. "
            f"Need {CALIBRATION_SAMPLE_THRESHOLD}+ responses per item for IRT calibration. "
            "Use CTT scoring until threshold is reached."
        ),
        "convergence_error": (
            "MLE optimization failed to converge. "
            "Check for response patterns (all correct/incorrect). "
            "Use CTT scoring as fallback."
        ),
        "numerical_instability": (
            "Numerical instability detected in MLE calculation. "
            "Verify data quality and response patterns. "
            "Use CTT scoring as fallback."
        ),
        "missing_parameters": (
            "Required IRT parameters not available. "
            "Ensure items are calibrated before using IRT mode. "
            "Use CTT scoring until calibration is complete."
        ),
        "default": (
            "IRT scoring unavailable. "
            "Falling back to CTT scoring mode. "
            "Check logs for details."
        )
    }
    recommendation = recommendations.get(reason, recommendations["default"])
    logger.warning(
        f"IRT fallback to CTT - Reason: {reason}, Context: {context}"
    )
    return {
        "fallback_mode": "ctt",
        "reason": reason,
        "recommendation": recommendation,
        "context": context,
        # Aware UTC timestamp (datetime.utcnow() is deprecated).
        "timestamp": datetime.now(timezone.utc).isoformat()
    }
def validate_irt_parameters(
    theta: Optional[float] = None,
    b: Optional[float] = None,
    se: Optional[float] = None
) -> tuple[bool, list[str]]:
    """
    Validate IRT parameters against the PRD constraints.

    Each argument is optional; only supplied values are checked.

    Parameters
    ----------
    theta : Optional[float]
        Ability estimate; must lie in [THETA_MIN, THETA_MAX].
    b : Optional[float]
        Difficulty parameter; must lie in [B_MIN, B_MAX].
    se : Optional[float]
        Standard error; must be non-negative. Values at or above
        SE_PRECISION_THRESHOLD only log a warning (still valid).

    Returns
    -------
    tuple[bool, list[str]]
        (is_valid, list of error messages).
    """
    errors: list[str] = []

    if theta is not None and not (THETA_MIN <= theta <= THETA_MAX):
        errors.append(f"Theta {theta} out of range [{THETA_MIN}, {THETA_MAX}]")

    if b is not None and not (B_MIN <= b <= B_MAX):
        errors.append(f"b-parameter {b} out of range [{B_MIN}, {B_MAX}]")

    if se is not None:
        if se < 0:
            errors.append(f"Standard error {se} must be non-negative")
        elif se >= SE_PRECISION_THRESHOLD:
            # Low precision is a warning, not a validation failure.
            logger.warning(f"Standard error {se} exceeds precision threshold {SE_PRECISION_THRESHOLD}")

    return not errors, errors
async def get_calibration_status(
    tryout_id: str,
    website_id: int,
    db: AsyncSession
) -> dict:
    """
    Summarize IRT calibration progress for one tryout.

    Parameters
    ----------
    tryout_id : str
        Tryout identifier.
    website_id : int
        Website identifier.
    db : AsyncSession
        Database session.

    Returns
    -------
    dict
        total_items, calibrated_items, calibration_percentage (as a
        percentage rounded to 1 decimal), ready_for_irt, and a per-item
        status list ordered by slot.
    """
    stmt = (
        select(Item)
        .where(Item.tryout_id == tryout_id)
        .where(Item.website_id == website_id)
        .order_by(Item.slot)
    )
    items = (await db.execute(stmt)).scalars().all()

    total_items = len(items)
    calibrated_items = sum(1 for it in items if it.calibrated)
    pct = calibrated_items / total_items if total_items > 0 else 0.0

    item_status = [
        {
            "item_id": it.id,
            "slot": it.slot,
            "level": it.level,
            "calibrated": it.calibrated,
            "irt_b": it.irt_b,
            "irt_se": it.irt_se,
            "calibration_sample_size": it.calibration_sample_size
        }
        for it in items
    ]

    return {
        "tryout_id": tryout_id,
        "website_id": website_id,
        "total_items": total_items,
        "calibrated_items": calibrated_items,
        "calibration_percentage": round(pct * 100, 1),
        "ready_for_irt": pct >= IRT_ROLLOUT_THRESHOLD,
        "items": item_status
    }
# Export public API
# Fix: the public (non-underscore) session-theta functions were missing
# from __all__, so `from ... import *` did not expose them.
__all__ = [
    # Constants
    "THETA_MIN",
    "THETA_MAX",
    "B_MIN",
    "B_MAX",
    "CALIBRATION_SAMPLE_THRESHOLD",
    "IRT_ROLLOUT_THRESHOLD",
    "SE_PRECISION_THRESHOLD",
    # Enums
    "CalibrationStatus",
    # Data classes
    "CalibrationResult",
    "BatchCalibrationResult",
    # Exceptions
    "IRTCalibrationError",
    # Core functions
    "estimate_theta_mle",
    "estimate_b",
    "calibrate_item",
    "calibrate_all",
    "fallback_to_ctt",
    "validate_irt_parameters",
    "get_calibration_status",
    # Session theta functions
    "get_session_responses",
    "update_session_theta",
    "update_theta_after_response",
    # Conversion functions
    "estimate_b_from_ctt_p",
    "theta_to_nn",
    "nn_to_theta",
    # Calculation functions
    "calculate_probability",
    "calculate_fisher_information",
    "calculate_theta_se",
    "calculate_item_information",
]