Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve deepaas-cli for OSCAR integration #158

Merged
merged 11 commits into from
Jul 19, 2024
132 changes: 91 additions & 41 deletions deepaas/cmd/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@


# import asyncio
import argparse
import ast
import deepaas
import json
import mimetypes
Expand All @@ -43,23 +45,37 @@

# Not all types are covered! If not listed, the type is 'str'
# see https://marshmallow.readthedocs.io/en/stable/marshmallow.fields.html

# Passing lists or dicts with argparse is not straightforward, so we pass them as
# strings and parse them with `ast.literal_eval`
# ref: https://stackoverflow.com/questions/7625786/type-dict-in-argparse-add-argument
# We follow a similar approach with bools
# ref: https://stackoverflow.com/a/59579733/18471590
# ref: https://stackoverflow.com/questions/715417/converting-from-a-string-to-boolean-in-python/18472142#18472142 # noqa


def str2bool(v):
    """Convert a command-line value into a boolean.

    The strings "yes", "true", "t" and "1" (case-insensitive, surrounding
    whitespace ignored) map to True; every other string maps to False.

    :param v: value to convert, normally the raw string given on the CLI
    :return: bool
    """
    # A real boolean (e.g. a default that was never a string) has no
    # ``.lower()``; hand it back unchanged instead of raising.
    if isinstance(v, bool):
        return v
    # Normalise through ``str`` so stray whitespace or non-string values
    # cannot break the comparison.
    return str(v).strip().lower() in ("yes", "true", "t", "1")


# Lookup table from marshmallow field class to the callable that converts the
# raw command-line string; `_fields_to_dict` stores the looked-up value as the
# parameter's "type".
# NOTE(review): fields.Bool, fields.Boolean, fields.Dict and fields.List appear
# twice in this listing. In a Python dict literal the later entry replaces the
# earlier one, so the effective converters are `str2bool` and
# `ast.literal_eval`, not the builtin `bool`, `dict` and `list`.
FIELD_TYPE_CONVERTERS = {
fields.Bool: bool,
fields.Boolean: bool,
# booleans are parsed from text by str2bool
fields.Bool: str2bool,
fields.Boolean: str2bool,
fields.Date: str,
fields.DateTime: str,
fields.Dict: dict,
# dicts are passed as a string and parsed as a Python literal
fields.Dict: ast.literal_eval,
fields.Email: str,
fields.Float: float,
fields.Int: int,
fields.Integer: int,
fields.List: list,
# lists are passed as a string and parsed as a Python literal
fields.List: ast.literal_eval,
fields.Str: str,
fields.String: str,
fields.Time: str,
fields.URL: str,
fields.Url: str,
fields.UUID: str,
fields.Field: str,
}


Expand Down Expand Up @@ -87,38 +103,43 @@ def _fields_to_dict(fields_in):
# initialise param with no "default", type "str" (!), empty "help"
param = {"default": None, "type": str, "help": ""}

# infer "type"
# see FIELD_TYPE_CONVERTERS for converting
# mashmallow field types to python types
# Initialize help string
val_help = val.metadata.get("description", "")
if "%" in val_help:
# argparse hates % sign:
# replace single occurrences of '%' with '%%'
# since '%%' is accepted by argparse
val_help = re.sub(r"(?<!%)%(?!%)", r"%%", val_help)

# Infer "type"
val_type = type(val)
if val_type in FIELD_TYPE_CONVERTERS:
param["type"] = FIELD_TYPE_CONVERTERS[val_type]

if key == "files" or key == "urls":
param["type"] = str

# infer "required"
try:
val_req = val.required
except Exception:
val_req = False
param["required"] = val_req

# infer 'default'
# if the field is not required, there must be default value
if not val_req:
if val_type is fields.List:
val_help += '\nType: list, enclosed as string: "[...]"'
elif val_type is fields.Dict:
val_help += '\nType: dict, enclosed as string: "{...}"'
elif val_type in [fields.Bool, fields.Boolean]:
val_help += "\nType: bool"
else:
val_help += f"\nType: {param['type'].__name__}"
if val_type is fields.Field:
val_help += " (filepath)"

# Infer "required"
param["required"] = val.required
if not val.required:
param["default"] = val.missing
val_help += f"\nDefault: {val.missing}"
else:
val_help += "\n*Required*"

