diff --git a/README.md b/README.md index eacaef24..3908cf64 100644 --- a/README.md +++ b/README.md @@ -436,6 +436,7 @@ export GITHUB_MCP_TOOL_ADD_ISSUE_COMMENT_DESCRIPTION="an alternative description - `state`: New state ('open' or 'closed') (string, optional) - `base`: New base branch name (string, optional) - `maintainer_can_modify`: Allow maintainer edits (boolean, optional) + - `reviewers`: GitHub usernames to request reviews from (string[], optional) ### Repositories diff --git a/pkg/github/pullrequests.go b/pkg/github/pullrequests.go index 9c8fca17..88baf15d 100644 --- a/pkg/github/pullrequests.go +++ b/pkg/github/pullrequests.go @@ -111,6 +111,12 @@ func UpdatePullRequest(getClient GetClientFn, t translations.TranslationHelperFu mcp.WithBoolean("maintainer_can_modify", mcp.Description("Allow maintainer edits"), ), + mcp.WithArray("reviewers", + mcp.Description("GitHub usernames to request reviews from"), + mcp.Items(map[string]interface{}{ + "type": "string", + }), + ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { owner, err := requiredParam[string](request, "owner") @@ -165,26 +171,101 @@ func UpdatePullRequest(getClient GetClientFn, t translations.TranslationHelperFu updateNeeded = true } - if !updateNeeded { - return mcp.NewToolResultError("No update parameters provided."), nil + // Handle reviewers separately + var reviewers []string + if reviewersArr, ok := request.Params.Arguments["reviewers"].([]interface{}); ok && len(reviewersArr) > 0 { + for _, reviewer := range reviewersArr { + if reviewerStr, ok := reviewer.(string); ok { + reviewers = append(reviewers, reviewerStr) + } + } } + // Create the GitHub client client, err := getClient(ctx) if err != nil { return nil, fmt.Errorf("failed to get GitHub client: %w", err) } - pr, resp, err := client.PullRequests.Edit(ctx, owner, repo, pullNumber, update) - if err != nil { - return nil, fmt.Errorf("failed to update pull request: %w", err) + + var pr *github.PullRequest + var resp *http.Response + + // First, update the PR if needed + if updateNeeded { + var ghResp *github.Response + pr, ghResp, err = client.PullRequests.Edit(ctx, owner, repo, pullNumber, update) + if err != nil { + return nil, fmt.Errorf("failed to update pull request: %w", err) + } + resp = ghResp.Response + defer func() { + if resp != nil && resp.Body != nil { + _ = resp.Body.Close() + } + }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to update pull request: %s", string(body))), nil + } + } else { + // If no update needed, just get the current PR + var ghResp *github.Response + pr, ghResp, err = client.PullRequests.Get(ctx, owner, repo, pullNumber) + if err != nil { + return nil, fmt.Errorf("failed to get pull request: %w", err) + } + resp = ghResp.Response + defer func() { + if resp != nil && resp.Body != nil { + _ = resp.Body.Close() + } + }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to get pull request: %s", string(body))), nil + } } - defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) + // Add reviewers if specified + if len(reviewers) > 0 { + reviewersRequest := github.ReviewersRequest{ + Reviewers: reviewers, + } + + // Use the direct result of RequestReviewers which includes the requested reviewers + updatedPR, resp, err := client.PullRequests.RequestReviewers(ctx, owner, repo, pullNumber, reviewersRequest) if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) + return nil, fmt.Errorf("failed to request reviewers: %w", err) } - return mcp.NewToolResultError(fmt.Sprintf("failed to update pull request: %s", string(body))), nil + defer func() { + if resp != nil && resp.Body != nil { + _ = resp.Body.Close() + } + }() + + if resp.StatusCode != http.StatusCreated && resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to request reviewers: %s", string(body))), nil + } + + // Use the updated PR with reviewers + pr = updatedPR + } + + // If no updates and no reviewers, return error + if !updateNeeded && len(reviewers) == 0 { + return mcp.NewToolResultError("No update parameters provided"), nil } r, err := json.Marshal(pr) diff --git a/pkg/github/pullrequests_test.go b/pkg/github/pullrequests_test.go index bb372624..3a064a39 100644 --- a/pkg/github/pullrequests_test.go +++ b/pkg/github/pullrequests_test.go @@ -141,6 +141,7 @@ func Test_UpdatePullRequest(t *testing.T) { assert.Contains(t, tool.InputSchema.Properties, "state") assert.Contains(t, tool.InputSchema.Properties, "base") assert.Contains(t, tool.InputSchema.Properties, "maintainer_can_modify") + assert.Contains(t, tool.InputSchema.Properties, "reviewers") assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "pullNumber"}) // Setup mock PR for success case @@ -162,6 +163,23 @@ func Test_UpdatePullRequest(t *testing.T) { State: github.Ptr("closed"), // State updated } + // Mock PR for when there are no updates but we still need a response + mockNoUpdatePR := &github.PullRequest{ + Number: github.Ptr(42), + Title: github.Ptr("Test PR"), + State: github.Ptr("open"), + } + + mockPRWithReviewers := &github.PullRequest{ + Number: github.Ptr(42), + Title: github.Ptr("Test PR"), + State: github.Ptr("open"), + RequestedReviewers: []*github.User{ + {Login: github.Ptr("reviewer1")}, + {Login: github.Ptr("reviewer2")}, + }, + } + tests := []struct { name string mockedClient *http.Client @@ -220,8 +238,40 @@ func Test_UpdatePullRequest(t *testing.T) { expectedPR: mockClosedPR, }, { - name: "no update parameters provided", - mockedClient: mock.NewMockedHTTPClient(), // No API call expected + name: "successful PR update with reviewers", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetReposPullsByOwnerByRepoByPullNumber, + &github.PullRequest{ + Number: github.Ptr(42), + Title: github.Ptr("Test PR"), + State: github.Ptr("open"), + }, + ), + // Mock for RequestReviewers call, returning the PR with reviewers + mock.WithRequestMatch( + mock.PostReposPullsRequestedReviewersByOwnerByRepoByPullNumber, + mockPRWithReviewers, + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "pullNumber": float64(42), + "reviewers": []interface{}{"reviewer1", "reviewer2"}, + }, + expectError: false, + expectedPR: mockPRWithReviewers, + }, + { + name: "no update parameters provided", + mockedClient: mock.NewMockedHTTPClient( + // Mock a response for the GET PR request in case of no updates + mock.WithRequestMatch( + mock.GetReposPullsByOwnerByRepoByPullNumber, + mockNoUpdatePR, + ), + ), requestArgs: map[string]interface{}{ "owner": "owner", "repo": "repo", @@ -251,6 +301,32 @@ func Test_UpdatePullRequest(t *testing.T) { expectError: true, expectedErrMsg: "failed to update pull request", }, + { + name: "request reviewers fails", + mockedClient: mock.NewMockedHTTPClient( + // First it gets the PR (no fields to update) + mock.WithRequestMatch( + mock.GetReposPullsByOwnerByRepoByPullNumber, + mockNoUpdatePR, + ), + // Then reviewer request fails + mock.WithRequestMatchHandler( + mock.PostReposPullsRequestedReviewersByOwnerByRepoByPullNumber, + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusUnprocessableEntity) + _, _ = w.Write([]byte(`{"message": "Invalid reviewers"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "pullNumber": float64(42), + "reviewers": []interface{}{"invalid-user"}, + }, + expectError: true, + expectedErrMsg: "failed to request reviewers", + }, } for _, tc := range tests { @@ -304,6 +380,26 @@ func Test_UpdatePullRequest(t *testing.T) { if tc.expectedPR.MaintainerCanModify != nil { assert.Equal(t, *tc.expectedPR.MaintainerCanModify, *returnedPR.MaintainerCanModify) } + + // Check reviewers if they exist in the expected PR + if tc.expectedPR.RequestedReviewers != nil && len(tc.expectedPR.RequestedReviewers) > 0 { + assert.NotNil(t, returnedPR.RequestedReviewers) + assert.Equal(t, len(tc.expectedPR.RequestedReviewers), len(returnedPR.RequestedReviewers)) + + // Create maps of reviewer logins for easy comparison + expectedReviewers := make(map[string]bool) + for _, reviewer := range tc.expectedPR.RequestedReviewers { + expectedReviewers[*reviewer.Login] = true + } + + actualReviewers := make(map[string]bool) + for _, reviewer := range returnedPR.RequestedReviewers { + actualReviewers[*reviewer.Login] = true + } + + // Compare the maps + assert.Equal(t, expectedReviewers, actualReviewers) + } }) } }