Skip to content

Commit ecb933c

Browse files
weiji14seisman
andauthored
Implement gmt xarray BackendEntrypoint (#3919)
Co-authored-by: Dongdong Tian <[email protected]>
1 parent 2992d22 commit ecb933c

File tree

6 files changed

+210
-0
lines changed

6 files changed

+210
-0
lines changed

doc/api/index.rst

+8
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,14 @@ Getting metadata from tabular or grid data:
195195
info
196196
grdinfo
197197

198+
Xarray Integration
199+
------------------
200+
201+
.. autosummary::
202+
:toctree: generated
203+
204+
GMTBackendEntrypoint
205+
198206
Enums
199207
-----
200208

pygmt/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
x2sys_init,
6666
xyz2grd,
6767
)
68+
from pygmt.xarray import GMTBackendEntrypoint
6869

6970
# Start our global modern mode session
7071
_begin()

pygmt/tests/test_xarray_backend.py

+74
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
"""
2+
Tests for xarray 'gmt' backend engine.
3+
"""
4+
5+
import re
6+
7+
import numpy as np
8+
import numpy.testing as npt
9+
import pytest
10+
import xarray as xr
11+
from pygmt.enums import GridRegistration, GridType
12+
from pygmt.exceptions import GMTInvalidInput
13+
14+
15+
def test_xarray_backend_gmt_open_nc_grid():
16+
"""
17+
Ensure that passing engine='gmt' to xarray.open_dataarray works for opening NetCDF
18+
grids.
19+
"""
20+
with xr.open_dataarray(
21+
"@static_earth_relief.nc", engine="gmt", raster_kind="grid"
22+
) as da:
23+
assert da.sizes == {"lat": 14, "lon": 8}
24+
assert da.dtype == "float32"
25+
assert da.gmt.registration == GridRegistration.PIXEL
26+
assert da.gmt.gtype == GridType.GEOGRAPHIC
27+
28+
29+
def test_xarray_backend_gmt_open_tif_image():
30+
"""
31+
Ensure that passing engine='gmt' to xarray.open_dataarray works for opening GeoTIFF
32+
images.
33+
"""
34+
with xr.open_dataarray("@earth_day_01d", engine="gmt", raster_kind="image") as da:
35+
assert da.sizes == {"band": 3, "y": 180, "x": 360}
36+
assert da.dtype == "uint8"
37+
assert da.gmt.registration == GridRegistration.PIXEL
38+
assert da.gmt.gtype == GridType.GEOGRAPHIC
39+
40+
41+
def test_xarray_backend_gmt_load_grd_grid():
42+
"""
43+
Ensure that passing engine='gmt' to xarray.load_dataarray works for loading GRD
44+
grids.
45+
"""
46+
da = xr.load_dataarray(
47+
"@earth_relief_20m_holes.grd", engine="gmt", raster_kind="grid"
48+
)
49+
# Ensure data is in memory.
50+
assert isinstance(da.data, np.ndarray)
51+
npt.assert_allclose(da.min(), -4929.5)
52+
assert da.sizes == {"lat": 31, "lon": 31}
53+
assert da.dtype == "float32"
54+
assert da.gmt.registration == GridRegistration.GRIDLINE
55+
assert da.gmt.gtype == GridType.GEOGRAPHIC
56+
57+
58+
def test_xarray_backend_gmt_read_invalid_kind():
59+
"""
60+
Check that xarray.open_dataarray(..., engine="gmt") fails with missing or incorrect
61+
'raster_kind'.
62+
"""
63+
with pytest.raises(
64+
TypeError,
65+
match=re.escape(
66+
"GMTBackendEntrypoint.open_dataset() missing 1 required keyword-only argument: 'raster_kind'"
67+
),
68+
):
69+
xr.open_dataarray("nokind.nc", engine="gmt")
70+
71+
with pytest.raises(GMTInvalidInput):
72+
xr.open_dataarray(
73+
filename_or_obj="invalid.tif", engine="gmt", raster_kind="invalid"
74+
)

pygmt/xarray/__init__.py

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
"""
2+
PyGMT integration with Xarray accessors and backends.
3+
"""
4+
5+
from pygmt.xarray.backend import GMTBackendEntrypoint

pygmt/xarray/backend.py

