diff --git a/gtfparse/__init__.py b/gtfparse/__init__.py index 99d19c3..aabf120 100644 --- a/gtfparse/__init__.py +++ b/gtfparse/__init__.py @@ -13,6 +13,7 @@ from .attribute_parsing import expand_attribute_strings from .create_missing_features import create_missing_features from .parsing_error import ParsingError +from .write_gtf import write_gtf from .read_gtf import ( read_gtf, parse_gtf, diff --git a/gtfparse/write_gtf.py b/gtfparse/write_gtf.py new file mode 100644 index 0000000..1ea9c55 --- /dev/null +++ b/gtfparse/write_gtf.py @@ -0,0 +1,23 @@ +import polars +from pathlib import Path +import typing as t + + +COMMONS_COL = ['seqname', 'source', 'feature', 'start', 'end', 'score', 'strand', 'frame'] + + +def write_gtf(df: polars.DataFrame, export_path: str | Path, headers: t.List[str] = None): + headers = headers or [] + with open(export_path, 'w') as f: + for header in headers: + f.write(f"{header}\n") + for row in df.iter_rows(named=True): + f.write(f"{commons_cols(row)}\t{custom_fields(row)}\n") + + +def commons_cols(row) -> str : + return "\t".join([str(row[field] or '.') for field in COMMONS_COL]) + + +def custom_fields(row) -> str: + return "; ".join([f'{field} "{row[field]}"' for field in row.keys() if (field not in COMMONS_COL) and (row[field])]) diff --git a/tests/test_write_gtf.py b/tests/test_write_gtf.py new file mode 100644 index 0000000..dc40a48 --- /dev/null +++ b/tests/test_write_gtf.py @@ -0,0 +1,14 @@ +from gtfparse import read_gtf, write_gtf +from .data import data_path +from polars import DataFrame + +REFSEQ_GTF_PATH = data_path("refseq.ucsc.small.gtf") + + +def test_write_gtf(tmp_path): + expected_gtf = read_gtf(REFSEQ_GTF_PATH) + write_gtf(expected_gtf, tmp_path/"dummy_gtf.gtf") + created_gtf = read_gtf(str(tmp_path/"dummy_gtf.gtf")) + assert isinstance(created_gtf, DataFrame) + assert expected_gtf == created_gtf +