-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
15 changed files
with
498 additions
and
1 deletion.
There are no files selected for viewing
Empty file.
7 changes: 7 additions & 0 deletions
7
cumulus_lambda_functions/authorization/uds_authorizer_abstract.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
from abc import ABC, abstractmethod | ||
|
||
|
||
class UDSAuthorizorAbstract(ABC):
    """Interface for UDS authorization back-ends.

    A concrete implementation decides whether a given user may perform
    an action on a resource.
    """

    @abstractmethod
    def authorize(self, username, resource, action) -> bool:
        """Return True when `username` may perform `action` on `resource`.

        The abstract base denies everything; subclasses must override.
        """
        return False
39 changes: 39 additions & 0 deletions
39
cumulus_lambda_functions/authorization/uds_authorizer_es_identity_pool.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
import logging | ||
import os | ||
|
||
from cumulus_lambda_functions.authorization.uds_authorizer_abstract import UDSAuthorizorAbstract | ||
from cumulus_lambda_functions.lib.aws.aws_cognito import AwsCognito | ||
from cumulus_lambda_functions.lib.aws.es_abstract import ESAbstract | ||
from cumulus_lambda_functions.lib.aws.es_factory import ESFactory | ||
|
||
LOGGER = logging.getLogger(__name__) | ||
|
||
|
||
class UDSAuthorizorEsIdentityPool(UDSAuthorizorAbstract):
    """Authorizer that combines Cognito group membership with an
    Elasticsearch index of authorized groups.

    A user is authorized when at least one of their Cognito groups
    appears in the ES authorization index.
    """

    def __init__(self, user_pool_id: str) -> None:
        super().__init__()
        es_url = os.getenv('ES_URL')  # TODO validation
        authorization_index = os.getenv('AUTHORIZATION_URL')  # LDAP_Group_Permission
        es_port = int(os.getenv('ES_PORT', '443'))
        self.__cognito = AwsCognito(user_pool_id)
        self.__es: ESAbstract = ESFactory().get_instance(
            'AWS',
            index=authorization_index,
            base_url=es_url,
            port=es_port)

    def authorize(self, username, resource, action) -> bool:
        """True when the user belongs to at least one authorized group."""
        belonged_groups = set(self.__cognito.get_groups(username))
        es_result = self.__es.query({
            'query': {
                'match_all': {}  # TODO
            }
        })
        LOGGER.debug(f'belonged_groups for {username}: {belonged_groups}')
        authorized_groups = {hit['_source']['group_name'] for hit in es_result['hits']['hits']}
        LOGGER.debug(f'authorized_groups for {resource}-{action}: {authorized_groups}')
        if belonged_groups & authorized_groups:
            LOGGER.debug(f'{username} is authorized for {resource}-{action}')
            return True
        LOGGER.debug(f'{username} is NOT authorized for {resource}-{action}')
        return False
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
from cumulus_lambda_functions.lib.aws.aws_cred import AwsCred | ||
|
||
|
||
class AwsCognito(AwsCred):
    """Thin wrapper over the Cognito identity-provider (`cognito-idp`) client
    scoped to a single user pool."""

    def __init__(self, user_pool_id: str):
        super().__init__()
        self.__cognito = self.get_client('cognito-idp')
        self.__user_pool_id = user_pool_id

    def get_groups(self, username: str):
        """Return the list of Cognito group names `username` belongs to.

        Returns an empty list when the response is missing or has no
        'Groups' key. Only the first page (up to 60 groups) is fetched.
        """
        response = self.__cognito.admin_list_groups_for_user(
            Username=username,
            UserPoolId=self.__user_pool_id,
            Limit=60,
            # NextToken='string'
        )
        if response is None or 'Groups' not in response:
            return []
        return [each_group['GroupName'] for each_group in response['Groups']]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
from abc import ABC, abstractmethod | ||
from typing import Any, Union, Callable | ||
|
||
# Single mapping type used by pre-7.x style Elasticsearch document APIs.
DEFAULT_TYPE = '_doc'


