Skip to content

Allow passing region to GMTBackendEntrypoint.open_dataset #3932

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 10 commits into
base: main
Choose a base branch
from
23 changes: 4 additions & 19 deletions pygmt/datasets/load_remote_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,8 @@
from typing import Any, Literal, NamedTuple

import xarray as xr
from pygmt.clib import Session
from pygmt.exceptions import GMTInvalidInput
from pygmt.helpers import build_arg_list, kwargs_to_strings
from pygmt.src import which
from pygmt.helpers import kwargs_to_strings

with contextlib.suppress(ImportError):
# rioxarray is needed to register the rio accessor
Expand Down Expand Up @@ -581,22 +579,9 @@ def _load_remote_dataset(
raise GMTInvalidInput(msg)

fname = f"@{prefix}_{resolution}_{reg}"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see a lot of error messages like:

Error: h [ERROR]: Tile @S90W180.earth_age_01m_g.nc not found!
Error: h [ERROR]: Tile @S90W150.earth_age_01m_g.nc not found!
Error: h [ERROR]: Tile @S90W120.earth_age_01m_g.nc not found!
Error: h [ERROR]: Tile @S90W090.earth_age_01m_g.nc not found!
Error: h [ERROR]: Tile @S90W060.earth_age_01m_g.nc not found!
Error: h [ERROR]: Tile @S90W030.earth_age_01m_g.nc not found!
Error: h [ERROR]: Tile @S90E000.earth_age_01m_g.nc not found!
Error: h [ERROR]: Tile @S90E030.earth_age_01m_g.nc not found!

This is because, in the GMT backend, we use something like which("@earth_age_01m_g") to get the file path, which doesn't work well for tiled grids.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, we used to do this:

    # Full path to the grid if not tiled grids.
    source = which(fname, download="a") if not resinfo.tiled else None
    # Manually add source to xarray.DataArray encoding to make the GMT accessors work.
    if source:
        grid.encoding["source"] = source

i.e. only add the source for non-tiled grids, so that the accessor's which call doesn't report this error. I'm wondering whether it's possible to either 1) silence the which call (does verbose="q" work?), or 2) add some heuristic/logic to determine whether the source is a tiled grid before calling which in GMTBackendEntrypoint.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering whether it's possible to either 1) silence the which call (does verbose="q" work?), or 2) add some heuristic/logic to determine whether the source is a tiled grid before calling which in GMTBackendEntrypoint.

I think either works. Perhaps verbose="q" is easier?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in commit 5557b33

kwdict = {"R": region, "T": {"grid": "g", "image": "i"}[dataset.kind]}
with Session() as lib:
with lib.virtualfile_out(kind=dataset.kind) as voutgrd:
lib.call_module(
module="read",
args=[fname, voutgrd, *build_arg_list(kwdict)],
)
grid = lib.virtualfile_to_raster(
kind=dataset.kind, outgrid=None, vfname=voutgrd
)

# Full path to the grid if not tiled grids.
source = which(fname, download="a") if not resinfo.tiled else None
# Manually add source to xarray.DataArray encoding to make the GMT accessors work.
if source:
grid.encoding["source"] = source
grid = xr.load_dataarray(
fname, engine="gmt", raster_kind=dataset.kind, region=region
)

