From ee6618e674caec1fba7ceef849717c111b66b9d7 Mon Sep 17 00:00:00 2001 From: Ricardo Fearing <9965014+rfearing@users.noreply.github.com> Date: Fri, 25 Apr 2025 15:55:19 +0000 Subject: [PATCH 1/3] Combine tools --- pkg/github/notifications.go | 215 +++++++++++++----------------------- 1 file changed, 79 insertions(+), 136 deletions(-) diff --git a/pkg/github/notifications.go b/pkg/github/notifications.go index d7252e39..9684ff55 100644 --- a/pkg/github/notifications.go +++ b/pkg/github/notifications.go @@ -127,50 +127,19 @@ func GetNotifications(getClient GetClientFn, t translations.TranslationHelperFun } } -// markNotificationRead creates a tool to mark a notification as read. -func MarkNotificationRead(getclient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { - return mcp.NewTool("mark_notification_read", - mcp.WithDescription(t("TOOL_MARK_NOTIFICATION_READ_DESCRIPTION", "Mark a notification as read")), - mcp.WithString("threadID", +// ManageNotifications creates a tool to manage notifications (mark as read, mark all as read, or mark as done). +func ManageNotifications(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { + return mcp.NewTool("manage_notifications", + mcp.WithDescription(t("TOOL_MANAGE_NOTIFICATIONS_DESCRIPTION", "Manage notifications (mark as read, mark all as read, or mark as done)")), + mcp.WithString("action", mcp.Required(), - mcp.Description("The ID of the notification thread"), + mcp.Description("The action to perform: 'mark_read', 'mark_all_read', or 'mark_done'"), + ), + mcp.WithString("threadID", + mcp.Description("The ID of the notification thread (required for 'mark_read' and 'mark_done')"), ), - ), - func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - client, err := getclient(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - - threadID, err := requiredParam[string](request, "threadID") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - - resp, err := client.Activity.MarkThreadRead(ctx, threadID) - if err != nil { - return nil, fmt.Errorf("failed to mark notification as read: %w", err) - } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusResetContent && resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) - } - return mcp.NewToolResultError(fmt.Sprintf("failed to mark notification as read: %s", string(body))), nil - } - - return mcp.NewToolResultText("Notification marked as read"), nil - } -} - -// MarkAllNotificationsRead creates a tool to mark all notifications as read. -func MarkAllNotificationsRead(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { - return mcp.NewTool("mark_all_notifications_read", - mcp.WithDescription(t("TOOL_MARK_ALL_NOTIFICATIONS_READ_DESCRIPTION", "Mark all notifications as read")), mcp.WithString("lastReadAt", - mcp.Description("Describes the last point that notifications were checked (optional). Default: Now"), + mcp.Description("Describes the last point that notifications were checked (optional, for 'mark_all_read'). Default: Now"), ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { @@ -179,122 +148,96 @@ func MarkAllNotificationsRead(getClient GetClientFn, t translations.TranslationH return nil, fmt.Errorf("failed to get GitHub client: %w", err) } - lastReadAt, err := OptionalStringParam(request, "lastReadAt") + action, err := requiredParam[string](request, "action") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - var markReadOptions github.Timestamp - if lastReadAt != "" { - lastReadTime, err := time.Parse(time.RFC3339, lastReadAt) + switch action { + case "mark_read": + threadID, err := requiredParam[string](request, "threadID") if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("invalid lastReadAt time format, should be RFC3339/ISO8601: %v", err)), nil - } - markReadOptions = github.Timestamp{ - Time: lastReadTime, + return mcp.NewToolResultError(err.Error()), nil } - } - - resp, err := client.Activity.MarkNotificationsRead(ctx, markReadOptions) - if err != nil { - return nil, fmt.Errorf("failed to mark all notifications as read: %w", err) - } - defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != http.StatusResetContent && resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) + resp, err := client.Activity.MarkThreadRead(ctx, threadID) if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) + return nil, fmt.Errorf("failed to mark notification as read: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusResetContent && resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to mark notification as read: %s", string(body))), nil } - return mcp.NewToolResultError(fmt.Sprintf("failed to mark all notifications as read: %s", string(body))), nil - } - - return mcp.NewToolResultText("All notifications marked as read"), nil - } -} - -// GetNotificationThread creates a tool to get a specific notification thread. -func GetNotificationThread(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { - return mcp.NewTool("get_notification_thread", - mcp.WithDescription(t("TOOL_GET_NOTIFICATION_THREAD_DESCRIPTION", "Get a specific notification thread")), - mcp.WithString("threadID", - mcp.Required(), - mcp.Description("The ID of the notification thread"), - ), - ), - func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - client, err := getClient(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - - threadID, err := requiredParam[string](request, "threadID") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - thread, resp, err := client.Activity.GetThread(ctx, threadID) - if err != nil { - return nil, fmt.Errorf("failed to get notification thread: %w", err) - } - defer func() { _ = resp.Body.Close() }() + return mcp.NewToolResultText("Notification marked as read"), nil - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) + case "mark_done": + threadIDStr, err := requiredParam[string](request, "threadID") if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) + return mcp.NewToolResultError(err.Error()), nil } - return mcp.NewToolResultError(fmt.Sprintf("failed to get notification thread: %s", string(body))), nil - } - - r, err := json.Marshal(thread) - if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) - } - return mcp.NewToolResultText(string(r)), nil - } -} + threadID, err := strconv.ParseInt(threadIDStr, 10, 64) + if err != nil { + return mcp.NewToolResultError("Invalid threadID: must be a numeric value"), nil + } -// markNotificationDone creates a tool to mark a notification as done. -func MarkNotificationDone(getclient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { - return mcp.NewTool("mark_notification_done", - mcp.WithDescription(t("TOOL_MARK_NOTIFICATION_DONE_DESCRIPTION", "Mark a notification as done")), - mcp.WithString("threadID", - mcp.Required(), - mcp.Description("The ID of the notification thread"), - ), - ), - func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - client, err := getclient(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) - } + resp, err := client.Activity.MarkThreadDone(ctx, threadID) + if err != nil { + return nil, fmt.Errorf("failed to mark notification as done: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusResetContent && resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to mark notification as done: %s", string(body))), nil + } - threadIDStr, err := requiredParam[string](request, "threadID") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } + return mcp.NewToolResultText("Notification marked as done"), nil - threadID, err := strconv.ParseInt(threadIDStr, 10, 64) - if err != nil { - return mcp.NewToolResultError("Invalid threadID: must be a numeric value"), nil - } + case "mark_all_read": + lastReadAt, err := OptionalStringParam(request, "lastReadAt") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } - resp, err := client.Activity.MarkThreadDone(ctx, threadID) - if err != nil { - return nil, fmt.Errorf("failed to mark notification as done: %w", err) - } - defer func() { _ = resp.Body.Close() }() + var markReadOptions github.Timestamp + if lastReadAt != "" { + lastReadTime, err := time.Parse(time.RFC3339, lastReadAt) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("invalid lastReadAt time format, should be RFC3339/ISO8601: %v", err)), nil + } + markReadOptions = github.Timestamp{ + Time: lastReadTime, + } + } - if resp.StatusCode != http.StatusResetContent && resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) + resp, err := client.Activity.MarkNotificationsRead(ctx, markReadOptions) if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) + return nil, fmt.Errorf("failed to mark all notifications as read: %w", err) } - return mcp.NewToolResultError(fmt.Sprintf("failed to mark notification as done: %s", string(body))), nil - } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusResetContent && resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to mark all notifications as read: %s", string(body))), nil + } + + return mcp.NewToolResultText("All notifications marked as read"), nil - return mcp.NewToolResultText("Notification marked as done"), nil + default: + return mcp.NewToolResultError("Invalid action: must be 'mark_read', 'mark_all_read', or 'mark_done'"), nil + } } } From 569711f54db4c2bab6de7e61a8f23d9d67fd0a6e Mon Sep 17 00:00:00 2001 From: Ricardo Fearing Date: Fri, 25 Apr 2025 12:38:32 -0400 Subject: [PATCH 2/3] Update tools.go --- pkg/github/tools.go | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/pkg/github/tools.go b/pkg/github/tools.go index fd0f231b..4d4889a8 100644 --- a/pkg/github/tools.go +++ b/pkg/github/tools.go @@ -81,14 +81,10 @@ func InitToolsets(passedToolsets []string, readOnly bool, getClient GetClientFn, notifications := toolsets.NewToolset("notifications", "GitHub Notifications related tools"). AddReadTools( - - toolsets.NewServerTool(MarkNotificationRead(getClient, t)), - toolsets.NewServerTool(MarkAllNotificationsRead(getClient, t)), - toolsets.NewServerTool(MarkNotificationDone(getClient, t)), + toolsets.NewServerTool(GetNotifications(getClient, t)), ). AddWriteTools( - toolsets.NewServerTool(GetNotifications(getClient, t)), - toolsets.NewServerTool(GetNotificationThread(getClient, t)), + toolsets.NewServerTool(ManageNotifications(getClient, t)), ) // Keep experiments alive so the system doesn't error out when it's always enabled From a5550e434af8ccb4914ef1d83c1a91f681e8efef Mon Sep 17 00:00:00 2001 From: Ricardo Fearing <9965014+rfearing@users.noreply.github.com> Date: Fri, 25 Apr 2025 12:54:51 -0400 Subject: [PATCH 3/3] Update pkg/github/notifications.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- pkg/github/notifications.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/github/notifications.go b/pkg/github/notifications.go index 9684ff55..2c2e1a13 100644 --- a/pkg/github/notifications.go +++ b/pkg/github/notifications.go @@ -198,7 +198,7 @@ func ManageNotifications(getClient GetClientFn, t translations.TranslationHelper if err != nil { return nil, fmt.Errorf("failed to read response body: %w", err) } - return mcp.NewToolResultError(fmt.Sprintf("failed to mark notification as done: %s", string(body))), nil + return mcp.NewToolResultError(fmt.Sprintf("failed to mark all notifications as read: %s", string(body))), nil } return mcp.NewToolResultText("Notification marked as done"), nil