Skip to content

Commit d7ca533

Browse files
committed
Merge branch 'refactor/validate_data_input' into refactor/virtualfile_in
2 parents 80a178d + 9672f05 commit d7ca533

File tree

2 files changed

+50
-36
lines changed

2 files changed

+50
-36
lines changed

pygmt/accessors.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ class GMTDataArrayAccessor:
126126
(<GridRegistration.GRIDLINE: 0>, <GridType.GEOGRAPHIC: 1>)
127127
"""
128128

129-
def __init__(self, xarray_obj):
129+
def __init__(self, xarray_obj: xr.DataArray):
130130
self._obj = xarray_obj
131131

132132
# Default to Gridline registration and Cartesian grid type
@@ -137,19 +137,19 @@ def __init__(self, xarray_obj):
137137
# two columns of the shortened summary information of grdinfo.
138138
if (_source := self._obj.encoding.get("source")) and Path(_source).exists():
139139
with contextlib.suppress(ValueError):
140-
self._registration, self._gtype = map(
140+
self._registration, self._gtype = map( # type: ignore[assignment]
141141
int, grdinfo(_source, per_column="n").split()[-2:]
142142
)
143143

144144
@property
145-
def registration(self):
145+
def registration(self) -> GridRegistration:
146146
"""
147147
Grid registration type :class:`pygmt.enums.GridRegistration`.
148148
"""
149149
return self._registration
150150

151151
@registration.setter
152-
def registration(self, value):
152+
def registration(self, value: GridRegistration | int):
153153
# TODO(Python>=3.12): Simplify to `if value not in GridRegistration`.
154154
if value not in GridRegistration.__members__.values():
155155
msg = (
@@ -160,14 +160,14 @@ def registration(self, value):
160160
self._registration = GridRegistration(value)
161161

162162
@property
163-
def gtype(self):
163+
def gtype(self) -> GridType:
164164
"""
165165
Grid coordinate system type :class:`pygmt.enums.GridType`.
166166
"""
167167
return self._gtype
168168

169169
@gtype.setter
170-
def gtype(self, value):
170+
def gtype(self, value: GridType | int):
171171
# TODO(Python>=3.12): Simplify to `if value not in GridType`.
172172
if value not in GridType.__members__.values():
173173
msg = (

pygmt/helpers/utils.py

+44-30
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from pathlib import Path
1616
from typing import Any, Literal
1717

18+
import numpy as np
1819
import xarray as xr
1920
from pygmt.encodings import charset
2021
from pygmt.exceptions import GMTInvalidInput
@@ -39,11 +40,21 @@
3940
"ISO-8859-15",
4041
"ISO-8859-16",
4142
]
43+
# Type hints for the list of possible data kinds.
44+
Kind = Literal[
45+
"arg", "empty", "file", "geojson", "grid", "image", "matrix", "stringio", "vectors"
46+
]
4247

4348

44-
def _validate_data_input( # noqa: PLR0912
45-
data=None, x=None, y=None, z=None, required_data=True, required_cols=2, kind=None
46-
):
49+
def _validate_data_input(
50+
data=None,
51+
x=None,
52+
y=None,
53+
z=None,
54+
required_data: bool = True,
55+
required_cols: int = 2,
56+
kind: Kind | None = None,
57+
) -> None:
4758
"""
4859
Check if the combination of data/x/y/z is valid.
4960
@@ -76,23 +87,23 @@ def _validate_data_input( # noqa: PLR0912
7687
>>> _validate_data_input(data=data, required_cols=3, kind="matrix")
7788
Traceback (most recent call last):
7889
...
79-
pygmt.exceptions.GMTInvalidInput: data needs 3 columns but 2 column(s) are given.
90+
pygmt.exceptions.GMTInvalidInput: Need at least 3 columns but 2 column(s) are given.
8091
>>> _validate_data_input(
8192
... data=pd.DataFrame(data, columns=["x", "y"]),
8293
... required_cols=3,
8394
... kind="vectors",
8495
... )
8596
Traceback (most recent call last):
8697
...
87-
pygmt.exceptions.GMTInvalidInput: data needs 3 columns but 2 column(s) are given.
98+
pygmt.exceptions.GMTInvalidInput: Need at least 3 columns but 2 column(s) are given.
8899
>>> _validate_data_input(
89100
... data=xr.Dataset(pd.DataFrame(data, columns=["x", "y"])),
90101
... kind="vectors",
91102
... required_cols=3,
92103
... )
93104
Traceback (most recent call last):
94105
...
95-
pygmt.exceptions.GMTInvalidInput: data needs 3 columns but 2 column(s) are given.
106+
pygmt.exceptions.GMTInvalidInput: Need at least 3 columns but 2 column(s) are given.
96107
>>> _validate_data_input(data="infile", x=[1, 2, 3])
97108
Traceback (most recent call last):
98109
...
@@ -115,42 +126,49 @@ def _validate_data_input( # noqa: PLR0912
115126
GMTInvalidInput
116127
If the data input is not valid.
117128
"""
118-
if kind is None:
119-
kind = data_kind(data, required=required_data)
120-
129+
# Check if too much data is provided.
121130
if data is not None and any(v is not None for v in (x, y, z)):
122131
msg = "Too much data. Use either data or x/y/z."
123132
raise GMTInvalidInput(msg)
124133

134+
# Determine the data kind if not provided.
135+
kind = kind or data_kind(data, required=required_data)
136+
137+
# Check based on the data kind.
125138
match kind:
126-
case "empty":
127-
if x is None and y is None: # Both x and y are None.
139+
case "empty": # data is given via a series of vectors like x/y/z.
140+
if x is None and y is None:
128141
msg = "No input data provided."
129142
raise GMTInvalidInput(msg)
130-
if x is None or y is None: # Either x or y is None.
143+
if x is None or y is None:
131144
msg = "Must provide both x and y."
132145
raise GMTInvalidInput(msg)
133146
if required_cols >= 3 and z is None:
134-
# Both x and y are not None, now check z.
135147
msg = "Must provide x, y, and z."
136148
raise GMTInvalidInput(msg)
137149
case "matrix": # 2-D numpy.ndarray
138150
if (actual_cols := data.shape[1]) < required_cols:
139-
msg = f"data needs {required_cols} columns but {actual_cols} column(s) are given."
151+
msg = (
152+
f"Need at least {required_cols} columns but {actual_cols} column(s) "
153+
"are given."
154+
)
140155
raise GMTInvalidInput(msg)
141156
case "vectors":
157+
# The if-else block should match the code in the virtualfile_in function.
142158
if hasattr(data, "items") and not hasattr(data, "to_frame"):
143-
# Dict, pd.DataFrame, xr.Dataset
144-
arrays = [array for _, array in data.items()]
145-
if (actual_cols := len(arrays)) < required_cols:
146-
msg = f"data needs {required_cols} columns but {actual_cols} column(s) are given."
147-
raise GMTInvalidInput(msg)
148-
149-
# Loop over columns to make sure they're not None
150-
for idx, array in enumerate(arrays[:required_cols]):
151-
if array is None:
152-
msg = f"data needs {required_cols} columns but the {idx} column is None."
153-
raise GMTInvalidInput(msg)
159+
# Dict, pandas.DataFrame, or xarray.Dataset, but not pd.Series.
160+
_data = [array for _, array in data.items()]
161+
else:
162+
# Python list, tuple, numpy.ndarray, and pandas.Series types
163+
_data = np.atleast_2d(np.asanyarray(data).T)
164+
165+
# Check if the number of columns is sufficient.
166+
if (actual_cols := len(_data)) < required_cols:
167+
msg = (
168+
f"Need at least {required_cols} columns but {actual_cols} "
169+
"column(s) are given."
170+
)
171+
raise GMTInvalidInput(msg)
154172

155173

156174
def _is_printable_ascii(argstr: str) -> bool:
@@ -269,11 +287,7 @@ def _check_encoding(argstr: str) -> Encoding:
269287
return "ISOLatin1+"
270288

271289

272-
def data_kind(
273-
data: Any, required: bool = True
274-
) -> Literal[
275-
"arg", "empty", "file", "geojson", "grid", "image", "matrix", "stringio", "vectors"
276-
]:
290+
def data_kind(data: Any, required: bool = True) -> Kind:
277291
r"""
278292
Check the kind of data that is provided to a module.
279293

0 commit comments

Comments
 (0)