diff --git a/.vscode/launch.json b/.vscode/launch.json index 2308cfec65..fb5b50b338 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -34,7 +34,10 @@ "module": "blueapi", "args": "--config ${input:config_path} serve", "env": { - "OTLP_EXPORT_ENABLED": "false" + "OTLP_EXPORT_ENABLED": "false", + "EPICS_CA_NAME_SERVERS": "127.0.0.1:9064", + "EPICS_PVA_NAME_SERVERS": "127.0.0.1:9075", + "EPICS_CA_ADDR_LIST": "127.0.0.1:9064" }, }, { diff --git a/pyproject.toml b/pyproject.toml index 659779994a..baceb70eb5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -101,7 +101,7 @@ filterwarnings = ["error", "ignore::DeprecationWarning"] # Doctest python code in docs, python code in src docstrings, test functions in tests testpaths = "docs src tests" asyncio_mode = "auto" -timeout = 3 +timeout = 100 [tool.coverage.run] patch = ["subprocess"] diff --git a/src/blueapi/service/authentication.py b/src/blueapi/service/authentication.py index b107f7b2b2..61f439d5f6 100644 --- a/src/blueapi/service/authentication.py +++ b/src/blueapi/service/authentication.py @@ -272,3 +272,15 @@ def get_access_token(self): def sync_auth_flow(self, request): request.headers["Authorization"] = f"Bearer {self.get_access_token()}" yield request + + +class OPAClient: # placeholder until https://jira.diamond.ac.uk/browse/ACQP-550 is done + def do_some_checks(self, task_request) -> bool: + return True + + def admin(self): + return False + + +def get_opa_client() -> OPAClient: # placeholder + return OPAClient() diff --git a/src/blueapi/service/main.py b/src/blueapi/service/main.py index a53c46885a..857a6b82b2 100644 --- a/src/blueapi/service/main.py +++ b/src/blueapi/service/main.py @@ -37,6 +37,7 @@ from blueapi import __version__ from blueapi.config import ApplicationConfig, OIDCConfig, Tag from blueapi.service import interface +from blueapi.service.authentication import OPAClient, get_opa_client from blueapi.worker import TrackableTask, WorkerState from blueapi.worker.event import TaskStatusEnum @@ -166,6 +167,50 @@ def inner(request: Request, access_token: str = Depends(oauth_scheme)): TRACER = get_tracer("interface") +def submit_permission( + opa: Annotated[OPAClient, Depends(get_opa_client)], + task_request: TaskRequest, +): + allowed = opa.do_some_checks(task_request) + + if not allowed: + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) + + +def access_task_permission( + opa: Annotated[OPAClient, Depends(get_opa_client)], + request: Request, + task_id: str, + runner: Annotated[WorkerDispatcher, Depends(_runner)], +): + access_token: dict[str, Any] | None = getattr( + request.state, "decoded_access_token", None + ) + task = runner.run(interface.get_task_by_id, task_id) + + if not opa.admin() and ( + access_token + and task + and access_token.get("fedid") != task.task.metadata.get("user") + ): + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) + + +# start_task_permission is used when there is WorkerTask +def start_task_permission( + opa: Annotated[OPAClient, Depends(get_opa_client)], + request: Request, + task: WorkerTask, + runner: Annotated[WorkerDispatcher, Depends(_runner)], +): + if not task.task_id: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, + detail="No task id provided", + ) + access_task_permission(opa, request, task.task_id, runner) + + async def on_key_error_404(_: Request, __: Exception): return JSONResponse( status_code=status.HTTP_404_NOT_FOUND, @@ -291,6 +336,7 @@ def submit_task( request: Request, response: Response, task_request: Annotated[TaskRequest, Body(..., examples=[example_task_request])], + _: Annotated[None, Depends(submit_permission)], runner: Annotated[WorkerDispatcher, Depends(_runner)], ) -> TaskResponse: """Submit a task to the worker.""" @@ -299,10 +345,7 @@ def submit_task( access_token: dict[str, Any] | None = getattr( request.state, "decoded_access_token", None ) - if access_token: - user: str = access_token.get("fedid", "Unknown") - else: - user = "Unknown" + user = access_token.get("fedid") if access_token else None task_id: str = runner.run(interface.submit_task, task_request, {"user": user}) response.headers["Location"] = f"{request.url}/{task_id}" @@ -336,6 +379,7 @@ def submit_task( @start_as_current_span(TRACER, "task_id") def delete_submitted_task( task_id: str, + _: Annotated[None, Depends(access_task_permission)], runner: Annotated[WorkerDispatcher, Depends(_runner)], ) -> TaskResponse: return TaskResponse(task_id=runner.run(interface.clear_task, task_id)) @@ -353,6 +397,7 @@ def validate_task_status(v: str) -> TaskStatusEnum: @secure_router.get("/tasks", status_code=status.HTTP_200_OK, tags=[Tag.TASK]) @start_as_current_span(TRACER) def get_tasks( + request: Request, runner: Annotated[WorkerDispatcher, Depends(_runner)], task_status: str | SkipJsonSchema[None] = None, ) -> TasksListResponse: @@ -373,6 +418,14 @@ def get_tasks( tasks = runner.run(interface.get_tasks_by_status, desired_status) else: tasks = runner.run(interface.get_tasks) + + access_token: dict[str, Any] | None = getattr( + request.state, "decoded_access_token", None + ) + user = access_token.get("fedid") if access_token else None + + tasks = [t for t in tasks if t.task.metadata.get("user") == user] + return TasksListResponse(tasks=tasks) @@ -390,6 +443,7 @@ def get_tasks( def set_active_task( request: Request, task: WorkerTask, + _: Annotated[None, Depends(start_task_permission)], runner: Annotated[WorkerDispatcher, Depends(_runner)], ) -> WorkerTask: """Set a task to active status, the worker should begin it as soon as possible. @@ -420,6 +474,7 @@ def get_passthrough_headers(request: Request) -> dict[str, str]: @start_as_current_span(TRACER, "task_id") def get_task( task_id: str, + _: Annotated[None, Depends(access_task_permission)], runner: Annotated[WorkerDispatcher, Depends(_runner)], ) -> TrackableTask: """Retrieve a task""" @@ -495,8 +550,11 @@ def get_state(runner: Annotated[WorkerDispatcher, Depends(_runner)]) -> WorkerSt ) @start_as_current_span(TRACER, "state_change_request.new_state") def set_state( + request: Request, state_change_request: StateChangeRequest, response: Response, + opa: Annotated[OPAClient, Depends(get_opa_client)], + # _: Annotated[None, Depends(access_task_permission)], runner: Annotated[WorkerDispatcher, Depends(_runner)], ) -> WorkerState: """ @@ -523,6 +581,15 @@ def set_state( current_state in _ALLOWED_TRANSITIONS and new_state in _ALLOWED_TRANSITIONS[current_state] ): + active = runner.run(interface.get_active_task) + access_token: dict[str, Any] | None = getattr( + request.state, "decoded_access_token", None + ) + user = access_token.get("fedid") if access_token else None + + if not opa.admin() and active and active.task.metadata.get("user") != user: + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) + if new_state == WorkerState.PAUSED: runner.run(interface.pause_worker, state_change_request.defer) elif new_state == WorkerState.RUNNING: diff --git a/tests/unit_tests/service/test_rest_api.py b/tests/unit_tests/service/test_rest_api.py index c1d3b6a957..bf0a6a9977 100644 --- a/tests/unit_tests/service/test_rest_api.py +++ b/tests/unit_tests/service/test_rest_api.py @@ -251,7 +251,7 @@ def test_create_task(mock_runner: Mock, client: TestClient) -> None: response = client.post("/tasks", json=task.model_dump()) - mock_runner.run.assert_called_with(submit_task, task, {"user": "Unknown"}) + mock_runner.run.assert_called_with(submit_task, task, {"user": None}) assert response.json() == {"task_id": task_id}