WIP: Add support for reading and writing .wfdb archive files #541

Open · wants to merge 14 commits into main
213 changes: 213 additions & 0 deletions tests/test_archive.py
@@ -0,0 +1,213 @@
import os
import tempfile
import zipfile

import numpy as np
import pytest

from wfdb import rdrecord, wrsamp
from wfdb.io.archive import WFDBArchive

np.random.seed(1234)


@pytest.fixture
def temp_record():
"""
Create a temporary WFDB record and archive for testing.

This fixture generates a synthetic 2-channel signal, writes it to a temporary
directory using `wrsamp`, then creates an uncompressed `.wfdb` archive (ZIP container)
containing the `.hea` and `.dat` files. The archive is used to test read/write
round-trip support for WFDB archives.

Yields
------
dict
A dictionary containing:
- 'record_name': Path to the record base name (without extension).
- 'archive_path': Full path to the created `.wfdb` archive.
- 'original_signal': The original NumPy array of the signal.
- 'fs': The sampling frequency.
"""
with tempfile.TemporaryDirectory() as tmpdir:
record_basename = "testrecord"
fs = 250
sig_len = 1000
sig = (np.random.randn(sig_len, 2) * 1000).astype(np.float32)

# Write into tmpdir with record name only
wrsamp(
record_name=record_basename,
fs=fs,
units=["mV", "mV"],
sig_name=["I", "II"],
p_signal=sig,
fmt=["24", "24"],
adc_gain=[200.0, 200.0],
baseline=[0, 0],
write_dir=tmpdir,
)

# Construct full paths for archive creation
hea_path = os.path.join(tmpdir, record_basename + ".hea")
dat_path = os.path.join(tmpdir, record_basename + ".dat")
archive_path = os.path.join(tmpdir, record_basename + ".wfdb")

with WFDBArchive(record_name=record_basename, mode="w") as archive:
archive.create_archive(
file_list=[hea_path, dat_path],
output_path=archive_path,
)

try:
yield {
"record_name": os.path.join(tmpdir, record_basename),
"archive_path": archive_path,
"original_signal": sig,
"fs": fs,
}
finally:
# Clean up any open archive handles
from wfdb.io.archive import _archive_cache

for archive in _archive_cache.values():
if archive is not None:
archive.close()
_archive_cache.clear()


def test_wfdb_archive_inline_round_trip():
"""
There are two ways of creating an archive:

1. Inline archive creation via wrsamp(..., wfdb_archive=...)
This creates the .hea and .dat files directly inside the archive as part of the record writing step.

2. Two-step creation via wrsamp(...) followed by WFDBArchive.create_archive(...)
This writes regular WFDB files to disk, which are then added to an archive container afterward.

Test round-trip read/write using inline archive creation via `wrsamp(..., wfdb_archive=...)`.
"""
with tempfile.TemporaryDirectory() as tmpdir:
record_basename = "testrecord"
record_path = os.path.join(tmpdir, record_basename)
archive_path = record_path + ".wfdb"
fs = 250
sig_len = 1000
sig = (np.random.randn(sig_len, 2) * 1000).astype(np.float32)

# Create archive inline using context manager
with WFDBArchive(record_path, mode="w") as wfdb_archive:
wrsamp(
record_name=record_basename,
fs=fs,
units=["mV", "mV"],
sig_name=["I", "II"],
p_signal=sig,
fmt=["24", "24"],
adc_gain=[200.0, 200.0],
baseline=[0, 0],
write_dir=tmpdir,
wfdb_archive=wfdb_archive,
)

assert os.path.exists(archive_path), "Archive was not created"

# Read back from archive
record = rdrecord(archive_path)

try:
assert record.fs == fs
assert record.n_sig == 2
assert record.p_signal.shape == sig.shape

# Add tolerance to account for loss of precision during archive round-trip
np.testing.assert_allclose(
record.p_signal, sig, rtol=1e-2, atol=3e-3
)
finally:
# Ensure we close the archive after reading
if (
hasattr(record, "wfdb_archive")
and record.wfdb_archive is not None
):
record.wfdb_archive.close()


def test_wfdb_archive_round_trip(temp_record):
record_name = temp_record["record_name"]
archive_path = temp_record["archive_path"]
original_signal = temp_record["original_signal"]
fs = temp_record["fs"]

assert os.path.exists(archive_path), "Archive was not created"

record = rdrecord(archive_path)

assert record.fs == fs
assert record.n_sig == 2
assert record.p_signal.shape == original_signal.shape

# Add tolerance to account for loss of precision during archive round-trip
np.testing.assert_allclose(
record.p_signal, original_signal, rtol=1e-2, atol=3e-3
)


def test_archive_read_subset_channels(temp_record):
"""
Test reading a subset of channels from an archive.
"""
archive_path = temp_record["archive_path"]
original_signal = temp_record["original_signal"]

record = rdrecord(archive_path, channels=[1])

assert record.n_sig == 1
assert record.p_signal.shape[0] == original_signal.shape[0]

# Add tolerance to account for loss of precision during archive round-trip
np.testing.assert_allclose(
record.p_signal[:, 0], original_signal[:, 1], rtol=1e-2, atol=3e-3
)


def test_archive_read_partial_samples(temp_record):
"""
Test reading a sample range from the archive.
"""
archive_path = temp_record["archive_path"]
original_signal = temp_record["original_signal"]

start, stop = 100, 200
record = rdrecord(archive_path, sampfrom=start, sampto=stop)

assert record.p_signal.shape == (stop - start, original_signal.shape[1])
np.testing.assert_allclose(
record.p_signal, original_signal[start:stop], rtol=1e-2, atol=1e-3
)


def test_archive_missing_file_error(temp_record):
"""
Ensure an appropriate error is raised when an expected file is missing from the archive.
"""
archive_path = temp_record["archive_path"]

# Remove one file from archive (e.g. the .dat file)
with zipfile.ZipFile(archive_path, "a") as zf:
zf_name = [name for name in zf.namelist() if name.endswith(".dat")][0]
zf.fp = None # Prevent auto-close bug in some zipfile implementations
os.rename(archive_path, archive_path + ".bak")
with (
zipfile.ZipFile(archive_path + ".bak", "r") as zin,
zipfile.ZipFile(archive_path, "w") as zout,
):
for item in zin.infolist():
if not item.filename.endswith(".dat"):
zout.writestr(item, zin.read(item.filename))
os.remove(archive_path + ".bak")

with pytest.raises(FileNotFoundError, match=r".*\.dat.*"):
rdrecord(archive_path)
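The docstring of test_wfdb_archive_inline_round_trip describes two ways of producing a .wfdb archive. The sketch below condenses both paths using only the calls exercised by the tests above (wrsamp, WFDBArchive.create_archive, rdrecord); since the archive module is still a work in progress, treat it as an illustration of the intended API rather than a reference, and note that "mydir"/"myrecord" are placeholder names.

# Illustrative usage sketch, not part of this diff; names and paths are placeholders.
import os

import numpy as np

from wfdb import rdrecord, wrsamp
from wfdb.io.archive import WFDBArchive

os.makedirs("mydir", exist_ok=True)
sig = np.random.randn(1000, 2).astype(np.float32)

# Path 1: inline creation. wrsamp writes the .hea/.dat members directly
# into the open archive.
with WFDBArchive("mydir/myrecord", mode="w") as wfdb_archive:
    wrsamp(
        record_name="myrecord",
        fs=250,
        units=["mV", "mV"],
        sig_name=["I", "II"],
        p_signal=sig,
        fmt=["24", "24"],
        adc_gain=[200.0, 200.0],
        baseline=[0, 0],
        write_dir="mydir",
        wfdb_archive=wfdb_archive,
    )

# Path 2: two-step creation. Write ordinary WFDB files to disk first, then
# bundle them into a .wfdb container.
wrsamp(
    record_name="myrecord",
    fs=250,
    units=["mV", "mV"],
    sig_name=["I", "II"],
    p_signal=sig,
    fmt=["24", "24"],
    adc_gain=[200.0, 200.0],
    baseline=[0, 0],
    write_dir="mydir",
)
with WFDBArchive(record_name="myrecord", mode="w") as archive:
    archive.create_archive(
        file_list=["mydir/myrecord.hea", "mydir/myrecord.dat"],
        output_path="mydir/myrecord.wfdb",
    )

# Reading back works like any other record, including channel and sample subsets.
record = rdrecord("mydir/myrecord.wfdb", channels=[0], sampfrom=100, sampto=200)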
43 changes: 33 additions & 10 deletions wfdb/io/_header.py
@@ -1,11 +1,11 @@
import datetime
import os
from typing import Any, Dict, List, Optional, Sequence, Tuple

import numpy as np
import pandas as pd

from wfdb.io import _signal
from wfdb.io import util
from wfdb.io import _signal, util
from wfdb.io.header import HeaderSyntaxError, rx_record, rx_segment, rx_signal

"""
@@ -278,7 +278,7 @@ def set_defaults(self):
for f in sfields:
self.set_default(f)

def wrheader(self, write_dir="", expanded=True):
def wrheader(self, write_dir="", expanded=True, wfdb_archive=None):
"""
Write a WFDB header file. The signals are not used. Before
writing:
@@ -325,7 +325,12 @@ def wrheader(self, write_dir="", expanded=True):
self.check_field_cohesion(rec_write_fields, list(sig_write_fields))

# Write the header file using the specified fields
self.wr_header_file(rec_write_fields, sig_write_fields, write_dir)
self.wr_header_file(
rec_write_fields,
sig_write_fields,
write_dir,
wfdb_archive=wfdb_archive,
)

def get_write_fields(self):
"""
@@ -508,7 +513,9 @@ def check_field_cohesion(self, rec_write_fields, sig_write_fields):
"Each file_name (dat file) specified must have the same byte offset"
)

def wr_header_file(self, rec_write_fields, sig_write_fields, write_dir):
def wr_header_file(
self, rec_write_fields, sig_write_fields, write_dir, wfdb_archive=None
):
"""
Write a header file using the specified fields. Converts Record
attributes into appropriate WFDB format strings.
@@ -522,6 +529,8 @@ def wr_header_file(self, rec_write_fields, sig_write_fields, write_dir):
being equal to a list of channels to write for each field.
write_dir : str
The directory in which to write the header file.
wfdb_archive : WFDBArchive, optional
If provided, write the header into this archive instead of to disk.

Returns
-------
@@ -583,7 +592,13 @@ def wr_header_file(self, rec_write_fields, sig_write_fields, write_dir):
comment_lines = ["# " + comment for comment in self.comments]
header_lines += comment_lines

util.lines_to_file(self.record_name + ".hea", write_dir, header_lines)
header_str = "\n".join(header_lines) + "\n"
hea_filename = os.path.basename(self.record_name) + ".hea"

if wfdb_archive:
wfdb_archive.write(hea_filename, header_str.encode("utf-8"))
else:
util.lines_to_file(hea_filename, write_dir, header_lines)


class MultiHeaderMixin(BaseHeaderMixin):
@@ -621,7 +636,7 @@ def set_defaults(self):
for field in self.get_write_fields():
self.set_default(field)

def wrheader(self, write_dir=""):
def wrheader(self, write_dir="", wfdb_archive=None):
"""
Write a multi-segment WFDB header file. The signals or segments are
not used. Before writing:
@@ -655,7 +670,7 @@ def wrheader(self, write_dir=""):
self.check_field_cohesion()

# Write the header file using the specified fields
self.wr_header_file(write_fields, write_dir)
self.wr_header_file(write_fields, write_dir, wfdb_archive=wfdb_archive)

def get_write_fields(self):
"""
@@ -733,7 +748,7 @@ def check_field_cohesion(self):
"The sum of the 'seg_len' fields do not match the 'sig_len' field"
)

def wr_header_file(self, write_fields, write_dir):
def wr_header_file(self, write_fields, write_dir, wfdb_archive=None):
"""
Write a header file using the specified fields.

@@ -744,6 +759,8 @@ def wr_header_file(self, write_fields, write_dir):
and their dependencies.
write_dir : str
The output directory in which the header is written.
wfdb_archive : WFDBArchive, optional
If provided, write the header into this archive instead of to disk.

Returns
-------
@@ -779,7 +796,13 @@ def wr_header_file(self, write_fields, write_dir):
comment_lines = ["# " + comment for comment in self.comments]
header_lines += comment_lines

util.lines_to_file(self.record_name + ".hea", write_dir, header_lines)
header_str = "\n".join(header_lines) + "\n"
hea_filename = os.path.basename(self.record_name) + ".hea"

if wfdb_archive:
wfdb_archive.write(hea_filename, header_str.encode("utf-8"))
else:
util.lines_to_file(hea_filename, write_dir, header_lines)

def get_sig_segments(self, sig_name=None):
"""
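On the _header.py side, the new wfdb_archive keyword threads from wrheader down to wr_header_file, where the rendered header lines are either encoded and stored in the archive or, as before, written to disk with util.lines_to_file. A minimal caller-side sketch of that dispatch, assuming record is an already populated wfdb.Record and reusing the WFDBArchive API shown in the tests:

# Illustrative sketch, not part of this diff; `record` is assumed to be a
# fully populated wfdb.Record instance, "mydir"/"myrecord" are placeholders.
from wfdb.io.archive import WFDBArchive

with WFDBArchive("mydir/myrecord", mode="w") as archive:
    # With an archive, wr_header_file encodes the header text and writes it
    # into the archive as "myrecord.hea".
    record.wrheader(write_dir="mydir", wfdb_archive=archive)

# Without the keyword the behaviour is unchanged: the header lines go to
# disk through util.lines_to_file.
record.wrheader(write_dir="mydir")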