diff --git a/elx/state.py b/elx/state.py index e1007b6..07c62ec 100644 --- a/elx/state.py +++ b/elx/state.py @@ -31,6 +31,18 @@ def params(self) -> Dict[str, Any]: "client": self.client, } + def has_existing_state(self, state_file_name: str) -> bool: + """ + Checks for a pre-existing state file. + + Args: + state_file_name (str): The name of the state file to load. + + Returns: + bool: Boolean flag to indicate whether there is a pre-existing state file. + """ + raise NotImplementedError + class S3StateClient(StateClient): """ @@ -61,6 +73,32 @@ def client(self): os.environ["AZURE_STORAGE_CONNECTION_STRING"] ) + @property + def container_name(self) -> str: + """ + Gives the container name where state files are stored in Azure Blob Storage. + + Returns: + str: Name of the container. + """ + return self.base_path.replace("azure://", "") + + def has_existing_state(self, state_file_name: str) -> bool: + """ + Checks for a pre-existing state file. + + Args: + state_file_name (str): The name of the state file to load. + + Returns: + bool: Boolean flag to indicate whether there is a pre-existing state file. + """ + # Get container client where the state file would be located + container = self.client.get_container_client(container=self.container_name) + + # Check if state file exists + return container.get_blob_client(blob=state_file_name).exists() + class GCSStateClient(StateClient): """ @@ -92,6 +130,18 @@ class LocalStateClient(StateClient): def params(self) -> dict: return {} + def has_existing_state(self, state_file_name: str) -> bool: + """ + Checks for a pre-existing state file. + + Args: + state_file_name (str): The name of the state file to load. + + Returns: + bool: Boolean flag to indicate whether there is a pre-existing state file. + """ + return Path(f"{self.base_path}/{state_file_name}").exists() + def state_client_factory(base_path: str) -> StateClient: if base_path.startswith("s3://"): @@ -123,7 +173,7 @@ def load(self, state_file_name: str) -> dict: Returns: dict: The contents of the state file. """ - if not Path(f"{self.base_path}/{state_file_name}").exists(): + if not self.state_client.has_existing_state(state_file_name): return {} with open(