Skip to content

Commit

Permalink
feat: preparing auth interfaces for partner auth plugins (#14)
Browse files Browse the repository at this point in the history
* splitting Auth interface from Cognito implementation

* documentation update

* removed layers, standardized functions

* deleting userpool, catch user not found

* adding missing iam policies for UserPool deletion

* chore: self mutation

Signed-off-by: github-actions <[email protected]>

* renaming auth interface file to auth-interface

* defining powertoos layer once on CognitoAuth

* removing user utils and jsonpickle

---------

Signed-off-by: github-actions <[email protected]>
Co-authored-by: Humberto Somensi <[email protected]>
Co-authored-by: github-actions <[email protected]>
  • Loading branch information
3 people authored Mar 20, 2024
1 parent 2463b9a commit 1bddfb7
Show file tree
Hide file tree
Showing 25 changed files with 366 additions and 464 deletions.
220 changes: 78 additions & 142 deletions API.md

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,26 @@
import boto3
import os
import uuid
import cognito.user_management_util as user_management_util
import user_management_util as user_management_util
from aws_lambda_powertools import Logger
from abstract_classes.identity_provider_abstract_class import IdentityProviderAbstractClass

logger = Logger()
cognito = boto3.client('cognito-idp')
region = os.environ['AWS_REGION']

class CognitoIdentityProviderManagement():
def delete_control_plane_idp(self, userPoolId):
response = cognito.describe_user_pool(
UserPoolId=userPoolId
)
domain = response['UserPool']['Domain']

cognito.delete_user_pool_domain(
UserPoolId=userPoolId,
Domain=domain
)
cognito.delete_user_pool(UserPoolId=userPoolId)

class CognitoIdentityProviderManagement(IdentityProviderAbstractClass):
def create_control_plane_idp(self, event):
idp_response = {}
idp_response['idp'] = {}
Expand All @@ -24,27 +35,17 @@ def create_control_plane_idp(self, event):

user_pool_response = self.__create_user_pool(
'SaaSControlPlaneUserPool', control_plane_callback_url)
logger.info(user_pool_response)
user_pool_id = user_pool_response['UserPool']['Id']
logger.info(user_pool_id)

app_client_response = self.__create_user_pool_client(
user_pool_id, control_plane_callback_url)
app_client_response = self.__create_user_pool_client(user_pool_id, control_plane_callback_url)
logger.info(app_client_response)
app_client_id = app_client_response['UserPoolClient']['ClientId']
user_pool_domain = 'saascontrolplane'+uuid.uuid1().hex
user_pool_domain_response = self.__create_user_pool_domain(
user_pool_id, user_pool_domain)

tenant_user_group_response = user_management_util.create_user_group(
user_pool_id, user_details['userRole'])

create_tenant_admin_response = user_management_util.create_user(
user_pool_id, user_details)

add_tenant_admin_to_group_response = user_management_util.add_user_to_group(
user_pool_id, user_details['userName'], tenant_user_group_response['Group']['GroupName'])

region = os.environ['AWS_REGION']
self.__create_user_pool_domain(user_pool_id, user_pool_domain)
tenant_user_group_response = user_management_util.create_user_group(user_pool_id, user_details['userRole'])
user_management_util.create_user(user_pool_id, user_details)
user_management_util.add_user_to_group(user_pool_id, user_details['userName'], tenant_user_group_response['Group']['GroupName'])

idp_response['idp']['name'] = 'Cognito'
idp_response['idp']['userPoolId'] = user_pool_id
Expand All @@ -56,8 +57,8 @@ def create_control_plane_idp(self, event):

def __create_user_pool(self, user_pool_name, control_plane_site_url):
email_message = ''.join(["Login into control plane UI at ",
control_plane_site_url,
" with username {username} and temporary password {####}"])
control_plane_site_url,
" with username {username} and temporary password {####}"])
email_subject = ''.join(
["Your temporary password for control plane UI"])
response = cognito.create_user_pool(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,10 @@
# SPDX-License-Identifier: Apache-2.0

import json
import os
import idp_object_factory

from cognito_identity_provider_management import CognitoIdentityProviderManagement
from crhelper import CfnResource
helper = CfnResource()

try:
idp_name = os.environ['IDP_NAME']
idp_mgmt_service = idp_object_factory.get_idp_mgmt_object(idp_name)
except Exception as e:
helper.init_failure(e)
idp_mgmt_service = CognitoIdentityProviderManagement()


@helper.create
Expand Down Expand Up @@ -43,14 +36,17 @@ def do_action(event, _):
helper.Data['ClientId'] = client_id
helper.Data['WellKnownEndpointUrl'] = well_known_endpoint

return idpDetails['idp']['userPoolId']
except Exception as e:
raise e


@helper.delete
def do_nothing(_, __):
pass

def do_delete(event, _):
try:
userPoolId = event['PhysicalResourceId']
idp_mgmt_service.delete_control_plane_idp(userPoolId)
except Exception as e:
raise e

def handler(event, context):
helper(event, context)
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ def add_user_to_group(user_pool_id, user_name, group_name):
def user_group_exists(user_pool_id, group_name):
try:
response=cognito.get_group(
UserPoolId=user_pool_id,
GroupName=group_name)
UserPoolId=user_pool_id,
GroupName=group_name)
return True
except Exception as e:
return False
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

