Skip to content

Commit

Permalink
Add filter function to list
Browse files Browse the repository at this point in the history
  • Loading branch information
omer-dayan committed Dec 19, 2024
1 parent 48eb474 commit 65c5e61
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,9 @@ def request_ready_chunks(self) -> Iterator:
self.current_request_chunks[:relative_index]
)

def list(self, path: str, pattern: str) -> List[str]:
def list(self, path: str, pattern: str = "*.*") -> List[str]:
all_files = runai_list(self.streamer, path)
return [file for file in all_files if fnmatch.fnmatch(file, pattern)]

def copy(self, src_path : str, dst_path : str) -> None:
return runai_copy(self.streamer, src_path, dst_path)
def copy(self, src_path: str, dst_path: str) -> None:
return runai_copy(self.streamer, src_path, dst_path)
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,13 @@ def test_limited_memory_cap(self, mock_get_memory_mode, mock_getenv):
id_to_results[id]["expected_text"],
)

@patch("runai_model_streamer.file_streamer.file_streamer.runai_list")
def test_list_pattern(self, mock_runai_list):
mock_runai_list.return_value = ["a.safetensors", "b.py", "c.safetensors"]
with FileStreamer() as fs:
filtered = fs.list("test", "*.safetensors")
self.assertCountEqual(filtered, ["c.safetensors", "a.safetensors"])

def tearDown(self):
shutil.rmtree(self.temp_dir)

Expand Down

0 comments on commit 65c5e61

Please sign in to comment.