-
Notifications
You must be signed in to change notification settings - Fork 37
/
Copy pathmodels.py
67 lines (51 loc) · 1.95 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import ollama
from tqdm import tqdm
def __pull_model(name: str) -> None:
    """Pull a model from the Ollama repository, rendering tqdm progress bars.

    Streams pull progress from the Ollama service; each layer (identified by
    its digest) gets its own byte-scaled progress bar. Status-only messages
    (e.g. "pulling manifest") are printed as plain text.

    Args:
        name (str): The name of the model to pull.
    """
    current_digest, bars = "", {}
    for progress in ollama.pull(name, stream=True):
        digest = progress.get("digest", "")
        # A new digest means the previous layer finished: close its bar.
        if digest != current_digest and current_digest in bars:
            bars[current_digest].close()
        if not digest:
            # Messages without a digest carry only a status string.
            print(progress.get("status"))
            continue
        if digest not in bars and (total := progress.get("total")):
            bars[digest] = tqdm(
                total=total, desc=f"pulling {digest[7:19]}", unit="B", unit_scale=True
            )
        # Guard with `digest in bars`: a "completed" field may appear for a
        # digest whose "total" was never reported, so no bar exists for it.
        if (completed := progress.get("completed")) and digest in bars:
            bars[digest].update(completed - bars[digest].n)
        current_digest = digest
    # The digest-change check above never closes the final layer's bar;
    # close anything still open once the stream ends.
    for bar in bars.values():
        bar.close()
def __is_model_available_locally(model_name: str) -> bool:
    """Report whether *model_name* is already present on the local Ollama host.

    Args:
        model_name (str): The name of the model to look up.

    Returns:
        bool: True if `ollama.show` succeeds, False if the service responds
        with an error (model not found).
    """
    try:
        ollama.show(model_name)
    except ollama.ResponseError:
        return False
    return True
def get_list_of_models() -> list[str]:
    """Return the names of all models known to the Ollama repository.

    Returns:
        list[str]: One entry per model reported by `ollama.list()`.
    """
    listing = ollama.list()
    return [entry["name"] for entry in listing["models"]]
def check_if_model_is_available(model_name: str) -> None:
    """Ensure that the specified model is available locally.

    If the model is not available, attempt to pull it from the Ollama
    repository.

    Args:
        model_name (str): The name of the model to check.

    Raises:
        Exception: If the Ollama service cannot be reached, or if the model
            cannot be pulled (e.g. the name is wrong). The underlying error
            is attached as the exception's __cause__.
    """
    try:
        available = __is_model_available_locally(model_name)
    except Exception as err:
        # Chain the cause so the underlying connection error stays visible
        # in the traceback instead of being silently discarded.
        raise Exception("Unable to communicate with the Ollama service") from err
    if not available:
        try:
            __pull_model(model_name)
        except Exception as err:
            raise Exception(
                f"Unable to find model '{model_name}', please check the name and try again."
            ) from err