# infer 'help'
val_help = val.metadata["description"]
# argparse hates % sign:
if "%" in val_help:
# replace single occurancies of '%' with '%%'
# since "%%"" is accepted by argparse
val_help = re.sub(r"(?<!%)%(?!%)", r"%%", val_help)

# Add choices to help message
if "enum" in val.metadata.keys():
val_help = f"{val_help}. Choices: {val.metadata['enum']}"
val_help += f"\nChoices: {val.metadata['enum']}"

val_help = val_help.lstrip("\n") # remove escape when no description found
param["help"] = val_help

dict_out[key] = param
Expand Down Expand Up @@ -159,6 +180,18 @@ def _get_model_name(model_name=None):
sys.exit(1)


def _get_file_args(fields_in):
"""Function to retrieve a list of file-type fields
:param fields_in: mashmallow fields
:return: list
"""
file_fields = []
for k, v in fields_in.items():
if type(v) is fields.Field:
file_fields.append(k)
return file_fields


# Get the model name
model_name = CONF.model_name

Expand All @@ -168,34 +201,47 @@ def _get_model_name(model_name=None):
# model_obj = v2_wrapper.ModelWrapper(name=model_name,
# model_obj=model_obj)


# With the model resolved, convert its predict and train argument
# definitions into plain dictionaries
predict_args = _fields_to_dict(model_obj.get_predict_args())
train_args = _fields_to_dict(model_obj.get_train_args())

# Per method, the names of the arguments that are going to be files
file_args = {
    "predict": _get_file_args(model_obj.get_predict_args()),
    "train": _get_file_args(model_obj.get_train_args()),
}


# Function to add later these arguments to CONF via SubCommandOpt
def _add_methods(subparsers):
"""Function to add argparse subparsers via SubCommandOpt (see below)
for DEEPaaS methods get_metadata, warm, predict, train
"""

# Use RawTextHelpFormatter to allow for line breaks in argparse help messages.
def help_formatter(prog):
return argparse.RawTextHelpFormatter(prog, max_help_position=10)

# in case no method requested, we return get_metadata(). check main()
subparsers.required = False

get_metadata_parser = subparsers.add_parser( # noqa: F841
"get_metadata", help="get_metadata method"
"get_metadata",
help="get_metadata method",
formatter_class=help_formatter,
)

get_warm_parser = subparsers.add_parser( # noqa: F841
"warm",
help="warm method, e.g. to " "prepare the model for execution",
formatter_class=help_formatter,
)

# get predict arguments configured
predict_parser = subparsers.add_parser(
"predict", help="predict method, use " "predict --help for the full list"
"predict",
help="predict method, use " "predict --help for the full list",
formatter_class=help_formatter,
)

for key, val in predict_args.items():
Expand All @@ -208,7 +254,9 @@ def _add_methods(subparsers):
)
# get train arguments configured
train_parser = subparsers.add_parser(
"train", help="train method, use " "train --help for the full list"
"train",
help="train method, use " "train --help for the full list",
formatter_class=help_formatter,
)

for key, val in train_args.items():
Expand Down Expand Up @@ -285,29 +333,31 @@ def main():
if CONF.deepaas_with_multiprocessing:
mp.set_start_method("spawn", force=True)

# TODO(multi-file): change to many files ('for' itteration)
if CONF.methods.__contains__("files"):
if CONF.methods.files:
# Create file wrapper for file args (if provided)
for farg in file_args[CONF.methods.name]:
if getattr(CONF.methods, farg, None):
fpath = conf_vars[farg]

# create tmp file as later it supposed
# to be deleted by the application
temp = tempfile.NamedTemporaryFile()
temp.close()
# copy original file into tmp file
with open(conf_vars["files"], "rb") as f:
with open(fpath, "rb") as f:
with open(temp.name, "wb") as f_tmp:
for line in f:
f_tmp.write(line)

# create file object
file_type = mimetypes.MimeTypes().guess_type(conf_vars["files"])[0]
file_type = mimetypes.MimeTypes().guess_type(fpath)[0]
file_obj = v2_wrapper.UploadedFile(
name="data",
filename=temp.name,
content_type=file_type,
original_filename=conf_vars["files"],
original_filename=fpath,
)
# re-write 'files' parameter in conf_vars
conf_vars["files"] = file_obj
# re-write parameter in conf_vars
conf_vars[farg] = file_obj

# debug of input parameters
LOG.debug("[DEBUG provided options, conf_vars]: {}".format(conf_vars))
Expand Down
Loading