Skip to content

Commit

Permalink
Implemented post_process in Altair based components (#2641)
Browse files Browse the repository at this point in the history
* implemented post-process

* updated the docs

* added test for post_process

* corrected docs

* corrected docs

* fixed bugs and added test for post_process

* fixing tests

* fixing tests

---------

Co-authored-by: Jan Kwakkel <[email protected]>
  • Loading branch information
sanika-n and quaquel authored Jan 31, 2025
1 parent ffdd525 commit c58af09
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 15 deletions.
18 changes: 10 additions & 8 deletions mesa/visualization/components/altair_components.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
"""Altair based solara components for visualization mesa spaces."""

import contextlib
import warnings

import altair as alt
import solara

with contextlib.suppress(ImportError):
import altair as alt

from mesa.experimental.cell_space import DiscreteSpace, Grid
from mesa.space import ContinuousSpace, _Grid
from mesa.visualization.utils import update_counter
Expand All @@ -30,7 +27,7 @@ def make_altair_space(
Args:
agent_portrayal: Function to portray agents.
propertylayer_portrayal: not yet implemented
post_process :not yet implemented
post_process :A user specified callable that will be called with the Chart instance from Altair. Allows for fine tuning plots (e.g., control ticks)
space_drawing_kwargs : not yet implemented
``agent_portrayal`` is called with an agent and should return a dict. Valid fields in this dict are "color",
Expand All @@ -46,13 +43,15 @@ def agent_portrayal(a):
return {"id": a.unique_id}

def MakeSpaceAltair(model):
return SpaceAltair(model, agent_portrayal)
return SpaceAltair(model, agent_portrayal, post_process=post_process)

return MakeSpaceAltair


@solara.component
def SpaceAltair(model, agent_portrayal, dependencies: list[any] | None = None):
def SpaceAltair(
model, agent_portrayal, dependencies: list[any] | None = None, post_process=None
):
"""Create an Altair-based space visualization component.
Returns:
Expand All @@ -65,6 +64,9 @@ def SpaceAltair(model, agent_portrayal, dependencies: list[any] | None = None):
space = model.space

chart = _draw_grid(space, agent_portrayal)
# Apply post-processing if provided
if post_process is not None:
chart = post_process(chart)
solara.FigureAltair(chart)


Expand Down Expand Up @@ -159,7 +161,7 @@ def _draw_grid(space, agent_portrayal):
# no y-axis label
"y": alt.Y("y", axis=None, type=x_y_type),
"tooltip": [
alt.Tooltip(key, type=alt.utils.infer_vegalite_type([value]))
alt.Tooltip(key, type=alt.utils.infer_vegalite_type_for_pandas([value]))
for key, value in all_agent_data[0].items()
if key not in invalid_tooltips
],
Expand Down
6 changes: 5 additions & 1 deletion mesa/visualization/solara_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,11 @@ def SolaraViz(
reduce update frequency,resulting in faster execution.
"""
if components == "default":
components = [components_altair.make_altair_space()]
components = [
components_altair.make_altair_space(
agent_portrayal=None, propertylayer_portrayal=None, post_process=None
)
]
if model_params is None:
model_params = {}

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ network = [
viz = [
"matplotlib",
"solara",
"altair",
]
# Dev and CI stuff
dev = [
Expand Down
45 changes: 39 additions & 6 deletions tests/test_solara_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import mesa
import mesa.visualization.components.altair_components
import mesa.visualization.components.matplotlib_components
from mesa.space import MultiGrid
from mesa.visualization.components.altair_components import make_altair_space
from mesa.visualization.components.matplotlib_components import make_mpl_space_component
from mesa.visualization.solara_viz import (
Slider,
Expand Down Expand Up @@ -101,17 +103,22 @@ def test_call_space_drawer(mocker): # noqa: D103
mesa.visualization.components.altair_components, "SpaceAltair"
)

class MockAgent(mesa.Agent):
def __init__(self, model):
super().__init__(model)

class MockModel(mesa.Model):
def __init__(self, seed=None):
super().__init__(seed=seed)
self.grid = MultiGrid(width=10, height=10, torus=True)
a = MockAgent(self)
self.grid.place_agent(a, (5, 5))

model = MockModel()
mocker.patch.object(mesa.Model, "__init__", return_value=None)

agent_portrayal = {
"marker": "circle",
"color": "gray",
}
def agent_portrayal(agent):
return {"marker": "o", "color": "gray"}

propertylayer_portrayal = None
# initialize with space drawer unspecified (use default)
# component must be rendered for code to run
Expand All @@ -131,7 +138,33 @@ def __init__(self, seed=None):
solara.render(SolaraViz(model))
# should call default method with class instance and agent portrayal
assert mock_space_matplotlib.call_count == 0
assert mock_space_altair.call_count == 0
assert mock_space_altair.call_count == 1 # altair is the default method

# checking if SpaceAltair is working as intended with post_process

mock_post_process = mocker.MagicMock()
solara.render(
SolaraViz(
model,
components=[
make_altair_space(
agent_portrayal,
propertylayer_portrayal,
mock_post_process,
)
],
)
)

args, kwargs = mock_space_altair.call_args
assert args == (model, agent_portrayal)
assert kwargs == {"post_process": mock_post_process}
mock_post_process.assert_called_once()
assert mock_space_matplotlib.call_count == 0

mock_space_altair.reset_mock()
mock_space_matplotlib.reset_mock()
mock_post_process.reset_mock()

# specify a custom space method
class AltSpace:
Expand Down

0 comments on commit c58af09

Please sign in to comment.