Skip to content

Commit 9437df9

Browse files
authored
fix: #175 BidiStream SIGTERM causes CPU hot loop and pod stuck in Terminating (#176)
* fix: #175 BidiStream SIGTERM causes CPU hot loop and pod stuck in Terminating * Fix typing issues with python 3.13 * fix: python 3.12 tests/lint * fix: Python 3.11 tests * Add SIGTERM handler * Add type ignore
1 parent bc12e03 commit 9437df9

File tree

4 files changed

+295
-6
lines changed

4 files changed

+295
-6
lines changed

python/restate/server.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212

1313
import asyncio
1414
import logging
15-
from typing import Dict, TypedDict, Literal
15+
import signal
16+
from typing import Dict, Set, TypedDict, Literal
1617

1718
from restate.discovery import compute_discovery_json
1819
from restate.endpoint import Endpoint
@@ -213,7 +214,23 @@ def asgi_app(endpoint: Endpoint) -> RestateAppT:
213214
# Prepare request signer
214215
identity_verifier = PyIdentityVerifier(endpoint.identity_keys)
215216

217+
active_channels: Set[ReceiveChannel] = set()
218+
sigterm_installed = False
219+
220+
def _on_sigterm() -> None:
221+
"""Notify all active receive channels of graceful shutdown."""
222+
for ch in active_channels:
223+
ch.notify_shutdown()
224+
216225
async def app(scope: Scope, receive: Receive, send: Send):
226+
nonlocal sigterm_installed
227+
if not sigterm_installed:
228+
loop = asyncio.get_running_loop()
229+
try:
230+
loop.add_signal_handler(signal.SIGTERM, _on_sigterm)
231+
except (NotImplementedError, RuntimeError):
232+
pass # Windows or non-main thread
233+
sigterm_installed = True
217234
try:
218235
if scope["type"] == "lifespan":
219236
raise LifeSpanNotImplemented()
@@ -265,11 +282,13 @@ async def app(scope: Scope, receive: Receive, send: Send):
265282
# Let us set up restate's execution context for this invocation and handler.
266283
#
267284
receive_channel = ReceiveChannel(receive)
285+
active_channels.add(receive_channel)
268286
try:
269287
await process_invocation_to_completion(
270288
VMWrapper(request_headers), handler, dict(request_headers), receive_channel, send
271289
)
272290
finally:
291+
active_channels.discard(receive_channel)
273292
await receive_channel.close()
274293
except LifeSpanNotImplemented as e:
275294
raise e

python/restate/server_context.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -454,6 +454,8 @@ async def leave(self):
454454
# {'type': 'http.request', 'body': b'', 'more_body': True}
455455
# {'type': 'http.request', 'body': b'', 'more_body': False}
456456
# {'type': 'http.disconnect'}
457+
# Wait for the runtime to explicitly close its side of the input.
458+
# On SIGTERM, the shutdown event unblocks this instead of an arbitrary timeout.
457459
await self.receive.block_until_http_input_closed()
458460
# finally, we close our side
459461
# it is important to do it, after the other side has closed its side,
@@ -545,9 +547,9 @@ async def wrapper(f):
545547
continue
546548
if chunk.get("type") == "http.disconnect":
547549
raise DisconnectedException()
548-
if chunk.get("body", None) is not None:
549-
body = chunk.get("body")
550-
assert isinstance(body, bytes)
550+
# Skip empty body frames to avoid hot loop (see #175)
551+
body: bytes | None = chunk.get("body", None) # type: ignore[assignment]
552+
if body is not None and len(body) > 0:
551553
self.vm.notify_input(body)
552554
if not chunk.get("more_body", False):
553555
self.vm.notify_input_closed()

python/restate/server_types.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,12 @@ class HTTPRequestEvent(TypedDict):
5858
more_body: bool
5959

6060

61+
class HTTPDisconnectEvent(TypedDict):
62+
"""ASGI Disconnect event"""
63+
64+
type: Literal["http.disconnect"]
65+
66+
6167
class HTTPResponseStartEvent(TypedDict):
6268
"""ASGI Response start event"""
6369

@@ -75,7 +81,7 @@ class HTTPResponseBodyEvent(TypedDict):
7581
more_body: bool
7682

7783

78-
ASGIReceiveEvent = HTTPRequestEvent
84+
ASGIReceiveEvent = Union[HTTPRequestEvent, HTTPDisconnectEvent]
7985

8086

8187
ASGISendEvent = Union[HTTPResponseStartEvent, HTTPResponseBodyEvent]
@@ -158,12 +164,18 @@ async def loop():
158164

159165
async def __call__(self) -> ASGIReceiveEvent | RestateEvent:
160166
"""Get the next message."""
167+
if self._disconnected.is_set() and self._queue.empty():
168+
return {"type": "http.disconnect"}
161169
what = await self._queue.get()
162170
self._queue.task_done()
163171
return what
164172

