Commit 47b546a

feat(mcp): add LocalAI endpoint to stream live results of the agent (#7274)

* feat(mcp): add LocalAI endpoint to stream live results of the agent
* wip
* Refactoring
* MCP UX integration
* Enhance UX
* Support also non-SSE

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

1 parent a09d49d commit 47b546a

File tree

7 files changed: +1188 -105 lines changed

core/config/model_config.go

Lines changed: 38 additions & 0 deletions

@@ -9,6 +9,7 @@ import (
 	"github.com/mudler/LocalAI/core/schema"
 	"github.com/mudler/LocalAI/pkg/downloader"
 	"github.com/mudler/LocalAI/pkg/functions"
+	"github.com/mudler/cogito"
 	"gopkg.in/yaml.v3"
 )
 
@@ -668,3 +669,40 @@ func (c *ModelConfig) GuessUsecases(u ModelConfigUsecases) bool {
 
 	return true
 }
+
+// BuildCogitoOptions generates cogito options from the model configuration.
+// Request-scoped options (context, MCP sessions, and callbacks for status,
+// reasoning, tool calls, and tool results) are appended by callers on top
+// of these defaults.
+func (c *ModelConfig) BuildCogitoOptions() []cogito.Option {
+	cogitoOpts := []cogito.Option{
+		cogito.WithIterations(3),  // default to 3 iterations
+		cogito.WithMaxAttempts(3), // default to 3 attempts
+		cogito.WithForceReasoning(),
+	}
+
+	// Apply agent configuration options
+	if c.Agent.EnableReasoning {
+		cogitoOpts = append(cogitoOpts, cogito.EnableToolReasoner)
+	}
+
+	if c.Agent.EnablePlanning {
+		cogitoOpts = append(cogitoOpts, cogito.EnableAutoPlan)
+	}
+
+	if c.Agent.EnableMCPPrompts {
+		cogitoOpts = append(cogitoOpts, cogito.EnableMCPPrompts)
+	}
+
+	if c.Agent.EnablePlanReEvaluator {
+		cogitoOpts = append(cogitoOpts, cogito.EnableAutoPlanReEvaluator)
+	}
+
+	if c.Agent.MaxIterations != 0 {
+		cogitoOpts = append(cogitoOpts, cogito.WithIterations(c.Agent.MaxIterations))
+	}
+
+	if c.Agent.MaxAttempts != 0 {
+		cogitoOpts = append(cogitoOpts, cogito.WithMaxAttempts(c.Agent.MaxAttempts))
+	}
+
+	return cogitoOpts
+}
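These defaults are meant to be extended per request. A minimal sketch of how a caller composes them, assuming only the API shown in this commit (the new MCP endpoint below does the same, additionally appending MCP sessions and streaming callbacks):

package main

import (
	"context"

	"github.com/mudler/LocalAI/core/config"
	"github.com/mudler/cogito"
)

func main() {
	// Normally loaded from the model's YAML config; zero-valued here for illustration.
	cfg := &config.ModelConfig{}

	// Per-model defaults: iterations, attempts, forced reasoning, plus any
	// Agent.* flags set in the configuration.
	opts := cfg.BuildCogitoOptions()

	// Request-scoped options are appended on top, as mcp.go does below
	// (there it also appends cogito.WithMCPs(sessions...) and callbacks).
	opts = append(opts, cogito.WithContext(context.Background()))

	_ = opts // handed to cogito.ExecuteTools(llm, fragment, opts...)
}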

core/http/app.go

Lines changed: 1 addition & 1 deletion

@@ -205,7 +205,7 @@ func API(application *application.Application) (*echo.Echo, error) {
 		opcache = services.NewOpCache(application.GalleryService())
 	}
 
-	routes.RegisterLocalAIRoutes(e, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService(), opcache)
+	routes.RegisterLocalAIRoutes(e, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService(), opcache, application.TemplatesEvaluator())
 	routes.RegisterOpenAIRoutes(e, requestExtractor, application)
 	if !application.ApplicationConfig().DisableWebUI {
 		routes.RegisterUIAPIRoutes(e, application.ModelConfigLoader(), application.ApplicationConfig(), application.GalleryService(), opcache)

The extra TemplatesEvaluator argument threads the template evaluator through to the LocalAI routes, where the new MCPStreamEndpoint below expects it as a parameter.

core/http/endpoints/localai/mcp.go

Lines changed: 323 additions & 0 deletions

@@ -0,0 +1,323 @@
+package localai
+
+import (
+	"context"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"strings"
+	"time"
+
+	"github.com/labstack/echo/v4"
+	"github.com/mudler/LocalAI/core/config"
+	mcpTools "github.com/mudler/LocalAI/core/http/endpoints/mcp"
+	"github.com/mudler/LocalAI/core/http/middleware"
+	"github.com/mudler/LocalAI/core/schema"
+	"github.com/mudler/LocalAI/core/templates"
+	"github.com/mudler/LocalAI/pkg/model"
+	"github.com/mudler/cogito"
+	"github.com/rs/zerolog/log"
+)
+
+// MCP SSE Event Types
+type MCPReasoningEvent struct {
+	Type    string `json:"type"`
+	Content string `json:"content"`
+}
+
+type MCPToolCallEvent struct {
+	Type      string                 `json:"type"`
+	Name      string                 `json:"name"`
+	Arguments map[string]interface{} `json:"arguments"`
+	Reasoning string                 `json:"reasoning"`
+}
+
+type MCPToolResultEvent struct {
+	Type   string `json:"type"`
+	Name   string `json:"name"`
+	Result string `json:"result"`
+}
+
+type MCPStatusEvent struct {
+	Type    string `json:"type"`
+	Message string `json:"message"`
+}
+
+type MCPAssistantEvent struct {
+	Type    string `json:"type"`
+	Content string `json:"content"`
+}
+
+type MCPErrorEvent struct {
+	Type    string `json:"type"`
+	Message string `json:"message"`
+}
+
+// MCPStreamEndpoint is the SSE streaming endpoint for MCP chat completions
+// @Summary Stream MCP chat completions with reasoning, tool calls, and results
+// @Param request body schema.OpenAIRequest true "query params"
+// @Success 200 {object} schema.OpenAIResponse "Response"
+// @Router /v1/mcp/chat/completions [post]
+func MCPStreamEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) echo.HandlerFunc {
+	return func(c echo.Context) error {
+		ctx := c.Request().Context()
+		created := int(time.Now().Unix())
+
+		// Handle Correlation
+		id := c.Request().Header.Get("X-Correlation-ID")
+		if id == "" {
+			id = fmt.Sprintf("mcp-%d", time.Now().UnixNano())
+		}
+
+		input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
+		if !ok || input.Model == "" {
+			return echo.ErrBadRequest
+		}
+
+		config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
+		if !ok || config == nil {
+			return echo.ErrBadRequest
+		}
+
+		if config.MCP.Servers == "" && config.MCP.Stdio == "" {
+			return fmt.Errorf("no MCP servers configured")
+		}
+
+		// Get MCP config from model config
+		remote, stdio, err := config.MCP.MCPConfigFromYAML()
+		if err != nil {
+			return fmt.Errorf("failed to get MCP config: %w", err)
+		}
+
+		// Check if we have tools in cache, or we have to have an initial connection
+		sessions, err := mcpTools.SessionsFromMCPConfig(config.Name, remote, stdio)
+		if err != nil {
+			return fmt.Errorf("failed to get MCP sessions: %w", err)
+		}
+
+		if len(sessions) == 0 {
+			return fmt.Errorf("no working MCP servers found")
+		}
+
+		// Build fragment from messages
+		fragment := cogito.NewEmptyFragment()
+		for _, message := range input.Messages {
+			fragment = fragment.AddMessage(message.Role, message.StringContent)
+		}
+
+		port := appConfig.APIAddress[strings.LastIndex(appConfig.APIAddress, ":")+1:]
+		apiKey := ""
+		if len(appConfig.ApiKeys) > 0 {
+			apiKey = appConfig.ApiKeys[0]
+		}
+
+		ctxWithCancellation, cancel := context.WithCancel(ctx)
+		defer cancel()
+
+		// TODO: instead of connecting to the API, we should just wire this internally
+		// and act like completion.go.
+		// We can do this as cogito expects an interface and we can create one that
+		// we satisfy to just call internally ComputeChoices
+		defaultLLM := cogito.NewOpenAILLM(config.Name, apiKey, "http://127.0.0.1:"+port)
+
+		// Build cogito options using the consolidated method
+		cogitoOpts := config.BuildCogitoOptions()
+		cogitoOpts = append(
+			cogitoOpts,
+			cogito.WithContext(ctxWithCancellation),
+			cogito.WithMCPs(sessions...),
+		)
+		// Check if streaming is requested
+		toStream := input.Stream
+
+		if !toStream {
+			// Non-streaming mode: execute synchronously and return JSON response
+			cogitoOpts = append(
+				cogitoOpts,
+				cogito.WithStatusCallback(func(s string) {
+					log.Debug().Msgf("[model agent] [model: %s] Status: %s", config.Name, s)
+				}),
+				cogito.WithReasoningCallback(func(s string) {
+					log.Debug().Msgf("[model agent] [model: %s] Reasoning: %s", config.Name, s)
+				}),
+				cogito.WithToolCallBack(func(t *cogito.ToolChoice) bool {
+					log.Debug().Str("model", config.Name).Str("tool", t.Name).Str("reasoning", t.Reasoning).Interface("arguments", t.Arguments).Msg("[model agent] Tool call")
+					return true
+				}),
+				cogito.WithToolCallResultCallback(func(t cogito.ToolStatus) {
+					log.Debug().Str("model", config.Name).Str("tool", t.Name).Str("result", t.Result).Interface("tool_arguments", t.ToolArguments).Msg("[model agent] Tool call result")
+				}),
+			)
+
+			f, err := cogito.ExecuteTools(
+				defaultLLM, fragment,
+				cogitoOpts...,
+			)
+			if err != nil && !errors.Is(err, cogito.ErrNoToolSelected) {
+				return err
+			}
+
+			f, err = defaultLLM.Ask(ctxWithCancellation, f)
+			if err != nil {
+				return err
+			}
+
+			resp := &schema.OpenAIResponse{
+				ID:      id,
+				Created: created,
+				Model:   input.Model, // we have to return what the user sent here, due to OpenAI spec.
+				Choices: []schema.Choice{{Message: &schema.Message{Role: "assistant", Content: &f.LastMessage().Content}}},
+				Object:  "chat.completion",
+			}
+
+			jsonResult, _ := json.Marshal(resp)
+			log.Debug().Msgf("Response: %s", jsonResult)
+
+			// Return the prediction in the response body
+			return c.JSON(200, resp)
+		}
+
+		// Streaming mode: use SSE
+		// Set up SSE headers
+		c.Response().Header().Set("Content-Type", "text/event-stream")
+		c.Response().Header().Set("Cache-Control", "no-cache")
+		c.Response().Header().Set("Connection", "keep-alive")
+		c.Response().Header().Set("X-Correlation-ID", id)
+
+		// Create channel for streaming events
+		events := make(chan interface{})
+		ended := make(chan error, 1)
+
+		// Set up callbacks for streaming
+		statusCallback := func(s string) {
+			events <- MCPStatusEvent{
+				Type:    "status",
+				Message: s,
+			}
+		}
+
+		reasoningCallback := func(s string) {
+			events <- MCPReasoningEvent{
+				Type:    "reasoning",
+				Content: s,
+			}
+		}
+
+		toolCallCallback := func(t *cogito.ToolChoice) bool {
+			events <- MCPToolCallEvent{
+				Type:      "tool_call",
+				Name:      t.Name,
+				Arguments: t.Arguments,
+				Reasoning: t.Reasoning,
+			}
+			return true
+		}
+
+		toolCallResultCallback := func(t cogito.ToolStatus) {
+			events <- MCPToolResultEvent{
+				Type:   "tool_result",
+				Name:   t.Name,
+				Result: t.Result,
+			}
+		}
+
+		cogitoOpts = append(cogitoOpts,
+			cogito.WithStatusCallback(statusCallback),
+			cogito.WithReasoningCallback(reasoningCallback),
+			cogito.WithToolCallBack(toolCallCallback),
+			cogito.WithToolCallResultCallback(toolCallResultCallback),
+		)
+
+		// Execute tools in a goroutine
+		go func() {
+			defer close(events)
+
+			f, err := cogito.ExecuteTools(
+				defaultLLM, fragment,
+				cogitoOpts...,
+			)
+			if err != nil && !errors.Is(err, cogito.ErrNoToolSelected) {
+				events <- MCPErrorEvent{
+					Type:    "error",
+					Message: fmt.Sprintf("Failed to execute tools: %v", err),
+				}
+				ended <- err
+				return
+			}
+
+			// Get final response
+			f, err = defaultLLM.Ask(ctxWithCancellation, f)
+			if err != nil {
+				events <- MCPErrorEvent{
+					Type:    "error",
+					Message: fmt.Sprintf("Failed to get response: %v", err),
+				}
+				ended <- err
+				return
+			}
+
+			// Stream final assistant response
+			content := f.LastMessage().Content
+			events <- MCPAssistantEvent{
+				Type:    "assistant",
+				Content: content,
+			}
+
+			ended <- nil
+		}()
+
+		// Stream events to client
+	LOOP:
+		for {
+			select {
+			case <-ctx.Done():
+				// Context was cancelled (client disconnected or request cancelled)
+				log.Debug().Msgf("Request context cancelled, stopping stream")
+				cancel()
+				break LOOP
+			case event := <-events:
+				if event == nil {
+					// Channel closed
+					break LOOP
+				}
+				eventData, err := json.Marshal(event)
+				if err != nil {
+					log.Debug().Msgf("Failed to marshal event: %v", err)
+					continue
+				}
+				log.Debug().Msgf("Sending event: %s", string(eventData))
+				_, err = fmt.Fprintf(c.Response().Writer, "data: %s\n\n", string(eventData))
+				if err != nil {
+					log.Debug().Msgf("Sending event failed: %v", err)
+					cancel()
+					return err
+				}
+				c.Response().Flush()
+			case err := <-ended:
+				if err == nil {
+					// Send done signal
+					fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n")
+					c.Response().Flush()
+					break LOOP
+				}
+				log.Error().Msgf("Stream ended with error: %v", err)
+				errorEvent := MCPErrorEvent{
+					Type:    "error",
+					Message: err.Error(),
+				}
+				errorData, marshalErr := json.Marshal(errorEvent)
+				if marshalErr != nil {
+					fmt.Fprintf(c.Response().Writer, "data: {\"type\":\"error\",\"message\":\"Internal error\"}\n\n")
+				} else {
+					fmt.Fprintf(c.Response().Writer, "data: %s\n\n", string(errorData))
+				}
+				fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n")
+				c.Response().Flush()
+				return nil
+			}
+		}
+
+		log.Debug().Msgf("Stream ended")
+		return nil
+	}
+}
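For reference, a minimal client sketch against the new endpoint. Assumptions, not part of this commit: a LocalAI instance on http://localhost:8080 and a model named "my-model" with MCP servers configured; both are placeholders. With "stream": true the handler emits the SSE events defined above and terminates the stream with data: [DONE]; with "stream": false it instead returns a single OpenAI-style JSON response.

package main

import (
	"bufio"
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
	"strings"
)

func main() {
	// OpenAI-style request body; "stream": true selects the SSE path.
	body := []byte(`{
		"model": "my-model",
		"stream": true,
		"messages": [{"role": "user", "content": "What is the weather in Rome?"}]
	}`)

	resp, err := http.Post("http://localhost:8080/v1/mcp/chat/completions",
		"application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	scanner := bufio.NewScanner(resp.Body)
	scanner.Buffer(make([]byte, 0, 1024*1024), 1024*1024) // tool results can be large

	for scanner.Scan() {
		line := scanner.Text()
		if !strings.HasPrefix(line, "data: ") {
			continue // skip the blank separator lines between SSE events
		}
		payload := strings.TrimPrefix(line, "data: ")
		if payload == "[DONE]" {
			break // terminator sent by the endpoint
		}
		var event map[string]any
		if err := json.Unmarshal([]byte(payload), &event); err != nil {
			continue
		}
		// "type" is one of: status, reasoning, tool_call, tool_result,
		// assistant, error (see the event structs above).
		fmt.Printf("[%v] %v\n", event["type"], event)
	}
}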
