Skip to content

Combine tools #349

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: notifications-tooling
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
215 changes: 79 additions & 136 deletions pkg/github/notifications.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,50 +127,19 @@ func GetNotifications(getClient GetClientFn, t translations.TranslationHelperFun
}
}

// markNotificationRead creates a tool to mark a notification as read.
func MarkNotificationRead(getclient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
return mcp.NewTool("mark_notification_read",
mcp.WithDescription(t("TOOL_MARK_NOTIFICATION_READ_DESCRIPTION", "Mark a notification as read")),
mcp.WithString("threadID",
// ManageNotifications creates a tool to manage notifications (mark as read, mark all as read, or mark as done).
func ManageNotifications(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
return mcp.NewTool("manage_notifications",
mcp.WithDescription(t("TOOL_MANAGE_NOTIFICATIONS_DESCRIPTION", "Manage notifications (mark as read, mark all as read, or mark as done)")),
mcp.WithString("action",
mcp.Required(),
mcp.Description("The ID of the notification thread"),
mcp.Description("The action to perform: 'mark_read', 'mark_all_read', or 'mark_done'"),
),
mcp.WithString("threadID",
mcp.Description("The ID of the notification thread (required for 'mark_read' and 'mark_done')"),
),
),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
client, err := getclient(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get GitHub client: %w", err)
}

threadID, err := requiredParam[string](request, "threadID")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}

resp, err := client.Activity.MarkThreadRead(ctx, threadID)
if err != nil {
return nil, fmt.Errorf("failed to mark notification as read: %w", err)
}
defer func() { _ = resp.Body.Close() }()

if resp.StatusCode != http.StatusResetContent && resp.StatusCode != http.StatusOK {
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
}
return mcp.NewToolResultError(fmt.Sprintf("failed to mark notification as read: %s", string(body))), nil
}

return mcp.NewToolResultText("Notification marked as read"), nil
}
}

// MarkAllNotificationsRead creates a tool to mark all notifications as read.
func MarkAllNotificationsRead(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
return mcp.NewTool("mark_all_notifications_read",
mcp.WithDescription(t("TOOL_MARK_ALL_NOTIFICATIONS_READ_DESCRIPTION", "Mark all notifications as read")),
mcp.WithString("lastReadAt",
mcp.Description("Describes the last point that notifications were checked (optional). Default: Now"),
mcp.Description("Describes the last point that notifications were checked (optional, for 'mark_all_read'). Default: Now"),
),
),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
Expand All @@ -179,122 +148,96 @@ func MarkAllNotificationsRead(getClient GetClientFn, t translations.TranslationH
return nil, fmt.Errorf("failed to get GitHub client: %w", err)
}

lastReadAt, err := OptionalStringParam(request, "lastReadAt")
action, err := requiredParam[string](request, "action")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}

var markReadOptions github.Timestamp
if lastReadAt != "" {
lastReadTime, err := time.Parse(time.RFC3339, lastReadAt)
switch action {
case "mark_read":
Copy link
Preview

Copilot AI Apr 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The 'mark_read' action does not call any API to mark a notification as read; it directly returns a success message. Please add the appropriate API call to actually mark the notification as read.

Copilot uses AI. Check for mistakes.

threadID, err := requiredParam[string](request, "threadID")
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("invalid lastReadAt time format, should be RFC3339/ISO8601: %v", err)), nil
}
markReadOptions = github.Timestamp{
Time: lastReadTime,
return mcp.NewToolResultError(err.Error()), nil
}
}

resp, err := client.Activity.MarkNotificationsRead(ctx, markReadOptions)
if err != nil {
return nil, fmt.Errorf("failed to mark all notifications as read: %w", err)
}
defer func() { _ = resp.Body.Close() }()

if resp.StatusCode != http.StatusResetContent && resp.StatusCode != http.StatusOK {
body, err := io.ReadAll(resp.Body)
resp, err := client.Activity.MarkThreadRead(ctx, threadID)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
return nil, fmt.Errorf("failed to mark notification as read: %w", err)
}
defer func() { _ = resp.Body.Close() }()

if resp.StatusCode != http.StatusResetContent && resp.StatusCode != http.StatusOK {
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
}
return mcp.NewToolResultError(fmt.Sprintf("failed to mark notification as read: %s", string(body))), nil
}
return mcp.NewToolResultError(fmt.Sprintf("failed to mark all notifications as read: %s", string(body))), nil
}

return mcp.NewToolResultText("All notifications marked as read"), nil
}
}

// GetNotificationThread creates a tool to get a specific notification thread.
func GetNotificationThread(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
return mcp.NewTool("get_notification_thread",
mcp.WithDescription(t("TOOL_GET_NOTIFICATION_THREAD_DESCRIPTION", "Get a specific notification thread")),
mcp.WithString("threadID",
mcp.Required(),
mcp.Description("The ID of the notification thread"),
),
),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
client, err := getClient(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get GitHub client: %w", err)
}

threadID, err := requiredParam[string](request, "threadID")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}

