-
-
Notifications
You must be signed in to change notification settings - Fork 80
Expand file tree
/
Copy pathclient.go
More file actions
432 lines (389 loc) · 13.1 KB
/
Copy pathclient.go
File metadata and controls
432 lines (389 loc) · 13.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
package mcp_server //nolint:revive // fine for now
// This file provides clients for talking to the MCP server: an HTTP client and a stdio placeholder.
import (
"bytes"
"context"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"encoding/json"
"encoding/pem"
"fmt"
"net/http"
"net/url"
"os"
"github.com/modelcontextprotocol/go-sdk/mcp"
"github.com/sirupsen/logrus"
)
// Client transport types accepted by NewMCPClient.
const (
	// MCPClientTypeHTTP selects the HTTP client (streamable HTTP transport).
	MCPClientTypeHTTP = "http"
	// MCPClientTypeSTDIO selects the stdio client, which is currently a
	// placeholder whose methods only log that they are not implemented.
	MCPClientTypeSTDIO = "stdio"
)
// MCPClient is the minimal client surface used to talk to an MCP server.
type MCPClient interface {
	// InspectTools lists the tools exposed by the server. The HTTP
	// implementation in this file returns one map per tool with "name" and
	// "description" keys.
	InspectTools() ([]map[string]any, error)
	// CallToolText invokes toolName with args and returns the tool's output
	// rendered as a string.
	CallToolText(toolName string, args map[string]any) (string, error)
}
// NewMCPClient builds an MCPClient for the requested transport.
//
// clientType must be MCPClientTypeHTTP or MCPClientTypeSTDIO; any other value
// yields an error. baseURL and clientCfgMap are only consumed by the HTTP
// client. A nil logger is replaced by the concrete constructors with a fresh
// Info-level logrus logger.
func NewMCPClient(clientType string, baseURL string, clientCfgMap map[string]any, logger *logrus.Logger) (MCPClient, error) {
	if clientType == MCPClientTypeHTTP {
		return newHTTPMCPClient(baseURL, clientCfgMap, logger)
	}
	if clientType == MCPClientTypeSTDIO {
		return newStdioMCPClient(logger)
	}
	return nil, fmt.Errorf("unknown client type: %s", clientType)
}
// getHTTPClient returns the *http.Client used to reach an HTTP MCP server.
//
// When clientCfgMap has no "ca_file" entry (or is nil) it returns
// http.DefaultClient. Otherwise it builds a dedicated client whose TLS config
// trusts the PEM certificates in that file, layered on top of the system pool
// when x509.SystemCertPool succeeds. These optional keys are only read in that
// case:
//
//   - "promote_leaf_to_ca" (bool): pin the first CERTIFICATE block of ca_file
//     and accept a server only if its leaf matches it byte-for-byte.
//   - "insecure_skip_verify" (bool): disable standard chain verification.
//   - "server_name" (string): TLS ServerName (SNI); also checked against the
//     pinned leaf when promote_leaf_to_ca is set.
//   - "apply_tls_globally" (bool): additionally install the TLS config on a
//     clone of http.DefaultTransport, affecting every request in the process.
//   - "base_url" (string): only used to log the implicit SNI host name.
//
// NOTE(review): insecure_skip_verify and server_name have no effect unless
// ca_file is set; confirm that is intended.
//
// An error is returned when ca_file is not a string or cannot be read, when
// no certificate from it could be added to the pool, or when an option holds
// a value of the wrong type.
func getHTTPClient(logger *logrus.Logger, clientCfgMap map[string]any) (*http.Client, error) {
	// Indexing a nil map is safe and yields nil.
	if clientCfgMap["ca_file"] == nil {
		return http.DefaultClient, nil
	}
	logger.Infof("Configuring HTTP client with custom CA certificate")
	caFile, isString := clientCfgMap["ca_file"].(string)
	if !isString {
		return nil, fmt.Errorf("ca_file must be a string")
	}
	caCertPool, firstCertRaw, err := loadCustomCAPool(logger, caFile)
	if err != nil {
		return nil, err
	}

	promoteLeaf, err := clientCfgBool(clientCfgMap, "promote_leaf_to_ca")
	if err != nil {
		return nil, err
	}
	insecureSkipVerify, err := clientCfgBool(clientCfgMap, "insecure_skip_verify")
	if err != nil {
		return nil, err
	}
	serverName, err := clientCfgString(clientCfgMap, "server_name")
	if err != nil {
		return nil, err
	}
	logServerName(logger, clientCfgMap, serverName)

	//nolint:gosec // testing client only
	tlsConfig := &tls.Config{
		RootCAs:            caCertPool,
		MinVersion:         tls.VersionTLS12,
		InsecureSkipVerify: insecureSkipVerify, // forced on when promoting the leaf
		ServerName:         serverName,
	}
	if promoteLeaf {
		err = configurePromotedLeaf(logger, tlsConfig, caFile, firstCertRaw)
		if err != nil {
			return nil, err
		}
	}
	transport := &http.Transport{TLSClientConfig: tlsConfig}

	applyGlobally, err := clientCfgBool(clientCfgMap, "apply_tls_globally")
	if err != nil {
		return nil, err
	}
	if applyGlobally {
		installTLSConfigGlobally(logger, tlsConfig)
	}
	return &http.Client{Transport: transport}, nil
}

