# File listing metadata: 272 lines, 6.9 KiB, Python.
import asyncio
|
|
from types import SimpleNamespace
|
|
|
|
import pytest
|
|
from fastapi import HTTPException
|
|
|
|
from app.core import rate_limit
|
|
from app.core.config import Settings
|
|
from app.models.report_schedule import ReportScheduleModel
|
|
from app.services import ai_generation
|
|
from app.services import cat_selection
|
|
from app.services.reporting import (
|
|
cancel_scheduled_report,
|
|
get_scheduled_report,
|
|
list_scheduled_reports,
|
|
schedule_report,
|
|
)
|
|
from app.schemas.ai import GeneratedQuestion
|
|
|
|
|
|
class DummyRequest:
    """Minimal request stand-in exposing only ``client.host``.

    Passed to ``rate_limit.enforce_rate_limit`` by the rate-limit tests in
    this module; presumably the limiter keys on the client host (confirm
    against ``app.core.rate_limit``).
    """

    # Class-level attribute: every instance reports the same loopback client.
    client = SimpleNamespace(host="127.0.0.1")
|
|
|
|
|
|
class DummyScalarResult:
    """Query-result double that hands back one preset value.

    Both accessors return the value given at construction, unchanged. The
    method names mirror the scalar accessors of a database result object.
    """

    def __init__(self, value):
        """Store ``value`` so the accessors can return it later."""
        self._value = value

    def scalar_one_or_none(self):
        """Return the preset value (which may be ``None``)."""
        return self._value

    def scalar(self):
        """Return the preset value."""
        return self._value
|
|
|
|
|
|
class DummyScalars:
    """Double for the object returned by a result's ``scalars()`` call."""

    def __init__(self, values):
        """Keep a reference to ``values`` (not copied)."""
        self._values = values

    def all(self):
        """Return the stored values exactly as they were supplied."""
        return self._values
|
|
|
|
|
|
class DummyListResult:
    """Query-result double for list queries read via ``scalars().all()``."""

    def __init__(self, values):
        """Keep a reference to the rows that ``scalars()`` should expose."""
        self._values = values

    def scalars(self):
        """Return a ``DummyScalars`` wrapper around the stored values."""
        return DummyScalars(self._values)
|
|
|
|
|
|
class DummyRowsResult:
    """Query-result double whose ``all()`` returns preset rows directly."""

    def __init__(self, values):
        """Keep a reference to the rows to be returned by ``all()``."""
        self._values = values

    def all(self):
        """Return the stored rows exactly as they were supplied."""
        return self._values
|
|
|
|
|
|
class DummyDb:
    """In-memory stand-in for an async database session.

    Records what the code under test adds and whether it flushed, and
    replays a queue of canned results for successive ``execute`` calls.
    """

    def __init__(self, execute_results=None):
        """Set up the double.

        Args:
            execute_results: Optional iterable of objects to return from
                ``execute``, in call order. It is copied into a list, so the
                caller's sequence is never mutated.
        """
        self.execute_results = list(execute_results or [])
        self.added = []
        self.flushed = False

    def add(self, row):
        """Record ``row`` in ``added`` (synchronous)."""
        self.added.append(row)

    async def flush(self):
        """Mark the session as flushed."""
        self.flushed = True

    async def execute(self, _query):
        """Return the next queued result; the query itself is ignored.

        Raises:
            IndexError: If no queued results remain, i.e. the code under
                test issued more queries than the test prepared for.
        """
        if not self.execute_results:
            # Same exception type as popping an empty list, but with a
            # message that points at the test setup.
            raise IndexError(
                "DummyDb.execute called with no queued results left"
            )
        return self.execute_results.pop(0)
|
|
|
|
|
|
class DummyRedis:
    """Async Redis double with a single counter shared by all keys."""

    def __init__(self):
        """Start with zero recorded ``incr`` calls."""
        self.calls = 0

    async def incr(self, _key):
        """Bump the shared counter and return its new value; key is ignored."""
        self.calls += 1
        return self.calls

    async def expire(self, _key, _seconds):
        """Pretend the expiry was set; always returns ``True``."""
        return True

    async def ttl(self, _key):
        """Report a fixed remaining time-to-live of 60."""
        return 60
|
|
|
|
|
|
def test_ai_stats_accepts_website_scope(monkeypatch):
    """Every query issued by ``get_ai_stats`` must mention ``items.website_id``."""
    captured_queries = []

    class CaptureDb:
        """Session double that records each query's string form."""

        async def execute(self, query):
            captured_queries.append(str(query))
            # The first query gets a scalar result; all later ones get an
            # empty row set.
            if len(captured_queries) == 1:
                return DummyScalarResult(0)
            return DummyRowsResult([])

    asyncio.run(ai_generation.get_ai_stats(CaptureDb(), website_id=9))

    # Guard against a vacuous pass: all() over an empty list is True.
    assert captured_queries
    assert all("items.website_id" in query for query in captured_queries)
|
|
|
|
|
|
def test_ai_prompt_preserves_basis_option_labels():
    """A five-option basis item must yield a prompt that asks for labels A-E."""
    prompt = ai_generation.get_prompt_template(
        basis_stem="<p>Basis question?</p>",
        basis_options={
            "A": "Option A",
            "B": "Option B",
            "C": "Option C",
            "D": "Option D",
            "E": "Option E",
        },
        basis_correct="A",
        basis_explanation="<p>Because A.</p>",
        target_level="mudah",
    )

    assert "Create exactly 5 answer options with labels exactly: A, B, C, D, E" in prompt
    assert '"E": "Option E text"' in prompt
    assert "The correct field must be exactly one of: A, B, C, D, E" in prompt