# Add some metadata to the grid
grid.attrs["description"] = dataset.description
Expand Down
2 changes: 1 addition & 1 deletion pygmt/tests/test_xarray_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def test_xarray_accessor_grid_source_file_not_exist():
# Registration and gtype are correct.
assert grid.gmt.registration == GridRegistration.PIXEL
assert grid.gmt.gtype == GridType.GEOGRAPHIC
# The source grid file is undefined.
# The source grid file is undefined for tiled grids.
assert grid.encoding.get("source") is None
Comment on lines +147 to 148
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we keep grid.encoding["source"] as undefined/None for tiled grids (xref #3673 (comment))? Or select the first tile (e.g. S90E000.earth_relief_05m_p.nc)? May need to update this test depending on what we decide.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or select the first tile (e.g. S90E000.earth_relief_05m_p.nc)?

Sounds good.


# For a sliced grid, fallback to default registration and gtype, because the source
Expand Down
47 changes: 40 additions & 7 deletions pygmt/tests/test_xarray_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ def test_xarray_backend_load_dataarray():

def test_xarray_backend_gmt_open_nc_grid():
"""
Ensure that passing engine='gmt' to xarray.open_dataarray works for opening NetCDF
grids.
Ensure that passing engine='gmt' to xarray.open_dataarray works to open a netCDF
grid.
"""
with xr.open_dataarray(
"@static_earth_relief.nc", engine="gmt", raster_kind="grid"
Expand All @@ -52,10 +52,29 @@ def test_xarray_backend_gmt_open_nc_grid():
assert da.gmt.registration == GridRegistration.PIXEL


def test_xarray_backend_gmt_open_nc_grid_with_region_bbox():
"""
Ensure that passing engine='gmt' with a `region` argument to xarray.open_dataarray
works to open a netCDF grid over a specific bounding box.
"""
with xr.open_dataarray(
"@static_earth_relief.nc",
engine="gmt",
raster_kind="grid",
region=[-52, -48, -18, -12],
) as da:
assert da.sizes == {"lat": 6, "lon": 4}
npt.assert_allclose(da.lat, [-17.5, -16.5, -15.5, -14.5, -13.5, -12.5])
npt.assert_allclose(da.lon, [-51.5, -50.5, -49.5, -48.5])
assert da.dtype == "float32"
assert da.gmt.gtype == GridType.GEOGRAPHIC
assert da.gmt.registration == GridRegistration.PIXEL


def test_xarray_backend_gmt_open_tif_image():
"""
Ensure that passing engine='gmt' to xarray.open_dataarray works for opening GeoTIFF
images.
Ensure that passing engine='gmt' to xarray.open_dataarray works to open a GeoTIFF
image.
"""
with xr.open_dataarray("@earth_day_01d", engine="gmt", raster_kind="image") as da:
assert da.sizes == {"band": 3, "y": 180, "x": 360}
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Coordinate names are y/x when region=None, but lat/lon when region is not None at L90 below. Need to fix this inconsistency.

Expand All @@ -64,6 +83,22 @@ def test_xarray_backend_gmt_open_tif_image():
assert da.gmt.registration == GridRegistration.PIXEL


def test_xarray_backend_gmt_open_tif_image_with_region_iso():
"""
Ensure that passing engine='gmt' with a `region` argument to xarray.open_dataarray
works to open a GeoTIFF image over a specific ISO country code border.
"""
with xr.open_dataarray(
"@earth_day_01d", engine="gmt", raster_kind="image", region="BN"
) as da:
assert da.sizes == {"band": 3, "lat": 2, "lon": 2}
npt.assert_allclose(da.lat, [5.5, 4.5])
npt.assert_allclose(da.lon, [114.5, 115.5])
assert da.dtype == "uint8"
assert da.gmt.gtype == GridType.GEOGRAPHIC
assert da.gmt.registration == GridRegistration.PIXEL


def test_xarray_backend_gmt_load_grd_grid():
"""
Ensure that passing engine='gmt' to xarray.load_dataarray works for loading GRD
Expand All @@ -88,9 +123,7 @@ def test_xarray_backend_gmt_read_invalid_kind():
"""
with pytest.raises(
TypeError,
match=re.escape(
"GMTBackendEntrypoint.open_dataset() missing 1 required keyword-only argument: 'raster_kind'"
),
match=re.escape("missing a required argument: 'raster_kind'"),
):
xr.open_dataarray("nokind.nc", engine="gmt")

Expand Down
4 changes: 3 additions & 1 deletion pygmt/xarray/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,11 @@ def __init__(self, xarray_obj: xr.DataArray):
# two columns of the shortened summary information of grdinfo.
if (_source := self._obj.encoding.get("source")) and Path(_source).exists():
with contextlib.suppress(ValueError):
self._registration, self._gtype = map( # type: ignore[assignment]
_registration, _gtype = map(
int, grdinfo(_source, per_column="n").split()[-2:]
)
self._registration = GridRegistration(_registration)
self._gtype = GridType(_gtype)

@property
def registration(self) -> GridRegistration:
Expand Down
15 changes: 10 additions & 5 deletions pygmt/xarray/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@
An xarray backend for reading raster grid/image files using the 'gmt' engine.
"""

from collections.abc import Sequence
from typing import Literal

import xarray as xr
from pygmt._typing import PathLike
from pygmt.clib import Session
from pygmt.exceptions import GMTInvalidInput
from pygmt.helpers import build_arg_list
from pygmt.helpers import build_arg_list, kwargs_to_strings
from pygmt.src.which import which
from xarray.backends import BackendEntrypoint

Expand Down Expand Up @@ -71,15 +72,17 @@ class GMTBackendEntrypoint(BackendEntrypoint):
"""

description = "Open raster (.grd, .nc or .tif) files in Xarray via GMT."
open_dataset_parameters = ("filename_or_obj", "raster_kind")
open_dataset_parameters = ("filename_or_obj", "raster_kind", "region")
url = "https://pygmt.org/dev/api/generated/pygmt.GMTBackendEntrypoint.html"

@kwargs_to_strings(region="sequence")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've been thinking about whether we should avoid using the @kwargs_to_strings decorator in new functions/methods, and instead write a new function like seqjoin that does exactly the same thing.

def open_dataset( # type: ignore[override]
self,
filename_or_obj: PathLike,
*,
drop_variables=None, # noqa: ARG002
raster_kind: Literal["grid", "image"],
region: Sequence[float] | str | None = None,
# other backend specific keyword arguments
# `chunks` and `cache` DO NOT go here, they are handled by xarray
) -> xr.Dataset:
Expand All @@ -94,14 +97,17 @@ def open_dataset( # type: ignore[override]
:gmt-docs:`reference/features.html#grid-file-format`.
raster_kind
Whether to read the file as a "grid" (single-band) or "image" (multi-band).
region
Optional. The subregion of the grid or image to load, in the form of a
sequence [*xmin*, *xmax*, *ymin*, *ymax*] or an ISO country code.
"""
if raster_kind not in {"grid", "image"}:
msg = f"Invalid raster kind: '{raster_kind}'. Valid values are 'grid' or 'image'."
raise GMTInvalidInput(msg)

with Session() as lib:
with lib.virtualfile_out(kind=raster_kind) as voutfile:
kwdict = {"T": {"grid": "g", "image": "i"}[raster_kind]}
kwdict = {"R": region, "T": {"grid": "g", "image": "i"}[raster_kind]}
lib.call_module(
module="read",
args=[filename_or_obj, voutfile, *build_arg_list(kwdict)],
Expand All @@ -111,9 +117,8 @@ def open_dataset( # type: ignore[override]
vfname=voutfile, kind=raster_kind
)
# Add "source" encoding
source = which(fname=filename_or_obj)
source: str | list = which(fname=filename_or_obj, verbose="q")
raster.encoding["source"] = (
source[0] if isinstance(source, list) else source
)
_ = raster.gmt # Load GMTDataArray accessor information
return raster.to_dataset()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, it's likely that the accessor information will be lost when converting via to_dataset.

Loading