Skip to content

Commit

Permalink
Add filter function to list
Browse files Browse the repository at this point in the history
  • Loading branch information
omer-dayan committed Dec 19, 2024
1 parent 48eb474 commit e04b46e
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -106,5 +106,5 @@ def list(self, path: str, pattern: str) -> List[str]:
all_files = runai_list(self.streamer, path)
return [file for file in all_files if fnmatch.fnmatch(file, pattern)]

def copy(self, src_path : str, dst_path : str) -> None:
return runai_copy(self.streamer, src_path, dst_path)
def copy(self, src_path: str, dst_path: str) -> None:
return runai_copy(self.streamer, src_path, dst_path)
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,13 @@ def test_limited_memory_cap(self, mock_get_memory_mode, mock_getenv):
id_to_results[id]["expected_text"],
)

@patch("runai_model_streamer.file_streamer.file_streamer.runai_list")
def test_list_pattern(self, mock_runai_list):
mock_runai_list.return_value = ["a.safetensors", "b.py", "c.safetensors"]
with FileStreamer() as fs:
filtered = fs.list("test", "*.safetensors")
self.assertCountEqual(filtered, ["c.safetensors", "a.safetensors"])

def tearDown(self):
shutil.rmtree(self.temp_dir)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,16 +81,14 @@ def runai_response(streamer: t_streamer) -> Optional[int]:
def runai_response_str(response_code: int) -> str:
return dll.fn_runai_response_str(response_code)


def runai_list(streamer: t_streamer, path: str) -> List[str]:
keys = ctypes.POINTER(ctypes.c_char_p)()
count = ctypes.c_size_t()

# Call the `runai_list` function
error_code = dll.fn_runai_list(
streamer,
path.encode('utf-8'),
ctypes.byref(keys),
ctypes.byref(count)
streamer, path.encode("utf-8"), ctypes.byref(keys), ctypes.byref(count)
)

if error_code != SUCCESS_ERROR_CODE:
Expand All @@ -100,7 +98,7 @@ def runai_list(streamer: t_streamer, path: str) -> List[str]:

# Convert the result to a Python list
object_list = [
ctypes.cast(keys[i], ctypes.c_char_p).value.decode('utf-8')
ctypes.cast(keys[i], ctypes.c_char_p).value.decode("utf-8")
for i in range(count.value)
]

Expand All @@ -113,8 +111,11 @@ def runai_list(streamer: t_streamer, path: str) -> List[str]:

return object_list


def runai_copy(streamer: t_streamer, src_path: str, dst_path: str) -> None:
error_code = dll.fn_runai_read_object_to_file(streamer, src_path.encode('utf-8'), dst_path.encode('utf-8'))
error_code = dll.fn_runai_read_object_to_file(
streamer, src_path.encode("utf-8"), dst_path.encode("utf-8")
)
if error_code != SUCCESS_ERROR_CODE:
raise Exception(
f"Could not download in libstreamer due to: {runai_response_str(error_code)}"
Expand Down

0 comments on commit e04b46e

Please sign in to comment.