|
|
|
|
|
|
def test_generated_question_must_match_basis_option_labels():
    """A generated question missing label E must not match a five-option basis."""
    basis_item = SimpleNamespace(
        options={
            "A": "Option A",
            "B": "Option B",
            "C": "Option C",
            "D": "Option D",
            "E": "Option E",
        }
    )
    # Same options as the basis, minus "E".
    generated = GeneratedQuestion(
        stem="Generated",
        options={
            "A": "Option A",
            "B": "Option B",
            "C": "Option C",
            "D": "Option D",
        },
        correct="A",
    )

    assert not ai_generation.generated_matches_basis_options(generated, basis_item)
|
|
|
|
|
|
def test_cat_selection_only_serves_active_or_approved_variants():
    """The servable-item filter must name active/approved, not draft/rejected."""
    # literal_binds inlines the bound values so the status names appear in
    # the compiled text.
    compiled = str(
        cat_selection._servable_item_filter().compile(
            compile_kwargs={"literal_binds": True}
        )
    )

    assert "active" in compiled
    assert "approved" in compiled
    assert "draft" not in compiled
    assert "rejected" not in compiled
|
|
|
|
|
|
def test_production_init_db_skips_create_all(monkeypatch):
    """``init_db`` must not open an engine transaction in production."""
    import app.database as database

    class ExplodingEngine:
        """Engine double that fails the test as soon as ``begin`` is called."""

        def begin(self):
            raise AssertionError("create_all should not run in production")

    monkeypatch.setattr(database, "settings", Settings(ENVIRONMENT="production"))
    monkeypatch.setattr(database, "engine", ExplodingEngine())

    # Passes only if init_db returns without touching engine.begin().
    asyncio.run(database.init_db())
|
|
|
|
|
|
def test_rate_limit_uses_redis_and_blocks_when_limit_exceeded(monkeypatch):
    """With a Redis client available, the second hit in a window gets a 429."""
    dummy_redis = DummyRedis()
    rate_limit.reset_rate_limit_state()
    monkeypatch.setattr(rate_limit, "_get_redis_client", lambda: dummy_redis)

    def hit():
        """Run one rate-limited request against the ``test.redis`` scope."""
        asyncio.run(
            rate_limit.enforce_rate_limit(
                DummyRequest(),
                scope="test.redis",
                max_requests=1,
                window_seconds=60,
            )
        )

    # First request is within the limit of one per window.
    hit()

    with pytest.raises(HTTPException) as exc_info:
        hit()

    # Checked after the with-block: statements following the raising call
    # inside it would never execute.
    assert exc_info.value.status_code == 429
|
|
|
|
|
|
def test_rate_limit_falls_back_to_memory_when_redis_unavailable(monkeypatch):
    """With no Redis client, the second hit in a window still gets a 429."""
    rate_limit.reset_rate_limit_state()
    monkeypatch.setattr(rate_limit, "_get_redis_client", lambda: None)

    def hit():
        """Run one rate-limited request against the ``test.memory`` scope."""
        asyncio.run(
            rate_limit.enforce_rate_limit(
                DummyRequest(),
                scope="test.memory",
                max_requests=1,
                window_seconds=60,
            )
        )

    # First request is within the limit of one per window.
    hit()

    with pytest.raises(HTTPException) as exc_info:
        hit()

    # Checked after the with-block: statements following the raising call
    # inside it would never execute.
    assert exc_info.value.status_code == 429
|
|
|
|
|
|
def test_schedule_report_persists_model_row():
    """``schedule_report`` must add a ``ReportScheduleModel`` row and flush."""
    db = DummyDb()

    schedule_id = asyncio.run(
        schedule_report(
            db,
            report_type="student_performance",
            schedule="daily",
            tryout_ids=["t1"],
            website_id=3,
            recipients=["ops@example.com"],
            export_format="xlsx",
        )
    )

    assert db.flushed is True
    assert isinstance(db.added[0], ReportScheduleModel)
    # The returned id must be the one stored on the persisted row.
    assert db.added[0].schedule_id == schedule_id
    assert db.added[0].website_id == 3
|
|
|
|
|
|
def test_schedule_helpers_read_list_and_soft_cancel():
    """Get and list return the stored row; cancel flips ``is_active`` to False."""
    row = ReportScheduleModel(
        schedule_id="sched-1",
        report_type="student_performance",
        schedule="daily",
        tryout_ids=["t1"],
        website_id=3,
        recipients=["ops@example.com"],
        format="xlsx",
        is_active=True,
    )

    # One session double per helper, each primed with the single result that
    # helper's query should receive.
    get_db = DummyDb([DummyScalarResult(row)])
    listed_db = DummyDb([DummyListResult([row])])
    cancel_db = DummyDb([DummyScalarResult(row)])

    got = asyncio.run(get_scheduled_report(get_db, "sched-1"))
    listed = asyncio.run(list_scheduled_reports(listed_db, website_id=3))
    cancelled = asyncio.run(cancel_scheduled_report(cancel_db, "sched-1"))

    assert got.schedule_id == "sched-1"
    assert listed[0].website_id == 3
    assert cancelled is True
    # Soft cancel: the row object is deactivated in place.
    assert row.is_active is False
|