diff --git a/api/tests/unit/app/test_unit_app_routers.py b/api/tests/unit/app/test_unit_app_routers.py index f2505c0d5ecc..1fe9a8d18901 100644 --- a/api/tests/unit/app/test_unit_app_routers.py +++ b/api/tests/unit/app/test_unit_app_routers.py @@ -1,5 +1,8 @@ +from unittest import mock + import pytest from django.db import models +from django.db.models.options import Options from app import routers @@ -16,14 +19,13 @@ def test_analytics_router_db_for_read__given_app_label__returns_expected_db( expected_db: str | None, ) -> None: # Given - class AnalyticsModel(models.Model): - class Meta: - app_label = given_app_label - + mock_model = mock.MagicMock(spec=models.Model) + mock_model._meta = mock.MagicMock(spec=Options) + mock_model._meta.app_label = given_app_label router = routers.AnalyticsRouter() # When - db = router.db_for_read(AnalyticsModel) + db = router.db_for_read(mock_model) # Then assert db == expected_db @@ -41,14 +43,13 @@ def test_analytics_router_db_for_write__given_app_label__returns_expected_db( expected_db: str | None, ) -> None: # Given - class MyModel(models.Model): - class Meta: - app_label = model_app_label - + mock_model = mock.MagicMock(spec=models.Model) + mock_model._meta = mock.MagicMock(spec=Options) + mock_model._meta.app_label = model_app_label router = routers.AnalyticsRouter() # When - db = router.db_for_write(MyModel) + db = router.db_for_write(mock_model) # Then assert db == expected_db @@ -67,18 +68,16 @@ def test_analytics_router_allow_relation__given_app_labels__returns_expected( expected: bool | None, ) -> None: # Given - class MyModel1(models.Model): - class Meta: - app_label = model1_app_label - - class MyModel2(models.Model): - class Meta: - app_label = model2_app_label - + mock_instance1 = mock.MagicMock(spec=models.Model) + mock_instance1._meta = mock.MagicMock(spec=Options) + mock_instance1._meta.app_label = model1_app_label + mock_instance2 = mock.MagicMock(spec=models.Model) + mock_instance2._meta = mock.MagicMock(spec=Options) + mock_instance2._meta.app_label = model2_app_label router = routers.AnalyticsRouter() # When - result = router.allow_relation(MyModel1(), MyModel2()) + result = router.allow_relation(mock_instance1, mock_instance2) # Then assert result == expected