From 65c5e612260213f49686fcfd324ba5917aa2559e Mon Sep 17 00:00:00 2001 From: OmerD Date: Thu, 19 Dec 2024 15:40:19 +0000 Subject: [PATCH] Add filter function to list --- .../runai_model_streamer/file_streamer/file_streamer.py | 6 +++--- .../file_streamer/tests/test_file_streamer.py | 7 +++++++ 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/py/runai_model_streamer/runai_model_streamer/file_streamer/file_streamer.py b/py/runai_model_streamer/runai_model_streamer/file_streamer/file_streamer.py index 889d4c1..f9f0c12 100644 --- a/py/runai_model_streamer/runai_model_streamer/file_streamer/file_streamer.py +++ b/py/runai_model_streamer/runai_model_streamer/file_streamer/file_streamer.py @@ -102,9 +102,9 @@ def request_ready_chunks(self) -> Iterator: self.current_request_chunks[:relative_index] ) - def list(self, path: str, pattern: str) -> List[str]: + def list(self, path: str, pattern: str = "*.*") -> List[str]: all_files = runai_list(self.streamer, path) return [file for file in all_files if fnmatch.fnmatch(file, pattern)] - def copy(self, src_path : str, dst_path : str) -> None: - return runai_copy(self.streamer, src_path, dst_path) \ No newline at end of file + def copy(self, src_path: str, dst_path: str) -> None: + return runai_copy(self.streamer, src_path, dst_path) diff --git a/py/runai_model_streamer/runai_model_streamer/file_streamer/tests/test_file_streamer.py b/py/runai_model_streamer/runai_model_streamer/file_streamer/tests/test_file_streamer.py index 8f622a8..c4f76d8 100644 --- a/py/runai_model_streamer/runai_model_streamer/file_streamer/tests/test_file_streamer.py +++ b/py/runai_model_streamer/runai_model_streamer/file_streamer/tests/test_file_streamer.py @@ -99,6 +99,13 @@ def test_limited_memory_cap(self, mock_get_memory_mode, mock_getenv): id_to_results[id]["expected_text"], ) + @patch("runai_model_streamer.file_streamer.file_streamer.runai_list") + def test_list_pattern(self, mock_runai_list): + mock_runai_list.return_value = ["a.safetensors", "b.py", "c.safetensors"] + with FileStreamer() as fs: + filtered = fs.list("test", "*.safetensors") + self.assertCountEqual(filtered, ["c.safetensors", "a.safetensors"]) + def tearDown(self): shutil.rmtree(self.temp_dir)