diff --git a/components/alibi-detect-server/adserver/base/storage.py b/components/alibi-detect-server/adserver/base/storage.py index 25b314a58c..38dd178330 100644 --- a/components/alibi-detect-server/adserver/base/storage.py +++ b/components/alibi-detect-server/adserver/base/storage.py @@ -2,7 +2,7 @@ import sys import logging import tempfile -from distutils.util import strtobool +from typing import Optional ARTIFACT_DOWNLOAD_LOCATION = os.environ.get("DRIFT_ARTIFACTS_DIR", "/tmp") @@ -18,13 +18,13 @@ class Rclone: - def __init__(self, cfg_file: str = None): + def __init__(self, cfg_file: Optional[str] = None): self.cfg_file = cfg_file - def copy(self, src: str, dest: str = None): + def copy(self, src: str, dest: Optional[str] = None): if rclone is None: raise RuntimeError( - "rclone binary not found - rclone-based storage funcionality disabled" + "rclone binary not found - rclone-based storage functionality disabled" ) if dest is None: diff --git a/components/alibi-detect-server/adserver/cm_model.py b/components/alibi-detect-server/adserver/cm_model.py index 2833858692..33f4bdcc46 100644 --- a/components/alibi-detect-server/adserver/cm_model.py +++ b/components/alibi-detect-server/adserver/cm_model.py @@ -21,7 +21,7 @@ SELDON_PREDICTOR_ID = DEFAULT_LABELS["predictor_name"] -def _load_class_module(module_path: str) -> str: +def _load_class_module(module_path: str): components = module_path.split(".") mod = __import__(".".join(components[:-1])) for comp in components[1:]: @@ -32,7 +32,7 @@ def _load_class_module(module_path: str) -> str: class CustomMetricsModel(CEModel): # pylint:disable=c-extension-no-member def __init__( - self, name: str, storage_uri: str, elasticsearch_uri: str = None, model=None + self, name: str, storage_uri: str, elasticsearch_uri: Optional[str] = None, model=None ): """ Custom Metrics Model diff --git a/components/alibi-detect-server/adserver/server.py b/components/alibi-detect-server/adserver/server.py index 2e4c0c46ff..95351003de 100644 --- a/components/alibi-detect-server/adserver/server.py +++ b/components/alibi-detect-server/adserver/server.py @@ -39,7 +39,7 @@ def __init__( event_type: str, event_source: str, http_port: int = DEFAULT_HTTP_PORT, - reply_url: str = None, + reply_url: Optional[str] = None, ): """ CloudEvents server @@ -146,29 +146,21 @@ def get_request_handler(protocol, request: Dict) -> RequestHandler: raise Exception(f"Unknown protocol {protocol}") -def sendCloudEvent(event: v1.Event, url: str): +def forward_request(headers, data, url): """ - Send CloudEvent + Forward request Parameters ---------- - event - CloudEvent to send + headers + Headers to forward + data + Data to forward url - Url to send event + Url to forward to """ - http_marshaller = marshaller.NewDefaultHTTPMarshaller() - binary_headers, binary_data = http_marshaller.ToRequest( - event, converters.TypeBinary, json.dumps - ) - - logging.info("binary CloudEvent") - for k, v in binary_headers.items(): - logging.info("{0}: {1}\r\n".format(k, v)) - logging.info(binary_data) - - response = requests.post(url, headers=binary_headers, data=binary_data) + response = requests.post(url, headers=headers, data=data) response.raise_for_status() @@ -252,27 +244,73 @@ def post(self): else: logging.error("Metrics returned are invalid: " + str(runtime_metrics)) - if response.data is not None: + revent = create_cloud_event( + response.data, + self.event_type, + self.event_source, + event_id=event.EventID(), + extensions=event.Extensions(), + ) + if response.data is not None: # Create event from response if reply_url is active + revent_headers, revent_data = http_marshaller.ToRequest( + revent, converters.TypeBinary, json.dumps + ) + if not self.reply_url == "": - if event.EventID() is None or event.EventID() == "": - resp_event_id = uuid.uuid1().hex - else: - resp_event_id = event.EventID() - revent = ( - v1.Event() - .SetContentType("application/json") - .SetData(response.data) - .SetEventID(resp_event_id) - .SetSource(self.event_source) - .SetEventType(self.event_type) - .SetExtensions(event.Extensions()) - ) logging.debug(json.dumps(revent.Properties())) - sendCloudEvent(revent, self.reply_url) - self.write(json.dumps(response.data)) + logging.info("binary CloudEvent") + for k, v in revent_headers.items(): + logging.info("{0}: {1}\r\n".format(k, v)) + logging.info(revent_data) + forward_request(revent_headers, revent_data, self.reply_url) + + self.set_header("Content-Type", "application/json") + for headers in revent_headers: + self.set_header(headers, revent_headers[headers]) + self.write(revent_data) + + +def create_cloud_event( + data: dict, + event_type: str, + event_source: str, + extensions: dict, + event_id: str = None, +) -> v1.Event: + """ + Create a CloudEvent + + Parameters + ---------- + data + The data to send + event_type + The CE event type + event_source + The CE event source + extensions + Any extensions to add + event_id + The event id + Returns + ------- + A CloudEvent + """ + if event_id is None or event_id == "": + event_id = uuid.uuid1().hex + + event = ( + v1.Event() + .SetData(data) + .SetEventID(event_id if event_id else str(uuid.uuid1().hex)) + .SetSource(event_source) + .SetEventType(event_type) + .SetExtensions(extensions) + ) + return event class LivenessHandler(tornado.web.RequestHandler): def get(self): diff --git a/components/alibi-detect-server/adserver/tests/test_server.py b/components/alibi-detect-server/adserver/tests/test_server.py index 6ca27b4f4a..ad5a56a142 100644 --- a/components/alibi-detect-server/adserver/tests/test_server.py +++ b/components/alibi-detect-server/adserver/tests/test_server.py @@ -4,6 +4,9 @@ from typing import List, Dict, Optional, Union import json import requests_mock +from cloudevents.sdk import converters +from cloudevents.sdk import marshaller +from cloudevents.sdk.event import v1 class TestProtocol(AsyncHTTPTestCase): @@ -74,11 +77,31 @@ def test_basic(self): ) self.assertEqual(response.code, 200) expectedResponse = DummyModel.getResponse().data + # assert that the expected response conforms to the CloudEvent spec + event = v1.Event() + http_marshaller = marshaller.NewDefaultHTTPMarshaller() + try: + event = http_marshaller.FromRequest( + event, response.headers, response.body, json.loads + ) + except Exception as e: + assert False, f"Failed to unmarshall data with error: {type(e).__name__}('{e}')" + + # assert cloud event properties have been set correctly in response + self.assertEqual(event.Data(), expectedResponse) + self.assertEqual(event.Source(), self.eventSource) + self.assertEqual(event.EventType(), self.eventType) + self.assertEqual(event.ContentType(), "application/json") + self.assertEqual(event.EventID(), "1234") + self.assertEqual(event.CloudEventVersion(), "1.0") self.assertEqual(response.body.decode("utf-8"), json.dumps(expectedResponse)) + + # assert requests have been made with the correct headers and data self.assertEqual(m.request_history[0].json(), expectedResponse) headers: Dict = m.request_history[0]._request.headers self.assertEqual(headers["ce-source"], self.eventSource) self.assertEqual(headers["ce-type"], self.eventType) + self.assertNotIn("ce-datacontenttype", headers) class TestKFservingV2HttpModel(AsyncHTTPTestCase):