class ESAbstract(ABC):
    """Abstract interface that every Elasticsearch middleware must implement."""

    @abstractmethod
    def create_index(self, index_name, index_body):
        """Create `index_name` with the given settings/mappings body."""
        ...

    @abstractmethod
    def has_index(self, index_name):
        """Check whether `index_name` exists."""
        ...

    @abstractmethod
    def create_alias(self, index_name, alias_name):
        """Point `alias_name` at `index_name`."""
        ...

    @abstractmethod
    def delete_index(self, index_name):
        """Delete `index_name`."""
        ...

    @abstractmethod
    def index_many(self, docs=None, doc_ids=None, doc_dict=None, index=None):
        """Bulk-index documents given either parallel lists or an id->doc dict."""
        ...

    @abstractmethod
    def index_one(self, doc, doc_id, index=None):
        """Index a single document under `doc_id`."""
        ...

    @abstractmethod
    def update_many(self, docs=None, doc_ids=None, doc_dict=None, index=None):
        """Bulk-upsert documents given either parallel lists or an id->doc dict."""
        ...

    @abstractmethod
    def update_one(self, doc, doc_id, index=None):
        """Upsert a single document under `doc_id`."""
        ...

    @staticmethod
    @abstractmethod
    def get_result_size(result):
        """Extract the total hit count from a search result."""
        ...

    @abstractmethod
    def query_with_scroll(self, dsl, querying_index=None):
        """Run `dsl` and gather all pages via the scroll API."""
        ...

    @abstractmethod
    def query(self, dsl, querying_index=None):
        """Run `dsl` once and return the raw result."""
        ...

    @abstractmethod
    def query_pages(self, dsl, querying_index=None):
        """Run `dsl` and gather pages via search_after pagination."""
        ...

    @abstractmethod
    def query_by_id(self, doc_id):
        """Fetch a single document by its `_id`."""
        ...
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
from cumulus_lambda_functions.lib.aws.factory_abstract import FactoryAbstract | ||
|
||
|
||
class ESFactory(FactoryAbstract):
    """Factory returning an Elasticsearch middleware for the requested auth flavor."""

    NO_AUTH = 'NO_AUTH'
    AWS = 'AWS'

    def get_instance(self, class_type, **kwargs):
        """Instantiate an ES middleware.

        :param class_type: 'NO_AUTH' or 'AWS' (case-insensitive)
        :param kwargs: must contain 'index' and 'base_url'; 'port' is
            optional and defaults to 443 (same default as ESMiddleware)
        :raises ModuleNotFoundError: when class_type is not recognized
        """
        ct = class_type.upper()
        # Robustness fix: 'port' used to be a mandatory kwarg (KeyError when
        # omitted) even though the middleware itself defaults to 443.
        port = kwargs.get('port', 443)
        if ct == self.NO_AUTH:
            from cumulus_lambda_functions.lib.aws.es_middleware import ESMiddleware
            return ESMiddleware(kwargs['index'], kwargs['base_url'], port=port)
        if ct == self.AWS:
            from cumulus_lambda_functions.lib.aws.es_middleware_aws import EsMiddlewareAws
            return EsMiddlewareAws(kwargs['index'], kwargs['base_url'], port=port)
        raise ModuleNotFoundError(f'cannot find ES class for {ct}')
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,216 @@ | ||
import json | ||
import logging | ||
|
||
from elasticsearch import Elasticsearch | ||
|
||
from cumulus_lambda_functions.lib.aws.es_abstract import ESAbstract, DEFAULT_TYPE | ||
|
||
LOGGER = logging.getLogger(__name__) | ||
|
||
|
||
class ESMiddleware(ESAbstract):
    """Elasticsearch implementation of ESAbstract backed by the official
    `elasticsearch` Python client.

    Review fixes applied:
    - `query_pages` read the targeted size from `dsl['sort']` instead of
      `dsl['size']` (wrong key) -- corrected.
    - bulk helpers now resolve the effective index BEFORE building the bulk
      action metadata, so `_index` can no longer be emitted as None.
    - bare `except:` clauses narrowed to `except Exception:` so that
      SystemExit / KeyboardInterrupt are not swallowed.
    """

    def __init__(self, index, base_url, port=443) -> None:
        """Store the default index and open a client to `base_url`:`port`.

        :raises ValueError: when index or base_url is None
        """
        if any([k is None for k in [index, base_url]]):
            raise ValueError('index or base_url is None')
        self.__index = index
        base_url = base_url.replace('https://', '')  # hide https
        self._engine = Elasticsearch(hosts=[{'host': base_url, 'port': port}])

    def __validate_index(self, index):
        """Return the explicit index if given, else the default; raise if neither."""
        if index is not None:
            return index
        if self.__index is not None:
            return self.__index
        raise ValueError('index value is NULL')

    def __get_doc_dict(self, docs=None, doc_ids=None, doc_dict=None):
        """Normalize (docs, doc_ids) parallel lists or doc_dict into {doc_id: doc}.

        :raises ValueError: when neither form is provided, or list lengths differ
        """
        if doc_dict is None and (docs is None and doc_ids is None):
            raise ValueError('must provide either doc dictionary or doc list & id list')
        if doc_dict is None:  # it comes as a list
            if len(docs) != len(doc_ids):
                raise ValueError('length of doc and id is different')
            doc_dict = {k: v for k, v in zip(doc_ids, docs)}
        return doc_dict

    def __check_errors_for_bulk(self, index_result):
        """Return per-item errors from a bulk response, or None when clean."""
        if 'errors' not in index_result or index_result['errors'] is False:
            return
        err_list = [[{'id': v['_id'], 'error': v['error']} for _, v in each.items() if 'error' in v] for each in
                    index_result['items']]
        if len(err_list) < 1:
            return
        LOGGER.exception('failed to add some items. details: {}'.format(err_list))
        return err_list

    def create_index(self, index_name, index_body):
        """Create an index; return the 'acknowledged' flag (or raw result)."""
        result = self._engine.indices.create(index=index_name, body=index_body, include_type_name=False)
        if 'acknowledged' not in result:
            return result
        return result['acknowledged']

    def has_index(self, index_name):
        """Return True when `index_name` exists."""
        result = self._engine.indices.exists(index=index_name)
        return result

    def create_alias(self, index_name, alias_name):
        """Create an alias for an index; return the 'acknowledged' flag (or raw result)."""
        result = self._engine.indices.put_alias(index_name, alias_name)
        if 'acknowledged' not in result:
            return result
        return result['acknowledged']

    def delete_index(self, index_name):
        """Delete an index; return the 'acknowledged' flag (or raw result)."""
        result = self._engine.indices.delete(index_name)
        if 'acknowledged' not in result:
            return result
        return result['acknowledged']

    def index_many(self, docs=None, doc_ids=None, doc_dict=None, index=None):
        """Bulk-index documents.

        :returns: None on success, an error list on partial failure, or the
            normalized doc_dict when the bulk call itself fails
        """
        doc_dict = self.__get_doc_dict(docs, doc_ids, doc_dict)
        # Resolve the index BEFORE building action metadata; previously a None
        # `index` leaked into the per-action '_index' field.
        index = self.__validate_index(index)
        body = []
        for k, v in doc_dict.items():
            # NOTE(review): 'retry_on_conflict' is an `update`-action option;
            # verify it is accepted on an `index` action by the target cluster.
            body.append({'index': {'_index': index, '_id': k, 'retry_on_conflict': 3}})
            body.append(v)
        try:
            index_result = self._engine.bulk(index=index,
                                             body=body, doc_type=DEFAULT_TYPE)
            LOGGER.info('indexed. result: {}'.format(index_result))
            return self.__check_errors_for_bulk(index_result)
        except Exception:
            LOGGER.exception('cannot add indices with ids: {} for index: {}'.format(list(doc_dict.keys()), index))
            return doc_dict

    def index_one(self, doc, doc_id, index=None):
        """Index a single document; return self on success, None on failure."""
        index = self.__validate_index(index)
        try:
            index_result = self._engine.index(index=index,
                                              body=doc, doc_type=DEFAULT_TYPE, id=doc_id)
            LOGGER.info('indexed. result: {}'.format(index_result))
        except Exception:
            LOGGER.exception('cannot add a new index with id: {} for index: {}'.format(doc_id, index))
            return None
        return self

    def update_many(self, docs=None, doc_ids=None, doc_dict=None, index=None):
        """Bulk-upsert documents.

        :returns: None on success, an error list on partial failure, or the
            normalized doc_dict when the bulk call itself fails
        """
        doc_dict = self.__get_doc_dict(docs, doc_ids, doc_dict)
        # Resolve the index BEFORE building action metadata (see index_many).
        index = self.__validate_index(index)
        body = []
        for k, v in doc_dict.items():
            body.append({'update': {'_index': index, '_id': k, 'retry_on_conflict': 3}})
            body.append({'doc': v, 'doc_as_upsert': True})
        try:
            index_result = self._engine.bulk(index=index,
                                             body=body, doc_type=DEFAULT_TYPE)
            LOGGER.info('indexed. result: {}'.format(index_result))
            return self.__check_errors_for_bulk(index_result)
        except Exception:
            LOGGER.exception('cannot update indices with ids: {} for index: {}'.format(list(doc_dict.keys()),
                                                                                       index))
            return doc_dict

    def update_one(self, doc, doc_id, index=None):
        """Upsert a single document; return self on success, None on failure."""
        update_body = {
            'doc': doc,
            'doc_as_upsert': True
        }
        index = self.__validate_index(index)
        try:
            update_result = self._engine.update(index=index,
                                                id=doc_id, body=update_body, doc_type=DEFAULT_TYPE)
            LOGGER.info('updated. result: {}'.format(update_result))
        except Exception:
            LOGGER.exception('cannot update id: {} for index: {}'.format(doc_id, index))
            return None
        return self

    @staticmethod
    def get_result_size(result):
        """Return total hit count; handles both int and dict 'total' shapes."""
        if isinstance(result['hits']['total'], dict):  # fix for different datatype in elastic-search result
            return result['hits']['total']['value']
        else:
            return result['hits']['total']

    def query_with_scroll(self, dsl, querying_index=None):
        """Run `dsl` and accumulate every page of hits via the scroll API."""
        scroll_timeout = '30s'
        index = self.__validate_index(querying_index)
        dsl['size'] = 10000  # replacing with the maximum size to minimize number of scrolls
        params = {
            'index': index,
            'size': 10000,
            'scroll': scroll_timeout,
            'body': dsl,
        }
        first_batch = self._engine.search(**params)
        total_size = self.get_result_size(first_batch)
        current_size = len(first_batch['hits']['hits'])
        scroll_id = first_batch['_scroll_id']
        while current_size < total_size:  # need to scroll
            scrolled_result = self._engine.scroll(scroll_id=scroll_id, scroll=scroll_timeout)
            scroll_id = scrolled_result['_scroll_id']
            scrolled_result_size = len(scrolled_result['hits']['hits'])
            if scrolled_result_size == 0:
                break
            current_size += scrolled_result_size
            first_batch['hits']['hits'].extend(scrolled_result['hits']['hits'])
        return first_batch

    def query(self, dsl, querying_index=None):
        """Run `dsl` once against the resolved index and return the raw result."""
        index = self.__validate_index(querying_index)
        return self._engine.search(body=dsl, index=index)

    def __is_querying_next_page(self, targeted_size: int, current_size: int, total_size: int):
        """True while more pages should be fetched (targeted_size < 0 means 'all')."""
        if targeted_size < 0:
            return current_size > 0
        return current_size > 0 and total_size < targeted_size

    def query_pages(self, dsl, querying_index=None):
        """Run `dsl` and accumulate hits via search_after pagination.

        :raises ValueError: when the DSL has no 'sort' (search_after needs a
            unique sort order)
        """
        if 'sort' not in dsl:
            raise ValueError('missing `sort` in DSL. Make sure sorting is unique')
        index = self.__validate_index(querying_index)
        # BUG FIX: the targeted size was read from dsl['sort'] (wrong key).
        targeted_size = dsl['size'] if 'size' in dsl else -1
        dsl['size'] = 10000  # replacing with the maximum size to minimize number of scrolls
        params = {
            'index': index,
            'size': 10000,
            'body': dsl,
        }
        LOGGER.debug(f'dsl: {dsl}')
        result_list = []
        total_size = 0
        result_batch = self._engine.search(**params)
        result_list.extend(result_batch['hits']['hits'])
        current_size = len(result_batch['hits']['hits'])
        total_size += current_size
        while self.__is_querying_next_page(targeted_size, current_size, total_size):
            params['body']['search_after'] = result_batch['hits']['hits'][-1]['sort']
            result_batch = self._engine.search(**params)
            result_list.extend(result_batch['hits']['hits'])
            current_size = len(result_batch['hits']['hits'])
            total_size += current_size
        return {
            'hits': {
                'hits': result_list,
                'total': total_size,
            }
        }

    def query_by_id(self, doc_id):
        """Fetch a single document by `_id`; return the first hit or None."""
        index = self.__validate_index(None)
        dsl = {
            'query': {
                'term': {'_id': doc_id}
            }
        }
        result = self._engine.search(index=index, body=dsl)
        if self.get_result_size(result) < 1:
            return None
        return result['hits']['hits'][0]
Oops, something went wrong.