Skip to content

Commit 5dd8bce

Browse files
authored
Feature/rest api sdk (#220)
Add RAG SDK functions
1 parent e14eb71 commit 5dd8bce

31 files changed

+835
-198
lines changed

Makefile

+1
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@ createPythonEnvironment:
154154
installPythonRequirements:
155155
pip3 install pip --upgrade
156156
pip3 install -r requirements-dev.txt
157+
pip3 install -e lisa-sdk
157158

158159

159160
## Set up TypeScript interpreter environment

lambda/authorizer/lambda_functions.py

+15-8
Original file line numberDiff line numberDiff line change
@@ -54,18 +54,22 @@ def lambda_handler(event: Dict[str, Any], context) -> Dict[str, Any]: # type: i
5454
jwt_groups_property = os.environ.get("JWT_GROUPS_PROP", "")
5555

5656
deny_policy = generate_policy(effect="Deny", resource=event["methodArn"])
57-
57+
groups: str
5858
if id_token in get_management_tokens():
59-
allow_policy = generate_policy(effect="Allow", resource=event["methodArn"], username="lisa-management-token")
59+
username = "lisa-management-token"
60+
# Add management token to Admin groups
61+
groups = json.dumps([admin_group])
62+
allow_policy = generate_policy(effect="Allow", resource=event["methodArn"], username=username)
63+
allow_policy["context"] = {"username": username, "groups": groups}
6064
logger.debug(f"Generated policy: {allow_policy}")
6165
return allow_policy
6266

6367
if jwt_data := id_token_is_valid(id_token=id_token, client_id=client_id, authority=authority):
6468
is_admin_user = is_admin(jwt_data, admin_group, jwt_groups_property)
65-
groups = get_property_path(jwt_data, jwt_groups_property)
69+
groups = json.dumps(get_property_path(jwt_data, jwt_groups_property) or [])
6670
username = find_jwt_username(jwt_data)
6771
allow_policy = generate_policy(effect="Allow", resource=event["methodArn"], username=username)
68-
allow_policy["context"] = {"username": username, "groups": json.dumps(groups or [])}
72+
allow_policy["context"] = {"username": username, "groups": groups}
6973

7074
if requested_resource.startswith("/models") and not is_admin_user:
7175
# non-admin users can still list models
@@ -185,10 +189,13 @@ def get_management_tokens() -> list[str]:
185189
secret_tokens.append(
186190
secrets_manager.get_secret_value(SecretId=secret_id, VersionStage="AWSCURRENT")["SecretString"]
187191
)
188-
secret_tokens.append(
189-
secrets_manager.get_secret_value(SecretId=secret_id, VersionStage="AWSPREVIOUS")["SecretString"]
190-
)
192+
try:
193+
secret_tokens.append(
194+
secrets_manager.get_secret_value(SecretId=secret_id, VersionStage="AWSPREVIOUS")["SecretString"]
195+
)
196+
except Exception:
197+
logger.info("No previous management token version found")
191198
except ClientError as e:
192-
logger.warn(f"Unable to fetch {secret_id}. {e.response['Error']['Code']}: {e.response['Error']['Message']}")
199+
logger.warning(f"Unable to fetch {secret_id}. {e.response['Error']['Code']}: {e.response['Error']['Message']}")
193200

194201
return secret_tokens

lambda/repository/lambda_functions.py

+7-15
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from lisapy.langchain import LisaOpenAIEmbeddings
2525
from models.domain_objects import IngestionType, RagDocument
2626
from repository.rag_document_repo import RagDocumentRepository
27-
from utilities.common_functions import api_wrapper, get_cert_path, get_id_token, get_username, retry_config
27+
from utilities.common_functions import api_wrapper, get_cert_path, get_groups, get_id_token, get_username, retry_config
2828
from utilities.exceptions import HTTPException
2929
from utilities.file_processing import process_record
3030
from utilities.validation import validate_model_name, ValidationError
@@ -186,22 +186,14 @@ def _get_embeddings_pipeline(model_name: str) -> Any:
186186

187187
@api_wrapper
def list_all(event: dict, context: dict) -> List[Dict[str, Any]]:
    """Return the registered repositories the caller is allowed to see.

    The caller's groups are read from the event via ``get_groups``. A registered
    repository is included when ``user_has_group`` accepts those groups against
    the repository's ``allowedGroups`` entry.

    Args:
        event: Lambda event for the request.
        context: Lambda context (unused here).

    Returns:
        The matching repository dictionaries, in registration order.
    """
    caller_groups = get_groups(event)
    visible: List[Dict[str, Any]] = []
    for repository in get_registered_repositories():
        if user_has_group(caller_groups, repository["allowedGroups"]):
            visible.append(repository)
    return visible
201193

202194

203195
def user_has_group(user_groups: List[str], allowed_groups: List[str]) -> bool:
204-
"""Returns if user groups has at least one intersections with allowed groups.
196+
"""Returns if user groups has at least one intersection with allowed groups.
205197
206198
If allowed groups is empty this will return True.
207199
"""
@@ -290,8 +282,8 @@ def delete_document(event: dict, context: dict) -> Dict[str, Any]:
290282
path_params = event.get("pathParameters", {})
291283
repository_id = path_params.get("repositoryId")
292284

293-
query_string_params = event["queryStringParameters"]
294-
collection_id = query_string_params["collectionId"]
285+
query_string_params = event.get("queryStringParameters", {})
286+
collection_id = query_string_params.get("collectionId")
295287
document_id = query_string_params.get("documentId")
296288
document_name = query_string_params.get("documentName")
297289

lambda/utilities/common_functions.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -360,7 +360,7 @@ def get_session_id(event: dict) -> str:
360360

361361
def get_groups(event: Any) -> List[str]:
    """Get user groups from event.

    Reads ``event["requestContext"]["authorizer"]["groups"]``. The value is
    expected to be a JSON-encoded list of group names; an already-decoded list
    is accepted as well.

    Args:
        event: Lambda event dictionary.

    Returns:
        The list of group names, or an empty list when the request context,
        the authorizer context, or the groups value is missing, ``None``,
        empty, or decodes to JSON ``null``.

    Raises:
        json.JSONDecodeError: If the groups value is a non-empty string that is
            not valid JSON.
    """
    # Each level may be absent or explicitly None, so fall back to {} step by step.
    authorizer = (event.get("requestContext") or {}).get("authorizer") or {}
    raw_groups = authorizer.get("groups")
    if not raw_groups:
        return []
    # json.loads only accepts str/bytes; pass through values that are already decoded.
    if isinstance(raw_groups, (str, bytes, bytearray)):
        raw_groups = json.loads(raw_groups)
    # A JSON "null" decodes to None; normalise it to an empty list.
    groups: List[str] = raw_groups or []
    return groups
365365

366366

lib/core/api_deployment.ts

+9-2
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,12 @@
1414
limitations under the License.
1515
*/
1616

17-
import { CfnOutput, Stack, StackProps, Aws } from 'aws-cdk-lib';
17+
import { Aws, CfnOutput, Stack, StackProps } from 'aws-cdk-lib';
1818
import { Deployment, RestApi } from 'aws-cdk-lib/aws-apigateway';
1919
import { Construct } from 'constructs';
2020

2121
import { BaseProps } from '../schema';
22+
import { StringParameter } from 'aws-cdk-lib/aws-ssm';
2223

2324
type LisaApiDeploymentStackProps = {
2425
restApiId: string;
@@ -43,8 +44,14 @@ export class LisaApiDeploymentStack extends Stack {
4344
// https://github.com/aws/aws-cdk/issues/25582
4445
(deployment as any).resource.stageName = config.deploymentStage;
4546

47+
const api_url = `https://${restApiId}.execute-api.${this.region}.${Aws.URL_SUFFIX}/${config.deploymentStage}`;
48+
new StringParameter(this, 'LisaApiDeploymentStringParameter', {
49+
parameterName: `${config.deploymentPrefix}/${config.deploymentName}/${config.appName}/LisaApiUrl`,
50+
stringValue: api_url,
51+
description: 'API Gateway URL for LISA',
52+
});
4653
new CfnOutput(this, 'ApiUrl', {
47-
value: `https://${restApiId}.execute-api.${this.region}.${Aws.URL_SUFFIX}/${config.deploymentStage}`,
54+
value: api_url,
4855
description: 'API Gateway URL'
4956
});
5057
}
+2-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1+
# urllib3<2 // Provided by Lambda
2+
requests==2.32.3
13
cryptography==43.0.1
24
PyJWT==2.9.0
3-
requests==2.32.3
4-
urllib3<2
+3-3
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
boto3>=1.34.131
2-
botocore>=1.34.131
3-
urllib3<2
1+
# boto3>=1.34.131 // Provided by Lambda
2+
# botocore>=1.34.131 // Provided by Lambda
3+
# urllib3<2 // Provided by Lambda

lib/core/layers/fastapi/requirements.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
boto3==1.34.131
1+
# boto3==1.34.131 // Provided by Lambda
22
fastapi==0.111.0
33
mangum==0.17.0
44
pydantic==2.8.2

lib/models/model-api.ts

+6-5
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ export class ModelsApi extends Construct {
8686
StringParameter.valueForStringParameter(this, `${config.deploymentPrefix}/layerVersion/fastapi`),
8787
);
8888

89+
const lambdaLayers = [commonLambdaLayer, fastapiLambdaLayer];
8990
const restApi = RestApi.fromRestApiAttributes(this, 'RestApi', {
9091
restApiId: restApiId,
9192
rootResourceId: rootResourceId,
@@ -140,7 +141,7 @@ export class ModelsApi extends Construct {
140141
const createModelStateMachine = new CreateModelStateMachine(this, 'CreateModelWorkflow', {
141142
config: config,
142143
modelTable: modelTable,
143-
lambdaLayers: [commonLambdaLayer, fastapiLambdaLayer],
144+
lambdaLayers: lambdaLayers,
144145
role: stateMachinesLambdaRole,
145146
vpc: vpc,
146147
securityGroups: securityGroups,
@@ -155,7 +156,7 @@ export class ModelsApi extends Construct {
155156
const deleteModelStateMachine = new DeleteModelStateMachine(this, 'DeleteModelWorkflow', {
156157
config: config,
157158
modelTable: modelTable,
158-
lambdaLayers: [commonLambdaLayer, fastapiLambdaLayer],
159+
lambdaLayers: lambdaLayers,
159160
role: stateMachinesLambdaRole,
160161
vpc: vpc,
161162
securityGroups: securityGroups,
@@ -167,7 +168,7 @@ export class ModelsApi extends Construct {
167168
const updateModelStateMachine = new UpdateModelStateMachine(this, 'UpdateModelWorkflow', {
168169
config: config,
169170
modelTable: modelTable,
170-
lambdaLayers: [commonLambdaLayer, fastapiLambdaLayer],
171+
lambdaLayers: lambdaLayers,
171172
role: stateMachinesLambdaRole,
172173
vpc: vpc,
173174
securityGroups: securityGroups,
@@ -193,7 +194,7 @@ export class ModelsApi extends Construct {
193194
restApi,
194195
authorizer,
195196
'./lambda',
196-
[commonLambdaLayer, fastapiLambdaLayer],
197+
lambdaLayers,
197198
{
198199
name: 'handler',
199200
resource: 'models',
@@ -272,7 +273,7 @@ export class ModelsApi extends Construct {
272273
restApi,
273274
authorizer,
274275
'./lambda',
275-
[commonLambdaLayer],
276+
lambdaLayers,
276277
f,
277278
getDefaultRuntime(),
278279
vpc,

lib/rag/layer/requirements.txt

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
boto3>=1.34.131
2-
botocore>=1.34.131
1+
# boto3>=1.34.131 // Provided by Lambda
2+
# botocore>=1.34.131 // Provided by Lambda
3+
# urllib3<2 // Provided by Lambda
34
langchain==0.3.9
45
langchain-community==0.3.9
56
langchain-openai==0.2.11
@@ -9,4 +10,3 @@ psycopg2-binary==2.9.9
910
pypdf==4.3.1
1011
python-docx==1.1.0
1112
requests-aws4auth==1.2.3
12-
urllib3<2

lib/serve/rest-api/src/api/endpoints/v2/litellm_passthrough.py

+3
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@
3939
# List models
4040
"models",
4141
"v1/models",
42+
# Model Info
43+
"model/info",
44+
"v1/model/info",
4245
# Text completions
4346
"chat/completions",
4447
"v1/chat/completions",

lisa-sdk/lisapy/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,5 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
# flake8: noqa
16-
from .main import Lisa
15+
from .api import LisaApi # noqa: F401
16+
from .main import LisaLlm # noqa: F401

lisa-sdk/lisapy/api.py

+45
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License").
4+
# You may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import Any, Dict, Optional, Union
16+
17+
from pydantic import BaseModel, Field
18+
from requests import Session
19+
20+
from .config import ConfigMixin
21+
from .doc import DocsMixin
22+
from .model import ModelMixin
23+
from .rag import RagMixin
24+
from .repository import RepositoryMixin
25+
from .session import SessionMixin
26+
27+
28+
class LisaApi(BaseModel, RepositoryMixin, ModelMixin, ConfigMixin, DocsMixin, RagMixin, SessionMixin):
    """REST API client assembled from the repository, model, config, docs, RAG and session mixins.

    Creating an instance validates the fields below and then builds a
    ``requests.Session`` that carries the configured headers, cookies and TLS
    verification setting. The mixins issue their requests through that session.
    """

    url: str = Field(..., description="REST API url for LiteLLM")
    headers: Optional[Dict[str, str]] = Field(None, description="Headers for request.")
    cookies: Optional[Dict[str, str]] = Field(None, description="Cookies for request.")
    verify: Optional[Union[str, bool]] = Field(None, description="Whether to verify SSL certificates.")
    # NOTE(review): the description says minutes; how the mixins apply this value
    # is not visible here — confirm the unit.
    timeout: int = Field(10, description="Timeout in minutes request.")
    _session: Session

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Validate the model fields and configure the underlying HTTP session."""
        super().__init__(*args, **kwargs)

        self._session = Session()
        if self.headers:
            # Merge into the session's existing header mapping instead of replacing
            # it, so the session defaults and case-insensitive lookup are kept.
            self._session.headers.update(self.headers)
        if self.verify is not None:
            self._session.verify = self.verify
        if self.cookies:
            # Session.cookies must stay a cookie jar: requests calls jar methods on
            # it when persisting response cookies, which a plain dict does not have.
            self._session.cookies.update(self.cookies)

lisa-sdk/lisapy/common.py

+26
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License").
4+
# You may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import Dict, Optional, Union
16+
17+
from requests import Session
18+
19+
20+
class BaseMixin:
    """Attribute declarations shared by the API mixins.

    The class only declares annotations; it assigns no values itself.
    """

    # Base URL that request paths are appended to.
    url: str

    # Optional per-client request settings.
    headers: Optional[Dict[str, str]]
    cookies: Optional[Dict[str, str]]
    verify: Optional[Union[str, bool]]
    timeout: int

    # HTTP session used to send requests.
    _session: Session

lisa-sdk/lisapy/config.py

+30
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License").
4+
# You may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import Dict, List
16+
17+
from .common import BaseMixin
18+
from .errors import parse_error
19+
20+
21+
class ConfigMixin(BaseMixin):
    """Mixin for config-related operations."""

    def get_configs(self, config_scope: str = "global") -> List[Dict]:
        """Fetch the configuration entries for a scope.

        Sends ``GET {url}/configuration`` with ``configScope`` as a query parameter.

        Args:
            config_scope: Value sent as the ``configScope`` query parameter.

        Returns:
            The decoded JSON body of a 200 response.

        Raises:
            Exception: Whatever ``parse_error`` builds for a non-200 response.
        """
        # Pass the scope via ``params`` so requests URL-encodes it, rather than
        # interpolating the raw value into the query string.
        response = self._session.get(f"{self.url}/configuration", params={"configScope": config_scope})
        if response.status_code != 200:
            raise parse_error(response.status_code, response)
        json_configs: List[Dict] = response.json()
        return json_configs

lisa-sdk/lisapy/doc.py

+28
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License").
4+
# You may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from .common import BaseMixin
16+
from .errors import parse_error
17+
18+
19+
class DocsMixin(BaseMixin):
    """Mixin for doc-related operations."""

    def list_docs(self) -> str:
        """Fetch ``{url}/docs`` and return the response body as text.

        Returns:
            The response text of a 200 response.

        Raises:
            Exception: Whatever ``parse_error`` builds for a non-200 response.
        """
        reply = self._session.get(f"{self.url}/docs")
        if reply.status_code != 200:
            raise parse_error(reply.status_code, reply)
        page: str = reply.text
        return page

0 commit comments

Comments
 (0)