173+
def notify_shutdown(self) -> None:
174+
"""Signal that a graceful shutdown has been requested (e.g. SIGTERM)."""
175+
self._http_input_closed.set()
176+
165177
async def block_until_http_input_closed(self) -> None:
166-
"""Wait until the HTTP input is closed"""
178+
"""Wait until the HTTP input is closed or a shutdown signal is received."""
167179
await self._http_input_closed.wait()
168180

169181
async def enqueue_restate_event(self, what: RestateEvent):

tests/disconnect_hotloop.py

Lines changed: 256 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,256 @@
1+
#
2+
# Copyright (c) 2023-2025 - Restate Software, Inc., Restate GmbH
3+
#
4+
# This file is part of the Restate SDK for Python,
5+
# which is released under the MIT license.
6+
#
7+
# You can find a copy of the license in file LICENSE in the root
8+
# directory of this repository or package, or at
9+
# https://github.com/restatedev/sdk-python/blob/main/LICENSE
10+
#
11+
"""
12+
Regression tests for disconnect and SIGTERM shutdown handling.
13+
14+
Covers:
15+
- Hot-loop bug when BidiStream disconnects (empty queue, empty body frames)
16+
- Graceful shutdown via notify_shutdown() unblocking block_until_http_input_closed()
17+
"""
18+
19+
import asyncio
20+
from typing import cast
21+
from unittest.mock import MagicMock
22+
23+
import pytest
24+
25+
from restate.server_types import ASGIReceiveEvent, ReceiveChannel
26+
27+
28+
@pytest.fixture(scope="session")
29+
def anyio_backend():
30+
return "asyncio"
31+
32+
33+
pytestmark = [
34+
pytest.mark.anyio,
35+
]
36+
37+
38+
async def test_receive_channel_returns_disconnect_when_drained():
39+
"""After disconnect, an empty queue should return http.disconnect immediately."""
40+
events = [
41+
{"type": "http.request", "body": b"hello", "more_body": True},
42+
{"type": "http.request", "body": b"", "more_body": False},
43+
{"type": "http.disconnect"},
44+
]
45+
event_iter = iter(events)
46+
47+
async def mock_receive() -> ASGIReceiveEvent:
48+
try:
49+
return cast(ASGIReceiveEvent, next(event_iter))
50+
except StopIteration:
51+
# Block forever — simulates the real ASGI receive after disconnect
52+
await asyncio.Event().wait()
53+
raise RuntimeError("unreachable")
54+
55+
channel = ReceiveChannel(mock_receive)
56+
57+
# Drain all queued events
58+
try:
59+
result1 = await asyncio.wait_for(channel(), timeout=1.0)
60+
assert result1["type"] == "http.request"
61+
62+
result2 = await asyncio.wait_for(channel(), timeout=1.0)
63+
assert result2["type"] == "http.request"
64+
65+
result3 = await asyncio.wait_for(channel(), timeout=1.0)
66+
assert result3["type"] == "http.disconnect"
67+
68+
# Now the queue is drained and _disconnected is set.
69+
# This call should return immediately with a synthetic disconnect,
70+
# NOT block forever.
71+
result4 = await asyncio.wait_for(channel(), timeout=1.0)
72+
assert result4["type"] == "http.disconnect"
73+
finally:
74+
await channel.close()
75+
76+
77+
async def test_receive_channel_does_not_block_after_disconnect():
78+
"""Repeated calls after disconnect should all return promptly."""
79+
events = [
80+
{"type": "http.disconnect"},
81+
]
82+
event_iter = iter(events)
83+
84+
async def mock_receive() -> ASGIReceiveEvent:
85+
try:
86+
return cast(ASGIReceiveEvent, next(event_iter))
87+
except StopIteration:
88+
await asyncio.Event().wait()
89+
raise RuntimeError("unreachable")
90+
91+
channel = ReceiveChannel(mock_receive)
92+
93+
try:
94+
# Consume the real disconnect
95+
result = await asyncio.wait_for(channel(), timeout=1.0)
96+
assert result["type"] == "http.disconnect"
97+
98+
# Subsequent calls should not block
99+
for _ in range(5):
100+
result = await asyncio.wait_for(channel(), timeout=0.5)
101+
assert result["type"] == "http.disconnect"
102+
finally:
103+
await channel.close()
104+
105+
106+
async def test_empty_body_frames_do_not_cause_hotloop():
107+
"""
108+
When the VM returns DoProgressReadFromInput and the chunk has body=b'',
109+
notify_input should NOT be called (it would cause a tight loop).
110+
The loop should exit via DisconnectedException when http.disconnect arrives.
111+
"""
112+
from restate.server_context import ServerInvocationContext, DisconnectedException
113+
from restate.vm import DoProgressReadFromInput
114+
115+
# Build a minimal mock context
116+
vm = MagicMock()
117+
vm.take_output.return_value = None
118+
vm.do_progress.return_value = DoProgressReadFromInput()
119+
120+
handler = MagicMock()
121+
invocation = MagicMock()
122+
send = MagicMock()
123+
124+
events = [
125+
{"type": "http.request", "body": b"", "more_body": True},
126+
{"type": "http.request", "body": b"", "more_body": False},
127+
{"type": "http.disconnect"},
128+
]
129+
event_iter = iter(events)
130+
131+
async def mock_receive() -> ASGIReceiveEvent:
132+
try:
133+
return cast(ASGIReceiveEvent, next(event_iter))
134+
except StopIteration:
135+
await asyncio.Event().wait()
136+
raise RuntimeError("unreachable")
137+
138+
receive_channel = ReceiveChannel(mock_receive)
139+
140+
ctx = ServerInvocationContext.__new__(ServerInvocationContext)
141+
ctx.vm = vm
142+
ctx.handler = handler
143+
ctx.invocation = invocation
144+
ctx.send = send
145+
ctx.receive = receive_channel
146+
ctx.run_coros_to_execute = {}
147+
ctx.tasks = MagicMock()
148+
149+
try:
150+
with pytest.raises(DisconnectedException):
151+
await asyncio.wait_for(
152+
ctx.create_poll_or_cancel_coroutine([0]),
153+
timeout=2.0,
154+
)
155+
156+
# notify_input should never have been called with empty bytes
157+
for call in vm.notify_input.call_args_list:
158+
arg = call[0][0]
159+
assert len(arg) > 0, f"notify_input called with empty bytes: {arg!r}"
160+
finally:
161+
await receive_channel.close()
162+
163+
164+
# ---- Shutdown / SIGTERM tests ----
165+
166+
167+
async def test_block_until_http_input_closed_returns_on_normal_close():
168+
"""block_until_http_input_closed returns when the runtime closes its input."""
169+
events = [
170+
{"type": "http.request", "body": b"data", "more_body": True},
171+
{"type": "http.request", "body": b"", "more_body": False},
172+
]
173+
event_iter = iter(events)
174+
175+
async def mock_receive() -> ASGIReceiveEvent:
176+
try:
177+
return cast(ASGIReceiveEvent, next(event_iter))
178+
except StopIteration:
179+
await asyncio.Event().wait()
180+
raise RuntimeError("unreachable")
181+
182+
channel = ReceiveChannel(mock_receive)
183+
try:
184+
# Should return promptly once more_body=False is received
185+
await asyncio.wait_for(channel.block_until_http_input_closed(), timeout=1.0)
186+
finally:
187+
await channel.close()
188+
189+
190+
async def test_block_until_http_input_closed_returns_on_shutdown():
191+
"""block_until_http_input_closed returns when notify_shutdown() is called,
192+
even if the runtime never closes its input."""
193+
194+
async def mock_receive() -> ASGIReceiveEvent:
195+
# Never sends any events — simulates the runtime not closing its side
196+
await asyncio.Event().wait()
197+
raise RuntimeError("unreachable")
198+
199+
channel = ReceiveChannel(mock_receive)
200+
try:
201+
# Schedule shutdown after a short delay
202+
async def trigger_shutdown():
203+
await asyncio.sleep(0.05)
204+
channel.notify_shutdown()
205+
206+
asyncio.create_task(trigger_shutdown())
207+
208+
# Should return promptly due to shutdown, NOT block forever
209+
await asyncio.wait_for(channel.block_until_http_input_closed(), timeout=1.0)
210+
finally:
211+
await channel.close()
212+
213+
214+
async def test_notify_shutdown_is_idempotent():
215+
"""Calling notify_shutdown() multiple times does not raise."""
216+
217+
async def mock_receive() -> ASGIReceiveEvent:
218+
await asyncio.Event().wait()
219+
raise RuntimeError("unreachable")
220+
221+
channel = ReceiveChannel(mock_receive)
222+
try:
223+
channel.notify_shutdown()
224+
channel.notify_shutdown() # should not raise
225+
226+
# Should return immediately since shutdown is already set
227+
await asyncio.wait_for(channel.block_until_http_input_closed(), timeout=0.5)
228+
finally:
229+
await channel.close()
230+
231+
232+
async def test_shutdown_unblocks_concurrent_waiters():
233+
"""Multiple concurrent waiters on block_until_http_input_closed
234+
should all be unblocked by a single notify_shutdown()."""
235+
236+
async def mock_receive() -> ASGIReceiveEvent:
237+
await asyncio.Event().wait()
238+
raise RuntimeError("unreachable")
239+
240+
channel = ReceiveChannel(mock_receive)
241+
try:
242+
results = []
243+
244+
async def waiter(idx: int):
245+
await channel.block_until_http_input_closed()
246+
results.append(idx)
247+
248+
tasks = [asyncio.create_task(waiter(i)) for i in range(3)]
249+
250+
await asyncio.sleep(0.05)
251+
channel.notify_shutdown()
252+
253+
await asyncio.wait_for(asyncio.gather(*tasks), timeout=1.0)
254+
assert sorted(results) == [0, 1, 2]
255+
finally:
256+
await channel.close()

0 commit comments

Comments
 (0)