from abstract_classes.idp_authorizer_abstract_class import IdpAuthorizerAbstractClass
import json
import boto3
import time
Expand All @@ -15,7 +14,7 @@
region = boto3.session.Session().region_name


class CognitoAuthorizer(IdpAuthorizerAbstractClass):
class CognitoAuthorizer():
def validateJWT(self, event):

input_details = event
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,56 +4,55 @@
import os
import re
import json
from jose import jwt
import idp_object_factory
from cognito_authorizer import CognitoAuthorizer
from aws_lambda_powertools import Logger

logger = Logger()

sys_admin_role_name = os.environ['SYS_ADMIN_ROLE_NAME']
idp_name = os.environ['IDP_NAME']
idp_details=json.loads(os.environ['IDP_DETAILS'])
idp_authorizer_service = idp_object_factory.get_idp_authorizer_object(idp_name)
idp_authorizer_service = CognitoAuthorizer()

def lambda_handler(event, context):
input_details={}
input_details['idpDetails'] = idp_details
#get JWT token after Bearer from authorization
token = event['authorizationToken'].split(" ")
if (token[0] != 'Bearer'):
raise Exception('Authorization header should have a format Bearer <JWT> Token')
jwt_bearer_token = token[1]
logger.info("Method ARN: " + event['methodArn'])
input_details['jwtToken']=jwt_bearer_token

response = idp_authorizer_service.validateJWT(input_details)

if (response == False):
input_details={}
input_details['idpDetails'] = idp_details
#get JWT token after Bearer from authorization
token = event['authorizationToken'].split(" ")
if (token[0] != 'Bearer'):
raise Exception('Authorization header should have a format Bearer <JWT> Token')
jwt_bearer_token = token[1]
logger.info("Method ARN: " + event['methodArn'])

input_details['jwtToken']=jwt_bearer_token

response = idp_authorizer_service.validateJWT(input_details)

if (response == False):
logger.error('Unauthorized')
raise Exception('Unauthorized')
else:
else:
logger.info(response)
principal_id = response["sub"]
user_name = response["cognito:username"]
user_role = response["custom:userRole"]

tmp = event['methodArn'].split(':')
api_gateway_arn_tmp = tmp[5].split('/')
aws_account_id = tmp[4]
tmp = event['methodArn'].split(':')
api_gateway_arn_tmp = tmp[5].split('/')
aws_account_id = tmp[4]

policy = AuthPolicy(principal_id, aws_account_id)
policy.restApiId = api_gateway_arn_tmp[0]
policy.region = tmp[3]
policy.stage = api_gateway_arn_tmp[1]

if (user_role != sys_admin_role_name):
logger.error('Unauthorized')
return Exception('Unauthorized')

policy.allowAllMethods()

policy = AuthPolicy(principal_id, aws_account_id)
policy.restApiId = api_gateway_arn_tmp[0]
policy.region = tmp[3]
policy.stage = api_gateway_arn_tmp[1]

