Skip to content

Commit

Permalink
Merge pull request #163 from trivenay/main
Browse files Browse the repository at this point in the history
Raise all init errors in init instead of suppressing them until the first invoke
  • Loading branch information
trivenay authored Jul 31, 2024
2 parents dc83706 + bd40404 commit 904ef66
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 70 deletions.
31 changes: 13 additions & 18 deletions awslambdaric/bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,36 +37,30 @@ def _get_handler(handler):
try:
(modname, fname) = handler.rsplit(".", 1)
except ValueError as e:
fault = FaultException(
raise FaultException(
FaultException.MALFORMED_HANDLER_NAME,
"Bad handler '{}': {}".format(handler, str(e)),
)
return make_fault_handler(fault)

try:
if modname.split(".")[0] in sys.builtin_module_names:
fault = FaultException(
raise FaultException(
FaultException.BUILT_IN_MODULE_CONFLICT,
"Cannot use built-in module {} as a handler module".format(modname),
)
return make_fault_handler(fault)
m = importlib.import_module(modname.replace("/", "."))
except ImportError as e:
fault = FaultException(
raise FaultException(
FaultException.IMPORT_MODULE_ERROR,
"Unable to import module '{}': {}".format(modname, str(e)),
)
request_handler = make_fault_handler(fault)
return request_handler
except SyntaxError as e:
trace = [' File "%s" Line %s\n %s' % (e.filename, e.lineno, e.text)]
fault = FaultException(
raise FaultException(
FaultException.USER_CODE_SYNTAX_ERROR,
"Syntax error in module '{}': {}".format(modname, str(e)),
trace,
)
request_handler = make_fault_handler(fault)
return request_handler

try:
request_handler = getattr(m, fname)
Expand All @@ -76,15 +70,8 @@ def _get_handler(handler):
"Handler '{}' missing on module '{}'".format(fname, modname),
None,
)
request_handler = make_fault_handler(fault)
return request_handler


def make_fault_handler(fault):
def result(*args):
raise fault

return result
return request_handler


def make_error(
Expand Down Expand Up @@ -475,15 +462,23 @@ def run(app_root, handler, lambda_runtime_api_addr):
lambda_runtime_client = LambdaRuntimeClient(
lambda_runtime_api_addr, use_thread_for_polling_next
)
error_result = None

try:
_setup_logging(_AWS_LAMBDA_LOG_FORMAT, _AWS_LAMBDA_LOG_LEVEL, log_sink)
global _GLOBAL_AWS_REQUEST_ID

request_handler = _get_handler(handler)
except FaultException as e:
error_result = make_error(
e.msg,
e.exception_type,
e.trace,
)
except Exception:
error_result = build_fault_result(sys.exc_info(), None)

if error_result is not None:
log_error(error_result, log_sink)
lambda_runtime_client.post_init_error(to_json(error_result))

Expand Down
60 changes: 8 additions & 52 deletions tests/test_bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,43 +603,6 @@ def raise_exception_handler(json_input, lambda_context):

self.assertEqual(mock_stdout.getvalue(), error_logs)

# The order of patches matter. Using MagicMock resets sys.stdout to the default.
@patch("importlib.import_module")
@patch("sys.stdout", new_callable=StringIO)
def test_handle_event_request_fault_exception_logging_syntax_error(
self, mock_stdout, mock_import_module
):
try:
eval("-")
except SyntaxError as e:
syntax_error = e

mock_import_module.side_effect = syntax_error

response_handler = bootstrap._get_handler("a.b")

bootstrap.handle_event_request(
self.lambda_runtime,
response_handler,
"invoke_id",
self.event_body,
"application/json",
{},
{},
"invoked_function_arn",
0,
bootstrap.StandardLogSink(),
)
error_logs = (
lambda_unhandled_exception_warning_message
+ f"[ERROR] Runtime.UserCodeSyntaxError: Syntax error in module 'a': {syntax_error}\r"
)
error_logs += "Traceback (most recent call last):\r"
error_logs += '  File "<string>" Line 1\r'
error_logs += "    -\n"

self.assertEqual(mock_stdout.getvalue(), error_logs)


class TestXrayFault(unittest.TestCase):
def test_make_xray(self):
Expand Down Expand Up @@ -717,10 +680,8 @@ def __eq__(self, other):

def test_get_event_handler_bad_handler(self):
handler_name = "bad_handler"
response_handler = bootstrap._get_handler(handler_name)
with self.assertRaises(FaultException) as cm:
response_handler()

response_handler = bootstrap._get_handler(handler_name)
returned_exception = cm.exception
self.assertEqual(
self.FaultExceptionMatcher(
Expand All @@ -732,9 +693,8 @@ def test_get_event_handler_bad_handler(self):

def test_get_event_handler_import_error(self):
handler_name = "no_module.handler"
response_handler = bootstrap._get_handler(handler_name)
with self.assertRaises(FaultException) as cm:
response_handler()
response_handler = bootstrap._get_handler(handler_name)
returned_exception = cm.exception
self.assertEqual(
self.FaultExceptionMatcher(
Expand All @@ -757,10 +717,9 @@ def test_get_event_handler_syntax_error(self):
filename_w_ext = os.path.basename(tmp_file.name)
filename, _ = os.path.splitext(filename_w_ext)
handler_name = "{}.syntax_error".format(filename)
response_handler = bootstrap._get_handler(handler_name)

with self.assertRaises(FaultException) as cm:
response_handler()
response_handler = bootstrap._get_handler(handler_name)
returned_exception = cm.exception
self.assertEqual(
self.FaultExceptionMatcher(
Expand All @@ -782,9 +741,8 @@ def test_get_event_handler_missing_error(self):
filename_w_ext = os.path.basename(tmp_file.name)
filename, _ = os.path.splitext(filename_w_ext)
handler_name = "{}.my_handler".format(filename)
response_handler = bootstrap._get_handler(handler_name)
with self.assertRaises(FaultException) as cm:
response_handler()
response_handler = bootstrap._get_handler(handler_name)
returned_exception = cm.exception
self.assertEqual(
self.FaultExceptionMatcher(
Expand All @@ -801,9 +759,8 @@ def test_get_event_handler_slash(self):
response_handler()

def test_get_event_handler_build_in_conflict(self):
response_handler = bootstrap._get_handler("sys.hello")
with self.assertRaises(FaultException) as cm:
response_handler()
response_handler = bootstrap._get_handler("sys.hello")
returned_exception = cm.exception
self.assertEqual(
self.FaultExceptionMatcher(
Expand Down Expand Up @@ -1452,9 +1409,8 @@ def test_set_log_level_with_dictConfig(self, mock_stderr, mock_stdout):


class TestBootstrapModule(unittest.TestCase):
@patch("awslambdaric.bootstrap.handle_event_request")
@patch("awslambdaric.bootstrap.LambdaRuntimeClient")
def test_run(self, mock_runtime_client, mock_handle_event_request):
def test_run(self, mock_runtime_client):
expected_app_root = "/tmp/test/app_root"
expected_handler = "app.my_test_handler"
expected_lambda_runtime_api_addr = "test_addr"
Expand All @@ -1467,12 +1423,12 @@ def test_run(self, mock_runtime_client, mock_handle_event_request):
MagicMock(),
]

with self.assertRaises(TypeError):
with self.assertRaises(SystemExit) as cm:
bootstrap.run(
expected_app_root, expected_handler, expected_lambda_runtime_api_addr
)

mock_handle_event_request.assert_called_once()
self.assertEqual(cm.exception.code, 1)

@patch(
"awslambdaric.bootstrap.LambdaLoggerHandler",
Expand Down

0 comments on commit 904ef66

Please sign in to comment.