diff --git a/.gitignore b/.gitignore index 8f39b7b1c87..b18b48f1e1b 100644 --- a/.gitignore +++ b/.gitignore @@ -90,4 +90,4 @@ CHANGELOG.ignore.md # Python bytecode files __pycache__/ *.pyc - +.worktrees/ diff --git a/codex-rs/mcp-server/src/codex_tool_runner.rs b/codex-rs/mcp-server/src/codex_tool_runner.rs index dcaf8a89e6c..a7975a0814b 100644 --- a/codex-rs/mcp-server/src/codex_tool_runner.rs +++ b/codex-rs/mcp-server/src/codex_tool_runner.rs @@ -305,6 +305,19 @@ async fn run_codex_tool_session_inner( .remove(&request_id); break; } + EventMsg::TurnAborted(_) => { + let result = create_call_tool_result_with_thread_id( + thread_id, + "Turn aborted".to_string(), + Some(true), + ); + outgoing.send_response(request_id.clone(), result).await; + running_requests_id_to_codex_uuid + .lock() + .await + .remove(&request_id); + break; + } EventMsg::SessionConfigured(_) => { tracing::error!("unexpected SessionConfigured event"); } @@ -344,7 +357,6 @@ async fn run_codex_tool_session_inner( | EventMsg::WebSearchEnd(_) | EventMsg::GetHistoryEntryResponse(_) | EventMsg::PlanUpdate(_) - | EventMsg::TurnAborted(_) | EventMsg::UserMessage(_) | EventMsg::ShutdownComplete | EventMsg::ViewImageToolCall(_) diff --git a/codex-rs/mcp-server/tests/common/mcp_process.rs b/codex-rs/mcp-server/tests/common/mcp_process.rs index 9a3f076fb19..096c6d4c044 100644 --- a/codex-rs/mcp-server/tests/common/mcp_process.rs +++ b/codex-rs/mcp-server/tests/common/mcp_process.rs @@ -13,6 +13,8 @@ use anyhow::Context; use codex_mcp_server::CodexToolCallParam; use mcp_types::CallToolRequestParams; +use mcp_types::CancelledNotification; +use mcp_types::CancelledNotificationParams; use mcp_types::ClientCapabilities; use mcp_types::Implementation; use mcp_types::InitializeRequestParams; @@ -235,6 +237,23 @@ impl McpProcess { Ok(()) } + pub async fn send_cancelled_notification( + &mut self, + request_id: RequestId, + reason: &str, + ) -> anyhow::Result<()> { + let params = CancelledNotificationParams { + request_id, + reason: Some(reason.to_string()), + }; + self.send_jsonrpc_message(JSONRPCMessage::Notification(JSONRPCNotification { + jsonrpc: JSONRPC_VERSION.into(), + method: CancelledNotification::METHOD.into(), + params: Some(serde_json::to_value(params)?), + })) + .await + } + async fn read_jsonrpc_message(&mut self) -> anyhow::Result { let mut line = String::new(); self.stdout.read_line(&mut line).await?; diff --git a/codex-rs/mcp-server/tests/suite/codex_tool_abort.rs b/codex-rs/mcp-server/tests/suite/codex_tool_abort.rs new file mode 100644 index 00000000000..cecdccb33a6 --- /dev/null +++ b/codex-rs/mcp-server/tests/suite/codex_tool_abort.rs @@ -0,0 +1,128 @@ +use std::path::Path; + +use codex_mcp_server::CodexToolCallParam; +use mcp_types::CallToolResult; +use mcp_types::RequestId; +use tempfile::TempDir; +use tokio::time::timeout; + +use core_test_support::skip_if_no_network; +use mcp_test_support::McpProcess; +use mcp_test_support::create_mock_chat_completions_server; +use mcp_test_support::create_shell_command_sse_response; +use mcp_test_support::to_response; + +// Allow ample time on slower CI or under load to avoid flakes. +const DEFAULT_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(20); + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn test_tool_call_returns_on_turn_aborted() { + skip_if_no_network!(); + + if let Err(err) = tool_call_returns_on_turn_aborted().await { + panic!("failure: {err}"); + } +} + +async fn tool_call_returns_on_turn_aborted() -> anyhow::Result<()> { + let workdir_for_shell_function_call = TempDir::new()?; + let created_filename = "created_by_shell_tool.txt"; + let shell_command = vec![ + "python3".to_string(), + "-c".to_string(), + format!("import pathlib; pathlib.Path('{created_filename}').touch()"), + ]; + + let responses = vec![create_shell_command_sse_response( + shell_command, + Some(workdir_for_shell_function_call.path()), + Some(5_000), + "call1234", + )?]; + + let McpHandle { + process: mut mcp_process, + server: _server, + dir: _dir, + } = create_mcp_process(responses).await?; + + let codex_request_id = mcp_process + .send_codex_tool_call(CodexToolCallParam { + prompt: "run command".to_string(), + ..Default::default() + }) + .await?; + + let _elicitation_request = timeout( + DEFAULT_READ_TIMEOUT, + mcp_process.read_stream_until_request_message(), + ) + .await??; + + let request_id = RequestId::Integer(codex_request_id); + mcp_process + .send_cancelled_notification(request_id.clone(), "test abort") + .await?; + + let response = timeout( + DEFAULT_READ_TIMEOUT, + mcp_process.read_stream_until_response_message(request_id), + ) + .await??; + + let result: CallToolResult = to_response(response)?; + assert!(result.is_error.unwrap_or(false)); + + Ok(()) +} + +/// This handle is used to ensure that the MockServer and TempDir are not dropped while +/// the McpProcess is still running. +struct McpHandle { + process: McpProcess, + /// Retain the server for the lifetime of the McpProcess. + #[allow(dead_code)] + server: wiremock::MockServer, + /// Retain the temporary directory for the lifetime of the McpProcess. + #[allow(dead_code)] + dir: TempDir, +} + +async fn create_mcp_process(responses: Vec) -> anyhow::Result { + let server = create_mock_chat_completions_server(responses).await; + let codex_home = TempDir::new()?; + create_config_toml(codex_home.path(), &server.uri())?; + let mut mcp_process = McpProcess::new(codex_home.path()).await?; + timeout(DEFAULT_READ_TIMEOUT, mcp_process.initialize()).await??; + Ok(McpHandle { + process: mcp_process, + server, + dir: codex_home, + }) +} + +/// Create a Codex config that uses the mock server as the model provider. +/// It also uses `approval_policy = "untrusted"` so that we exercise the +/// elicitation code path for shell commands. +fn create_config_toml(codex_home: &Path, server_uri: &str) -> std::io::Result<()> { + let config_toml = codex_home.join("config.toml"); + std::fs::write( + config_toml, + format!( + r#" +model = "mock-model" +approval_policy = "untrusted" +sandbox_policy = "workspace-write" + +model_provider = "mock_provider" + +[model_providers.mock_provider] +name = "Mock provider for test" +base_url = "{server_uri}/v1" +wire_api = "chat" +request_max_retries = 0 +stream_max_retries = 0 +"# + ), + ) +} diff --git a/codex-rs/mcp-server/tests/suite/mod.rs b/codex-rs/mcp-server/tests/suite/mod.rs index 6b50853b165..bb378150473 100644 --- a/codex-rs/mcp-server/tests/suite/mod.rs +++ b/codex-rs/mcp-server/tests/suite/mod.rs @@ -1 +1,2 @@ mod codex_tool; +mod codex_tool_abort;