Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Preliminary Vast AI support #4365

Open
wants to merge 19 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions sky/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ def set_proxy_env_var(proxy_var: str, urllib_var: Optional[str]):
OCI = clouds.OCI
Paperspace = clouds.Paperspace
RunPod = clouds.RunPod
Vast = clouds.Vast
Vsphere = clouds.Vsphere
Fluidstack = clouds.Fluidstack
optimize = Optimizer.optimize
Expand All @@ -149,6 +150,7 @@ def set_proxy_env_var(proxy_var: str, urllib_var: Optional[str]):
'OCI',
'Paperspace',
'RunPod',
'Vast',
'SCP',
'Vsphere',
'Fluidstack',
Expand Down
29 changes: 29 additions & 0 deletions sky/adaptors/vast.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
"""Vast cloud adaptor."""

import functools

# Shared Vast SDK client; created on first use by `import_package`.
_vast_sdk = None


def import_package(func):
    """Decorator that lazily creates the shared Vast SDK client.

    On the first call of any decorated function, `vastai_sdk` is imported
    and a `VastAI()` client is stored in the module-level `_vast_sdk`.
    Later calls reuse that client.

    Args:
        func: The function to wrap.

    Returns:
        A wrapper that ensures `_vast_sdk` is initialized, then calls
        `func` with the original arguments.

    Raises:
        ImportError: If the Vast dependencies cannot be imported.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        global _vast_sdk

        if _vast_sdk is None:
            try:
                import vastai_sdk as _vast  # pylint: disable=import-outside-toplevel
                _vast_sdk = _vast.VastAI()
            except ImportError:
                # Adjacent literals are concatenated verbatim, so the
                # first one must end with a space to keep the two
                # sentences apart in the final message.
                raise ImportError('Fail to import dependencies for vast. '
                                  'Try pip install "skypilot[vast]"') from None
        return func(*args, **kwargs)

    return wrapper


@import_package
def vast():
    """Return the shared `VastAI` client instance (not the package)."""
    return _vast_sdk
14 changes: 14 additions & 0 deletions sky/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from sky.adaptors import ibm
from sky.adaptors import kubernetes
from sky.adaptors import runpod
from sky.adaptors import vast
from sky.provision.fluidstack import fluidstack_utils
from sky.provision.kubernetes import utils as kubernetes_utils
from sky.provision.lambda_cloud import lambda_utils
Expand Down Expand Up @@ -473,6 +474,19 @@ def setup_runpod_authentication(config: Dict[str, Any]) -> Dict[str, Any]:
return configure_ssh_info(config)


def setup_vast_authentication(config: Dict[str, Any]) -> Dict[str, Any]:
    """Prepares SSH authentication for a Vast cluster config.

    - Obtains the local key pair via `get_or_generate_keys()` (generating
      one if it does not exist).
    - Registers the public key with the user's Vast account through the
      SDK's `create_ssh_key`.
    - Stores `PUBLIC_SSH_KEY_PATH` under `config['auth']['ssh_public_key']`
      and returns the result of `configure_ssh_info(config)`.
    """
    _, pub_path = get_or_generate_keys()
    with open(pub_path, 'r', encoding='UTF-8') as key_file:
        key_text = key_file.read().strip()
    vast.vast().create_ssh_key(ssh_key=key_text)
    config['auth']['ssh_public_key'] = PUBLIC_SSH_KEY_PATH
    return configure_ssh_info(config)


def setup_fluidstack_authentication(config: Dict[str, Any]) -> Dict[str, Any]:

get_or_generate_keys()
Expand Down
2 changes: 2 additions & 0 deletions sky/backends/backend_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -963,6 +963,8 @@ def _add_auth_to_cluster_config(cloud: clouds.Cloud, cluster_config_file: str):
config = auth.setup_ibm_authentication(config)
elif isinstance(cloud, clouds.RunPod):
config = auth.setup_runpod_authentication(config)
elif isinstance(cloud, clouds.Vast):
config = auth.setup_vast_authentication(config)
elif isinstance(cloud, clouds.Fluidstack):
config = auth.setup_fluidstack_authentication(config)
else:
Expand Down
1 change: 1 addition & 0 deletions sky/backends/cloud_vm_ray_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ def _get_cluster_config_template(cloud):
clouds.RunPod: 'runpod-ray.yml.j2',
clouds.Kubernetes: 'kubernetes-ray.yml.j2',
clouds.Vsphere: 'vsphere-ray.yml.j2',
clouds.Vast: 'vast-ray.yml.j2',
clouds.Fluidstack: 'fluidstack-ray.yml.j2'
}
return cloud_to_template[type(cloud)]
Expand Down
2 changes: 2 additions & 0 deletions sky/clouds/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from sky.clouds.paperspace import Paperspace
from sky.clouds.runpod import RunPod
from sky.clouds.scp import SCP
from sky.clouds.vast import Vast
from sky.clouds.vsphere import Vsphere

__all__ = [
Expand All @@ -37,6 +38,7 @@
'Paperspace',
'SCP',
'RunPod',
'Vast',
'OCI',
'Vsphere',
'Kubernetes',
Expand Down
2 changes: 1 addition & 1 deletion sky/clouds/service_catalog/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@
CATALOG_SCHEMA_VERSION = 'v5'
CATALOG_DIR = '~/.sky/catalogs'
ALL_CLOUDS = ('aws', 'azure', 'gcp', 'ibm', 'lambda', 'scp', 'oci',
'kubernetes', 'runpod', 'vsphere', 'cudo', 'fluidstack',
'kubernetes', 'runpod', 'vast', 'vsphere', 'cudo', 'fluidstack',
'paperspace')
67 changes: 67 additions & 0 deletions sky/clouds/service_catalog/data_fetchers/fetch_vast.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
"""A script that generates the Vast Cloud catalog. """

# pylint: disable=assignment-from-no-return
import csv
import json
import re
import sys

from vastai_sdk import VastAI


def create_instance_type(obj):
    """Build the catalog instance-type name for one Vast offer.

    The name has the form `<num_gpus>x-<gpu_name>-<cpu_cores>`, where each
    whitespace character in the GPU name is replaced by an underscore.

    Args:
        obj: Mapping with 'num_gpus', 'gpu_name' and 'cpu_cores' entries.

    Returns:
        The instance-type string, e.g. '2x-RTX_4090-16'.
    """
    gpu_stub = re.sub(r'\s', '_', obj['gpu_name'])
    return '{}x-{}-{}'.format(obj['num_gpus'], gpu_stub, obj['cpu_cores'])


def dot_get(d, key):
    """Resolve a dot-separated path (e.g. 'search.totalHour') in `d`.

    Each dot-separated component is used as a subscript on the value
    produced by the previous component. A lookup failure propagates
    whatever the subscript raises.
    """
    head, sep, rest = key.partition('.')
    value = d[head]
    if not sep:
        return value
    return dot_get(value, rest)


# Each pair is (key in a Vast offer, catalog column it fills); dotted keys
# are resolved with dot_get. The 'InstanceType' and 'GpuInfo' pairs only
# reserve their columns for the DictWriter: both values are overwritten
# for every offer in the loop below.
mapped_keys = (('gpu_name', 'InstanceType'), ('gpu_name', 'AcceleratorName'),
               ('num_gpus', 'AcceleratorCount'), ('cpu_cores', 'vCPUs'),
               ('gpu_total_ram', 'MemoryGiB'), ('search.totalHour', 'Price'),
               ('geolocation', 'Region'), ('gpu_name', 'GpuInfo'),
               ('search.totalHour', 'SpotPrice'))
writer = csv.DictWriter(sys.stdout,
                        fieldnames=[column for _, column in mapped_keys])
writer.writeheader()

offerList = VastAI().search_offers(limit=10000)
for offer in offerList:
    entry = {}
    for offer_key, column in mapped_keys:
        value = dot_get(offer, offer_key)
        # 'Price' and 'SpotPrice' are written with two decimal places.
        if 'Price' in column:
            value = '{:.2f}'.format(value)
        entry[column] = value

    entry['InstanceType'] = create_instance_type(offer)

    # NOTE(review): MemoryGiB is derived from 'gpu_total_ram' divided by
    # 1024 -- presumably a MiB -> GiB conversion of GPU memory rather than
    # host RAM; confirm this is the intended source field.
    entry['MemoryGiB'] /= 1024

    # The documented GpuInfo format looks like
    # "{'gpus': [{
    #     'name': 'v100',
    #     'manufacturer': 'nvidia',
    #     'count': 8.0,
    #     'memoryinfo': {'sizeinmib': 16384}
    #   }],
    #   'totalgpumemoryinmib': 16384}",
    # so the same shape is produced here, with JSON's double quotes
    # swapped for single quotes.
    total_gpu_ram = offer['gpu_total_ram']
    gpu_info = {
        'Gpus': [{
            'Name': offer['gpu_name'],
            'Count': offer['num_gpus'],
            'MemoryInfo': {
                'SizeInMiB': total_gpu_ram
            }
        }],
        'TotalGpuMemoryInMiB': total_gpu_ram
    }
    entry['GpuInfo'] = json.dumps(gpu_info).replace('"', "'")  # pylint: disable=invalid-string-quote

    writer.writerow(entry)
104 changes: 104 additions & 0 deletions sky/clouds/service_catalog/vast_catalog.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
""" Vast | Catalog

