Skip to content

Commit 23fb773

Browse files
authored
Update create model validation; Ensure Base image is set for LISA hosted models;
1 parent 241631b commit 23fb773

File tree

5 files changed

+60
-7
lines changed

5 files changed

+60
-7
lines changed

.pre-commit-config.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ repos:
7878
args:
7979
- --max-line-length=120
8080
- --extend-immutable-calls=Query,fastapi.Depends,fastapi.params.Depends
81-
- --ignore=B008,E203 # Ignore error for function calls in argument defaults
81+
- --ignore=B008,E203, W503 # Ignore error for function calls in argument defaults
8282
exclude: ^(__init__.py$|.*\/__init__.py$)
8383

8484

lambda/models/handler/create_model_handler.py

+36
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ def __call__(self, create_request: CreateModelRequest) -> CreateModelResponse:
3535
if table_item:
3636
raise ModelAlreadyExistsError(f"Model '{model_id}' already exists. Please select another name.")
3737

38+
self.validate(create_request)
39+
3840
self._stepfunctions.start_execution(
3941
stateMachineArn=os.environ["CREATE_SFN_ARN"], input=create_request.model_dump_json()
4042
)
@@ -46,3 +48,37 @@ def __call__(self, create_request: CreateModelRequest) -> CreateModelResponse:
4648
}
4749
)
4850
return CreateModelResponse(model=lisa_model)
51+
52+
@staticmethod
53+
def validate(create_request: CreateModelRequest) -> None:
54+
# The below check ensures that the model is LISA hosted
55+
if (
56+
create_request.containerConfig is not None
57+
and create_request.autoScalingConfig is not None
58+
and create_request.loadBalancerConfig is not None
59+
):
60+
if create_request.containerConfig.image.baseImage is None:
61+
raise ValueError("Base image must be provided for LISA hosted model.")
62+
63+
# Validate values relative to current ASG. All conflicting request values have been validated as part of the
64+
# AutoScalingInstanceConfig model validations, so those are not duplicated here.
65+
if create_request.autoScalingConfig is not None:
66+
# Min capacity can't be greater than the deployed ASG's max capacity
67+
if (
68+
create_request.autoScalingConfig.minCapacity is not None
69+
and create_request.autoScalingConfig.maxCapacity is not None
70+
and create_request.autoScalingConfig.minCapacity > create_request.autoScalingConfig.maxCapacity
71+
):
72+
raise ValueError(
73+
f"Min capacity cannot exceed ASG max of {create_request.autoScalingConfig.maxCapacity}."
74+
)
75+
76+
# Max capacity can't be less than the deployed ASG's min capacity
77+
if (
78+
create_request.autoScalingConfig.maxCapacity is not None
79+
and create_request.autoScalingConfig.minCapacity is not None
80+
and create_request.autoScalingConfig.maxCapacity < create_request.autoScalingConfig.minCapacity
81+
):
82+
raise ValueError(
83+
f"Max capacity cannot be less than ASG min of {create_request.autoScalingConfig.minCapacity}."
84+
)

lib/user-interface/react/src/components/model-management/create-model/CreateModelModal.tsx

+2-3
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ export function CreateModelModal (props: CreateModelModalProps) : ReactElement {
170170
}
171171
}
172172

173-
const requiredFields = [['modelId', 'modelName'], [], [], [], []];
173+
const requiredFields = [['modelId', 'modelName'], ['containerConfig.image.baseImage'], [], [], []];
174174

175175
useEffect(() => {
176176
const parsedValue = _.mergeWith({}, initialForm, props.selectedItems[0], (a: IModelRequest, b: IModelRequest) => b === null ? a : undefined);
@@ -318,8 +318,7 @@ export function CreateModelModal (props: CreateModelModalProps) : ReactElement {
318318
case 'next':
319319
case 'skip':
320320
{
321-
touchFields(requiredFields[state.activeStepIndex]);
322-
if (isValid) {
321+
if (touchFields(requiredFields[state.activeStepIndex]) && isValid) {
323322
setState({
324323
...state,
325324
activeStepIndex: event.detail.requestedStepIndex,

lib/user-interface/react/src/shared/model/model-management.model.ts

+11
Original file line numberDiff line numberDiff line change
@@ -230,5 +230,16 @@ export const ModelRequestSchema = z.object({
230230
});
231231
}
232232
}
233+
234+
const baseImageValidator = z.string().min(1, {message: 'Required for LISA hosted models.'});
235+
const baseImageResult = baseImageValidator.safeParse(value.containerConfig.image.baseImage);
236+
if (baseImageResult.success === false) {
237+
for (const error of baseImageResult.error.errors) {
238+
context.addIssue({
239+
...error,
240+
path: ['containerConfig', 'image', 'baseImage']
241+
});
242+
}
243+
}
233244
}
234245
});

lib/user-interface/react/src/shared/validation/index.ts

+10-3
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ export type SetFieldsFunction = (
112112
) => void;
113113

114114

115-
export type TouchFieldsFunction = (fields: string[], method?: ValidationTouchActionMethod) => void;
115+
export type TouchFieldsFunction = (fields: string[], method?: ValidationTouchActionMethod) => boolean;
116116

117117

118118
/**
@@ -268,7 +268,7 @@ export const useValidationReducer = <F, S extends ValidationReducerBaseState<F>>
268268
return {
269269
state,
270270
errors,
271-
isValid: parseResult.success,
271+
isValid: Object.keys(errors).length === 0,
272272
setState: (newState: Partial<S>, method: ValidationStateActionMethod = ModifyMethod.Default) => {
273273
setState({
274274
type: ValidationReducerActionTypes.STATE,
@@ -289,12 +289,19 @@ export const useValidationReducer = <F, S extends ValidationReducerBaseState<F>>
289289
touchFields: (
290290
fields: string[],
291291
method: ValidationTouchActionMethod = ModifyMethod.Default
292-
) => {
292+
): boolean => {
293293
setState({
294294
type: ValidationReducerActionTypes.TOUCH,
295295
method,
296296
fields,
297297
} as ValidationTouchAction);
298+
const parseResult = formSchema.safeParse({...state.form, ...{touched: fields}});
299+
if (!parseResult.success) {
300+
errors = issuesToErrors(parseResult.error.issues, fields.reduce((acc, key) => {
301+
acc[key] = true; return acc;
302+
}, {}));
303+
}
304+
return Object.keys(errors).length === 0;
298305
},
299306
};
300307
};

0 commit comments

Comments
 (0)