# 08_test/test_db.py
"""Test the model layer."""
import sqlite3
from unittest.mock import patch
import models
import util
# DDL script that creates the single `staff` table used by these tests.
SCHEMA = """
CREATE TABLE staff (
staff_id BIGINT,
personal TEXT,
family TEXT
);
"""
# Parameterized insert statement: one `?` placeholder per `staff` column.
INSERT = """
INSERT INTO staff VALUES(?, ?, ?);
"""
# Fixture rows as (staff_id, personal, family) tuples, matching the column
# order in SCHEMA.  Note that the ids are not contiguous (1, 2, 4).
STAFF = [
(1, "Catalina", "Moyano"),
(2, "Paloma", "Bellini Ruiz"),
(4, "Paula", "Martinez"),
]
def make_db():
    """Create an in-memory SQLite database pre-loaded with the STAFF rows.

    The schema in SCHEMA is applied first, then every tuple in STAFF is
    inserted through the parameterized INSERT statement.

    Returns:
        The open ``sqlite3`` connection, with ``util.dict_factory``
        installed as its row factory.
    """
    db = sqlite3.connect(
        ":memory:",
        detect_types=sqlite3.PARSE_DECLTYPES,
    )

    # Build the table and load the fixture rows.
    db.executescript(SCHEMA)
    db.executemany(INSERT, STAFF)

    # NOTE(review): dict_factory presumably yields rows addressable by column
    # name (the test indexes rows with "staff_id") -- confirm in util.
    db.row_factory = util.dict_factory
    return db
def test_can_get_all_staff():
    """Check that ``models.all_staff()`` returns every row of the STAFF fixture.

    ``models.connect`` is patched with ``make_db`` so the model layer runs
    against the in-memory test database.
    """
    expected_ids = {row[0] for row in STAFF}

    with patch("models.connect", make_db):
        staff = models.all_staff()

        # Same number of records as the fixture...
        assert len(staff) == len(STAFF)

        # ...and exactly the same set of staff ids.
        found_ids = {record["staff_id"] for record in staff}
        assert found_ids == expected_ids