// loadCustomCAPool reads the PEM file caFile and adds its certificates to a
// pool seeded from the system roots (or an empty pool if those are
// unavailable). It also returns the DER bytes of the file's first CERTIFICATE
// block, the candidate leaf for promote_leaf_to_ca (nil if there is none).
func loadCustomCAPool(logger *logrus.Logger, caFile string) (*x509.CertPool, []byte, error) {
	caBytes, err := os.ReadFile(caFile)
	if err != nil {
		return nil, nil, fmt.Errorf("failed to read CA file '%s': %w", caFile, err)
	}
	logger.Infof("Read CA file '%s' (%d bytes)", caFile, len(caBytes))

	// Start from the system pool when possible so the OS roots stay trusted.
	caCertPool, sysErr := x509.SystemCertPool()
	if sysErr == nil && caCertPool != nil {
		logger.Debug("Using system cert pool as base")
	} else {
		caCertPool = x509.NewCertPool()
		logger.Debug("System cert pool unavailable; using new pool")
	}

	firstCertRaw := firstPEMCertificate(caBytes)
	if !caCertPool.AppendCertsFromPEM(caBytes) {
		return nil, nil, describeRejectedCAFile(logger, caFile, caBytes)
	}
	logger.Infof("Successfully appended custom CA(s) from '%s'", caFile)
	if caCount := countCACertificates(logger, caBytes); caCount == 0 {
		logger.Warnf("No CA certificates (IsCA=true) found in '%s'. If this file contains only the server leaf certificate it cannot establish standard trust. Supply the issuing CA (or chain) or enable 'promote_leaf_to_ca'.", caFile)
	} else {
		logger.Debugf("Detected %d CA certificate(s) in '%s'", caCount, caFile)
	}
	return caCertPool, firstCertRaw, nil
}

// firstPEMCertificate returns the decoded bytes of the first PEM block of type
// CERTIFICATE in data, or nil when there is none.
func firstPEMCertificate(data []byte) []byte {
	for {
		block, rest := pem.Decode(data)
		if block == nil {
			return nil
		}
		if block.Type == "CERTIFICATE" {
			return block.Bytes
		}
		data = rest
	}
}

