diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index a0ca751bd..e165001f2 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -57,6 +57,8 @@ async def __call__( async def _default_message_handler( message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, ) -> None: + if isinstance(message, Exception): + logger.warning("Unhandled exception in message handler: %s", message) await anyio.lowlevel.checkpoint() diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 9f3dd5e0b..963e4e33c 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -47,7 +47,7 @@ # Reconnection defaults DEFAULT_RECONNECTION_DELAY_MS = 1000 # 1 second fallback when server doesn't provide retry -MAX_RECONNECTION_ATTEMPTS = 2 # Max retry attempts before giving up +MAX_RECONNECTION_ATTEMPTS = 5 # Max retry attempts before giving up class StreamableHTTPError(Exception): @@ -197,7 +197,9 @@ async def handle_get_stream(self, client: httpx.AsyncClient, read_stream_writer: event_source.response.raise_for_status() logger.debug("GET SSE connection established") + received_events = False async for sse in event_source.aiter_sse(): + received_events = True # Track last event ID for reconnection if sse.id: last_event_id = sse.id @@ -207,8 +209,9 @@ async def handle_get_stream(self, client: httpx.AsyncClient, read_stream_writer: await self._handle_sse_event(sse, read_stream_writer) - # Stream ended normally (server closed) - reset attempt counter - attempt = 0 + # Only reset attempts if we actually received events; + # empty connections count toward MAX_RECONNECTION_ATTEMPTS + attempt = 0 if received_events else attempt + 1 except Exception: # pragma: lax no cover logger.debug("GET stream error", exc_info=True) @@ -364,12 +367,18 @@ async def _handle_sse_response( await response.aclose() return # Normal completion, no reconnect needed except Exception: - logger.debug("SSE stream ended", exc_info=True) # pragma: no cover + logger.debug("SSE stream error", exc_info=True) - # Stream ended without response - reconnect if we received an event with ID - if last_event_id is not None: # pragma: no branch + # Stream ended without a complete response — attempt reconnection if possible + if last_event_id is not None: logger.info("SSE stream disconnected, reconnecting...") - await self._handle_reconnection(ctx, last_event_id, retry_interval_ms) + if await self._handle_reconnection(ctx, last_event_id, retry_interval_ms): + return # Reconnection delivered the response + + # No response delivered — unblock the waiting request with an error + error_data = ErrorData(code=INTERNAL_ERROR, message="SSE stream ended without a response") + error_msg = SessionMessage(JSONRPCError(jsonrpc="2.0", id=original_request_id, error=error_data)) + await ctx.read_stream_writer.send(error_msg) async def _handle_reconnection( self, @@ -377,12 +386,17 @@ async def _handle_reconnection( last_event_id: str, retry_interval_ms: int | None = None, attempt: int = 0, - ) -> None: - """Reconnect with Last-Event-ID to resume stream after server disconnect.""" + ) -> bool: + """Reconnect with Last-Event-ID to resume stream after server disconnect. + + Returns: + True if the response was successfully delivered, False if max + reconnection attempts were exceeded without delivering a response. + """ # Bail if max retries exceeded - if attempt >= MAX_RECONNECTION_ATTEMPTS: # pragma: no cover + if attempt >= MAX_RECONNECTION_ATTEMPTS: logger.debug(f"Max reconnection attempts ({MAX_RECONNECTION_ATTEMPTS}) exceeded") - return + return False # Always wait - use server value or default delay_ms = retry_interval_ms if retry_interval_ms is not None else DEFAULT_RECONNECTION_DELAY_MS @@ -419,15 +433,15 @@ async def _handle_reconnection( ) if is_complete: await event_source.response.aclose() - return + return True - # Stream ended again without response - reconnect again (reset attempt counter) + # Stream ended again without response - reconnect again logger.info("SSE stream disconnected, reconnecting...") - await self._handle_reconnection(ctx, reconnect_last_event_id, reconnect_retry_ms, 0) + return await self._handle_reconnection(ctx, reconnect_last_event_id, reconnect_retry_ms, attempt + 1) except Exception as e: # pragma: no cover logger.debug(f"Reconnection failed: {e}") # Try to reconnect again if we still have an event ID - await self._handle_reconnection(ctx, last_event_id, retry_interval_ms, attempt + 1) + return await self._handle_reconnection(ctx, last_event_id, retry_interval_ms, attempt + 1) async def post_writer( self, diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 42b1a3698..83fa25d3f 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -29,7 +29,14 @@ from mcp import MCPError, types from mcp.client.session import ClientSession -from mcp.client.streamable_http import StreamableHTTPTransport, streamable_http_client +from mcp.client.streamable_http import ( + MAX_RECONNECTION_ATTEMPTS, + StreamableHTTPTransport, + streamable_http_client, +) +from mcp.client.streamable_http import ( + RequestContext as TransportRequestContext, +) from mcp.server import Server, ServerRequestContext from mcp.server.streamable_http import ( MCP_PROTOCOL_VERSION_HEADER, @@ -2247,3 +2254,79 @@ async def test_streamable_http_client_preserves_custom_with_mcp_headers( assert "content-type" in headers_data assert headers_data["content-type"] == "application/json" + + +@pytest.mark.anyio +async def test_sse_read_timeout_propagates_error(basic_server: None, basic_server_url: str): + """SSE read timeout should propagate MCPError instead of hanging.""" + # Create client with very short SSE read timeout + short_timeout = httpx.Timeout(30.0, read=0.5) + async with httpx.AsyncClient(timeout=short_timeout, follow_redirects=True) as http_client: + async with streamable_http_client(f"{basic_server_url}/mcp", http_client=http_client) as ( + read_stream, + write_stream, + ): + async with ClientSession(read_stream, write_stream) as session: # pragma: no branch + await session.initialize() + + # Read a "slow" resource that takes 2s — longer than our 0.5s read timeout + with pytest.raises(MCPError): # pragma: no branch + with anyio.fail_after(10): # pragma: no branch + await session.read_resource("slow://test") + + +@pytest.mark.anyio +async def test_sse_error_when_reconnection_exhausted( + event_server: tuple[SimpleEventStore, str], + monkeypatch: pytest.MonkeyPatch, +): + """When SSE stream closes after events and reconnection fails, MCPError is raised.""" + _, server_url = event_server + + async def _always_fail_reconnection( + self: Any, ctx: Any, last_event_id: Any, retry_interval_ms: Any = None, attempt: int = 0 + ) -> bool: + return False + + monkeypatch.setattr(StreamableHTTPTransport, "_handle_reconnection", _always_fail_reconnection) + + async with streamable_http_client(f"{server_url}/mcp") as (read_stream, write_stream): + async with ClientSession(read_stream, write_stream) as session: # pragma: no branch + await session.initialize() + + # tool_with_stream_close sends a priming event (setting last_event_id), + # then closes the SSE stream. With reconnection patched to fail, + # _handle_sse_response falls through to send the error. + with pytest.raises(MCPError): # pragma: no branch + with anyio.fail_after(10): # pragma: no branch + await session.call_tool("tool_with_stream_close", {}) + + +@pytest.mark.anyio +async def test_handle_reconnection_returns_false_on_max_attempts(): + """_handle_reconnection returns False when max attempts exceeded.""" + transport = StreamableHTTPTransport(url="http://localhost:9999/mcp") + + read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](1) + + message = JSONRPCRequest(jsonrpc="2.0", id=42, method="tools/call", params={"name": "test"}) + session_message = SessionMessage(message) + + ctx = TransportRequestContext( + client=httpx.AsyncClient(), + session_id="test-session", + session_message=session_message, + metadata=None, + read_stream_writer=read_stream_writer, + ) + + try: + with anyio.fail_after(5): + result = await transport._handle_reconnection( + ctx, last_event_id="evt-1", retry_interval_ms=None, attempt=MAX_RECONNECTION_ATTEMPTS + ) + assert result is False + finally: + await read_stream_writer.aclose() + await read_stream.aclose() + await ctx.client.aclose()