thread, resp, err := client.Activity.GetThread(ctx, threadID)
if err != nil {
return nil, fmt.Errorf("failed to get notification thread: %w", err)
}
defer func() { _ = resp.Body.Close() }()
return mcp.NewToolResultText("Notification marked as read"), nil

if resp.StatusCode != http.StatusOK {
body, err := io.ReadAll(resp.Body)
case "mark_done":
threadIDStr, err := requiredParam[string](request, "threadID")
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
return mcp.NewToolResultError(err.Error()), nil
}
return mcp.NewToolResultError(fmt.Sprintf("failed to get notification thread: %s", string(body))), nil
}

r, err := json.Marshal(thread)
if err != nil {
return nil, fmt.Errorf("failed to marshal response: %w", err)
}

return mcp.NewToolResultText(string(r)), nil
}
}
threadID, err := strconv.ParseInt(threadIDStr, 10, 64)
if err != nil {
return mcp.NewToolResultError("Invalid threadID: must be a numeric value"), nil
}

// markNotificationDone creates a tool to mark a notification as done.
func MarkNotificationDone(getclient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
return mcp.NewTool("mark_notification_done",
mcp.WithDescription(t("TOOL_MARK_NOTIFICATION_DONE_DESCRIPTION", "Mark a notification as done")),
mcp.WithString("threadID",
mcp.Required(),
mcp.Description("The ID of the notification thread"),
),
),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
client, err := getclient(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get GitHub client: %w", err)
}
resp, err := client.Activity.MarkThreadDone(ctx, threadID)
if err != nil {
return nil, fmt.Errorf("failed to mark notification as done: %w", err)
}
defer func() { _ = resp.Body.Close() }()

if resp.StatusCode != http.StatusResetContent && resp.StatusCode != http.StatusOK {
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
}
return mcp.NewToolResultError(fmt.Sprintf("failed to mark all notifications as read: %s", string(body))), nil
}

threadIDStr, err := requiredParam[string](request, "threadID")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
return mcp.NewToolResultText("Notification marked as done"), nil

threadID, err := strconv.ParseInt(threadIDStr, 10, 64)
if err != nil {
return mcp.NewToolResultError("Invalid threadID: must be a numeric value"), nil
}
case "mark_all_read":
lastReadAt, err := OptionalStringParam(request, "lastReadAt")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}

resp, err := client.Activity.MarkThreadDone(ctx, threadID)
if err != nil {
return nil, fmt.Errorf("failed to mark notification as done: %w", err)
}
defer func() { _ = resp.Body.Close() }()
var markReadOptions github.Timestamp
if lastReadAt != "" {
lastReadTime, err := time.Parse(time.RFC3339, lastReadAt)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("invalid lastReadAt time format, should be RFC3339/ISO8601: %v", err)), nil
}
markReadOptions = github.Timestamp{
Time: lastReadTime,
}
}

if resp.StatusCode != http.StatusResetContent && resp.StatusCode != http.StatusOK {
body, err := io.ReadAll(resp.Body)
resp, err := client.Activity.MarkNotificationsRead(ctx, markReadOptions)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
return nil, fmt.Errorf("failed to mark all notifications as read: %w", err)
}
return mcp.NewToolResultError(fmt.Sprintf("failed to mark notification as done: %s", string(body))), nil
}
defer func() { _ = resp.Body.Close() }()

if resp.StatusCode != http.StatusResetContent && resp.StatusCode != http.StatusOK {
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
}
return mcp.NewToolResultError(fmt.Sprintf("failed to mark all notifications as read: %s", string(body))), nil
}

return mcp.NewToolResultText("All notifications marked as read"), nil

return mcp.NewToolResultText("Notification marked as done"), nil
default:
return mcp.NewToolResultError("Invalid action: must be 'mark_read', 'mark_all_read', or 'mark_done'"), nil
}
}
}
8 changes: 2 additions & 6 deletions pkg/github/tools.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,14 +81,10 @@ func InitToolsets(passedToolsets []string, readOnly bool, getClient GetClientFn,

notifications := toolsets.NewToolset("notifications", "GitHub Notifications related tools").
AddReadTools(

toolsets.NewServerTool(MarkNotificationRead(getClient, t)),
toolsets.NewServerTool(MarkAllNotificationsRead(getClient, t)),
toolsets.NewServerTool(MarkNotificationDone(getClient, t)),
toolsets.NewServerTool(GetNotifications(getClient, t)),
).
AddWriteTools(
toolsets.NewServerTool(GetNotifications(getClient, t)),
toolsets.NewServerTool(GetNotificationThread(getClient, t)),
toolsets.NewServerTool(ManageNotifications(getClient, t)),
)

// Keep experiments alive so the system doesn't error out when it's always enabled
Expand Down