Skip to content

Commit

Permalink
[k8s] support to use custom gpu resource name if it's not nvidia.com/gpu
Browse files Browse the repository at this point in the history
  • Loading branch information
nkwangleiGIT committed Nov 14, 2024
1 parent a2e670d commit 4e138a3
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 8 deletions.
2 changes: 1 addition & 1 deletion sky/clouds/kubernetes.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@ def make_deploy_resources_variables(
tpu_requested = True
k8s_resource_key = kubernetes_utils.TPU_RESOURCE_KEY
else:
k8s_resource_key = kubernetes_utils.GPU_RESOURCE_KEY
k8s_resource_key = kubernetes_utils.get_gpu_resource_key()

port_mode = network_utils.get_port_mode(None)

Expand Down
2 changes: 1 addition & 1 deletion sky/provision/kubernetes/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -685,7 +685,7 @@ def _create_pods(region: str, cluster_name_on_cloud: str,
limits = pod_spec['spec']['containers'][0].get('resources',
{}).get('limits')
if limits is not None:
needs_gpus = limits.get(kubernetes_utils.GPU_RESOURCE_KEY, 0) > 0
needs_gpus = limits.get(kubernetes_utils.get_gpu_resource_key(), 0) > 0

# TPU pods provisioned on GKE use the default containerd runtime.
# Reference: https://cloud.google.com/kubernetes-engine/docs/how-to/migrate-containerd#overview # pylint: disable=line-too-long
Expand Down
28 changes: 24 additions & 4 deletions sky/provision/kubernetes/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,7 @@ def detect_accelerator_resource(
nodes = get_kubernetes_nodes(context)
for node in nodes:
cluster_resources.update(node.status.allocatable.keys())
has_accelerator = (GPU_RESOURCE_KEY in cluster_resources or
has_accelerator = (get_gpu_resource_key() in cluster_resources or
TPU_RESOURCE_KEY in cluster_resources)

return has_accelerator, cluster_resources
Expand Down Expand Up @@ -2233,10 +2233,11 @@ def get_node_accelerator_count(attribute_dict: dict) -> int:
Number of accelerators allocated or available from the node. If no
resource is found, it returns 0.
"""
assert not (GPU_RESOURCE_KEY in attribute_dict and
gpuResourceKey = get_gpu_resource_key()
assert not (gpuResourceKey in attribute_dict and
TPU_RESOURCE_KEY in attribute_dict)
if GPU_RESOURCE_KEY in attribute_dict:
return int(attribute_dict[GPU_RESOURCE_KEY])
if gpuResourceKey in attribute_dict:
return int(attribute_dict[gpuResourceKey])
elif TPU_RESOURCE_KEY in attribute_dict:
return int(attribute_dict[TPU_RESOURCE_KEY])
return 0
Expand Down Expand Up @@ -2395,3 +2396,22 @@ def process_skypilot_pods(
num_pods = len(cluster.pods)
cluster.resources_str = f'{num_pods}x {cluster.resources}'
return list(clusters.values()), jobs_controllers, serve_controllers

def get_gpu_resource_key():
"""Get the GPU resource name to use in kubernetes.
The function first checks for an environment variable.
If defined, it uses its value; otherwise, it returns the default value.
Args:
name (str): Default GPU resource name, default is "nvidia.com/gpu".
Returns:
str: The selected GPU resource name.
"""
# Retrieve GPU resource name from environment variable, if set.
# E.g., can be nvidia.com/gpu-h100, amd.com/gpu etc.
custom_name = os.getenv('CUSTOM_GPU_RESOURCE_KEY')

# If the environment variable is not defined, return the default name
if custom_name is None:
return GPU_RESOURCE_KEY

return custom_name
4 changes: 2 additions & 2 deletions sky/utils/kubernetes/gpu_labeler.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def label():
# Get the list of nodes with GPUs
gpu_nodes = []
for node in nodes:
if kubernetes_utils.GPU_RESOURCE_KEY in node.status.capacity:
if kubernetes_utils.get_gpu_resource_key() in node.status.capacity:
gpu_nodes.append(node)

print(f'Found {len(gpu_nodes)} GPU nodes in the cluster')
Expand Down Expand Up @@ -142,7 +142,7 @@ def label():
if len(gpu_nodes) == 0:
print('No GPU nodes found in the cluster. If you have GPU nodes, '
'please ensure that they have the label '
f'`{kubernetes_utils.GPU_RESOURCE_KEY}: <number of GPUs>`')
f'`{kubernetes_utils.get_gpu_resource_key()}: <number of GPUs>`')
else:
print('GPU labeling started - this may take 10 min or more to complete.'
'\nTo check the status of GPU labeling jobs, run '
Expand Down

0 comments on commit 4e138a3

Please sign in to comment.