diff --git a/samples/sample_tap_csv/client.py b/samples/sample_tap_csv/client.py index 6aa3b1850..8aef69307 100644 --- a/samples/sample_tap_csv/client.py +++ b/samples/sample_tap_csv/client.py @@ -1,7 +1,7 @@ from __future__ import annotations +import abc import csv -import datetime import typing as t import fsspec @@ -11,6 +11,8 @@ from singer_sdk.streams.core import REPLICATION_INCREMENTAL if t.TYPE_CHECKING: + import datetime + from singer_sdk.helpers.types import Context, Record from singer_sdk.tap_base import Tap @@ -18,19 +20,15 @@ SDC_META_MODIFIED_AT = "_sdc_modified_at" -def _to_datetime(value: float) -> str: - return datetime.datetime.fromtimestamp(value).astimezone() - - -class CSVStream(Stream): - """CSV stream class.""" +class FileStream(Stream, metaclass=abc.ABCMeta): + """Abstract base class for file streams.""" def __init__( self, tap: Tap, - name: str | None = None, + name: str, *, - partitions: list[str] | None = None, + partitions: list[Context] | None = None, ) -> None: # TODO(edgarmondragon): Build schema from CSV file. schema = { @@ -46,34 +44,17 @@ def __init__( # TODO(edgarrmondragon): Make this None if the filesytem does not support it. self.replication_key = SDC_META_MODIFIED_AT - - self._partitions = partitions or [] - - self.filesystem: fsspec.AbstractFileSystem = fsspec.filesystem("local") self._sync_start_time = utc_now() + self.filesystem: fsspec.AbstractFileSystem = fsspec.filesystem("local") + self._partitions = partitions or [] @property def partitions(self) -> list[Context]: return self._partitions - def _read_file(self, path: str) -> t.Iterable[Record]: - # Make these configurable. - delimiter = "," - quotechar = '"' - escapechar = None - doublequote = True - lineterminator = "\r\n" - - with self.filesystem.open(path, mode="r") as file: - reader = csv.DictReader( - file, - delimiter=delimiter, - quotechar=quotechar, - escapechar=escapechar, - doublequote=doublequote, - lineterminator=lineterminator, - ) - yield from reader + @abc.abstractmethod + def read_file(self, context: Context | None) -> t.Iterable[Record]: + """Return a generator of records from the file.""" def get_records( self, @@ -97,6 +78,29 @@ def get_records( self.logger.info("File has not been modified since last read, skipping") return - for record in self._read_file(path): + for record in self.read_file(path): record[SDC_META_MODIFIED_AT] = mtime or self._sync_start_time yield record + + +class CSVStream(FileStream): + """CSV stream class.""" + + def read_file(self, path: str) -> t.Iterable[Record]: + # Make these configurable. + delimiter = "," + quotechar = '"' + escapechar = None + doublequote = True + lineterminator = "\r\n" + + with self.filesystem.open(path, mode="r") as file: + reader = csv.DictReader( + file, + delimiter=delimiter, + quotechar=quotechar, + escapechar=escapechar, + doublequote=doublequote, + lineterminator=lineterminator, + ) + yield from reader