Skip to content

Commit

Permalink
extract and save raster layers
Browse files Browse the repository at this point in the history
  • Loading branch information
wang-boyu committed Aug 13, 2022
1 parent c54624c commit baa1094
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 3 deletions.
72 changes: 69 additions & 3 deletions mesa_geo/raster_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ class RasterLayer(RasterBase):

cells: List[List[Cell]]
_neighborhood_cache: Dict[Any, List[Coordinate]]
_attributes: Set[str]

def __init__(self, width, height, crs, total_bounds, cell_cls: Type[Cell] = Cell):
super().__init__(width, height, crs, total_bounds)
Expand All @@ -166,8 +167,13 @@ def __init__(self, width, height, crs, total_bounds, cell_cls: Type[Cell] = Cell
col.append(self.cell_cls(pos=(x, y), indices=(row_idx, col_idx)))
self.cells.append(col)

self._attributes = set()
self._neighborhood_cache = {}

@property
def attributes(self) -> Set[str]:
return self._attributes

@overload
def __getitem__(self, index: int) -> List[Cell]:
...
Expand Down Expand Up @@ -233,14 +239,53 @@ def coord_iter(self) -> Iterator[Tuple[Cell, int, int]]:
for col in range(self.height):
yield self.cells[row][col], row, col # cell, x, y

def apply_raster(self, data: np.ndarray, attr_name: str = None) -> None:
assert data.shape == (1, self.height, self.width)
def apply_raster(self, data: np.ndarray, attr_name: str | None = None) -> None:
"""Apply raster data to the cells.
Args:
data: 2D numpy array with shape (1, height, width).
attr_name: name of the attribute to be added to the cells. If None, a random name will be generated. Default is None.
Returns:
None
Raises:
ValueError: if the shape of the data is not (1, height, width).
"""
if data.shape != (1, self.height, self.width):
raise ValueError(
f"Data shape does not match raster shape. "
f"Expected {(1, self.height, self.width)}, received {data.shape}."
)
if attr_name is None:
attr_name = f"attribute_{len(self.cell_cls.__dict__)}"
self._attributes.add(attr_name)
for x in range(self.width):
for y in range(self.height):
setattr(self.cells[x][y], attr_name, data[0, self.height - y - 1, x])

def get_raster(self, attr_name: str | None = None) -> np.ndarray:
"""Returns the values of given attribute.
Args:
attr_name: The name of the attribute to return. If None, returns all attributes. Default is None.
Returns:
The values of given attribute.
"""
if attr_name is not None and attr_name not in self.attributes:
raise ValueError(
f"Attribute {attr_name} does not exist. "
f"Choose from {self.attributes}, or set `attr_name` to `None` to retrieve all."
)
if attr_name is None:
num_bands = len(self.attributes)
attr_names = self.attributes
else:
num_bands = 1
attr_names = {attr_name}
data = np.empty((num_bands, self.height, self.width))
for ind, name in enumerate(attr_names):
for x in range(self.width):
for y in range(self.height):
data[ind, self.height - y - 1, x] = getattr(self.cells[x][y], name)
return data

def iter_neighborhood(
self,
pos: Coordinate,
Expand Down Expand Up @@ -399,7 +444,7 @@ def to_image(self, colormap) -> ImageLayer:

@classmethod
def from_file(
cls, raster_file, cell_cls: Type[Cell] = Cell, attr_name: str = None
cls, raster_file: str, cell_cls: Type[Cell] = Cell, attr_name: str | None = None
) -> RasterLayer:
with rio.open(raster_file, "r") as dataset:
values = dataset.read()
Expand All @@ -415,6 +460,27 @@ def from_file(
obj.apply_raster(values, attr_name=attr_name)
return obj

def to_file(self, raster_file: str, attr_name: str | None = None, driver="GTiff"):
"""Writes a raster layer to a file.
Args:
raster_file: Path to the raster file to write.
attr_name: Name of the attribute to write to the raster. If None, all attributes are written. Default is None.
driver: Driver to use for writing the raster. Default is "GTiff" (see GDAL docs at https://gdal.org/drivers/raster/index.html).
"""
data = self.get_raster(attr_name)
with rio.open(
raster_file,
"w",
driver=driver,
width=self.width,
height=self.height,
count=data.shape[0],
dtype=data.dtype,
crs=self.crs,
transform=self.transform,
) as dataset:
dataset.write(data)


class ImageLayer(RasterBase):
_values: np.ndarray
Expand Down
31 changes: 31 additions & 0 deletions tests/test_RasterLayer.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,40 @@ def test_apple_raster(self):
[5, 6]]]
"""
self.assertEqual(self.raster_layer.cells[0][1].attribute_5, 3)
self.assertEqual(self.raster_layer.attributes, {"attribute_5"})

self.raster_layer.apply_raster(raster_data, attr_name="elevation")
self.assertEqual(self.raster_layer.cells[0][1].elevation, 3)
self.assertEqual(self.raster_layer.attributes, {"attribute_5", "elevation"})

with self.assertRaises(ValueError):
self.raster_layer.apply_raster(np.empty((1, 100, 100)))

def test_get_raster(self):
raster_data = np.array([[[1, 2], [3, 4], [5, 6]]])
self.raster_layer.apply_raster(raster_data)
"""
(x, y) coordinates:
(0, 2), (1, 2)
(0, 1), (1, 1)
(0, 0), (1, 0)
values:
[[[1, 2],
[3, 4],
[5, 6]]]
"""
self.raster_layer.apply_raster(raster_data, attr_name="elevation")
np.testing.assert_array_equal(
self.raster_layer.get_raster(attr_name="elevation"), raster_data
)

self.raster_layer.apply_raster(raster_data)
np.testing.assert_array_equal(
self.raster_layer.get_raster(), np.concatenate((raster_data, raster_data))
)
with self.assertRaises(ValueError):
self.raster_layer.get_raster("not_existing_attr")

def test_get_min_cell(self):
self.raster_layer.apply_raster(
Expand Down

0 comments on commit baa1094

Please sign in to comment.