From 58cb3f672324b715aeae04ac90368a33e8b045fa Mon Sep 17 00:00:00 2001 From: Jake Van Vorhis <83739412+jakedoublev@users.noreply.github.com> Date: Tue, 3 Dec 2024 07:28:09 -0800 Subject: [PATCH] fix(authz): handle pagination in authz service (#1797) ### Proposed Changes * Handle pagination limit/offset in `authorization.GetEntitlements` * If quantity of attributes or subject mappings exceeds default "list" API limit as configured, we need to retrieve all to make accurate entitlement decision ### Checklist - [ ] I have added or updated unit tests - [ ] I have added or updated integration tests (if appropriate) - [ ] I have added or updated documentation ### Testing Instructions --- service/authorization/authorization.go | 59 +++++++-- service/authorization/authorization_test.go | 135 ++++++++++++++++++++ 2 files changed, 184 insertions(+), 10 deletions(-) diff --git a/service/authorization/authorization.go b/service/authorization/authorization.go index 711293444..49def91d0 100644 --- a/service/authorization/authorization.go +++ b/service/authorization/authorization.go @@ -16,6 +16,7 @@ import ( "github.com/open-policy-agent/opa/rego" "github.com/opentdf/platform/protocol/go/authorization" "github.com/opentdf/platform/protocol/go/authorization/authorizationconnect" + "github.com/opentdf/platform/protocol/go/common" "github.com/opentdf/platform/protocol/go/entityresolution" "github.com/opentdf/platform/protocol/go/policy" attr "github.com/opentdf/platform/protocol/go/policy/attributes" @@ -444,22 +445,60 @@ func makeScopeMap(scope *authorization.ResourceAttribute) map[string]bool { func (as *AuthorizationService) GetEntitlements(ctx context.Context, req *connect.Request[authorization.GetEntitlementsRequest]) (*connect.Response[authorization.GetEntitlementsResponse], error) { as.logger.DebugContext(ctx, "getting entitlements") - attrsRes, err := as.sdk.Attributes.ListAttributes(ctx, &attr.ListAttributesRequest{}) - if err != nil { - as.logger.ErrorContext(ctx, "failed to list attributes", slog.String("error", err.Error())) - return nil, connect.NewError(connect.CodeInternal, errors.New("failed to list attributes")) + + var nextOffset int32 + attrsList := make([]*policy.Attribute, 0) + subjectMappingsList := make([]*policy.SubjectMapping, 0) + + // If quantity of attributes exceeds maximum list pagination, all are needed to determine entitlements + for { + listed, err := as.sdk.Attributes.ListAttributes(ctx, &attr.ListAttributesRequest{ + State: common.ActiveStateEnum_ACTIVE_STATE_ENUM_ACTIVE, + Pagination: &policy.PageRequest{ + Offset: nextOffset, + }, + }) + if err != nil { + as.logger.ErrorContext(ctx, "failed to list attributes", slog.String("error", err.Error())) + return nil, connect.NewError(connect.CodeInternal, errors.New("failed to list attributes")) + } + + nextOffset = listed.GetPagination().GetNextOffset() + attrsList = append(attrsList, listed.GetAttributes()...) + + // offset becomes zero when list is exhausted + if nextOffset <= 0 { + break + } } - subMapsRes, err := as.sdk.SubjectMapping.ListSubjectMappings(ctx, &subjectmapping.ListSubjectMappingsRequest{}) - if err != nil { - as.logger.ErrorContext(ctx, "failed to list subject mappings", slog.String("error", err.Error())) - return nil, connect.NewError(connect.CodeInternal, errors.New("failed to list subject mappings")) + + // If quantity of subject mappings exceeds maximum list pagination, all are needed to determine entitlements + nextOffset = 0 + for { + listed, err := as.sdk.SubjectMapping.ListSubjectMappings(ctx, &subjectmapping.ListSubjectMappingsRequest{ + Pagination: &policy.PageRequest{ + Offset: nextOffset, + }, + }) + if err != nil { + as.logger.ErrorContext(ctx, "failed to list subject mappings", slog.String("error", err.Error())) + return nil, connect.NewError(connect.CodeInternal, errors.New("failed to list subject mappings")) + } + + nextOffset = listed.GetPagination().GetNextOffset() + subjectMappingsList = append(subjectMappingsList, listed.GetSubjectMappings()...) + + // offset becomes zero when list is exhausted + if nextOffset <= 0 { + break + } } // create a lookup map of attribute value FQNs (based on request scope) scopeMap := makeScopeMap(req.Msg.GetScope()) // create a lookup map of subject mappings by attribute value ID - subMapsByVal := makeSubMapsByValLookup(subMapsRes.GetSubjectMappings()) + subMapsByVal := makeSubMapsByValLookup(subjectMappingsList) // create a lookup map of attribute values by FQN (for rego query) - fqnAttrVals := makeValsByFqnsLookup(attrsRes.GetAttributes(), subMapsByVal, scopeMap) + fqnAttrVals := makeValsByFqnsLookup(attrsList, subMapsByVal, scopeMap) avf := &attr.GetAttributeValuesByFqnsResponse{ FqnAttributeValues: fqnAttrVals, } diff --git a/service/authorization/authorization_test.go b/service/authorization/authorization_test.go index e64f849b7..c24b3751b 100644 --- a/service/authorization/authorization_test.go +++ b/service/authorization/authorization_test.go @@ -57,6 +57,10 @@ type mySubjectMappingClient struct { sm.SubjectMappingServiceClient } +type paginatedMockSubjectMappingClient struct { + sm.SubjectMappingServiceClient +} + func (*mySubjectMappingClient) ListSubjectMappings(_ context.Context, _ *sm.ListSubjectMappingsRequest, _ ...grpc.CallOption) (*sm.ListSubjectMappingsResponse, error) { return &listSubjectMappings, nil } @@ -69,6 +73,52 @@ func (*myERSClient) ResolveEntities(_ context.Context, _ *entityresolution.Resol return &resolveEntitiesResp, nil } +var ( + smPaginationOffset = 3 + smListCallCount = 0 +) + +func (*paginatedMockSubjectMappingClient) ListSubjectMappings(_ context.Context, _ *sm.ListSubjectMappingsRequest, _ ...grpc.CallOption) (*sm.ListSubjectMappingsResponse, error) { + smListCallCount++ + // simulate paginated list and policy LIST behavior + if smPaginationOffset > 0 { + rsp := &sm.ListSubjectMappingsResponse{ + SubjectMappings: nil, + Pagination: &policy.PageResponse{ + NextOffset: int32(smPaginationOffset), + }, + } + smPaginationOffset = 0 + return rsp, nil + } + return &listSubjectMappings, nil +} + +type paginatedMockAttributesClient struct { + attr.AttributesServiceClient +} + +var ( + attrPaginationOffset = 3 + attrListCallCount = 0 +) + +func (*paginatedMockAttributesClient) ListAttributes(_ context.Context, _ *attr.ListAttributesRequest, _ ...grpc.CallOption) (*attr.ListAttributesResponse, error) { + attrListCallCount++ + // simulate paginated list and policy LIST behavior + if attrPaginationOffset > 0 { + rsp := &attr.ListAttributesResponse{ + Attributes: nil, + Pagination: &policy.PageResponse{ + NextOffset: int32(attrPaginationOffset), + }, + } + attrPaginationOffset = 0 + return rsp, nil + } + return &listAttributeResp, nil +} + func TestGetComprehensiveHierarchy(t *testing.T) { as := &AuthorizationService{ logger: logger.CreateTestLogger(), @@ -763,6 +813,91 @@ func Test_GetEntitlementsFqnCasing(t *testing.T) { assert.Equal(t, []string{"https://www.example.org/attr/foo/value/value1"}, resp.Msg.GetEntitlements()[0].GetAttributeValueFqns()) } +func Test_GetEntitlements_HandlesPagination(t *testing.T) { + logger := logger.CreateTestLogger() + + listAttributeResp = attr.ListAttributesResponse{} + attrDef := policy.Attribute{ + Name: mockAttrName, + Namespace: &policy.Namespace{ + Name: mockNamespace, + }, + Rule: policy.AttributeRuleTypeEnum_ATTRIBUTE_RULE_TYPE_ENUM_ALL_OF, + Values: []*policy.Value{ + { + Value: mockAttrValue1, + }, + { + Value: mockAttrValue2, + }, + }, + } + listAttributeResp.Attributes = []*policy.Attribute{&attrDef} + userRepresentation := map[string]interface{}{ + "A": "B", + "C": "D", + } + userStruct, _ := structpb.NewStruct(userRepresentation) + resolveEntitiesResp = entityresolution.ResolveEntitiesResponse{ + EntityRepresentations: []*entityresolution.EntityRepresentation{ + { + OriginalId: "e1", + AdditionalProps: []*structpb.Struct{ + userStruct, + }, + }, + }, + } + + ctxb := context.Background() + + rego := rego.New( + rego.Query("data.example.p"), + rego.Module("example.rego", + `package example + p = {"e1":["https://www.example.org/attr/foo/value/value1"]} { true }`, + )) + + // Run evaluation. + prepared, err := rego.PrepareForEval(ctxb) + require.NoError(t, err) + + as := AuthorizationService{ + logger: logger, sdk: &otdf.SDK{ + SubjectMapping: &paginatedMockSubjectMappingClient{}, + Attributes: &paginatedMockAttributesClient{}, + EntityResoution: &myERSClient{}, + }, + eval: prepared, + } + + req := connect.Request[authorization.GetEntitlementsRequest]{ + Msg: &authorization.GetEntitlementsRequest{ + Entities: []*authorization.Entity{{Id: "e1", EntityType: &authorization.Entity_ClientId{ClientId: "testclient"}, Category: authorization.Entity_CATEGORY_ENVIRONMENT}}, + // Using mixed case here + Scope: &authorization.ResourceAttribute{AttributeValueFqns: []string{"https://www.example.org/attr/foo/value/VaLuE1"}}, + }, + } + + for fqn := range makeScopeMap(req.Msg.GetScope()) { + assert.Equal(t, fqn, strings.ToLower(fqn)) + } + + resp, err := as.GetEntitlements(ctxb, &req) + + require.NoError(t, err) + assert.NotNil(t, resp) + assert.Len(t, resp.Msg.GetEntitlements(), 1) + assert.Equal(t, "e1", resp.Msg.GetEntitlements()[0].GetEntityId()) + assert.Equal(t, []string{"https://www.example.org/attr/foo/value/value1"}, resp.Msg.GetEntitlements()[0].GetAttributeValueFqns()) + + // paginated successfully + assert.Equal(t, 2, smListCallCount) + assert.Zero(t, smPaginationOffset) + assert.Equal(t, 2, attrListCallCount) + assert.Zero(t, attrPaginationOffset) +} + func Test_GetEntitlementsWithComprehensiveHierarchy(t *testing.T) { logger := logger.CreateTestLogger() attrDef := policy.Attribute{