diff --git a/backend/package.json b/backend/package.json index b609af47..b6100e28 100644 --- a/backend/package.json +++ b/backend/package.json @@ -76,6 +76,7 @@ "@typescript-eslint/eslint-plugin": "^8.53.1", "@typescript-eslint/parser": "^8.53.1", "bun-types": "^1.3.6", + "cookie-parser": "^1.4.7", "drizzle-kit": "^0.31.8", "eslint": "^9.39.2", "eslint-config-prettier": "^10.1.8", diff --git a/backend/src/app.module.ts b/backend/src/app.module.ts index c11cda7e..0c751013 100644 --- a/backend/src/app.module.ts +++ b/backend/src/app.module.ts @@ -25,6 +25,7 @@ import { IntegrationsModule } from './integrations/integrations.module'; import { SchedulesModule } from './schedules/schedules.module'; import { AnalyticsModule } from './analytics/analytics.module'; import { McpModule } from './mcp/mcp.module'; +import { StudioMcpModule } from './studio-mcp/studio-mcp.module'; import { ApiKeysModule } from './api-keys/api-keys.module'; import { WebhooksModule } from './webhooks/webhooks.module'; @@ -49,6 +50,7 @@ const coreModules = [ McpServersModule, McpGroupsModule, McpModule, + StudioMcpModule, ]; const testingModules = process.env.NODE_ENV === 'production' ? [] : [TestingSupportModule]; diff --git a/backend/src/auth/auth.guard.ts b/backend/src/auth/auth.guard.ts index 1dfb9e95..1d5b966f 100644 --- a/backend/src/auth/auth.guard.ts +++ b/backend/src/auth/auth.guard.ts @@ -124,6 +124,7 @@ export class AuthGuard implements CanActivate { roles: ['MEMBER'], // API keys have MEMBER role by default isAuthenticated: true, provider: 'api-key', + apiKeyPermissions: apiKey.permissions, }; } } diff --git a/backend/src/auth/types.ts b/backend/src/auth/types.ts index c76e7b71..45af28b5 100644 --- a/backend/src/auth/types.ts +++ b/backend/src/auth/types.ts @@ -1,11 +1,18 @@ export type AuthRole = 'ADMIN' | 'MEMBER'; +export interface ApiKeyPermissions { + workflows: { run: boolean; list: boolean; read: boolean }; + runs: { read: boolean; cancel: boolean }; +} + export interface AuthContext { userId: string | null; organizationId: string | null; roles: AuthRole[]; isAuthenticated: boolean; provider: string; + /** Present only when authenticated via API key. */ + apiKeyPermissions?: ApiKeyPermissions; } export const DEFAULT_ROLES: AuthRole[] = ['ADMIN', 'MEMBER']; diff --git a/backend/src/studio-mcp/__tests__/studio-mcp.controller.spec.ts b/backend/src/studio-mcp/__tests__/studio-mcp.controller.spec.ts new file mode 100644 index 00000000..e1151d22 --- /dev/null +++ b/backend/src/studio-mcp/__tests__/studio-mcp.controller.spec.ts @@ -0,0 +1,222 @@ +import { describe, it, expect, beforeEach, jest } from 'bun:test'; +import { StudioMcpController } from '../studio-mcp.controller'; +import type { StudioMcpService } from '../studio-mcp.service'; +import type { AuthContext } from '../../auth/types'; +import type { Request, Response } from 'express'; +import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js'; + +// Access private sessions map for assertions +type SessionsMap = Map< + string, + { transport: unknown; userId: string | null; organizationId: string | null } +>; +function getSessions(controller: StudioMcpController): SessionsMap { + return (controller as unknown as { sessions: SessionsMap }).sessions; +} + +function createMockRes(): Response & { _status: number; _json: unknown } { + const res = { + _status: 200, + _json: null, + status(code: number) { + res._status = code; + return res; + }, + json(body: unknown) { + res._json = body; + return res; + }, + on: jest.fn(), + } as unknown as Response & { _status: number; _json: unknown }; + return res; +} + +function createMockReq( + overrides: Partial & { auth?: AuthContext } = {}, +): Request & { auth?: AuthContext } { + return { + method: 'POST', + headers: {}, + header: jest.fn().mockReturnValue(undefined), + body: {}, + ...overrides, + } as unknown as Request & { auth?: AuthContext }; +} + +describe('StudioMcpController', () => { + let controller: StudioMcpController; + let mcpService: StudioMcpService; + + const authUser1: AuthContext = { + userId: 'user-1', + organizationId: 'org-1', + roles: ['MEMBER'], + isAuthenticated: true, + provider: 'api-key', + apiKeyPermissions: { + workflows: { run: true, list: true, read: true }, + runs: { read: true, cancel: true }, + }, + }; + + const authUser2: AuthContext = { + userId: 'user-2', + organizationId: 'org-2', + roles: ['MEMBER'], + isAuthenticated: true, + provider: 'api-key', + apiKeyPermissions: { + workflows: { run: true, list: true, read: true }, + runs: { read: true, cancel: true }, + }, + }; + + beforeEach(() => { + mcpService = { + createServer: jest.fn().mockReturnValue(new McpServer({ name: 'test', version: '1.0.0' })), + } as unknown as StudioMcpService; + + controller = new StudioMcpController(mcpService); + }); + + it('rejects unauthenticated requests with 401', async () => { + const req = createMockReq({ auth: undefined }); + const res = createMockRes(); + + await controller.handleMcp(req, res); + + expect(res._status).toBe(401); + expect(res._json).toEqual({ + error: 'Authentication required. Use Bearer sk_live_* API key.', + }); + }); + + it('rejects requests without session ID and without initialize body with 400', async () => { + const req = createMockReq({ + auth: authUser1, + method: 'POST', + headers: {}, + body: { jsonrpc: '2.0', method: 'tools/list', id: 1 }, + }); + const res = createMockRes(); + + await controller.handleMcp(req, res); + + expect(res._status).toBe(400); + }); + + it('returns 404 for unknown session ID', async () => { + const req = createMockReq({ + auth: authUser1, + headers: { 'mcp-session-id': 'nonexistent-session' }, + }); + const res = createMockRes(); + + await controller.handleMcp(req, res); + + expect(res._status).toBe(404); + expect(res._json).toEqual({ error: 'Session not found or expired' }); + }); + + describe('session identity binding', () => { + it('rejects session reuse from different user with 403', async () => { + // Manually insert a session owned by user-1 + const sessions = getSessions(controller); + const mockTransport = { handleRequest: jest.fn() }; + sessions.set('test-session-id', { + transport: mockTransport, + userId: authUser1.userId, + organizationId: authUser1.organizationId, + }); + + // User-2 tries to use user-1's session + const req = createMockReq({ + auth: authUser2, + method: 'POST', + headers: { 'mcp-session-id': 'test-session-id' }, + body: { jsonrpc: '2.0', method: 'tools/list', id: 1 }, + }); + const res = createMockRes(); + + await controller.handleMcp(req, res); + + expect(res._status).toBe(403); + expect(res._json).toEqual({ error: 'Session belongs to a different principal' }); + expect(mockTransport.handleRequest).not.toHaveBeenCalled(); + }); + + it('rejects session reuse from different org with 403', async () => { + const sessions = getSessions(controller); + const mockTransport = { handleRequest: jest.fn() }; + sessions.set('test-session-id', { + transport: mockTransport, + userId: authUser1.userId, + organizationId: authUser1.organizationId, + }); + + // Same user ID but different org + const crossOrgAuth: AuthContext = { + ...authUser1, + organizationId: 'different-org', + }; + const req = createMockReq({ + auth: crossOrgAuth, + method: 'POST', + headers: { 'mcp-session-id': 'test-session-id' }, + body: { jsonrpc: '2.0', method: 'tools/list', id: 1 }, + }); + const res = createMockRes(); + + await controller.handleMcp(req, res); + + expect(res._status).toBe(403); + expect(mockTransport.handleRequest).not.toHaveBeenCalled(); + }); + + it('allows session reuse from same principal', async () => { + const sessions = getSessions(controller); + const mockTransport = { handleRequest: jest.fn() }; + sessions.set('test-session-id', { + transport: mockTransport, + userId: authUser1.userId, + organizationId: authUser1.organizationId, + }); + + const req = createMockReq({ + auth: authUser1, + method: 'POST', + headers: { 'mcp-session-id': 'test-session-id' }, + body: { jsonrpc: '2.0', method: 'tools/list', id: 1 }, + }); + const res = createMockRes(); + + await controller.handleMcp(req, res); + + // Should have forwarded to the transport, not returned an error + expect(res._status).toBe(200); // not changed to 403 or 404 + expect(mockTransport.handleRequest).toHaveBeenCalled(); + }); + + it('cleans up session on DELETE from same principal', async () => { + const sessions = getSessions(controller); + const mockTransport = { handleRequest: jest.fn() }; + sessions.set('test-session-id', { + transport: mockTransport, + userId: authUser1.userId, + organizationId: authUser1.organizationId, + }); + + const req = createMockReq({ + auth: authUser1, + method: 'DELETE', + headers: { 'mcp-session-id': 'test-session-id' }, + }); + const res = createMockRes(); + + await controller.handleMcp(req, res); + + expect(sessions.has('test-session-id')).toBe(false); + expect(mockTransport.handleRequest).toHaveBeenCalled(); + }); + }); +}); diff --git a/backend/src/studio-mcp/__tests__/studio-mcp.service.spec.ts b/backend/src/studio-mcp/__tests__/studio-mcp.service.spec.ts new file mode 100644 index 00000000..9c1ff1bf --- /dev/null +++ b/backend/src/studio-mcp/__tests__/studio-mcp.service.spec.ts @@ -0,0 +1,366 @@ +import { describe, it, expect, beforeEach, jest } from 'bun:test'; +import { StudioMcpService } from '../studio-mcp.service'; +import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js'; +import type { AuthContext } from '../../auth/types'; +import type { WorkflowsService } from '../../workflows/workflows.service'; + +// Helper to access private _registeredTools on McpServer (plain object at runtime) +type ToolHandler = (...args: unknown[]) => unknown; +type RegisteredToolsMap = Record; +function getRegisteredTools(server: McpServer): RegisteredToolsMap { + return (server as unknown as { _registeredTools: RegisteredToolsMap })._registeredTools; +} + +describe('StudioMcpService Unit Tests', () => { + let service: StudioMcpService; + let workflowsService: WorkflowsService; + + const mockAuthContext: AuthContext = { + userId: 'test-user-id', + organizationId: 'test-org-id', + roles: ['ADMIN'], + isAuthenticated: true, + provider: 'test', + }; + + beforeEach(() => { + workflowsService = { + list: jest.fn().mockResolvedValue([]), + findById: jest.fn().mockResolvedValue(null), + run: jest.fn().mockResolvedValue({ + runId: 'test-run-id', + workflowId: 'test-workflow-id', + status: 'RUNNING', + workflowVersion: 1, + }), + listRuns: jest.fn().mockResolvedValue({ runs: [] }), + getRunStatus: jest.fn().mockResolvedValue({ + runId: 'test-run-id', + workflowId: 'test-workflow-id', + status: 'RUNNING', + startedAt: new Date().toISOString(), + updatedAt: new Date().toISOString(), + }), + getRunResult: jest.fn().mockResolvedValue({}), + cancelRun: jest.fn().mockResolvedValue(undefined), + } as unknown as WorkflowsService; + + service = new StudioMcpService(workflowsService); + }); + + it('should be defined', () => { + expect(service).toBeDefined(); + }); + + describe('createServer', () => { + it('returns an McpServer instance', () => { + const server = service.createServer(mockAuthContext); + + expect(server).toBeDefined(); + expect(server).toBeInstanceOf(McpServer); + }); + + it('registers all 9 expected tools', () => { + const server = service.createServer(mockAuthContext); + const registeredTools = getRegisteredTools(server); + + expect(registeredTools).toBeDefined(); + expect(Object.keys(registeredTools).length).toBe(9); + + const toolNames = Object.keys(registeredTools).sort(); + expect(toolNames).toEqual([ + 'cancel_run', + 'get_component', + 'get_run_result', + 'get_run_status', + 'get_workflow', + 'list_components', + 'list_runs', + 'list_workflows', + 'run_workflow', + ]); + }); + + it('workflow tools use auth context passed at creation time', async () => { + const server = service.createServer(mockAuthContext); + const registeredTools = getRegisteredTools(server); + const listWorkflowsTool = registeredTools['list_workflows']; + + expect(listWorkflowsTool).toBeDefined(); + await listWorkflowsTool.handler({}); + + expect(workflowsService.list).toHaveBeenCalledWith(mockAuthContext); + }); + + it('get_workflow tool uses auth context passed at creation time', async () => { + const workflowId = '11111111-1111-4111-8111-111111111111'; + (workflowsService.findById as jest.Mock).mockResolvedValue({ + id: workflowId, + name: 'Test Workflow', + description: 'Test description', + }); + + const server = service.createServer(mockAuthContext); + const registeredTools = getRegisteredTools(server); + const getWorkflowTool = registeredTools['get_workflow']; + + expect(getWorkflowTool).toBeDefined(); + await getWorkflowTool.handler({ workflowId }); + + expect(workflowsService.findById).toHaveBeenCalledWith(workflowId, mockAuthContext); + }); + + it('run_workflow tool uses auth context passed at creation time', async () => { + const workflowId = '11111111-1111-4111-8111-111111111111'; + const inputs = { key: 'value' }; + + const server = service.createServer(mockAuthContext); + const registeredTools = getRegisteredTools(server); + const runWorkflowTool = registeredTools['run_workflow']; + + expect(runWorkflowTool).toBeDefined(); + await runWorkflowTool.handler({ workflowId, inputs }); + + expect(workflowsService.run).toHaveBeenCalledWith( + workflowId, + { inputs, versionId: undefined }, + mockAuthContext, + { + trigger: { + type: 'api', + sourceId: mockAuthContext.userId, + label: 'Studio MCP', + }, + }, + ); + }); + + it('list_runs tool uses auth context passed at creation time', async () => { + const server = service.createServer(mockAuthContext); + const registeredTools = getRegisteredTools(server); + const listRunsTool = registeredTools['list_runs']; + + expect(listRunsTool).toBeDefined(); + await listRunsTool.handler({}); + + expect(workflowsService.listRuns).toHaveBeenCalledWith(mockAuthContext, { + workflowId: undefined, + status: undefined, + limit: 20, + }); + }); + + it('get_run_status tool uses auth context passed at creation time', async () => { + const runId = 'test-run-id'; + + const server = service.createServer(mockAuthContext); + const registeredTools = getRegisteredTools(server); + const getRunStatusTool = registeredTools['get_run_status']; + + expect(getRunStatusTool).toBeDefined(); + await getRunStatusTool.handler({ runId }); + + expect(workflowsService.getRunStatus).toHaveBeenCalledWith(runId, undefined, mockAuthContext); + }); + + it('get_run_result tool uses auth context passed at creation time', async () => { + const runId = 'test-run-id'; + + const server = service.createServer(mockAuthContext); + const registeredTools = getRegisteredTools(server); + const getRunResultTool = registeredTools['get_run_result']; + + expect(getRunResultTool).toBeDefined(); + await getRunResultTool.handler({ runId }); + + expect(workflowsService.getRunResult).toHaveBeenCalledWith(runId, undefined, mockAuthContext); + }); + + it('cancel_run tool uses auth context passed at creation time', async () => { + const runId = 'test-run-id'; + + const server = service.createServer(mockAuthContext); + const registeredTools = getRegisteredTools(server); + const cancelRunTool = registeredTools['cancel_run']; + + expect(cancelRunTool).toBeDefined(); + await cancelRunTool.handler({ runId }); + + expect(workflowsService.cancelRun).toHaveBeenCalledWith(runId, undefined, mockAuthContext); + }); + + it('component tools do not require auth context', async () => { + const server = service.createServer(mockAuthContext); + const registeredTools = getRegisteredTools(server); + const listComponentsTool = registeredTools['list_components']; + const getComponentTool = registeredTools['get_component']; + + expect(listComponentsTool).toBeDefined(); + expect(getComponentTool).toBeDefined(); + + const listResult = await listComponentsTool.handler({}); + expect(listResult).toBeDefined(); + + const getResult = await getComponentTool.handler({ + componentId: 'core.workflow.entrypoint', + }); + expect(getResult).toBeDefined(); + }); + + describe('API key permission gating', () => { + const restrictedAuth: AuthContext = { + userId: 'api-key-id', + organizationId: 'test-org-id', + roles: ['MEMBER'], + isAuthenticated: true, + provider: 'api-key', + apiKeyPermissions: { + workflows: { run: false, list: true, read: true }, + runs: { read: true, cancel: false }, + }, + }; + + it('allows list_workflows when workflows.list is true', async () => { + const server = service.createServer(restrictedAuth); + const tools = getRegisteredTools(server); + const result = (await tools['list_workflows'].handler({})) as { isError?: boolean }; + expect(result.isError).toBeUndefined(); + }); + + it('denies run_workflow when workflows.run is false', async () => { + const server = service.createServer(restrictedAuth); + const tools = getRegisteredTools(server); + const result = (await tools['run_workflow'].handler({ + workflowId: '11111111-1111-4111-8111-111111111111', + })) as { isError?: boolean; content: { text: string }[] }; + expect(result.isError).toBe(true); + expect(result.content[0].text).toContain('workflows.run'); + }); + + it('denies cancel_run when runs.cancel is false', async () => { + const server = service.createServer(restrictedAuth); + const tools = getRegisteredTools(server); + const result = (await tools['cancel_run'].handler({ + runId: 'test-run-id', + })) as { isError?: boolean; content: { text: string }[] }; + expect(result.isError).toBe(true); + expect(result.content[0].text).toContain('runs.cancel'); + }); + + it('allows get_run_status when runs.read is true', async () => { + const server = service.createServer(restrictedAuth); + const tools = getRegisteredTools(server); + const result = (await tools['get_run_status'].handler({ + runId: 'test-run-id', + })) as { isError?: boolean }; + expect(result.isError).toBeUndefined(); + }); + + it('allows all tools when no apiKeyPermissions (non-API-key auth)', async () => { + const server = service.createServer(mockAuthContext); // no apiKeyPermissions + const tools = getRegisteredTools(server); + + // All workflow/run tools should work without permission errors + const listResult = (await tools['list_workflows'].handler({})) as { isError?: boolean }; + expect(listResult.isError).toBeUndefined(); + + const runResult = (await tools['run_workflow'].handler({ + workflowId: '11111111-1111-4111-8111-111111111111', + })) as { isError?: boolean }; + expect(runResult.isError).toBeUndefined(); + + const cancelResult = (await tools['cancel_run'].handler({ + runId: 'test-run-id', + })) as { isError?: boolean }; + expect(cancelResult.isError).toBeUndefined(); + }); + + it('component tools are always allowed regardless of permissions', async () => { + const noPermsAuth: AuthContext = { + ...restrictedAuth, + apiKeyPermissions: { + workflows: { run: false, list: false, read: false }, + runs: { read: false, cancel: false }, + }, + }; + const server = service.createServer(noPermsAuth); + const tools = getRegisteredTools(server); + + const listResult = (await tools['list_components'].handler({})) as { isError?: boolean }; + expect(listResult.isError).toBeUndefined(); + + const getResult = (await tools['get_component'].handler({ + componentId: 'core.workflow.entrypoint', + })) as { isError?: boolean }; + expect(getResult.isError).toBeUndefined(); + }); + + it('denies all 7 gated tools when all permissions are false', async () => { + const noPermsAuth: AuthContext = { + ...restrictedAuth, + apiKeyPermissions: { + workflows: { run: false, list: false, read: false }, + runs: { read: false, cancel: false }, + }, + }; + const server = service.createServer(noPermsAuth); + const tools = getRegisteredTools(server); + + const gatedTools = [ + 'list_workflows', + 'get_workflow', + 'run_workflow', + 'list_runs', + 'get_run_status', + 'get_run_result', + 'cancel_run', + ]; + + for (const toolName of gatedTools) { + const result = (await tools[toolName].handler({ + workflowId: '11111111-1111-4111-8111-111111111111', + runId: 'test-run-id', + })) as { isError?: boolean }; + expect(result.isError).toBe(true); + } + }); + }); + + it('each server instance has isolated auth context', async () => { + const authContext1: AuthContext = { + userId: 'user-1', + organizationId: 'org-1', + roles: ['ADMIN'], + isAuthenticated: true, + provider: 'test', + }; + + const authContext2: AuthContext = { + userId: 'user-2', + organizationId: 'org-2', + roles: ['MEMBER'], + isAuthenticated: true, + provider: 'test', + }; + + const server1 = service.createServer(authContext1); + const server2 = service.createServer(authContext2); + + const registeredTools1 = getRegisteredTools(server1); + const registeredTools2 = getRegisteredTools(server2); + + const listWorkflowsTool1 = registeredTools1['list_workflows']; + const listWorkflowsTool2 = registeredTools2['list_workflows']; + + expect(listWorkflowsTool1).toBeDefined(); + expect(listWorkflowsTool2).toBeDefined(); + + await listWorkflowsTool1.handler({}); + await listWorkflowsTool2.handler({}); + + expect(workflowsService.list).toHaveBeenCalledTimes(2); + expect(workflowsService.list).toHaveBeenNthCalledWith(1, authContext1); + expect(workflowsService.list).toHaveBeenNthCalledWith(2, authContext2); + }); + }); +}); diff --git a/backend/src/studio-mcp/studio-mcp.controller.ts b/backend/src/studio-mcp/studio-mcp.controller.ts new file mode 100644 index 00000000..f7820ba6 --- /dev/null +++ b/backend/src/studio-mcp/studio-mcp.controller.ts @@ -0,0 +1,125 @@ +import { Controller, All, Req, Res, Logger } from '@nestjs/common'; +import { ApiTags, ApiOperation } from '@nestjs/swagger'; +import { randomUUID } from 'node:crypto'; +import type { Request, Response } from 'express'; +import { StreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/streamableHttp.js'; +import { isInitializeRequest } from '@modelcontextprotocol/sdk/types.js'; + +import type { AuthContext } from '../auth/types'; +import { StudioMcpService } from './studio-mcp.service'; + +/** + * Exposes ShipSec Studio as an MCP server for external agents. + * + * Auth: Uses global AuthGuard which validates Bearer sk_live_* API keys. + * Protocol: MCP Streamable HTTP only (POST for messages, GET for server-push, DELETE for session end). + * + * Endpoint: /api/v1/studio-mcp + */ +interface McpSession { + transport: StreamableHTTPServerTransport; + /** Identity of the caller who created this session — used to reject hijacking. */ + userId: string | null; + organizationId: string | null; +} + +@ApiTags('studio-mcp') +@Controller('studio-mcp') +export class StudioMcpController { + private readonly logger = new Logger(StudioMcpController.name); + + // Active session transports keyed by MCP session ID. + // NOTE: In-memory — single-instance design. For horizontal scaling, use sticky sessions. + private readonly sessions = new Map(); + + constructor(private readonly studioMcpService: StudioMcpService) {} + + @All() + @ApiOperation({ summary: 'Studio MCP endpoint (Streamable HTTP) for external agents' }) + async handleMcp(@Req() req: Request & { auth?: AuthContext }, @Res() res: Response) { + const auth = req.auth; + if (!auth?.isAuthenticated) { + return res + .status(401) + .json({ error: 'Authentication required. Use Bearer sk_live_* API key.' }); + } + + const sessionId = req.headers['mcp-session-id'] as string | undefined; + const body = req.body as unknown; + const isPost = req.method === 'POST'; + const isGet = req.method === 'GET'; + const isDelete = req.method === 'DELETE'; + const isInitRequest = + isPost && + (isInitializeRequest(body) || + (Array.isArray(body) && body.some((item) => isInitializeRequest(item)))); + + // ---- Existing session ---- + if (sessionId) { + const session = this.sessions.get(sessionId); + if (!session) { + return res.status(404).json({ error: 'Session not found or expired' }); + } + + // Verify the caller matches the session creator (prevent session hijacking) + if (session.userId !== auth.userId || session.organizationId !== auth.organizationId) { + this.logger.warn( + `Session identity mismatch for ${sessionId}: ` + + `expected user=${session.userId} org=${session.organizationId}, ` + + `got user=${auth.userId} org=${auth.organizationId}`, + ); + return res.status(403).json({ error: 'Session belongs to a different principal' }); + } + + const { transport } = session; + + if (isGet) { + res.on('close', () => { + this.logger.log(`Studio MCP SSE closed for session ${sessionId}`); + this.sessions.delete(sessionId); + }); + // Cast: Express Request extends IncomingMessage; handleRequest accepts it at runtime + void transport.handleRequest(req as any, res as any); + } else if (isDelete) { + this.logger.log(`Studio MCP session terminated: ${sessionId}`); + await transport.handleRequest(req as any, res as any, body); + this.sessions.delete(sessionId); + } else { + await transport.handleRequest(req as any, res as any, body); + } + return; + } + + // ---- New session (initialize) ---- + if (!isInitRequest) { + return res + .status(400) + .json({ error: 'Missing Mcp-Session-Id header. Send an initialize request first.' }); + } + + this.logger.log( + `New Studio MCP session for org=${auth.organizationId}, provider=${auth.provider}`, + ); + + const transport = new StreamableHTTPServerTransport({ + sessionIdGenerator: () => randomUUID(), + enableJsonResponse: true, + }); + + const server = this.studioMcpService.createServer(auth); + await server.connect(transport); + + // Handle the initialize request (sends response with Mcp-Session-Id header) + await transport.handleRequest(req as any, res as any, body); + + // Store transport + identity by the session ID generated during initialize + if (transport.sessionId) { + this.sessions.set(transport.sessionId, { + transport, + userId: auth.userId, + organizationId: auth.organizationId, + }); + this.logger.log(`Studio MCP session created: ${transport.sessionId}`); + } + } +} diff --git a/backend/src/studio-mcp/studio-mcp.module.ts b/backend/src/studio-mcp/studio-mcp.module.ts new file mode 100644 index 00000000..81cfe053 --- /dev/null +++ b/backend/src/studio-mcp/studio-mcp.module.ts @@ -0,0 +1,12 @@ +import { Module } from '@nestjs/common'; + +import { WorkflowsModule } from '../workflows/workflows.module'; +import { StudioMcpController } from './studio-mcp.controller'; +import { StudioMcpService } from './studio-mcp.service'; + +@Module({ + imports: [WorkflowsModule], + controllers: [StudioMcpController], + providers: [StudioMcpService], +}) +export class StudioMcpModule {} diff --git a/backend/src/studio-mcp/studio-mcp.service.ts b/backend/src/studio-mcp/studio-mcp.service.ts new file mode 100644 index 00000000..8ac4b92b --- /dev/null +++ b/backend/src/studio-mcp/studio-mcp.service.ts @@ -0,0 +1,412 @@ +import { Injectable, Logger } from '@nestjs/common'; +import { z } from 'zod'; +import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js'; + +// Ensure all worker components are registered before accessing the registry +import '@shipsec/studio-worker/components'; +import { + componentRegistry, + extractPorts, + isAgentCallable, + getToolSchema, + type CachedComponentMetadata, +} from '@shipsec/component-sdk'; +import type { ExecutionStatus } from '@shipsec/shared'; +import { categorizeComponent } from '../components/utils/categorization'; +import { WorkflowsService, type WorkflowRunSummary } from '../workflows/workflows.service'; +import type { ServiceWorkflowResponse } from '../workflows/dto/workflow-graph.dto'; +import type { AuthContext, ApiKeyPermissions } from '../auth/types'; + +type PermissionPath = + | 'workflows.list' + | 'workflows.read' + | 'workflows.run' + | 'runs.read' + | 'runs.cancel'; + +@Injectable() +export class StudioMcpService { + private readonly logger = new Logger(StudioMcpService.name); + + constructor(private readonly workflowsService: WorkflowsService) {} + + /** + * Check whether the caller's API key permits the given action. + * Non-API-key callers (e.g. internal service tokens) are always allowed. + */ + private checkPermission( + auth: AuthContext, + permission: PermissionPath, + ): + | { allowed: true } + | { allowed: false; error: { content: { type: 'text'; text: string }[]; isError: true } } { + const perms = auth.apiKeyPermissions; + if (!perms) return { allowed: true }; // non-API-key auth → unrestricted + + const [scope, action] = permission.split('.') as [keyof ApiKeyPermissions, string]; + const scopePerms = perms[scope] as Record | undefined; + if (!scopePerms || !scopePerms[action]) { + return { + allowed: false, + error: { + content: [ + { + type: 'text' as const, + text: `Permission denied: API key lacks '${permission}' permission.`, + }, + ], + isError: true, + }, + }; + } + return { allowed: true }; + } + + /** + * Create an MCP server with all Studio tools registered, scoped to the given auth context. + * Uses Streamable HTTP transport only (no legacy SSE). + */ + createServer(auth: AuthContext): McpServer { + const server = new McpServer({ + name: 'shipsec-studio', + version: '1.0.0', + }); + + this.registerTools(server, auth); + + return server; + } + + private registerTools(server: McpServer, auth: AuthContext): void { + this.registerWorkflowTools(server, auth); + this.registerComponentTools(server); + this.registerRunTools(server, auth); + } + + // --------------------------------------------------------------------------- + // Workflow tools + // --------------------------------------------------------------------------- + + private registerWorkflowTools(server: McpServer, auth: AuthContext): void { + server.registerTool( + 'list_workflows', + { + description: + 'List all workflows in the organization. Returns id, name, description, and version info.', + }, + async () => { + const gate = this.checkPermission(auth, 'workflows.list'); + if (!gate.allowed) return gate.error; + try { + const workflows = await this.workflowsService.list(auth); + const summary = workflows.map((w: ServiceWorkflowResponse) => ({ + id: w.id, + name: w.name, + description: w.description ?? null, + currentVersion: w.currentVersion, + currentVersionId: w.currentVersionId, + createdAt: w.createdAt, + updatedAt: w.updatedAt, + })); + return { + content: [{ type: 'text' as const, text: JSON.stringify(summary, null, 2) }], + }; + } catch (error) { + return this.errorResult(error); + } + }, + ); + + server.registerTool( + 'get_workflow', + { + description: + 'Get detailed information about a specific workflow, including its graph (nodes, edges) and runtime input definitions.', + inputSchema: { workflowId: z.string().uuid() }, + }, + async (args: { workflowId: string }) => { + const gate = this.checkPermission(auth, 'workflows.read'); + if (!gate.allowed) return gate.error; + try { + const workflow = await this.workflowsService.findById(args.workflowId, auth); + return { + content: [{ type: 'text' as const, text: JSON.stringify(workflow, null, 2) }], + }; + } catch (error) { + return this.errorResult(error); + } + }, + ); + + server.registerTool( + 'run_workflow', + { + description: + 'Start a workflow execution. Returns the run ID and initial status. Use get_run_status to poll for completion.', + inputSchema: { + workflowId: z.string().uuid(), + inputs: z.record(z.string(), z.unknown()).optional(), + versionId: z.string().uuid().optional(), + }, + }, + async (args: { + workflowId: string; + inputs?: Record; + versionId?: string; + }) => { + const gate = this.checkPermission(auth, 'workflows.run'); + if (!gate.allowed) return gate.error; + try { + const handle = await this.workflowsService.run( + args.workflowId, + { inputs: args.inputs ?? {}, versionId: args.versionId }, + auth, + { + trigger: { + type: 'api', + sourceId: auth.userId ?? 'api-key', + label: 'Studio MCP', + }, + }, + ); + return { + content: [ + { + type: 'text' as const, + text: JSON.stringify( + { + runId: handle.runId, + workflowId: handle.workflowId, + status: handle.status, + workflowVersion: handle.workflowVersion, + }, + null, + 2, + ), + }, + ], + }; + } catch (error) { + return this.errorResult(error); + } + }, + ); + } + + // --------------------------------------------------------------------------- + // Component tools + // --------------------------------------------------------------------------- + + private registerComponentTools(server: McpServer): void { + server.registerTool( + 'list_components', + { + description: + 'List all available workflow components (nodes) with their category, description, and whether they are agent-callable.', + }, + async () => { + try { + const entries = componentRegistry.listMetadata(); + const components = entries.map((entry: CachedComponentMetadata) => { + const def = entry.definition; + const category = categorizeComponent(def); + return { + id: def.id, + name: def.label, + category, + description: def.ui?.description ?? def.docs ?? '', + runner: def.runner?.kind ?? 'inline', + agentCallable: isAgentCallable(def), + inputCount: (entry.inputs ?? []).length, + outputCount: (entry.outputs ?? []).length, + }; + }); + return { + content: [{ type: 'text' as const, text: JSON.stringify(components, null, 2) }], + }; + } catch (error) { + return this.errorResult(error); + } + }, + ); + + server.registerTool( + 'get_component', + { + description: + 'Get detailed information about a specific component, including its full input/output/parameter schemas.', + inputSchema: { componentId: z.string() }, + }, + async (args: { componentId: string }) => { + try { + const entry = componentRegistry.getMetadata(args.componentId); + if (!entry) { + return { + content: [ + { + type: 'text' as const, + text: `Component "${args.componentId}" not found`, + }, + ], + isError: true, + }; + } + const def = entry.definition; + const category = categorizeComponent(def); + const result = { + id: def.id, + name: def.label, + category, + description: def.ui?.description ?? def.docs ?? '', + documentation: def.docs ?? null, + runner: def.runner, + inputs: entry.inputs ?? extractPorts(def.inputs), + outputs: entry.outputs ?? extractPorts(def.outputs), + parameters: entry.parameters ?? [], + agentCallable: isAgentCallable(def), + toolSchema: isAgentCallable(def) ? getToolSchema(def) : null, + examples: def.ui?.examples ?? [], + }; + return { + content: [{ type: 'text' as const, text: JSON.stringify(result, null, 2) }], + }; + } catch (error) { + return this.errorResult(error); + } + }, + ); + } + + // --------------------------------------------------------------------------- + // Run tools + // --------------------------------------------------------------------------- + + private registerRunTools(server: McpServer, auth: AuthContext): void { + server.registerTool( + 'list_runs', + { + description: 'List recent workflow runs. Optionally filter by workflow or status.', + inputSchema: { + workflowId: z.string().uuid().optional(), + status: z + .enum([ + 'RUNNING', + 'COMPLETED', + 'FAILED', + 'CANCELLED', + 'TERMINATED', + 'TIMED_OUT', + 'AWAITING_INPUT', + ]) + .optional(), + limit: z.number().int().positive().max(100).optional(), + }, + }, + async (args: { workflowId?: string; status?: ExecutionStatus; limit?: number }) => { + const gate = this.checkPermission(auth, 'runs.read'); + if (!gate.allowed) return gate.error; + try { + const result = await this.workflowsService.listRuns(auth, { + workflowId: args.workflowId, + status: args.status, + limit: args.limit ?? 20, + }); + const runs = result.runs.map((r: WorkflowRunSummary) => ({ + id: r.id, + workflowId: r.workflowId, + workflowName: r.workflowName, + status: r.status, + startTime: r.startTime, + endTime: r.endTime, + duration: r.duration, + triggerType: r.triggerType, + })); + return { + content: [{ type: 'text' as const, text: JSON.stringify(runs, null, 2) }], + }; + } catch (error) { + return this.errorResult(error); + } + }, + ); + + server.registerTool( + 'get_run_status', + { + description: + 'Get the current status of a workflow run including progress, failures, and timing.', + inputSchema: { runId: z.string() }, + }, + async (args: { runId: string }) => { + const gate = this.checkPermission(auth, 'runs.read'); + if (!gate.allowed) return gate.error; + try { + const status = await this.workflowsService.getRunStatus(args.runId, undefined, auth); + return { + content: [{ type: 'text' as const, text: JSON.stringify(status, null, 2) }], + }; + } catch (error) { + return this.errorResult(error); + } + }, + ); + + server.registerTool( + 'get_run_result', + { + description: 'Get the final result/output of a completed workflow run.', + inputSchema: { runId: z.string() }, + }, + async (args: { runId: string }) => { + const gate = this.checkPermission(auth, 'runs.read'); + if (!gate.allowed) return gate.error; + try { + const result = await this.workflowsService.getRunResult(args.runId, undefined, auth); + return { + content: [{ type: 'text' as const, text: JSON.stringify(result, null, 2) }], + }; + } catch (error) { + return this.errorResult(error); + } + }, + ); + + server.registerTool( + 'cancel_run', + { + description: 'Cancel a running workflow execution.', + inputSchema: { runId: z.string() }, + }, + async (args: { runId: string }) => { + const gate = this.checkPermission(auth, 'runs.cancel'); + if (!gate.allowed) return gate.error; + try { + await this.workflowsService.cancelRun(args.runId, undefined, auth); + return { + content: [ + { + type: 'text' as const, + text: JSON.stringify({ cancelled: true, runId: args.runId }, null, 2), + }, + ], + }; + } catch (error) { + return this.errorResult(error); + } + }, + ); + } + + // --------------------------------------------------------------------------- + // Helpers + // --------------------------------------------------------------------------- + + private errorResult(error: unknown) { + const message = error instanceof Error ? error.message : String(error); + this.logger.error(`Studio MCP tool error: ${message}`); + return { + content: [{ type: 'text' as const, text: `Error: ${message}` }], + isError: true, + }; + } +} diff --git a/bun.lock b/bun.lock index d994c1f6..16b97dcf 100644 --- a/bun.lock +++ b/bun.lock @@ -14,9 +14,11 @@ }, "devDependencies": { "@ai-sdk/mcp": "^1.0.13", + "@ai-sdk/openai": "^3.0.25", "@modelcontextprotocol/sdk": "^1.25.3", "@types/bun": "^1.3.6", "@types/node": "^24.10.9", + "ai": "^6.0.49", "bun-types": "^1.3.6", "husky": "^9.1.7", "lint-staged": "^16.2.7", @@ -89,6 +91,7 @@ "@typescript-eslint/eslint-plugin": "^8.53.1", "@typescript-eslint/parser": "^8.53.1", "bun-types": "^1.3.6", + "cookie-parser": "^1.4.7", "drizzle-kit": "^0.31.8", "eslint": "^9.39.2", "eslint-config-prettier": "^10.1.8", @@ -1553,7 +1556,9 @@ "cookie": ["cookie@0.7.2", "", {}, "sha512-yki5XnKuf750l50uGTllt6kKILY4nQ1eNIQatoXEByZ5dWgnKqbnqmTrBE5B4N7lrMJKQ2ytWMiTO2o0v6Ew/w=="], - "cookie-signature": ["cookie-signature@1.2.2", "", {}, "sha512-D76uU73ulSXrD1UXF4KE2TMxVVwhsnCgfAyTg9k8P6KGZjlXKrOLe4dJQKI3Bxi5wjesZoFXJWElNWBjPZMbhg=="], + "cookie-parser": ["cookie-parser@1.4.7", "", { "dependencies": { "cookie": "0.7.2", "cookie-signature": "1.0.6" } }, "sha512-nGUvgXnotP3BsjiLX2ypbQnWoGUPIIfHQNZkkC668ntrzGWEZVW70HDEB1qnNGMicPje6EttlIgzo51YSwNQGw=="], + + "cookie-signature": ["cookie-signature@1.0.6", "", {}, "sha512-QADzlaHc8icV8I7vbaJXJwod9HWYp8uCqf1xa4OfNu1T7JVxQIrUgOWtHdNDtPiywmFbiS12VjotIXLrKM3orQ=="], "cookiejar": ["cookiejar@2.1.4", "", {}, "sha512-LDx6oHrK+PhzLKJU9j5S7/Y3jM/mUHvD/DeI1WQmJn652iPC5Y4TBzC9l+5OMOXlyTTA+SmVUPm0HQUwpD5Jqw=="], @@ -3339,6 +3344,8 @@ "express/body-parser": ["body-parser@2.2.2", "", { "dependencies": { "bytes": "^3.1.2", "content-type": "^1.0.5", "debug": "^4.4.3", "http-errors": "^2.0.0", "iconv-lite": "^0.7.0", "on-finished": "^2.4.1", "qs": "^6.14.1", "raw-body": "^3.0.1", "type-is": "^2.0.1" } }, "sha512-oP5VkATKlNwcgvxi0vM0p/D3n2C3EReYVX+DNYs5TjZFn/oQt2j+4sVJtSMr18pdRr8wjTcBl6LoV+FUwzPmNA=="], + "express/cookie-signature": ["cookie-signature@1.2.2", "", {}, "sha512-D76uU73ulSXrD1UXF4KE2TMxVVwhsnCgfAyTg9k8P6KGZjlXKrOLe4dJQKI3Bxi5wjesZoFXJWElNWBjPZMbhg=="], + "fast-glob/glob-parent": ["glob-parent@5.1.2", "", { "dependencies": { "is-glob": "^4.0.1" } }, "sha512-AOIgSQCepiJYwP3ARnGx+5VnTu2HBYdzbGP45eLw1vr3zB3vZLeyed1sC9hnbcOc9/SrMyM5RPQrkGz4aS9Zow=="], "foreground-child/signal-exit": ["signal-exit@4.1.0", "", {}, "sha512-bzyZ1e88w9O1iNJbKnOlvYTrWPDl46O1bG0D3XInv+9tkPrxrN8jUUTiFlDkkmKWgn1M6CfIA13SuGqOa9Korw=="], @@ -3471,6 +3478,8 @@ "sucrase/commander": ["commander@4.1.1", "", {}, "sha512-NOKm8xhkzAjzFx8B2v5OAHT+u5pRQc2UCa2Vq9jYL/31o2wi9mxBA7LIFs3sV5VSC49z6pEhfbMULvShKj26WA=="], + "supertest/cookie-signature": ["cookie-signature@1.2.2", "", {}, "sha512-D76uU73ulSXrD1UXF4KE2TMxVVwhsnCgfAyTg9k8P6KGZjlXKrOLe4dJQKI3Bxi5wjesZoFXJWElNWBjPZMbhg=="], + "tailwindcss/postcss-selector-parser": ["postcss-selector-parser@6.1.2", "", { "dependencies": { "cssesc": "^3.0.0", "util-deprecate": "^1.0.2" } }, "sha512-Q8qQfPiZ+THO/3ZrOrO0cJJKfpYCagtMUkXbnEfmgUjwXg6z/WBeOyS9APBBPCTSiDV+s4SwQGu8yFsiMRIudg=="], "tailwindcss/resolve": ["resolve@1.22.11", "", { "dependencies": { "is-core-module": "^2.16.1", "path-parse": "^1.0.7", "supports-preserve-symlinks-flag": "^1.0.0" }, "bin": { "resolve": "bin/resolve" } }, "sha512-RfqAvLnMl313r7c9oclB1HhUEAezcpLjz95wFH4LVuhk9JF/r22qmVP9AMmOU4vMX7Q8pN8jwNg/CSpdFnMjTQ=="], diff --git a/e2e-tests/studio-mcp/studio-mcp-agent.test.ts b/e2e-tests/studio-mcp/studio-mcp-agent.test.ts new file mode 100644 index 00000000..702d409f --- /dev/null +++ b/e2e-tests/studio-mcp/studio-mcp-agent.test.ts @@ -0,0 +1,326 @@ +import { expect, beforeAll, afterAll } from 'bun:test'; +import { createMCPClient, type MCPClient } from '@ai-sdk/mcp'; +import { generateText, stepCountIs } from 'ai'; +import { createOpenAI } from '@ai-sdk/openai'; + +import { + API_BASE, + HEADERS, + e2eDescribe, + e2eTest, + createWorkflow, +} from '../helpers/e2e-harness'; + +interface ApiKeyResponse { + id: string; + plainKey: string; + name: string; + permissions: { + workflows: { run: boolean; list: boolean; read: boolean }; + runs: { read: boolean; cancel: boolean }; + }; +} + +e2eDescribe('Studio MCP: AI SDK Integration', () => { + let apiKeyId: string | null = null; + let plainKey: string | null = null; + let mcpClient: MCPClient | null = null; + let workflowId: string | null = null; + + beforeAll(async () => { + // Create API key for MCP authentication + const res = await fetch(`${API_BASE}/api-keys`, { + method: 'POST', + headers: HEADERS, + body: JSON.stringify({ + name: `e2e-studio-mcp-ai-sdk-${Date.now()}`, + permissions: { + workflows: { run: true, list: true, read: true }, + runs: { read: true, cancel: true }, + }, + }), + }); + + if (!res.ok) { + throw new Error(`Failed to create API key: ${res.status} ${await res.text()}`); + } + + const data = (await res.json()) as ApiKeyResponse; + apiKeyId = data.id; + plainKey = data.plainKey; + + expect(plainKey).toBeDefined(); + expect(plainKey).toMatch(/^sk_live_/); + }); + + afterAll(async () => { + if (mcpClient) { + try { + await mcpClient.close(); + } catch (error) { + console.warn('Error closing MCP client:', error); + } + } + + if (workflowId) { + try { + await fetch(`${API_BASE}/workflows/${workflowId}`, { + method: 'DELETE', + headers: HEADERS, + }); + } catch (error) { + console.warn('Error deleting workflow:', error); + } + } + + if (apiKeyId) { + try { + await fetch(`${API_BASE}/api-keys/${apiKeyId}`, { + method: 'DELETE', + headers: HEADERS, + }); + } catch (error) { + console.warn('Error deleting API key:', error); + } + } + }); + + e2eTest('AI SDK MCP client connects and discovers tools', { timeout: 60000 }, async () => { + expect(plainKey).toBeDefined(); + + mcpClient = await createMCPClient({ + transport: { + type: 'http', + url: `${API_BASE}/studio-mcp`, + headers: { + Authorization: `Bearer ${plainKey}`, + }, + }, + }); + + expect(mcpClient).toBeDefined(); + + const tools = await mcpClient!.tools(); + expect(tools).toBeDefined(); + + const toolNames = Object.keys(tools); + expect(toolNames.length).toBeGreaterThanOrEqual(9); + + const expectedTools = [ + 'list_workflows', + 'get_workflow', + 'run_workflow', + 'list_components', + 'get_component', + 'list_runs', + 'get_run_status', + 'get_run_result', + 'cancel_run', + ]; + + for (const expectedTool of expectedTools) { + expect(toolNames).toContain(expectedTool); + } + }); + + e2eTest( + 'AI SDK agent can use Studio MCP tools via generateText', + { timeout: 120000 }, + async () => { + const ZAI_API_KEY = process.env.ZAI_API_KEY; + + if (!ZAI_API_KEY) { + console.warn('Skipping AI agent test: ZAI_API_KEY not set'); + return; + } + + expect(plainKey).toBeDefined(); + + const client = await createMCPClient({ + transport: { + type: 'http', + url: `${API_BASE}/studio-mcp`, + headers: { + Authorization: `Bearer ${plainKey}`, + }, + }, + }); + + try { + const tools = await client.tools(); + + const openai = createOpenAI({ + baseURL: 'https://api.z.ai/api/coding/paas/v4', + apiKey: ZAI_API_KEY, + }); + + const model = openai.chat('glm-4.7'); + + const response = await generateText({ + model, + tools, + stopWhen: stepCountIs(3), + messages: [ + { + role: 'user', + content: [ + { + type: 'text', + text: 'List all available components using the list_components tool and tell me how many there are.', + }, + ], + }, + ], + }); + + // DEBUG: show agent behavior + console.log('\n=== TEST 2: list_components agent ==='); + console.log(`Steps: ${response.steps.length}`); + for (const [i, step] of response.steps.entries()) { + console.log(`\n--- Step ${i + 1} ---`); + if (step.toolCalls?.length) { + for (const tc of step.toolCalls) { + console.log(` Tool call: ${tc.toolName}(${JSON.stringify((tc as Record).input ?? {})})`); + } + } + if (step.toolResults?.length) { + for (const tr of step.toolResults) { + const raw = JSON.stringify((tr as Record).output ?? tr) ?? ''; + console.log(` Tool result: ${raw.slice(0, 500)}`); + } + } + if (step.text) { + console.log(` Text: ${step.text.slice(0, 500)}`); + } + } + console.log(`\nFinal response: ${response.text.slice(0, 1000)}`); + console.log('=== END TEST 2 ===\n'); + + expect(response.steps).toBeDefined(); + expect(response.steps.length).toBeGreaterThan(0); + + const hasToolCalls = response.steps.some( + (step) => step.toolCalls && step.toolCalls.length > 0, + ); + expect(hasToolCalls).toBe(true); + + expect(response.text).toBeDefined(); + expect(response.text.length).toBeGreaterThan(0); + + const lowerText = response.text.toLowerCase(); + const mentionsComponents = lowerText.includes('component') || /\d+/.test(response.text); + expect(mentionsComponents).toBe(true); + } finally { + await client.close(); + } + }, + ); + + e2eTest('AI SDK agent can execute workflow operations', { timeout: 120000 }, async () => { + const ZAI_API_KEY = process.env.ZAI_API_KEY; + + if (!ZAI_API_KEY) { + console.warn('Skipping workflow operations test: ZAI_API_KEY not set'); + return; + } + + expect(plainKey).toBeDefined(); + + const workflow = { + name: `E2E AI SDK MCP Test ${Date.now()}`, + nodes: [ + { + id: 'start', + type: 'core.workflow.entrypoint', + position: { x: 0, y: 0 }, + data: { + label: 'Start', + config: { + params: { + runtimeInputs: [{ id: 'message', label: 'Message', type: 'text' }], + }, + }, + }, + }, + ], + edges: [], + }; + + workflowId = await createWorkflow(workflow); + expect(workflowId).toBeDefined(); + + const client = await createMCPClient({ + transport: { + type: 'http', + url: `${API_BASE}/studio-mcp`, + headers: { + Authorization: `Bearer ${plainKey}`, + }, + }, + }); + + try { + const tools = await client.tools(); + + const openai = createOpenAI({ + baseURL: 'https://api.z.ai/api/coding/paas/v4', + apiKey: ZAI_API_KEY, + }); + + const model = openai.chat('glm-4.7'); + + const response = await generateText({ + model, + tools, + stopWhen: stepCountIs(5), + messages: [ + { + role: 'user', + content: [ + { + type: 'text', + text: `Run the workflow with ID "${workflowId}" using the input message "Hello from AI SDK test". Then check its status.`, + }, + ], + }, + ], + }); + + // DEBUG: show agent behavior + console.log('\n=== TEST 3: workflow operations agent ==='); + console.log(`Steps: ${response.steps.length}`); + for (const [i, step] of response.steps.entries()) { + console.log(`\n--- Step ${i + 1} ---`); + if (step.toolCalls?.length) { + for (const tc of step.toolCalls) { + console.log(` Tool call: ${tc.toolName}(${JSON.stringify((tc as Record).input ?? {})})`); + } + } + if (step.toolResults?.length) { + for (const tr of step.toolResults) { + const raw = JSON.stringify((tr as Record).output ?? tr) ?? ''; + console.log(` Tool result: ${raw.slice(0, 500)}`); + } + } + if (step.text) { + console.log(` Text: ${step.text.slice(0, 500)}`); + } + } + console.log(`\nFinal response: ${response.text.slice(0, 1000)}`); + console.log('=== END TEST 3 ===\n'); + + expect(response.steps).toBeDefined(); + expect(response.steps.length).toBeGreaterThan(0); + + const allToolCalls = response.steps.flatMap((step) => step.toolCalls || []); + const toolCallNames = allToolCalls.map((call) => call.toolName); + + expect(toolCallNames).toContain('run_workflow'); + + expect(response.text).toBeDefined(); + expect(response.text.length).toBeGreaterThan(0); + } finally { + await client.close(); + } + }); +}); diff --git a/package.json b/package.json index 88f406bf..9a71c68f 100644 --- a/package.json +++ b/package.json @@ -36,9 +36,11 @@ }, "devDependencies": { "@ai-sdk/mcp": "^1.0.13", + "@ai-sdk/openai": "^3.0.25", "@modelcontextprotocol/sdk": "^1.25.3", "@types/bun": "^1.3.6", "@types/node": "^24.10.9", + "ai": "^6.0.49", "bun-types": "^1.3.6", "husky": "^9.1.7", "lint-staged": "^16.2.7",