Skip to content

Commit

Permalink
Fix displaying custom predictors (#7)
Browse files Browse the repository at this point in the history
* Handle custom predictors

Signed-off-by: Mark Winter <[email protected]>

* Change return type to PredictorType

Signed-off-by: Mark Winter <[email protected]>
  • Loading branch information
markwinter committed Dec 23, 2021
1 parent c756bfd commit 483ab7e
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 71 deletions.
4 changes: 2 additions & 2 deletions frontend/src/app/pages/index/index.component.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ import { Component, OnInit, OnDestroy } from '@angular/core';
import { MWABackendService } from 'src/app/services/backend.service';
import { Clipboard } from '@angular/cdk-experimental/clipboard';
import {
PredictorSpec,
InferenceServiceK8s,
InferenceServiceIR,
} from 'src/app/types/kfserving/v1beta1';
Expand Down Expand Up @@ -186,9 +185,10 @@ export class IndexComponent implements OnInit, OnDestroy {
svc.ui.actions.copy = this.getCopyActionStatus(svc);
svc.ui.actions.delete = this.getDeletionActionStatus(svc);

const predictorType = getPredictorType(svc.spec.predictor);
const predictor = getPredictorExtensionSpec(svc.spec.predictor);

svc.ui.predictorType = getPredictorType(svc.spec.predictor);
svc.ui.predictorType = predictorType;
svc.ui.runtimeVersion = predictor.runtimeVersion;
svc.ui.storageUri = predictor.storageUri;
svc.ui.protocolVersion = predictor.protocolVersion;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
{{ basePredictor?.storageUri }}
</lib-details-list-item>

<lib-details-list-item key="Runtime">
{{ predictorType }} {{ basePredictor.runtimeVersion }}
<lib-details-list-item key="Predictor">
{{ predictorType }}
</lib-details-list-item>

<lib-details-list-item
key="Protocol version"
*ngIf="basePredictor.protocolVersion"
>
<lib-details-list-item key="Runtime" *ngIf="basePredictor?.runtimeVersion">
{{ basePredictor.runtimeVersion }}
</lib-details-list-item>

<lib-details-list-item key="Protocol Version" *ngIf="basePredictor?.protocolVersion">
{{ basePredictor.protocolVersion }}
</lib-details-list-item>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,15 @@
{{ basePredictor?.storageUri }}
</lib-details-list-item>

<lib-details-list-item key="Runtime">
{{ predictorType }} {{ basePredictor.runtimeVersion }}
<lib-details-list-item key="Predictor">
{{ predictorType }}
</lib-details-list-item>

<lib-details-list-item key="Protocol Version">
<lib-details-list-item key="Runtime" *ngIf="basePredictor?.runtimeVersion">
{{ basePredictor.runtimeVersion }}
</lib-details-list-item>

<lib-details-list-item key="Protocol Version" *ngIf="basePredictor?.protocolVersion">
{{ basePredictor.protocolVersion }}
</lib-details-list-item>

Expand Down
84 changes: 24 additions & 60 deletions frontend/src/app/shared/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import { V1Container } from '@kubernetes/client-node';
import {
InferenceServiceK8s,
PredictorSpec,
PredictorType,
PredictorExtensionSpec,
ExplainerSpec,
} from '../types/kfserving/v1beta1';
Expand Down Expand Up @@ -107,78 +108,41 @@ export function getK8sObjectStatus(obj: K8sObject): [string, string] {
}

// functions for processing the InferenceService spec
export function getPredictorType(predictor: PredictorSpec): string {
if ('tensorflow' in predictor) {
return 'Tensorflow';
}

if ('triton' in predictor) {
return 'Triton';
}

if ('sklearn' in predictor) {
return 'SKLearn';
}

if ('onnx' in predictor) {
return 'Onnx';
}

if ('pytorch' in predictor) {
return 'PyTorch';
}

if ('xgboost' in predictor) {
return 'XGBoost';
}

if ('pmml' in predictor) {
return 'PMML';
}

if ('lightgbm' in predictor) {
return 'LightGBM';
export function getPredictorType(predictor: PredictorSpec): PredictorType {
for (const predictorType of Object.values(PredictorType)) {
if (predictorType in predictor) {
return predictorType;
}
}

return 'Custom';
return PredictorType.Custom;
}

export function getPredictorExtensionSpec(
predictor: PredictorSpec,
): PredictorExtensionSpec {
if ('tensorflow' in predictor) {
return predictor.tensorflow;
}

if ('triton' in predictor) {
return predictor.triton;
}

if ('sklearn' in predictor) {
return predictor.sklearn;
}

if ('onnx' in predictor) {
return predictor.onnx;
}

if ('pytorch' in predictor) {
return predictor.pytorch;
for (const predictorType of Object.values(PredictorType)) {
if (predictorType in predictor) {
return predictor[predictorType];
}
}

if ('xgboost' in predictor) {
return predictor.xgboost;
}
// In the case of Custom predictors, set the additional PredictorExtensionSpec fields
// manually here
const spec = predictor.containers[0] as PredictorExtensionSpec;
spec.runtimeVersion = '';
spec.protocolVersion = '';

if ('pmml' in predictor) {
return predictor.pmml;
}

if ('lightgbm' in predictor) {
return predictor.lightgbm;
if (predictor.containers[0].env) {
const storageUri = predictor.containers[0].env.find(
envVar => envVar.name.toLowerCase() === 'storage_uri'
);
if (storageUri) {
spec.storageUri = storageUri.value;
}
}

return null;
return spec;
}

export function getExplainerContainer(explainer: ExplainerSpec): V1Container {
Expand Down
12 changes: 12 additions & 0 deletions frontend/src/app/types/kfserving/v1beta1.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,18 @@ export interface InferenceServiceSpec {
transformer: TransformerSpec;
}

export enum PredictorType {
Tensorflow = 'tensorflow',
Triton = 'triton',
Sklean = 'sklearn',
Onnx = 'onnx',
Pytorch = 'pytorch',
Xgboost = 'xgboost',
Pmml = 'pmml',
Lightgbm = 'lightgbm',
Custom = 'custom',
}

export interface PredictorSpec extends V1PodSpec, ComponentExtensionSpec {
sklearn?: PredictorExtensionSpec;
xgboost?: PredictorExtensionSpec;
Expand Down

0 comments on commit 483ab7e

Please sign in to comment.