
Commit eb41046

Merge 7f7f863 into d498fe3
2 parents d498fe3 + 7f7f863 commit eb41046

15 files changed: +497, -1 lines changed

cumulus_lambda_functions/authorization/__init__.py

Whitespace-only changes.

cumulus_lambda_functions/authorization/uds_authorizer_abstract.py

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+from abc import ABC, abstractmethod
+
+
+class UDSAuthorizorAbstract(ABC):
+    @abstractmethod
+    def authorize(self, username, resource, action) -> bool:
+        return False
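
For illustration only (not part of this commit): downstream code can depend on the abstract type so that concrete authorizers can be swapped out; `require_permission` is a hypothetical helper, not something defined in this changeset.

# Illustrative sketch only -- not part of this commit; `require_permission` is a hypothetical helper.
from cumulus_lambda_functions.authorization.uds_authorizer_abstract import UDSAuthorizorAbstract


def require_permission(authorizer: UDSAuthorizorAbstract, username: str, resource: str, action: str) -> None:
    """Raise if `username` is not allowed to perform `action` on `resource`."""
    if not authorizer.authorize(username, resource, action):
        raise PermissionError(f'{username} is not authorized for {resource}-{action}')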

Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
+import logging
+import os
+
+from cumulus_lambda_functions.authorization.uds_authorizer_abstract import UDSAuthorizorAbstract
+from cumulus_lambda_functions.lib.aws.aws_cognito import AwsCognito
+from cumulus_lambda_functions.lib.aws.es_abstract import ESAbstract
+from cumulus_lambda_functions.lib.aws.es_factory import ESFactory
+
+LOGGER = logging.getLogger(__name__)
+
+
+class UDSAuthorizorEsIdentityPool(UDSAuthorizorAbstract):
+
+    def __init__(self, user_pool_id: str) -> None:
+        super().__init__()
+        es_url = os.getenv('ES_URL')  # TODO validation
+        authorization_index = os.getenv('AUTHORIZATION_URL')  # LDAP_Group_Permission
+        es_port = int(os.getenv('ES_PORT', '443'))
+        self.__cognito = AwsCognito(user_pool_id)
+        self.__es: ESAbstract = ESFactory().get_instance('AWS',
+                                                         index=authorization_index,
+                                                         base_url=es_url,
+                                                         port=es_port)
+
+    def authorize(self, username, resource, action) -> bool:
+        belonged_groups = set(self.__cognito.get_groups(username))
+        authorized_groups = self.__es.query({
+            'query': {
+                'match_all': {}  # TODO
+            }
+        })
+        LOGGER.debug(f'belonged_groups for {username}: {belonged_groups}')
+        authorized_groups = set([k['_source']['group_name'] for k in authorized_groups['hits']['hits']])
+        LOGGER.debug(f'authorized_groups for {resource}-{action}: {authorized_groups}')
+        if any([k in authorized_groups for k in belonged_groups]):
+            LOGGER.debug(f'{username} is authorized for {resource}-{action}')
+            return True
+        LOGGER.debug(f'{username} is NOT authorized for {resource}-{action}')
+        return False
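
A minimal usage sketch (not part of this commit) of wiring up the new authorizer. The import path is assumed from the package layout, and the endpoint, index, and user pool ID values are placeholders; the environment variable names (`ES_URL`, `AUTHORIZATION_URL`, `ES_PORT`) come from the code above.

# Illustrative usage sketch only -- not part of this commit; import path assumed, values are placeholders.
import os

from cumulus_lambda_functions.authorization.uds_authorizer_es_identity_pool import UDSAuthorizorEsIdentityPool

os.environ['ES_URL'] = 'https://example-opensearch-domain.us-west-2.es.amazonaws.com'
os.environ['AUTHORIZATION_URL'] = 'ldap_group_permission'  # name of the authorization index
os.environ['ES_PORT'] = '443'

authorizer = UDSAuthorizorEsIdentityPool(user_pool_id='us-west-2_EXAMPLE')
if authorizer.authorize('some.user', 'granules', 'GET'):
    print('allowed')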

cumulus_lambda_functions/lib/aws/aws_cognito.py

Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
+from cumulus_lambda_functions.lib.aws.aws_cred import AwsCred
+
+
+class AwsCognito(AwsCred):
+    def __init__(self, user_pool_id: str):
+        super().__init__()
+        self.__cognito = self.get_client('cognito-idp')
+        self.__user_pool_id = user_pool_id
+
+    def get_groups(self, username: str):
+        response = self.__cognito.admin_list_groups_for_user(
+            Username=username,
+            UserPoolId=self.__user_pool_id,
+            Limit=60,
+            # NextToken='string'
+        )
+        if response is None or 'Groups' not in response:
+            return []
+        belonged_groups = [k['GroupName'] for k in response['Groups']]
+        return belonged_groups
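
Worth noting (illustration only, not part of this commit): `admin_list_groups_for_user` returns at most `Limit` groups per call, and the commit caps this at 60 with `NextToken` commented out, so users in more than 60 groups would be truncated. A hedged sketch of a paginated variant:

# Illustrative sketch only -- not part of this commit.
def get_all_groups(cognito_client, user_pool_id: str, username: str) -> list:
    """Collect every group for a user by following NextToken pagination."""
    groups, next_token = [], None
    while True:
        kwargs = {'Username': username, 'UserPoolId': user_pool_id, 'Limit': 60}
        if next_token is not None:
            kwargs['NextToken'] = next_token
        response = cognito_client.admin_list_groups_for_user(**kwargs)
        groups.extend(k['GroupName'] for k in response.get('Groups', []))
        next_token = response.get('NextToken')
        if not next_token:
            return groups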

cumulus_lambda_functions/lib/aws/aws_cred.py

Lines changed: 16 additions & 0 deletions
@@ -43,6 +43,19 @@ def __init__(self):
         else:
             LOGGER.debug('using default session as there is no aws_access_key_id')
 
+    @property
+    def region(self):
+        return self.__region
+
+    @region.setter
+    def region(self, val):
+        """
+        :param val:
+        :return: None
+        """
+        self.__region = val
+        return
+
     @property
     def boto3_session(self):
         return self.__boto3_session
@@ -56,6 +69,9 @@ def boto3_session(self, val):
         self.__boto3_session = val
         return
 
+    def get_session(self):
+        return boto3.Session(**self.boto3_session)
+
     def get_resource(self, service_name: str):
         return boto3.Session(**self.boto3_session).resource(service_name)
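
A brief sketch (not part of this commit) of the two additions to `AwsCred`: the `region` setter and `get_session()`, which hands back a configured boto3 `Session`. The region value is a placeholder.

# Illustrative sketch only -- not part of this commit; the region value is a placeholder.
from cumulus_lambda_functions.lib.aws.aws_cred import AwsCred

cred = AwsCred()
cred.region = 'us-west-2'      # new property setter added in this commit
session = cred.get_session()   # new helper added in this commit; wraps boto3.Session(**cred.boto3_session)
cognito = session.client('cognito-idp')  # any boto3 client can be created from the returned session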

cumulus_lambda_functions/lib/aws/es_abstract.py

Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
+from abc import ABC, abstractmethod
+from typing import Any, Union, Callable
+
+DEFAULT_TYPE = '_doc'
+
+
+class ESAbstract(ABC):
+    @abstractmethod
+    def create_index(self, index_name, index_body):
+        return
+
+    @abstractmethod
+    def has_index(self, index_name):
+        return
+
+    @abstractmethod
+    def create_alias(self, index_name, alias_name):
+        return
+
+    @abstractmethod
+    def delete_index(self, index_name):
+        return
+
+    @abstractmethod
+    def index_many(self, docs=None, doc_ids=None, doc_dict=None, index=None):
+        return
+
+    @abstractmethod
+    def index_one(self, doc, doc_id, index=None):
+        return
+
+    @abstractmethod
+    def update_many(self, docs=None, doc_ids=None, doc_dict=None, index=None):
+        return
+
+    @abstractmethod
+    def update_one(self, doc, doc_id, index=None):
+        return
+
+    @staticmethod
+    @abstractmethod
+    def get_result_size(result):
+        return
+
+    @abstractmethod
+    def query_with_scroll(self, dsl, querying_index=None):
+        return
+
+    @abstractmethod
+    def query(self, dsl, querying_index=None):
+        return
+
+    @abstractmethod
+    def query_pages(self, dsl, querying_index=None):
+        return
+
+    @abstractmethod
+    def query_by_id(self, doc_id):
+        return
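
For illustration only (not part of this commit): because the middleware is consumed through `ESAbstract`, a small in-memory fake can stand in for Elasticsearch in unit tests. Everything below is a hypothetical test helper under that assumption.

# Illustrative sketch only -- not part of this commit; a minimal in-memory fake of ESAbstract for tests.
from cumulus_lambda_functions.lib.aws.es_abstract import ESAbstract


class FakeES(ESAbstract):
    def __init__(self):
        self.__docs = {}

    def create_index(self, index_name, index_body): return True
    def has_index(self, index_name): return True
    def create_alias(self, index_name, alias_name): return True
    def delete_index(self, index_name): return True

    def index_many(self, docs=None, doc_ids=None, doc_dict=None, index=None):
        self.__docs.update(doc_dict or dict(zip(doc_ids or [], docs or [])))

    def index_one(self, doc, doc_id, index=None):
        self.__docs[doc_id] = doc
        return self

    # updates behave like upserts in this fake
    update_many = index_many
    update_one = index_one

    @staticmethod
    def get_result_size(result):
        return len(result['hits']['hits'])

    def query(self, dsl, querying_index=None):
        hits = [{'_id': k, '_source': v} for k, v in self.__docs.items()]
        return {'hits': {'hits': hits, 'total': len(hits)}}

    # scrolling and paging are irrelevant in memory, so reuse query
    query_with_scroll = query
    query_pages = query

    def query_by_id(self, doc_id):
        doc = self.__docs.get(doc_id)
        return None if doc is None else {'_id': doc_id, '_source': doc}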

cumulus_lambda_functions/lib/aws/es_factory.py

Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
+from cumulus_lambda_functions.lib.aws.factory_abstract import FactoryAbstract
+
+
+class ESFactory(FactoryAbstract):
+    NO_AUTH = 'NO_AUTH'
+    AWS = 'AWS'
+
+    def get_instance(self, class_type, **kwargs):
+        ct = class_type.upper()
+        if ct == self.NO_AUTH:
+            from cumulus_lambda_functions.lib.aws.es_middleware import ESMiddleware
+            return ESMiddleware(kwargs['index'], kwargs['base_url'], port=kwargs['port'])
+        if ct == self.AWS:
+            from cumulus_lambda_functions.lib.aws.es_middleware_aws import EsMiddlewareAws
+            return EsMiddlewareAws(kwargs['index'], kwargs['base_url'], port=kwargs['port'])
+        raise ModuleNotFoundError(f'cannot find ES class for {ct}')
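
A hedged usage sketch (not part of this commit): the authorizer above requests the 'AWS' flavour; the 'NO_AUTH' flavour returns the plain `ESMiddleware` shown below and could be handy against a local cluster. Endpoint, index name, and port are placeholders.

# Illustrative sketch only -- not part of this commit; endpoint, index, and port are placeholders.
from cumulus_lambda_functions.lib.aws.es_factory import ESFactory

# 'NO_AUTH' returns ESMiddleware; 'AWS' returns EsMiddlewareAws as seen in get_instance above.
es = ESFactory().get_instance('NO_AUTH',
                              index='ldap_group_permission',
                              base_url='localhost',
                              port=9200)
print(es.query({'query': {'match_all': {}}}))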

cumulus_lambda_functions/lib/aws/es_middleware.py

Lines changed: 216 additions & 0 deletions
@@ -0,0 +1,216 @@
+import json
+import logging
+
+from elasticsearch import Elasticsearch
+
+from cumulus_lambda_functions.lib.aws.es_abstract import ESAbstract, DEFAULT_TYPE
+
+LOGGER = logging.getLogger(__name__)
+
+
+class ESMiddleware(ESAbstract):
+
+    def __init__(self, index, base_url, port=443) -> None:
+        if any([k is None for k in [index, base_url]]):
+            raise ValueError('index or base_url is None')
+        self.__index = index
+        base_url = base_url.replace('https://', '')  # hide https
+        self._engine = Elasticsearch(hosts=[{'host': base_url, 'port': port}])
+
+    def __validate_index(self, index):
+        if index is not None:
+            return index
+        if self.__index is not None:
+            return self.__index
+        raise ValueError('index value is NULL')
+
+    def __get_doc_dict(self, docs=None, doc_ids=None, doc_dict=None):
+        if doc_dict is None and (docs is None and doc_ids is None):
+            raise ValueError('must provide either doc dictionary or doc list & id list')
+        if doc_dict is None:  # it comes as a list
+            if len(docs) != len(doc_ids):
+                raise ValueError('length of doc and id is different')
+            doc_dict = {k: v for k, v in zip(doc_ids, docs)}
+            pass
+        return doc_dict
+
+    def __check_errors_for_bulk(self, index_result):
+        if 'errors' not in index_result or index_result['errors'] is False:
+            return
+        err_list = [[{'id': v['_id'], 'error': v['error']} for _, v in each.items() if 'error' in v] for each in
+                    index_result['items']]
+        if len(err_list) < 1:
+            return
+        LOGGER.exception('failed to add some items. details: {}'.format(err_list))
+        return err_list
+
+    def create_index(self, index_name, index_body):
+        result = self._engine.indices.create(index=index_name, body=index_body, include_type_name=False)
+        if 'acknowledged' not in result:
+            return result
+        return result['acknowledged']
+
+    def has_index(self, index_name):
+        result = self._engine.indices.exists(index=index_name)
+        return result
+
+    def create_alias(self, index_name, alias_name):
+        result = self._engine.indices.put_alias(index_name, alias_name)
+        if 'acknowledged' not in result:
+            return result
+        return result['acknowledged']
+
+    def delete_index(self, index_name):
+        result = self._engine.indices.delete(index_name)
+        if 'acknowledged' not in result:
+            return result
+        return result['acknowledged']
+
+    def index_many(self, docs=None, doc_ids=None, doc_dict=None, index=None):
+        doc_dict = self.__get_doc_dict(docs, doc_ids, doc_dict)
+        body = []
+        for k, v in doc_dict.items():
+            body.append({'index': {'_index': index, '_id': k, 'retry_on_conflict': 3}})
+            body.append(v)
+            pass
+        index = self.__validate_index(index)
+        try:
+            index_result = self._engine.bulk(index=index,
+                                             body=body, doc_type=DEFAULT_TYPE)
+            LOGGER.info('indexed. result: {}'.format(index_result))
+            return self.__check_errors_for_bulk(index_result)
+        except:
+            LOGGER.exception('cannot add indices with ids: {} for index: {}'.format(list(doc_dict.keys()), index))
+            return doc_dict
+        return
+
+    def index_one(self, doc, doc_id, index=None):
+        index = self.__validate_index(index)
+        try:
+            index_result = self._engine.index(index=index,
+                                              body=doc, doc_type=DEFAULT_TYPE, id=doc_id)
+            LOGGER.info('indexed. result: {}'.format(index_result))
+            pass
+        except:
+            LOGGER.exception('cannot add a new index with id: {} for index: {}'.format(doc_id, index))
+            return None
+        return self
+
+    def update_many(self, docs=None, doc_ids=None, doc_dict=None, index=None):
+        doc_dict = self.__get_doc_dict(docs, doc_ids, doc_dict)
+        body = []
+        for k, v in doc_dict.items():
+            body.append({'update': {'_index': index, '_id': k, 'retry_on_conflict': 3}})
+            body.append({'doc': v, 'doc_as_upsert': True})
+            pass
+        index = self.__validate_index(index)
+        try:
+            index_result = self._engine.bulk(index=index,
+                                             body=body, doc_type=DEFAULT_TYPE)
+            LOGGER.info('indexed. result: {}'.format(index_result))
+            return self.__check_errors_for_bulk(index_result)
+        except:
+            LOGGER.exception('cannot update indices with ids: {} for index: {}'.format(list(doc_dict.keys()),
+                                                                                       index))
+            return doc_dict
+        return
+
+    def update_one(self, doc, doc_id, index=None):
+        update_body = {
+            'doc': doc,
+            'doc_as_upsert': True
+        }
+        index = self.__validate_index(index)
+        try:
+            update_result = self._engine.update(index=index,
+                                                id=doc_id, body=update_body, doc_type=DEFAULT_TYPE)
+            LOGGER.info('updated. result: {}'.format(update_result))
+            pass
+        except:
+            LOGGER.exception('cannot update id: {} for index: {}'.format(doc_id, index))
+            return None
+        return self
+
+    @staticmethod
+    def get_result_size(result):
+        if isinstance(result['hits']['total'], dict):  # fix for different datatype in elastic-search result
+            return result['hits']['total']['value']
+        else:
+            return result['hits']['total']
+
+    def query_with_scroll(self, dsl, querying_index=None):
+        scroll_timeout = '30s'
+        index = self.__validate_index(querying_index)
+        dsl['size'] = 10000  # replacing with the maximum size to minimize number of scrolls
+        params = {
+            'index': index,
+            'size': 10000,
+            'scroll': scroll_timeout,
+            'body': dsl,
+        }
+        first_batch = self._engine.search(**params)
+        total_size = self.get_result_size(first_batch)
+        current_size = len(first_batch['hits']['hits'])
+        scroll_id = first_batch['_scroll_id']
+        while current_size < total_size:  # need to scroll
+            scrolled_result = self._engine.scroll(scroll_id=scroll_id, scroll=scroll_timeout)
+            scroll_id = scrolled_result['_scroll_id']
+            scrolled_result_size = len(scrolled_result['hits']['hits'])
+            if scrolled_result_size == 0:
+                break
+            else:
+                current_size += scrolled_result_size
+                first_batch['hits']['hits'].extend(scrolled_result['hits']['hits'])
+        return first_batch
+
+    def query(self, dsl, querying_index=None):
+        index = self.__validate_index(querying_index)
+        return self._engine.search(body=dsl, index=index)
+
+    def __is_querying_next_page(self, targeted_size: int, current_size: int, total_size: int):
+        if targeted_size < 0:
+            return current_size > 0
+        return current_size > 0 and total_size < targeted_size
+
+    def query_pages(self, dsl, querying_index=None):
+        if 'sort' not in dsl:
+            raise ValueError('missing `sort` in DSL. Make sure sorting is unique')
+        index = self.__validate_index(querying_index)
+        targeted_size = dsl['size'] if 'size' in dsl else -1
+        dsl['size'] = 10000  # replacing with the maximum size to minimize number of scrolls
+        params = {
+            'index': index,
+            'size': 10000,
+            'body': dsl,
+        }
+        LOGGER.debug(f'dsl: {dsl}')
+        result_list = []
+        total_size = 0
+        result_batch = self._engine.search(**params)
+        result_list.extend(result_batch['hits']['hits'])
+        current_size = len(result_batch['hits']['hits'])
+        total_size += current_size
+        while self.__is_querying_next_page(targeted_size, current_size, total_size):
+            params['body']['search_after'] = result_batch['hits']['hits'][-1]['sort']
+            result_batch = self._engine.search(**params)
+            result_list.extend(result_batch['hits']['hits'])
+            current_size = len(result_batch['hits']['hits'])
+            total_size += current_size
+        return {
+            'hits': {
+                'hits': result_list,
+                'total': total_size,
+            }
+        }
+
+    def query_by_id(self, doc_id):
+        index = self.__validate_index(None)
+        dsl = {
+            'query': {
+                'term': {'_id': doc_id}
+            }
+        }
+        result = self._engine.search(index=index, body=dsl)
+        if self.get_result_size(result) < 1:
+            return None
+        return result['hits']['hits'][0]
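
A hedged usage sketch (not part of this commit) of `query_pages`, which pages with `search_after` and therefore requires a unique `sort` clause. The index name, sort field, and endpoint below are placeholders.

# Illustrative sketch only -- not part of this commit; index, sort field, and endpoint are placeholders.
from cumulus_lambda_functions.lib.aws.es_middleware import ESMiddleware

es = ESMiddleware(index='ldap_group_permission', base_url='localhost', port=9200)
dsl = {
    'query': {'match_all': {}},
    'sort': [{'group_name.keyword': {'order': 'asc'}}],  # must sort on something unique for search_after paging
    'size': 25,  # rough cap on documents to collect; omit to page through everything
}
result = es.query_pages(dsl)
group_names = [hit['_source'].get('group_name') for hit in result['hits']['hits']]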