// describeRejectedCAFile logs per-block diagnostics for a CA file that
// AppendCertsFromPEM rejected and returns the error reported to the caller.
func describeRejectedCAFile(logger *logrus.Logger, caFile string, caBytes []byte) error {
	logger.Warn("AppendCertsFromPEM returned false; attempting manual PEM decode for diagnostics")
	blockCount, validCerts := 0, 0
	rest := caBytes
	for {
		var block *pem.Block
		block, rest = pem.Decode(rest)
		if block == nil {
			break
		}
		blockCount++
		if block.Type != "CERTIFICATE" {
			logger.Debugf("Ignoring non-certificate PEM block type=%s", block.Type)
			continue
		}
		if _, perr := x509.ParseCertificate(block.Bytes); perr != nil {
			logger.Errorf("Failed to parse certificate PEM block %d: %v", blockCount, perr)
			continue
		}
		validCerts++
	}
	return fmt.Errorf("failed to append CA certificate '%s' into trust store: no valid CERTIFICATE PEM blocks found (blocks=%d, valid=%d)", caFile, blockCount, validCerts)
}

// countCACertificates returns how many parsable CERTIFICATE blocks in caBytes
// have IsCA set, logging (at debug level) any block that fails to parse.
func countCACertificates(logger *logrus.Logger, caBytes []byte) int {
	certBlockIdx, caCount := 0, 0
	rest := caBytes
	for {
		var block *pem.Block
		block, rest = pem.Decode(rest)
		if block == nil {
			return caCount
		}
		if block.Type != "CERTIFICATE" {
			continue
		}
		certBlockIdx++
		parsed, perr := x509.ParseCertificate(block.Bytes)
		if perr != nil {
			logger.Debugf("Skipping unparsable certificate block %d: %v", certBlockIdx, perr)
			continue
		}
		if parsed.IsCA {
			caCount++
		}
	}
}

// clientCfgBool returns cfg[key] as a bool: false when the key is absent, and
// an error when the key is present with any non-bool value (including nil).
func clientCfgBool(cfg map[string]any, key string) (bool, error) {
	raw, present := cfg[key]
	if !present {
		return false, nil
	}
	value, isBool := raw.(bool)
	if !isBool {
		return false, fmt.Errorf("%s must be a boolean", key)
	}
	return value, nil
}

// clientCfgString returns cfg[key] as a string: "" when the key is absent, and
// an error when the key is present with any non-string value (including nil).
func clientCfgString(cfg map[string]any, key string) (string, error) {
	raw, present := cfg[key]
	if !present {
		return "", nil
	}
	value, isString := raw.(string)
	if !isString {
		return "", fmt.Errorf("%s must be a string", key)
	}
	return value, nil
}

// logServerName logs which TLS server name (SNI) will be used: the explicit
// server_name when set, otherwise the host of cfg["base_url"] if it parses.
// When the URL host differs from the certificate's names (common with IP
// addresses) the user should set server_name explicitly; this only logs.
func logServerName(logger *logrus.Logger, cfg map[string]any, serverName string) {
	if serverName != "" {
		logger.Infof("Using explicit TLS server_name (SNI): %s", serverName)
		return
	}
	rawURL, ok := cfg["base_url"].(string)
	if !ok {
		return
	}
	if parsed, perr := url.Parse(rawURL); perr == nil && parsed.Hostname() != "" {
		logger.Debugf("Using implicit SNI server name '%s'", parsed.Hostname())
	}
}

// configurePromotedLeaf replaces standard verification in tlsConfig with an
// exact byte comparison against firstCertRaw (the first certificate found in
// caFile), plus a hostname check when tlsConfig.ServerName is set. This trusts
// a single non-CA certificate and is not suitable for production.
func configurePromotedLeaf(logger *logrus.Logger, tlsConfig *tls.Config, caFile string, firstCertRaw []byte) error {
	if firstCertRaw == nil {
		return fmt.Errorf("promote_leaf_to_ca enabled but no certificate PEM blocks found in '%s'", caFile)
	}
	// Parse only to log an identifying fingerprint (first 8 bytes of SHA-256).
	if leafCert, perr := x509.ParseCertificate(firstCertRaw); perr == nil {
		fp := sha256.Sum256(leafCert.Raw)
		logger.Warnf("Promoting leaf certificate (CN=%s, SHA256=%X) to trust anchor (non-CA). NOT recommended for production.", leafCert.Subject.CommonName, fp[:8])
	} else {
		logger.Warnf("Promoting leaf certificate (parse error for fingerprint: %v)", perr)
	}
	expected := make([]byte, len(firstCertRaw))
	copy(expected, firstCertRaw)
	// Chain verification is disabled; VerifyPeerCertificate does the check.
	tlsConfig.InsecureSkipVerify = true
	tlsConfig.VerifyPeerCertificate = promotedLeafVerifier(expected, tlsConfig.ServerName)
	return nil
}

