Add support for local mode via Ollama

Adds support for running against local models via the Ollama API, in
addition to the Qiskit Code Assistant service API.

This allows users to enter an Ollama API URL instead of a Qiskit Code
Assistant service URL; the server extension detects which API the URL
serves and calls the corresponding endpoints.
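
For reference, the detection this commit adds boils down to probing the
root of the configured URL: a plain GET against an Ollama server returns
the text "Ollama is running", which the Qiskit Code Assistant service
does not. A minimal standalone sketch of that heuristic (the default
Ollama address in the example is an assumption, not part of this commit):

    import requests

    def is_ollama_server(service_url: str) -> bool:
        """Best-effort check mirroring handlers.py: Ollama answers
        GET / with the plain-text body "Ollama is running"."""
        try:
            r = requests.get(service_url, timeout=5)
            return "Ollama is running" in r.text
        except requests.exceptions.RequestException:
            return False

    # e.g. is_ollama_server("http://localhost:11434") -> True if Ollama is up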
ajbozarth committed Oct 8, 2024
1 parent b6a76f1 commit b0a8cd8
Showing 2 changed files with 156 additions and 78 deletions.
206 changes: 144 additions & 62 deletions qiskit_code_assistant_jupyterlab/handlers.py
@@ -24,7 +24,11 @@
from jupyter_server.utils import url_path_join
from qiskit_ibm_runtime import QiskitRuntimeService

runtime_configs = {"service_url": "http://localhost", "api_token": ""}
runtime_configs = {
"service_url": "http://localhost",
"api_token": "",
"is_ollama": False,
}


def update_token(token):
@@ -55,11 +59,26 @@ def init_token():


def get_header():
-    return {
+    header = {
         "Accept": "application/json",
         "Content-Type": "application/json",
         "X-Caller": "qiskit-code-assistant-jupyterlab",
-        "Authorization": f"Bearer {runtime_configs['api_token']}",
     }
+    if not runtime_configs["is_ollama"]:
+        header["Authorization"] = f"Bearer {runtime_configs['api_token']}"
+    return header


+def convert_ollama(model):
+    return {
+        "_id": model["model"],
+        "disclaimer": {"accepted": "true"},
+        "display_name": model["name"],
+        "doc_link": "",
+        "license": {"name": "", "link": ""},
+        "model_id": model["model"],
+        "prompt_type": 1,
+        "token_limit": 255
+    }


@@ -74,13 +93,19 @@ def post(self):

runtime_configs["service_url"] = json_payload["url"]

self.finish(json.dumps({"url": runtime_configs["service_url"]}))
try:
r = requests.get(url_path_join(runtime_configs["service_url"]), headers=get_header())
# TODO: Replace with a check against the QCA service instead
runtime_configs["is_ollama"] = ("Ollama is running" in r.text)
finally:
self.finish(json.dumps({"url": runtime_configs["service_url"]}))


class TokenHandler(APIHandler):
@tornado.web.authenticated
def get(self):
self.finish(json.dumps({"success": (runtime_configs["api_token"] != "")}))
self.finish(json.dumps({"success": (runtime_configs["api_token"] != ""
or runtime_configs["is_ollama"])}))

@tornado.web.authenticated
def post(self):
@@ -94,98 +119,155 @@ def post(self):
class ModelsHandler(APIHandler):
@tornado.web.authenticated
def get(self):
url = url_path_join(runtime_configs["service_url"], "models")

try:
r = requests.get(url, headers=get_header())
r.raise_for_status()
except requests.exceptions.HTTPError as err:
self.set_status(err.response.status_code)
self.finish(json.dumps(err.response.json()))
if runtime_configs["is_ollama"]:
url = url_path_join(runtime_configs["service_url"], "api", "tags")
models = []
try:
r = requests.get(url, headers=get_header())
r.raise_for_status()

if r.ok:
ollama_models = r.json()["models"]
models = list(map(convert_ollama, ollama_models))
except requests.exceptions.HTTPError as err:
self.set_status(err.response.status_code)
self.finish(json.dumps(err.response.json()))
else:
self.finish(json.dumps({"models": models}))
else:
self.finish(json.dumps(r.json()))
url = url_path_join(runtime_configs["service_url"], "models")

try:
r = requests.get(url, headers=get_header())
r.raise_for_status()
except requests.exceptions.HTTPError as err:
self.set_status(err.response.status_code)
self.finish(json.dumps(err.response.json()))
else:
self.finish(json.dumps(r.json()))


class ModelHandler(APIHandler):
@tornado.web.authenticated
def get(self, id):
url = url_path_join(runtime_configs["service_url"], "model", id)

try:
r = requests.get(url, headers=get_header())
r.raise_for_status()
except requests.exceptions.HTTPError as err:
self.set_status(err.response.status_code)
self.finish(json.dumps(err.response.json()))
if runtime_configs["is_ollama"]:
self.set_status(501, "Not implemented")
self.finish()
else:
self.finish(json.dumps(r.json()))
url = url_path_join(runtime_configs["service_url"], "model", id)

try:
r = requests.get(url, headers=get_header())
r.raise_for_status()
except requests.exceptions.HTTPError as err:
self.set_status(err.response.status_code)
self.finish(json.dumps(err.response.json()))
else:
self.finish(json.dumps(r.json()))


class DisclaimerHandler(APIHandler):
@tornado.web.authenticated
def get(self, id):
url = url_path_join(runtime_configs["service_url"], "model", id, "disclaimer")

try:
r = requests.get(url, headers=get_header())
r.raise_for_status()
except requests.exceptions.HTTPError as err:
self.set_status(err.response.status_code)
self.finish(json.dumps(err.response.json()))
if runtime_configs["is_ollama"]:
self.set_status(501, "Not implemented")
self.finish()
else:
self.finish(json.dumps(r.json()))
url = url_path_join(runtime_configs["service_url"], "model", id, "disclaimer")

try:
r = requests.get(url, headers=get_header())
r.raise_for_status()
except requests.exceptions.HTTPError as err:
self.set_status(err.response.status_code)
self.finish(json.dumps(err.response.json()))
else:
self.finish(json.dumps(r.json()))


class DisclaimerAcceptanceHandler(APIHandler):
@tornado.web.authenticated
def post(self, id):
-        url = url_path_join(
-            runtime_configs["service_url"], "disclaimer", id, "acceptance"
-        )
-
-        try:
-            r = requests.post(url, headers=get_header(), json=self.get_json_body())
-            r.raise_for_status()
-        except requests.exceptions.HTTPError as err:
-            self.set_status(err.response.status_code)
-            self.finish(json.dumps(err.response.json()))
+        if runtime_configs["is_ollama"]:
+            self.set_status(501, "Not implemented")
+            self.finish()
         else:
-            self.finish(json.dumps(r.json()))
+            url = url_path_join(
+                runtime_configs["service_url"], "disclaimer", id, "acceptance"
+            )
+
+            try:
+                r = requests.post(url, headers=get_header(), json=self.get_json_body())
+                r.raise_for_status()
+            except requests.exceptions.HTTPError as err:
+                self.set_status(err.response.status_code)
+                self.finish(json.dumps(err.response.json()))
+            else:
+                self.finish(json.dumps(r.json()))


class PromptHandler(APIHandler):
@tornado.web.authenticated
def post(self, id):
url = url_path_join(runtime_configs["service_url"], "model", id, "prompt")

try:
r = requests.post(url, headers=get_header(), json=self.get_json_body())
r.raise_for_status()
except requests.exceptions.HTTPError as err:
self.set_status(err.response.status_code)
self.finish(json.dumps(err.response.json()))
if runtime_configs["is_ollama"]:
url = url_path_join(runtime_configs["service_url"], "api", "generate")
result = {}
try:
r = requests.post(url,
headers=get_header(),
json={
"model": id,
"prompt": self.get_json_body()["input"],
"stream": False
})
r.raise_for_status()

if r.ok:
ollama_response = r.json()
result = {
"results": [{"generated_text": ollama_response["response"]}],
"prompt_id": ollama_response["created_at"],
"created_at": ollama_response["created_at"]
}
except requests.exceptions.HTTPError as err:
self.set_status(err.response.status_code)
self.finish(json.dumps(err.response.json()))
else:
self.finish(json.dumps(result))
else:
self.finish(json.dumps(r.json()))
url = url_path_join(runtime_configs["service_url"], "model", id, "prompt")

try:
r = requests.post(url, headers=get_header(), json=self.get_json_body())
r.raise_for_status()
except requests.exceptions.HTTPError as err:
self.set_status(err.response.status_code)
self.finish(json.dumps(err.response.json()))
else:
self.finish(json.dumps(r.json()))


class PromptAcceptanceHandler(APIHandler):
@tornado.web.authenticated
def post(self, id):
url = url_path_join(runtime_configs["service_url"], "prompt", id, "acceptance")

try:
r = requests.post(url, headers=get_header(), json=self.get_json_body())
r.raise_for_status()
except requests.exceptions.HTTPError as err:
self.set_status(err.response.status_code)
self.finish(json.dumps(err.response.json()))
if runtime_configs["is_ollama"]:
self.finish(json.dumps({"success": "true"}))
else:
self.finish(json.dumps(r.json()))
url = url_path_join(runtime_configs["service_url"], "prompt", id, "acceptance")

try:
r = requests.post(url, headers=get_header(), json=self.get_json_body())
r.raise_for_status()
except requests.exceptions.HTTPError as err:
self.set_status(err.response.status_code)
self.finish(json.dumps(err.response.json()))
else:
self.finish(json.dumps(r.json()))


def setup_handlers(web_app):
host_pattern = ".*$"
id_regex = r"(?P<id>[\w\-]+)"
id_regex = r"(?P<id>[\w\-\_\.\:]+)" # valid chars: alphanum | "-" | "_" | "." | ":"
base_url = url_path_join(web_app.settings["base_url"], "qiskit-code-assistant")

handlers = [
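As an aside, the ModelsHandler branch above lists local models through
Ollama's /api/tags endpoint and reshapes each entry with convert_ollama,
which maps Ollama's "model" and "name" fields onto the "_id"/"model_id"
and "display_name" fields the frontend expects. A small sketch of that
listing call, runnable against a local Ollama (the default base URL is
an assumption):

    import requests

    def list_ollama_models(base_url: str = "http://localhost:11434") -> list:
        """Fetch Ollama's model list from the same endpoint ModelsHandler uses."""
        r = requests.get(base_url + "/api/tags", timeout=5)
        r.raise_for_status()
        # Each entry carries at least "name" and "model" keys, which
        # convert_ollama turns into a service-shaped model record.
        return r.json()["models"]
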
28 changes: 12 additions & 16 deletions src/service/autocomplete.ts
@@ -14,7 +14,7 @@
* limitations under the License.
*/

-import { getModel, postModelPrompt } from './api';
+import { postModelPrompt } from './api';
import { showDisclaimer } from './disclaimer';
import { getCurrentModel } from './modelHandler';
import { checkAPIToken } from './token';
@@ -51,25 +51,21 @@ export async function autoComplete(text: string): Promise<ICompletionReturn> {
const requestText = text.slice(startingOffset, text.length);
const model = getCurrentModel();

-    return await getModel(model?._id || '')
-      .then(async model => {
-        if (model.disclaimer?.accepted) {
+    if (model === undefined) {
+      console.error('Failed to send prompt', 'No model selected');
+      return emptyReturn;
+    } else if (model.disclaimer?.accepted) {
+      return await promptPromise(model._id, requestText);
+    } else {
+      return await showDisclaimer(model._id).then(async accepted => {
+        if (accepted) {
           return await promptPromise(model._id, requestText);
         } else {
-          return await showDisclaimer(model._id).then(async accepted => {
-            if (accepted) {
-              return await promptPromise(model._id, requestText);
-            } else {
-              console.error('Disclaimer not accepted');
-              return emptyReturn;
-            }
-          });
+          console.error('Disclaimer not accepted');
+          return emptyReturn;
         }
-      })
-      .catch(reason => {
-        console.error('Failed to send prompt', reason);
-        return emptyReturn;
       });
+    }
   })
   .catch(reason => {
     console.error('Failed to send prompt',

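For completeness, the prompt path in local mode posts to Ollama's
/api/generate endpoint with streaming disabled and reshapes the reply
into the service's result format, reusing created_at as the prompt id
since Ollama issues none. A standalone sketch of that round trip (the
base URL is again whatever the user configures; the helper name is
hypothetical):

    import requests

    def generate_completion(base_url: str, model_id: str, prompt: str) -> dict:
        """Mimic PromptHandler's Ollama branch: one non-streaming generate call."""
        r = requests.post(
            base_url + "/api/generate",
            json={"model": model_id, "prompt": prompt, "stream": False},
            timeout=60,
        )
        r.raise_for_status()
        body = r.json()
        # Reshape into the Qiskit Code Assistant response format:
        return {
            "results": [{"generated_text": body["response"]}],
            "prompt_id": body["created_at"],  # Ollama has no prompt ids
            "created_at": body["created_at"],
        }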