diff --git a/cmd/mcptools/main.go b/cmd/mcptools/main.go index 66a2eff..ccbdf3a 100644 --- a/cmd/mcptools/main.go +++ b/cmd/mcptools/main.go @@ -50,6 +50,20 @@ var ( errCommandRequired = fmt.Errorf("command to execute is required when using stdio transport") ) +// createClientFunc is the function used to create MCP clients. +// This can be replaced in tests to use a mock transport. +var createClientFunc = func(args []string) (*client.Client, error) { + if len(args) == 0 { + return nil, errCommandRequired + } + + if len(args) == 1 && (strings.HasPrefix(args[0], "http://") || strings.HasPrefix(args[0], "https://")) { + return client.NewHTTP(args[0]), nil + } + + return client.NewStdio(args), nil +} + func main() { cobra.EnableCommandSorting = false @@ -97,18 +111,6 @@ func newVersionCmd() *cobra.Command { } } -func createClient(args []string) (*client.Client, error) { - if len(args) == 0 { - return nil, errCommandRequired - } - - if len(args) == 1 && (strings.HasPrefix(args[0], "http://") || strings.HasPrefix(args[0], "https://")) { - return client.NewHTTP(args[0]), nil - } - - return client.NewStdio(args), nil -} - func processFlags(args []string) []string { parsedArgs := []string{} @@ -155,7 +157,7 @@ func newToolsCmd() *cobra.Command { parsedArgs := processFlags(args) - mcpClient, err := createClient(parsedArgs) + mcpClient, err := createClientFunc(parsedArgs) if err != nil { fmt.Fprintf(os.Stderr, "Error: %v\n", err) fmt.Fprintf(os.Stderr, "Example: mcp tools npx -y @modelcontextprotocol/server-filesystem ~\n") @@ -185,7 +187,7 @@ func newResourcesCmd() *cobra.Command { parsedArgs := processFlags(args) - mcpClient, err := createClient(parsedArgs) + mcpClient, err := createClientFunc(parsedArgs) if err != nil { fmt.Fprintf(os.Stderr, "Error: %v\n", err) fmt.Fprintf(os.Stderr, "Example: mcp resources npx -y @modelcontextprotocol/server-filesystem ~\n") @@ -215,7 +217,7 @@ func newPromptsCmd() *cobra.Command { parsedArgs := processFlags(args) - mcpClient, err := createClient(parsedArgs) + mcpClient, err := createClientFunc(parsedArgs) if err != nil { fmt.Fprintf(os.Stderr, "Error: %v\n", err) fmt.Fprintf(os.Stderr, "Example: mcp prompts npx -y @modelcontextprotocol/server-filesystem ~\n") @@ -311,7 +313,7 @@ func newCallCmd() *cobra.Command { } } - mcpClient, clientErr := createClient(parsedArgs) + mcpClient, clientErr := createClientFunc(parsedArgs) if clientErr != nil { fmt.Fprintf(os.Stderr, "Error: %v\n", clientErr) os.Exit(1) @@ -403,7 +405,7 @@ func newGetPromptCmd() *cobra.Command { } } - mcpClient, clientErr := createClient(parsedArgs) + mcpClient, clientErr := createClientFunc(parsedArgs) if clientErr != nil { fmt.Fprintf(os.Stderr, "Error: %v\n", clientErr) os.Exit(1) @@ -481,7 +483,7 @@ func newReadResourceCmd() *cobra.Command { } } - mcpClient, clientErr := createClient(parsedArgs) + mcpClient, clientErr := createClientFunc(parsedArgs) if clientErr != nil { fmt.Fprintf(os.Stderr, "Error: %v\n", clientErr) os.Exit(1) @@ -529,7 +531,7 @@ func newShellCmd() *cobra.Command { //nolint:gocyclo os.Exit(1) } - mcpClient, clientErr := createClient(parsedArgs) + mcpClient, clientErr := createClientFunc(parsedArgs) if clientErr != nil { fmt.Fprintf(os.Stderr, "Error: %v\n", clientErr) os.Exit(1) @@ -850,7 +852,7 @@ Available types: - prompt