// promotedLeafVerifier returns a VerifyPeerCertificate callback that accepts a
// connection only when the server's leaf equals expected and, if serverName is
// non-empty and differs from the leaf's CommonName, the leaf passes
// VerifyHostname(serverName).
func promotedLeafVerifier(expected []byte, serverName string) func([][]byte, [][]*x509.Certificate) error {
	return func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
		if len(rawCerts) == 0 {
			return fmt.Errorf("no server certificates presented")
		}
		if !bytes.Equal(rawCerts[0], expected) {
			return fmt.Errorf("server leaf certificate mismatch with promoted leaf")
		}
		// The leaf is already pinned byte-for-byte, so a parse failure here
		// is not treated as an error; it only skips the hostname check.
		cert, parseErr := x509.ParseCertificate(rawCerts[0])
		if parseErr != nil || serverName == "" || serverName == cert.Subject.CommonName {
			return nil
		}
		if verr := cert.VerifyHostname(serverName); verr != nil {
			return fmt.Errorf("hostname verification failed for promoted leaf: %w", verr)
		}
		return nil
	}
}

// installTLSConfigGlobally replaces http.DefaultTransport with a clone that
// uses tlsConfig, so code relying on http.DefaultClient inherits it. It logs
// and does nothing when DefaultTransport is not an *http.Transport. The
// reassignment of the package variable is not synchronized.
func installTLSConfigGlobally(logger *logrus.Logger, tlsConfig *tls.Config) {
	defTr, ok := http.DefaultTransport.(*http.Transport)
	if !ok {
		logger.Warn("apply_tls_globally requested but http.DefaultTransport is not *http.Transport; skipped")
		return
	}
	// Clone keeps the other fields (proxy, dialer, etc.).
	cloned := defTr.Clone()
	cloned.TLSClientConfig = tlsConfig
	http.DefaultTransport = cloned
	logger.Warn("Applied custom TLS config globally (http.DefaultTransport). This affects all outbound HTTP requests in this process.")
}
// newHTTPMCPClient builds the HTTP-backed MCPClient for baseURL.
//
// A nil logger is replaced with a new logrus logger at Info level. The HTTP
// client is derived from clientCfgMap via getHTTPClient; any failure there is
// wrapped and returned. clientCfgMap is also retained for per-call options.
func newHTTPMCPClient(baseURL string, clientCfgMap map[string]any, logger *logrus.Logger) (MCPClient, error) {
	effectiveLogger := logger
	if effectiveLogger == nil {
		effectiveLogger = logrus.New()
		effectiveLogger.SetLevel(logrus.InfoLevel)
	}

	transportClient, err := getHTTPClient(effectiveLogger, clientCfgMap)
	if err != nil {
		return nil, fmt.Errorf("error creating HTTP client: %w", err)
	}

	client := &httpMCPClient{
		baseURL:    baseURL,
		httpClient: transportClient,
		logger:     effectiveLogger,
		clientCfg:  clientCfgMap,
	}
	return client, nil
}
// httpMCPClient is the MCPClient implementation that talks to an MCP server
// over the streamable HTTP transport; each call opens its own session.
type httpMCPClient struct {
	baseURL    string         // MCP endpoint URL used as the transport endpoint
	httpClient *http.Client   // built by getHTTPClient from clientCfg
	logger     *logrus.Logger // never nil once constructed by newHTTPMCPClient
	clientCfg  map[string]any // raw client config; read by prefersText
}
// connect opens a new MCP session against c.baseURL over the streamable HTTP
// transport. The transport uses c.httpClient, so the TLS settings built by
// getHTTPClient (custom CA, server_name, leaf pinning) actually apply.
// Previously the transport fell back to the SDK's default client, which
// ignored that configuration unless apply_tls_globally was set. The caller
// owns the returned session and must Close it.
func (c *httpMCPClient) connect() (*mcp.ClientSession, error) {
	// Named endpoint (not url) to avoid shadowing the net/url import.
	endpoint := c.baseURL
	c.logger.Infof("Connecting to MCP server at %s", endpoint)

	client := mcp.NewClient(&mcp.Implementation{
		Name:    "stackql-client",
		Version: "1.0.0",
	}, nil)

	transport := &mcp.StreamableClientTransport{
		Endpoint:   endpoint,
		HTTPClient: c.httpClient, // nil falls back to the SDK default
	}
	return client.Connect(context.Background(), transport, nil)
}
// connectOrDie is connect for callers that cannot recover from a connection
// failure: on error it logs via Fatalf. logrus Fatalf then invokes the
// logger's exit hook, which terminates the process by default.
func (c *httpMCPClient) connectOrDie() *mcp.ClientSession {
	session, connectErr := c.connect()
	if connectErr == nil {
		return session
	}
	c.logger.Fatalf("Failed to connect: %v", connectErr)
	// Only reached if the logger's exit hook has been replaced.
	return session
}
// InspectTools connects to the server, lists its tools, and returns one map per
// tool with "name" and "description" keys. The slice is nil when the server
// reports no tools.
//
// Connection and listing failures are logged and returned as errors rather
// than terminating the process, so callers of the MCPClient interface can
// handle them, consistent with callTool.
//
// NOTE(review): ListTools is called with nil params, so only the first page of
// results is read; confirm the server never paginates the tool list.
func (c *httpMCPClient) InspectTools() ([]map[string]any, error) {
	session, err := c.connect()
	if err != nil {
		c.logger.Errorf("Failed to connect: %v", err)
		return nil, fmt.Errorf("connect to MCP server at %s: %w", c.baseURL, err)
	}
	defer session.Close()
	c.logger.Infof("Connected to server (session ID: %s)", session.ID())

	c.logger.Infof("Listing available tools...")
	toolsResult, err := session.ListTools(context.Background(), nil)
	if err != nil {
		c.logger.Errorf("Failed to list tools: %v", err)
		return nil, fmt.Errorf("list tools: %w", err)
	}

	// Kept as a nil slice when empty to preserve the previous return value.
	var rv []map[string]any
	for _, tool := range toolsResult.Tools {
		c.logger.Infof(" - %s: %s\n", tool.Name, tool.Description)
		rv = append(rv, map[string]any{
			"name":        tool.Name,
			"description": tool.Description,
		})
	}
	c.logger.Infof("Client completed successfully")
	return rv, nil
}
// callTool opens a session, invokes toolName with args, and returns the raw
// result. A connection failure is logged and returned as an error instead of
// terminating the process. On a CallTool error, whatever result the SDK
// returned is passed through alongside the error.
func (c *httpMCPClient) callTool(toolName string, args map[string]any) (*mcp.CallToolResult, error) {
	session, err := c.connect()
	if err != nil {
		c.logger.Errorf("Failed to connect: %v", err)
		return nil, fmt.Errorf("connect to MCP server at %s: %w", c.baseURL, err)
	}
	defer session.Close()
	c.logger.Infof("Connected to server (session ID: %s)", session.ID())

	c.logger.Infof("Calling tool %s...", toolName)
	result, err := session.CallTool(context.Background(), &mcp.CallToolParams{
		Name:      toolName,
		Arguments: args,
	})
	if err != nil {
		c.logger.Errorf("Failed to call tool %s: %v\n", toolName, err)
		return result, err
	}
	c.logger.Infof("Client completed successfully")
	return result, nil
}
// CallToolText invokes toolName with args and renders the result as a string
// for a scripting client. Transport errors are returned unchanged; rendering
// (and tool-level error reporting) is delegated to formatToolResult, honouring
// the client's prefer_text setting.
func (c *httpMCPClient) CallToolText(toolName string, args map[string]any) (string, error) {
	result, err := c.callTool(toolName, args)
	if err == nil {
		return formatToolResult(toolName, result, c.prefersText())
	}
	return "", err
}
// prefersText reports whether the client config sets `"prefer_text": true`,
// asking for the rendered text content blocks rather than the structured
// payload. It is useful for exercising or consuming the server's text
// renderings, e.g. the JSON render option of issue #669. A missing config,
// missing key, or non-bool value all mean false.
func (c *httpMCPClient) prefersText() bool {
	// Indexing a nil map is safe and yields nil, which is not a bool.
	switch flag := c.clientCfg["prefer_text"].(type) {
	case bool:
		return flag
	default:
		return false
	}
}
// formatToolResult renders a CallToolResult for a scripting client.
//
// Output selection, in priority order:
//  1. the TextContent blocks joined by extractText, when preferText is true;
//  2. StructuredContent re-encoded as compact JSON (the typed DTO);
//  3. the joined TextContent blocks, when there is no structured payload.
//
// A nil result and a tool-level error (IsError) are both reported as Go errors
// (the latter carrying the tool's text), so a CLI caller exits non-zero with
// the message on stderr, matching how transport errors are surfaced.
func formatToolResult(toolName string, toolCall *mcp.CallToolResult, preferText bool) (string, error) {
	switch {
	case toolCall == nil:
		return "", fmt.Errorf("tool %s returned no result", toolName)
	case toolCall.IsError:
		return "", fmt.Errorf("tool %s: %s", toolName, extractText(toolCall))
	case preferText, toolCall.StructuredContent == nil:
		return extractText(toolCall), nil
	}

	payload, marshalErr := json.Marshal(toolCall.StructuredContent)
	if marshalErr != nil {
		return "", fmt.Errorf("marshal structured content for %s: %w", toolName, marshalErr)
	}
	return string(payload), nil
}
// extractText joins the Text of every *mcp.TextContent block in toolCall with
// "\n", skipping all other content kinds. toolCall must be non-nil.
//
// The separator is written only between emitted text blocks. Keying it on the
// index over all content would produce a spurious leading newline whenever a
// non-text block (e.g. an image) comes first.
func extractText(toolCall *mcp.CallToolResult) string {
	var sb bytes.Buffer
	wroteAny := false
	for _, content := range toolCall.Content {
		textContent, ok := content.(*mcp.TextContent)
		if !ok {
			continue
		}
		if wroteAny {
			sb.WriteByte('\n')
		}
		sb.WriteString(textContent.Text)
		wroteAny = true
	}
	return sb.String()
}
// stdioMCPClient is a placeholder MCPClient for a stdio transport; its methods
// only log that they are not implemented and return zero values.
type stdioMCPClient struct {
	logger *logrus.Logger
}
// newStdioMCPClient returns the placeholder stdio MCPClient. A nil logger is
// replaced with a new logrus logger at Info level. It never returns an error.
func newStdioMCPClient(logger *logrus.Logger) (MCPClient, error) {
	client := &stdioMCPClient{logger: logger}
	if client.logger == nil {
		client.logger = logrus.New()
		client.logger.SetLevel(logrus.InfoLevel)
	}
	return client, nil
}
// InspectTools is not implemented for stdio: it logs a notice and returns a
// nil slice with a nil error.
func (c *stdioMCPClient) InspectTools() ([]map[string]any, error) {
	c.logger.Info("stdio MCP client not implemented yet")
	var noTools []map[string]any
	return noTools, nil
}
// CallToolText is not implemented for stdio: it ignores its arguments, logs a
// notice, and returns an empty string with a nil error.
func (c *stdioMCPClient) CallToolText(toolName string, args map[string]any) (string, error) {
	c.logger.Info("stdio MCP client not implemented yet")
	var empty string
	return empty, nil
}