if (user_role != sys_admin_role_name):
logger.error('Unauthorized')
return Exception('Unauthorized')

policy.allowAllMethods()

return policy.build()
return policy.build()

class HttpVerb:
GET = "GET"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

aws-lambda-powertools[all]==2.34.2
python-jose[cryptography]
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import botocore
from aws_lambda_powertools import Logger, Tracer
from aws_lambda_powertools.event_handler import (APIGatewayRestResolver,
CORSConfig)
CORSConfig)
from aws_lambda_powertools.logging import correlation_paths
from aws_lambda_powertools.event_handler.exceptions import (
BadRequestError,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,12 @@
# SPDX-License-Identifier: Apache-2.0

import boto3
import cognito.user_management_util as user_management_util
from abstract_classes.idp_user_management_abstract_class import IdpUserManagementAbstractClass

import json
import user_management_util as user_management_util

client = boto3.client('cognito-idp')

class CognitoUserManagementService(IdpUserManagementAbstractClass):
class CognitoUserManagementService():
def create_user(self, event):
user_details = event
user_pool_id = user_details['idpDetails']['idp']['userPoolId']
Expand All @@ -25,12 +24,12 @@ def create_user(self, event):
def get_users(self, event):
user_details = event
user_pool_id = user_details['idpDetails']['idp']['userPoolId']
users = []
users = UserInfoList()

response = client.list_users(
UserPoolId=user_pool_id
)

num_of_users = len(response['Users'])

if (num_of_users > 0):
Expand All @@ -47,28 +46,35 @@ def get_users(self, event):
user_info.modified = user["UserLastModifiedDate"]
user_info.status = user["UserStatus"]
user_info.user_name = user["Username"]
users.append(user_info)
users.add_user(user_info)

return users


def get_user(self, event):
user_details = event
user_pool_id = user_details['idpDetails']['idp']['userPoolId']
user_name = user_details['userName']
response = client.admin_get_user(
UserPoolId=user_pool_id,
Username=user_name
)

user_info = UserInfo()
user_info.user_name = response["Username"]
for attr in response["UserAttributes"]:
if(attr["Name"] == "custom:userRole"):
user_info.user_role = attr["Value"]
if(attr["Name"] == "email"):
user_info.email = attr["Value"]
return user_info
try:
user_details = event
user_pool_id = user_details['idpDetails']['idp']['userPoolId']
user_name = user_details['userName']
response = client.admin_get_user(
UserPoolId=user_pool_id,
Username=user_name
)

user_info = UserInfo()
user_info.user_name = response["Username"]
user_info.enabled = response["Enabled"]
user_info.created = response["UserCreateDate"]
user_info.modified = response["UserLastModifiedDate"]
user_info.status = response["UserStatus"]
for attr in response["UserAttributes"]:
if(attr["Name"] == "custom:userRole"):
user_info.user_role = attr["Value"]
if(attr["Name"] == "email"):
user_info.email = attr["Value"]
return user_info
except client.exceptions.UserNotFoundException as e:
return

def update_user(self, event):
user_details = event
Expand Down Expand Up @@ -122,17 +128,12 @@ def delete_user(self, event):
user_name = user_details['userName']

response = client.admin_delete_user(
UserPoolId=user_pool_id,
UserPoolId=user_pool_id,
Username=user_name
)

return response






class UserInfo:
def __init__(self, user_name=None, user_role=None,
email=None, status=None, enabled=None, created=None, modified=None):
Expand All @@ -142,4 +143,17 @@ def __init__(self, user_name=None, user_role=None,
self.status = status
self.enabled = enabled
self.created = created
self.modified = modified
self.modified = modified
def serialize(self):
return json.dumps(self.__dict__, default=str)

class UserInfoList:
def __init__(self):
self.users = []

def add_user(self, user):
self.users.append(user)

def serialize(self):
user_dicts = [user.__dict__ for user in self.users]
return json.dumps(user_dicts, default=str)
Loading

0 comments on commit 1bddfb7

Please sign in to comment.