This module loads the service catalog file and can be used to
query instance types and pricing information for Vast.ai.
"""

import typing
from typing import Dict, List, Optional, Tuple, Union

from sky.clouds.service_catalog import common
from sky.utils import ux_utils

if typing.TYPE_CHECKING:
from sky.clouds import cloud

_df = common.read_catalog('vast/vms.csv')


def instance_type_exists(instance_type: str) -> bool:
    """Checks `instance_type` against the loaded Vast catalog.

    Thin wrapper around `common.instance_type_exists_impl` using the
    module-level catalog frame.
    """
    found = common.instance_type_exists_impl(_df, instance_type)
    return found


def validate_region_zone(
        region: Optional[str],
        zone: Optional[str]) -> Tuple[Optional[str], Optional[str]]:
    """Validates `region` against the Vast catalog; zones are rejected.

    Raises:
        ValueError: If `zone` is not None.
    """
    zone_requested = zone is not None
    if zone_requested:
        with ux_utils.print_exception_no_traceback():
            raise ValueError('Vast does not support zones.')
    validated = common.validate_region_zone_impl('vast', _df, region, zone)
    return validated


def get_hourly_cost(instance_type: str,
                    use_spot: bool = False,
                    region: Optional[str] = None,
                    zone: Optional[str] = None) -> float:
    """Returns the hourly cost from `common.get_hourly_cost_impl`.

    Raises:
        ValueError: If `zone` is not None.
    """
    zone_requested = zone is not None
    if zone_requested:
        with ux_utils.print_exception_no_traceback():
            raise ValueError('Vast does not support zones.')
    cost = common.get_hourly_cost_impl(_df, instance_type, use_spot, region,
                                       zone)
    return cost


def get_vcpus_mem_from_instance_type(
        instance_type: str) -> Tuple[Optional[float], Optional[float]]:
    """Returns the (vCPUs, memory) pair for `instance_type`.

    Delegates to `common.get_vcpus_mem_from_instance_type_impl` with the
    Vast catalog frame.
    """
    vcpus_and_mem = common.get_vcpus_mem_from_instance_type_impl(
        _df, instance_type)
    return vcpus_and_mem


def get_default_instance_type(cpus: Optional[str] = None,
                              memory: Optional[str] = None,
                              disk_tier: Optional[str] = None) -> Optional[str]:
    """Picks an instance type for the requested `cpus` and `memory`.

    The choice is made by `common.get_instance_type_for_cpus_mem_impl`
    over the Vast catalog; `disk_tier` is ignored.
    """
    del disk_tier  # Unused.
    # NOTE: After expanding catalog to multiple entries, you may
    # want to specify a default instance type or family.
    chosen = common.get_instance_type_for_cpus_mem_impl(_df, cpus, memory)
    return chosen


def get_accelerators_from_instance_type(
        instance_type: str) -> Optional[Dict[str, Union[int, float]]]:
    """Returns the accelerators recorded for `instance_type`, if any.

    Delegates to `common.get_accelerators_from_instance_type_impl` with
    the Vast catalog frame.
    """
    accelerators = common.get_accelerators_from_instance_type_impl(
        _df, instance_type)
    return accelerators


def get_instance_type_for_accelerator(
        acc_name: str,
        acc_count: int,
        cpus: Optional[str] = None,
        memory: Optional[str] = None,
        use_spot: bool = False,
        region: Optional[str] = None,
        zone: Optional[str] = None) -> Tuple[Optional[List[str]], List[str]]:
    """Finds instance types that provide the requested accelerator.

    The lookup is performed by
    `common.get_instance_type_for_accelerator_impl` over the Vast catalog.

    Raises:
        ValueError: If `zone` is not None.
    """
    zone_requested = zone is not None
    if zone_requested:
        with ux_utils.print_exception_no_traceback():
            raise ValueError('Vast does not support zones.')
    result = common.get_instance_type_for_accelerator_impl(
        df=_df,
        acc_name=acc_name,
        acc_count=acc_count,
        cpus=cpus,
        memory=memory,
        use_spot=use_spot,
        region=region,
        zone=zone)
    return result


def get_region_zones_for_instance_type(instance_type: str,
                                       use_spot: bool) -> List['cloud.Region']:
    """Returns the regions listed for `instance_type` in the Vast catalog.

    The catalog is first narrowed to rows whose 'InstanceType' equals
    `instance_type`; `common.get_region_zones` builds the result.
    """
    matches_type = _df['InstanceType'] == instance_type
    return common.get_region_zones(_df[matches_type], use_spot)


# TODO: this differs from the fluffy catalog version
def list_accelerators(
        gpus_only: bool,
        name_filter: Optional[str],
        region_filter: Optional[str],
        quantity_filter: Optional[int],
        case_sensitive: bool = True,
        all_regions: bool = False,
        require_price: bool = True) -> Dict[str, List[common.InstanceTypeInfo]]:
    """Lists the accelerator offerings found in the Vast catalog.

    All filters are forwarded to `common.list_accelerators_impl`;
    `require_price` is accepted but not used.
    """
    del require_price  # Unused.
    listing = common.list_accelerators_impl('Vast', _df, gpus_only,
                                            name_filter, region_filter,
                                            quantity_filter, case_sensitive,
                                            all_regions)
    return listing
Loading
Loading