+119
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
"""
2+
An xarray backend for reading raster grid/image files using the 'gmt' engine.
3+
"""
4+
5+
from typing import Literal
6+
7+
import xarray as xr
8+
from pygmt._typing import PathLike
9+
from pygmt.clib import Session
10+
from pygmt.exceptions import GMTInvalidInput
11+
from pygmt.helpers import build_arg_list
12+
from pygmt.src.which import which
13+
from xarray.backends import BackendEntrypoint
14+
15+
16+
class GMTBackendEntrypoint(BackendEntrypoint):
17+
"""
18+
Xarray backend to read raster grid/image files using 'gmt' engine.
19+
20+
Internally, GMT uses the netCDF C library to read netCDF files, and GDAL for GeoTIFF
21+
and other raster formats. See :gmt-docs:`reference/features.html#grid-file-format`
22+
for more details about supported formats. This GMT engine can also read
23+
:gmt-docs:`GMT remote datasets <datasets/remote-data.html>` (file names starting
24+
with an `@`) directly, and pre-loads :class:`pygmt.GMTDataArrayAccessor` properties
25+
(in the '.gmt' accessor) for easy access to GMT-specific metadata and features.
26+
27+
When using :py:func:`xarray.open_dataarray` or :py:func:`xarray.load_dataarray` with
28+
``engine="gmt"``, the ``raster_kind`` parameter is required and can be either:
29+
30+
- ``"grid"``: for reading single-band raster grids
31+
- ``"image"``: for reading multi-band raster images
32+
33+
Examples
34+
--------
35+
Read a single-band netCDF file using ``raster_kind="grid"``
36+
37+
>>> import pygmt
38+
>>> import xarray as xr
39+
>>>
40+
>>> da_grid = xr.open_dataarray(
41+
... "@static_earth_relief.nc", engine="gmt", raster_kind="grid"
42+
... )
43+
>>> da_grid # doctest: +NORMALIZE_WHITESPACE, +ELLIPSIS
44+
<xarray.DataArray 'z' (lat: 14, lon: 8)>...
45+
[112 values with dtype=float32]
46+
Coordinates:
47+
* lat (lat) float64... -23.5 -22.5 -21.5 -20.5 ... -12.5 -11.5 -10.5
48+
* lon (lon) float64... -54.5 -53.5 -52.5 -51.5 -50.5 -49.5 -48.5 -47.5
49+
Attributes:...
50+
Conventions: CF-1.7
51+
title: Produced by grdcut
52+
history: grdcut @earth_relief_01d_p -R-55/-47/-24/-10 -Gstatic_eart...
53+
description: Reduced by Gaussian Cartesian filtering (111.2 km fullwidt...
54+
actual_range: [190. 981.]
55+
long_name: elevation (m)
56+
57+
Read a multi-band GeoTIFF file using ``raster_kind="image"``
58+
59+
>>> da_image = xr.open_dataarray(
60+
... "@earth_night_01d", engine="gmt", raster_kind="image"
61+
... )
62+
>>> da_image # doctest: +NORMALIZE_WHITESPACE, +ELLIPSIS
63+
<xarray.DataArray 'z' (band: 3, y: 180, x: 360)>...
64+
[194400 values with dtype=uint8]
65+
Coordinates:
66+
* y (y) float64... 89.5 88.5 87.5 86.5 ... -86.5 -87.5 -88.5 -89.5
67+
* x (x) float64... -179.5 -178.5 -177.5 -176.5 ... 177.5 178.5 179.5
68+
* band (band) uint8... 1 2 3
69+
Attributes:...
70+
long_name: z
71+
"""
72+
73+
description = "Open raster (.grd, .nc or .tif) files in Xarray via GMT."
74+
open_dataset_parameters = ("filename_or_obj", "raster_kind")
75+
url = "https://pygmt.org/dev/api/generated/pygmt.GMTBackendEntrypoint.html"
76+
77+
def open_dataset( # type: ignore[override]
78+
self,
79+
filename_or_obj: PathLike,
80+
*,
81+
drop_variables=None, # noqa: ARG002
82+
raster_kind: Literal["grid", "image"],
83+
# other backend specific keyword arguments
84+
# `chunks` and `cache` DO NOT go here, they are handled by xarray
85+
) -> xr.Dataset:
86+
"""
87+
Backend open_dataset method used by Xarray in :py:func:`~xarray.open_dataset`.
88+
89+
Parameters
90+
----------
91+
filename_or_obj
92+
File path to a netCDF (.nc), GeoTIFF (.tif) or other grid/image file format
93+
that can be read by GMT via the netCDF or GDAL C libraries. See also
94+
:gmt-docs:`reference/features.html#grid-file-format`.
95+
raster_kind
96+
Whether to read the file as a "grid" (single-band) or "image" (multi-band).
97+
"""
98+
if raster_kind not in {"grid", "image"}:
99+
msg = f"Invalid raster kind: '{raster_kind}'. Valid values are 'grid' or 'image'."
100+
raise GMTInvalidInput(msg)
101+
102+
with Session() as lib:
103+
with lib.virtualfile_out(kind=raster_kind) as voutfile:
104+
kwdict = {"T": {"grid": "g", "image": "i"}[raster_kind]}
105+
lib.call_module(
106+
module="read",
107+
args=[filename_or_obj, voutfile, *build_arg_list(kwdict)],
108+
)
109+
110+
raster: xr.DataArray = lib.virtualfile_to_raster(
111+
vfname=voutfile, kind=raster_kind
112+
)
113+
# Add "source" encoding
114+
source = which(fname=filename_or_obj)
115+
raster.encoding["source"] = (
116+
source[0] if isinstance(source, list) else source
117+
)
118+
_ = raster.gmt # Load GMTDataArray accessor information
119+
return raster.to_dataset()

pyproject.toml

+3
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,9 @@ all = [
5252
"rioxarray",
5353
]
5454

55+
[project.entry-points."xarray.backends"]
56+
gmt = "pygmt.xarray:GMTBackendEntrypoint"
57+
5558
[project.urls]
5659
"Homepage" = "https://www.pygmt.org"
5760
"Documentation" = "https://www.pygmt.org"

0 commit comments

Comments
 (0)