Skip to content

Commit

Permalink
Merge pull request #28 from quantile-development/bug/azure-state-client
Browse files Browse the repository at this point in the history
Bug/azure state client
  • Loading branch information
BernardWez authored Dec 5, 2023
2 parents 90b1255 + 17a51be commit d75f301
Showing 1 changed file with 51 additions and 1 deletion.
52 changes: 51 additions & 1 deletion elx/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,18 @@ def params(self) -> Dict[str, Any]:
"client": self.client,
}

def has_existing_state(self, state_file_name: str) -> bool:
"""
Checks for a pre-existing state file.
Args:
state_file_name (str): The name of the state file to load.
Returns:
bool: Boolean flag to indicate whether there is a pre-existing state file.
"""
raise NotImplementedError


class S3StateClient(StateClient):
"""
Expand Down Expand Up @@ -61,6 +73,32 @@ def client(self):
os.environ["AZURE_STORAGE_CONNECTION_STRING"]
)

@property
def container_name(self) -> str:
"""
Gives the container name where state files are stored in Azure Blob Storage.
Returns:
str: Name of the container.
"""
return self.base_path.replace("azure://", "")

def has_existing_state(self, state_file_name: str) -> bool:
"""
Checks for a pre-existing state file.
Args:
state_file_name (str): The name of the state file to load.
Returns:
bool: Boolean flag to indicate whether there is a pre-existing state file.
"""
# Get container client where the state file would be located
container = self.client.get_container_client(container=self.container_name)

# Check if state file exists
return container.get_blob_client(blob=state_file_name).exists()


class GCSStateClient(StateClient):
"""
Expand Down Expand Up @@ -92,6 +130,18 @@ class LocalStateClient(StateClient):
def params(self) -> dict:
return {}

def has_existing_state(self, state_file_name: str) -> bool:
"""
Checks for a pre-existing state file.
Args:
state_file_name (str): The name of the state file to load.
Returns:
bool: Boolean flag to indicate whether there is a pre-existing state file.
"""
return Path(f"{self.base_path}/{state_file_name}").exists()


def state_client_factory(base_path: str) -> StateClient:
if base_path.startswith("s3://"):
Expand Down Expand Up @@ -123,7 +173,7 @@ def load(self, state_file_name: str) -> dict:
Returns:
dict: The contents of the state file.
"""
if not Path(f"{self.base_path}/{state_file_name}").exists():
if not self.state_client.has_existing_state(state_file_name):
return {}

with open(
Expand Down

0 comments on commit d75f301

Please sign in to comment.