Commit
Merge ea5b735 into d498fe3
wphyojpl authored Oct 21, 2022
2 parents d498fe3 + ea5b735 commit 2f1a1ff
Showing 15 changed files with 498 additions and 1 deletion.
Empty file.
7 changes: 7 additions & 0 deletions cumulus_lambda_functions/authorization/uds_authorizer_abstract.py
@@ -0,0 +1,7 @@
from abc import ABC, abstractmethod


class UDSAuthorizorAbstract(ABC):
    @abstractmethod
    def authorize(self, username, resource, action) -> bool:
        return False
39 changes: 39 additions & 0 deletions cumulus_lambda_functions/authorization/uds_authorizer_es_identity_pool.py
@@ -0,0 +1,39 @@
import logging
import os

from cumulus_lambda_functions.authorization.uds_authorizer_abstract import UDSAuthorizorAbstract
from cumulus_lambda_functions.lib.aws.aws_cognito import AwsCognito
from cumulus_lambda_functions.lib.aws.es_abstract import ESAbstract
from cumulus_lambda_functions.lib.aws.es_factory import ESFactory

LOGGER = logging.getLogger(__name__)


class UDSAuthorizorEsIdentityPool(UDSAuthorizorAbstract):

    def __init__(self, user_pool_id: str) -> None:
        super().__init__()
        es_url = os.getenv('ES_URL')  # TODO validation
        authorization_index = os.getenv('AUTHORIZATION_URL')  # LDAP_Group_Permission
        es_port = int(os.getenv('ES_PORT', '443'))
        self.__cognito = AwsCognito(user_pool_id)
        self.__es: ESAbstract = ESFactory().get_instance('AWS',
                                                         index=authorization_index,
                                                         base_url=es_url,
                                                         port=es_port)

    def authorize(self, username, resource, action) -> bool:
        belonged_groups = set(self.__cognito.get_groups(username))
        authorized_groups = self.__es.query({
            'query': {
                'match_all': {}  # TODO
            }
        })
        LOGGER.debug(f'belonged_groups for {username}: {belonged_groups}')
        authorized_groups = set([k['_source']['group_name'] for k in authorized_groups['hits']['hits']])
        LOGGER.debug(f'authorized_groups for {resource}-{action}: {authorized_groups}')
        if any([k in authorized_groups for k in belonged_groups]):
            LOGGER.debug(f'{username} is authorized for {resource}-{action}')
            return True
        LOGGER.debug(f'{username} is NOT authorized for {resource}-{action}')
        return False
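
For reference, a minimal usage sketch of the new authorizer. The module path for the concrete class, the endpoint, the index name, and the user pool ID below are all hypothetical placeholders:

import os

from cumulus_lambda_functions.authorization.uds_authorizer_es_identity_pool import UDSAuthorizorEsIdentityPool

os.environ['ES_URL'] = 'https://example-es-endpoint.us-west-2.es.amazonaws.com'  # hypothetical endpoint
os.environ['AUTHORIZATION_URL'] = 'ldap_group_permission'  # hypothetical authorization index name
authorizer = UDSAuthorizorEsIdentityPool('us-west-2_EXAMPLE')  # hypothetical Cognito user pool ID
print(authorizer.authorize('some-user', 'some-collection', 'READ'))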
20 changes: 20 additions & 0 deletions cumulus_lambda_functions/lib/aws/aws_cognito.py
@@ -0,0 +1,20 @@
from cumulus_lambda_functions.lib.aws.aws_cred import AwsCred


class AwsCognito(AwsCred):
    def __init__(self, user_pool_id: str):
        super().__init__()
        self.__cognito = self.get_client('cognito-idp')
        self.__user_pool_id = user_pool_id

    def get_groups(self, username: str):
        response = self.__cognito.admin_list_groups_for_user(
            Username=username,
            UserPoolId=self.__user_pool_id,
            Limit=60,
            # NextToken='string'
        )
        if response is None or 'Groups' not in response:
            return []
        belonged_groups = [k['GroupName'] for k in response['Groups']]
        return belonged_groups
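
The commented-out NextToken hints at pagination that this commit does not implement: a user in more than 60 groups would be truncated. A hedged sketch of how the lookup could follow pagination with the NextToken that admin_list_groups_for_user returns (this helper is not part of the commit):

def get_all_groups(cognito_client, user_pool_id: str, username: str):
    # cognito_client is a boto3 'cognito-idp' client; collects group names across all pages
    groups, token = [], None
    while True:
        kwargs = {'Username': username, 'UserPoolId': user_pool_id, 'Limit': 60}
        if token is not None:
            kwargs['NextToken'] = token
        response = cognito_client.admin_list_groups_for_user(**kwargs)
        groups.extend(k['GroupName'] for k in response.get('Groups', []))
        token = response.get('NextToken')
        if token is None:
            return groups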
16 changes: 16 additions & 0 deletions cumulus_lambda_functions/lib/aws/aws_cred.py
@@ -43,6 +43,19 @@ def __init__(self):
        else:
            LOGGER.debug('using default session as there is no aws_access_key_id')

    @property
    def region(self):
        return self.__region

    @region.setter
    def region(self, val):
        """
        :param val:
        :return: None
        """
        self.__region = val
        return

    @property
    def boto3_session(self):
        return self.__boto3_session
@@ -56,6 +69,9 @@ def boto3_session(self, val):
        self.__boto3_session = val
        return

    def get_session(self):
        return boto3.Session(**self.boto3_session)

    def get_resource(self, service_name: str):
        return boto3.Session(**self.boto3_session).resource(service_name)
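
The new get_session helper exposes the configured boto3 session directly. A minimal sketch, assuming AwsCred can be constructed with no arguments as AwsCognito does (the service name is illustrative):

cred = AwsCred()
dynamodb = cred.get_session().client('dynamodb')  # any boto3 client can be created off the shared session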

59 changes: 59 additions & 0 deletions cumulus_lambda_functions/lib/aws/es_abstract.py
@@ -0,0 +1,59 @@
from abc import ABC, abstractmethod
from typing import Any, Union, Callable

DEFAULT_TYPE = '_doc'


class ESAbstract(ABC):
    @abstractmethod
    def create_index(self, index_name, index_body):
        return

    @abstractmethod
    def has_index(self, index_name):
        return

    @abstractmethod
    def create_alias(self, index_name, alias_name):
        return

    @abstractmethod
    def delete_index(self, index_name):
        return

    @abstractmethod
    def index_many(self, docs=None, doc_ids=None, doc_dict=None, index=None):
        return

    @abstractmethod
    def index_one(self, doc, doc_id, index=None):
        return

    @abstractmethod
    def update_many(self, docs=None, doc_ids=None, doc_dict=None, index=None):
        return

    @abstractmethod
    def update_one(self, doc, doc_id, index=None):
        return

    @staticmethod
    @abstractmethod
    def get_result_size(result):
        return

    @abstractmethod
    def query_with_scroll(self, dsl, querying_index=None):
        return

    @abstractmethod
    def query(self, dsl, querying_index=None):
        return

    @abstractmethod
    def query_pages(self, dsl, querying_index=None):
        return

    @abstractmethod
    def query_by_id(self, doc_id):
        return
16 changes: 16 additions & 0 deletions cumulus_lambda_functions/lib/aws/es_factory.py
@@ -0,0 +1,16 @@
from cumulus_lambda_functions.lib.aws.factory_abstract import FactoryAbstract


class ESFactory(FactoryAbstract):
    NO_AUTH = 'NO_AUTH'
    AWS = 'AWS'

    def get_instance(self, class_type, **kwargs):
        ct = class_type.upper()
        if ct == self.NO_AUTH:
            from cumulus_lambda_functions.lib.aws.es_middleware import ESMiddleware
            return ESMiddleware(kwargs['index'], kwargs['base_url'], port=kwargs['port'])
        if ct == self.AWS:
            from cumulus_lambda_functions.lib.aws.es_middleware_aws import EsMiddlewareAws
            return EsMiddlewareAws(kwargs['index'], kwargs['base_url'], port=kwargs['port'])
        raise ModuleNotFoundError(f'cannot find ES class for {ct}')
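
A short sketch of how the factory is called, mirroring UDSAuthorizorEsIdentityPool above; the endpoint and index name are placeholders:

from cumulus_lambda_functions.lib.aws.es_factory import ESFactory

# 'AWS' returns EsMiddlewareAws; 'NO_AUTH' returns the plain ESMiddleware defined below
es = ESFactory().get_instance('AWS',
                              index='ldap_group_permission',
                              base_url='https://example-es-endpoint.amazonaws.com',
                              port=443)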
216 changes: 216 additions & 0 deletions cumulus_lambda_functions/lib/aws/es_middleware.py
@@ -0,0 +1,216 @@
import json
import logging

from elasticsearch import Elasticsearch

from cumulus_lambda_functions.lib.aws.es_abstract import ESAbstract, DEFAULT_TYPE

LOGGER = logging.getLogger(__name__)


class ESMiddleware(ESAbstract):

    def __init__(self, index, base_url, port=443) -> None:
        if any([k is None for k in [index, base_url]]):
            raise ValueError('index or base_url is None')
        self.__index = index
        base_url = base_url.replace('https://', '')  # hide https
        self._engine = Elasticsearch(hosts=[{'host': base_url, 'port': port}])

    def __validate_index(self, index):
        if index is not None:
            return index
        if self.__index is not None:
            return self.__index
        raise ValueError('index value is NULL')

    def __get_doc_dict(self, docs=None, doc_ids=None, doc_dict=None):
        if doc_dict is None and (docs is None or doc_ids is None):
            raise ValueError('must provide either doc dictionary or doc list & id list')
        if doc_dict is None:  # docs & ids come as parallel lists
            if len(docs) != len(doc_ids):
                raise ValueError('length of doc and id is different')
            doc_dict = {k: v for k, v in zip(doc_ids, docs)}
        return doc_dict

    def __check_errors_for_bulk(self, index_result):
        if 'errors' not in index_result or index_result['errors'] is False:
            return
        err_list = [[{'id': v['_id'], 'error': v['error']} for _, v in each.items() if 'error' in v]
                    for each in index_result['items']]
        if len(err_list) < 1:
            return
        LOGGER.error('failed to add some items. details: {}'.format(err_list))
        return err_list

    def create_index(self, index_name, index_body):
        result = self._engine.indices.create(index=index_name, body=index_body, include_type_name=False)
        if 'acknowledged' not in result:
            return result
        return result['acknowledged']

    def has_index(self, index_name):
        result = self._engine.indices.exists(index=index_name)
        return result

    def create_alias(self, index_name, alias_name):
        result = self._engine.indices.put_alias(index_name, alias_name)
        if 'acknowledged' not in result:
            return result
        return result['acknowledged']

    def delete_index(self, index_name):
        result = self._engine.indices.delete(index_name)
        if 'acknowledged' not in result:
            return result
        return result['acknowledged']

    def index_many(self, docs=None, doc_ids=None, doc_dict=None, index=None):
        doc_dict = self.__get_doc_dict(docs, doc_ids, doc_dict)
        index = self.__validate_index(index)  # validate first so bulk actions never carry a None index
        body = []
        for k, v in doc_dict.items():
            body.append({'index': {'_index': index, '_id': k, 'retry_on_conflict': 3}})
            body.append(v)
        try:
            index_result = self._engine.bulk(index=index, body=body, doc_type=DEFAULT_TYPE)
            LOGGER.info('indexed. result: {}'.format(index_result))
            return self.__check_errors_for_bulk(index_result)
        except Exception:
            LOGGER.exception('cannot add indices with ids: {} for index: {}'.format(list(doc_dict.keys()), index))
            return doc_dict

    def index_one(self, doc, doc_id, index=None):
        index = self.__validate_index(index)
        try:
            index_result = self._engine.index(index=index, body=doc, doc_type=DEFAULT_TYPE, id=doc_id)
            LOGGER.info('indexed. result: {}'.format(index_result))
        except Exception:
            LOGGER.exception('cannot add a new index with id: {} for index: {}'.format(doc_id, index))
            return None
        return self

    def update_many(self, docs=None, doc_ids=None, doc_dict=None, index=None):
        doc_dict = self.__get_doc_dict(docs, doc_ids, doc_dict)
        index = self.__validate_index(index)  # validate first so bulk actions never carry a None index
        body = []
        for k, v in doc_dict.items():
            body.append({'update': {'_index': index, '_id': k, 'retry_on_conflict': 3}})
            body.append({'doc': v, 'doc_as_upsert': True})
        try:
            index_result = self._engine.bulk(index=index, body=body, doc_type=DEFAULT_TYPE)
            LOGGER.info('updated. result: {}'.format(index_result))
            return self.__check_errors_for_bulk(index_result)
        except Exception:
            LOGGER.exception('cannot update indices with ids: {} for index: {}'.format(list(doc_dict.keys()), index))
            return doc_dict

    def update_one(self, doc, doc_id, index=None):
        update_body = {
            'doc': doc,
            'doc_as_upsert': True
        }
        index = self.__validate_index(index)
        try:
            update_result = self._engine.update(index=index, id=doc_id, body=update_body, doc_type=DEFAULT_TYPE)
            LOGGER.info('updated. result: {}'.format(update_result))
        except Exception:
            LOGGER.exception('cannot update id: {} for index: {}'.format(doc_id, index))
            return None
        return self

    @staticmethod
    def get_result_size(result):
        if isinstance(result['hits']['total'], dict):  # fix for different datatype in elastic-search result
            return result['hits']['total']['value']
        return result['hits']['total']

    def query_with_scroll(self, dsl, querying_index=None):
        scroll_timeout = '30s'
        index = self.__validate_index(querying_index)
        dsl['size'] = 10000  # replacing with the maximum size to minimize number of scrolls
        params = {
            'index': index,
            'size': 10000,
            'scroll': scroll_timeout,
            'body': dsl,
        }
        first_batch = self._engine.search(**params)
        total_size = self.get_result_size(first_batch)
        current_size = len(first_batch['hits']['hits'])
        scroll_id = first_batch['_scroll_id']
        while current_size < total_size:  # need to scroll
            scrolled_result = self._engine.scroll(scroll_id=scroll_id, scroll=scroll_timeout)
            scroll_id = scrolled_result['_scroll_id']
            scrolled_result_size = len(scrolled_result['hits']['hits'])
            if scrolled_result_size == 0:
                break
            current_size += scrolled_result_size
            first_batch['hits']['hits'].extend(scrolled_result['hits']['hits'])
        return first_batch

    def query(self, dsl, querying_index=None):
        index = self.__validate_index(querying_index)
        return self._engine.search(body=dsl, index=index)

    def __is_querying_next_page(self, targeted_size: int, current_size: int, total_size: int):
        if targeted_size < 0:
            return current_size > 0
        return current_size > 0 and total_size < targeted_size

    def query_pages(self, dsl, querying_index=None):
        if 'sort' not in dsl:
            raise ValueError('missing `sort` in DSL. Make sure sorting is unique')
        index = self.__validate_index(querying_index)
        targeted_size = dsl['size'] if 'size' in dsl else -1  # caller-requested cap on total hits
        dsl['size'] = 10000  # replacing with the maximum size to minimize the number of pages
        params = {
            'index': index,
            'size': 10000,
            'body': dsl,
        }
        LOGGER.debug(f'dsl: {dsl}')
        result_list = []
        total_size = 0
        result_batch = self._engine.search(**params)
        result_list.extend(result_batch['hits']['hits'])
        current_size = len(result_batch['hits']['hits'])
        total_size += current_size
        while self.__is_querying_next_page(targeted_size, current_size, total_size):
            params['body']['search_after'] = result_batch['hits']['hits'][-1]['sort']
            result_batch = self._engine.search(**params)
            result_list.extend(result_batch['hits']['hits'])
            current_size = len(result_batch['hits']['hits'])
            total_size += current_size
        return {
            'hits': {
                'hits': result_list,
                'total': total_size,
            }
        }

    def query_by_id(self, doc_id):
        index = self.__validate_index(None)
        dsl = {
            'query': {
                'term': {'_id': doc_id}
            }
        }
        result = self._engine.search(index=index, body=dsl)
        if self.get_result_size(result) < 1:
            return None
        return result['hits']['hits'][0]
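
A sketch of paging through an index with query_pages; the DSL must carry a sort that orders documents uniquely, and the index name, endpoint, and sort field are placeholders:

from cumulus_lambda_functions.lib.aws.es_middleware import ESMiddleware

es = ESMiddleware('ldap_group_permission', 'https://example-es-endpoint.amazonaws.com', port=443)
results = es.query_pages({
    'query': {'match_all': {}},
    'sort': [{'group_name.keyword': 'asc'}],  # unique sort keeps search_after pagination stable
    'size': 20,  # requested cap; paging stops once at least this many hits are collected
})
print(results['hits']['total'])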