
Add type annotations to cairo_renderer.py #4393

Open
wants to merge 8 commits into main
44 changes: 25 additions & 19 deletions manim/renderer/cairo_renderer.py
@@ -41,7 +41,7 @@ def __init__(
camera_class: type[Camera] | None = None,
skip_animations: bool = False,
**kwargs: Any,
) -> None:
):
# All of the following are set to EITHER the value passed via kwargs,
# OR the value stored in the global config dict at the time of
# _instance construction_.
@@ -50,10 +50,10 @@
self.camera = camera_cls()
self._original_skipping_status = skip_animations
self.skip_animations = skip_animations
self.animations_hashes = []
self.animations_hashes: list[str | None] = []
self.num_plays = 0
self.time = 0
self.static_image = None
self.time = 0.0
self.static_image: PixelArray | None = None

def init_scene(self, scene: Scene) -> None:
self.file_writer: Any = self._file_writer_class(
@@ -65,8 +65,8 @@ def play(
self,
scene: Scene,
*args: Animation | Mobject | _AnimationBuilder,
**kwargs,
):
**kwargs: Any,
) -> None:
# Reset skip_animations to the original state.
# Needed when rendering only some animations, and skipping others.
self.skip_animations = self._original_skipping_status
@@ -159,7 +159,12 @@ def update_frame( # TODO Description in Docstring
kwargs["include_submobjects"] = include_submobjects
self.camera.capture_mobjects(mobjects, **kwargs)

def render(self, scene, time, moving_mobjects):
def render(
self,
scene: Scene,
time: float,
moving_mobjects: Iterable[Mobject] | None = None,
) -> None:
self.update_frame(scene, moving_mobjects)
self.add_frame(self.get_frame())

@@ -168,13 +173,13 @@ def get_frame(self) -> PixelArray:

Returns
-------
np.array
PixelArray
NumPy array of pixel values of each pixel in screen.
The shape of the array is height x width x 3.
"""
return np.array(self.camera.pixel_array)

def add_frame(self, frame: np.ndarray, num_frames: int = 1):
def add_frame(self, frame: PixelArray, num_frames: int = 1) -> None:
"""Adds a frame to the video_file_stream

Parameters
@@ -190,7 +195,7 @@ def add_frame(self, frame: np.ndarray, num_frames: int = 1):
self.time += num_frames * dt
self.file_writer.write_frame(frame, num_frames=num_frames)

def freeze_current_frame(self, duration: float):
def freeze_current_frame(self, duration: float) -> None:
"""Adds a static frame to the movie for a given duration. The static frame is the current frame.

Parameters
@@ -204,16 +209,18 @@
num_frames=int(duration / dt),
)

def show_frame(self):
"""Opens the current frame in the Default Image Viewer of your system."""
self.update_frame(ignore_skipping=True)
def show_frame(self, scene: Scene) -> None:
"""Opens the current frame in the Default Image Viewer
of your system.
"""
self.update_frame(scene, ignore_skipping=True)
self.camera.get_image().show()

def save_static_frame_data(
self,
scene: Scene,
static_mobjects: Iterable[Mobject],
) -> Iterable[Mobject] | None:
) -> PixelArray | None:
"""Compute and save the static frame, that will be reused at each frame
to avoid unnecessarily computing static mobjects.

Expand All @@ -226,8 +233,8 @@ def save_static_frame_data(

Returns
-------
typing.Iterable[Mobject]
The static image computed.
PixelArray | None
The static image computed. The return value is None if there are no static mobjects in the scene.
"""
self.static_image = None
Contributor Author:
For consistency with other methods and variables, self.static_image should ideally be renamed to self.static_frame, but that can be done in another PR.

if not static_mobjects:
@@ -236,9 +243,8 @@
self.static_image = self.get_frame()
return self.static_image

def update_skipping_status(self):
"""
This method is used internally to check if the current
def update_skipping_status(self) -> None:
"""This method is used internally to check if the current
animation needs to be skipped or not. It also checks if
the number of animations that were played correspond to
the number of animations that need to be played, and
2 changes: 1 addition & 1 deletion manim/scene/scene.py
@@ -186,7 +186,7 @@ def __init__(
self.moving_mobjects: list[Mobject] = []
self.static_mobjects: list[Mobject] = []
self.time_progression: tqdm[float] | None = None
self.duration: float | None = None
self.duration: float = 0.0
@fmuenkel (Contributor Author) commented on Aug 11, 2025:
As far as I can tell, setting self.duration = 0.0 will not cause any problems, and it avoids having to deal with it being None when calculating the number of static frames or self.time.
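A rough illustration of that point (hypothetical snippet, not taken from the PR; dt and the other names are made up):

    # With self.duration typed as plain float (defaulting to 0.0) instead of
    # float | None, the arithmetic below needs no None guard.
    dt: float = 1 / 60                      # assumed frame duration
    duration: float = 0.0                   # default instead of None

    num_static_frames = int(duration / dt)  # 0 frames when duration is 0.0
    time = 0.0
    time += num_static_frames * dt          # stays 0.0; no special-casing needed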

self.last_t = 0.0
self.queue: Queue[SceneInteractAction] = Queue()
self.skip_animation_preview = False
54 changes: 29 additions & 25 deletions manim/utils/hashing.py
@@ -54,22 +54,24 @@ class _Memoizer:
THRESHOLD_WARNING = 170_000

@classmethod
def reset_already_processed(cls):
def reset_already_processed(cls: type[_Memoizer]) -> None:
cls._already_processed.clear()

@classmethod
def check_already_processed_decorator(cls: _Memoizer, is_method: bool = False):
def check_already_processed_decorator(
cls: type[_Memoizer], is_method: bool = False
) -> Callable:
"""Decorator to handle the arguments that goes through the decorated function.
Returns _ALREADY_PROCESSED_PLACEHOLDER if the obj has been processed, or lets
the decorated function call go ahead.
Returns the value of ALREADY_PROCESSED_PLACEHOLDER if the obj has been processed,
or lets the decorated function call go ahead.

Parameters
----------
is_method
Whether the function passed is a method, by default False.
"""

def layer(func):
def layer(func: Callable[[Any], Any]) -> Callable:
# NOTE : There is probably a better way to separate both case when func is
# a method or a function.
if is_method:
@@ -82,9 +84,9 @@ def layer(func):
return layer

@classmethod
def check_already_processed(cls, obj: Any) -> Any:
def check_already_processed(cls: type[_Memoizer], obj: Any) -> Any:
"""Checks if obj has been already processed. Returns itself if it has not been,
or the value of _ALREADY_PROCESSED_PLACEHOLDER if it has.
or the value of ALREADY_PROCESSED_PLACEHOLDER if it has.
Marks the object as processed in the second case.

Parameters
@@ -101,7 +103,7 @@ def check_already_processed(cls, obj: Any) -> Any:
return cls._handle_already_processed(obj, lambda x: x)

@classmethod
def mark_as_processed(cls, obj: Any) -> None:
def mark_as_processed(cls: type[_Memoizer], obj: Any) -> None:
"""Marks an object as processed.

Parameters
@@ -114,10 +116,10 @@

@classmethod
def _handle_already_processed(
cls,
obj,
cls: type[_Memoizer],
obj: Any,
default_function: Callable[[Any], Any],
):
) -> str | Any:
if isinstance(
obj,
(
@@ -142,11 +144,11 @@ def _handle_already_processed(

@classmethod
def _return(
cls,
cls: type[_Memoizer],
obj: Any,
obj_to_membership_sign: Callable[[Any], int],
default_func,
memoizing=True,
default_func: Callable[[Any], Any],
memoizing: bool = True,
) -> str | Any:
obj_membership_sign = obj_to_membership_sign(obj)
if obj_membership_sign in cls._already_processed:
@@ -172,9 +174,8 @@


class _CustomEncoder(json.JSONEncoder):
def default(self, obj: Any):
"""
This method is used to serialize objects to JSON format.
def default(self, obj: Any) -> Any:
"""This method is used to serialize objects to JSON format.

If obj is a function, then it will return a dict with two keys : 'code', for
the code source, and 'nonlocals' for all nonlocals values. (including nonlocals
@@ -233,22 +234,22 @@ def default(self, obj: Any):
# Serialize it with only the type of the object. You can change this to whatever string when debugging the serialization process.
return str(type(obj))

def _cleaned_iterable(self, iterable: Iterable[Any]):
def _cleaned_iterable(self, iterable: Iterable[Any]) -> list[Any] | dict[Any, Any]:
@chopan050 (Contributor) commented on Aug 13, 2025:
Despite the name, iterable can be either a Sequence or a dict, according to the code. Currently, iterable can't really be an Iterable, because _iter_check_list calculates its length, something which not all iterables have (like map, range or filter).

EDIT: it is actually pretty easy to rewrite _iter_check_list to accept any iterables. See my suggestion below.

The name of this parameter (and function) should probably be changed.

Now, we could simply type iterable as Sequence[Any] | dict[Any, Any], but may I suggest the following overloads to indicate that, if iterable is a Sequence, the function returns a list and, if it's a dict, it returns a dict:

Suggested change
def _cleaned_iterable(self, iterable: Iterable[Any]) -> list[Any] | dict[Any, Any]:
@overload
def _cleaned_iterable(self, iterable: Sequence[Any]) -> list[Any]: ...
@overload
def _cleaned_iterable(self, iterable: dict[Any, Any]) -> dict[Any, Any]: ...
def _cleaned_iterable(self, iterable):

"""Check for circular reference at each iterable that will go through the JSONEncoder, as well as key of the wrong format.

If a key with a bad format is found (i.e. not an int, string, or float), it gets replaced by its hash using the same process implemented here.
If a circular reference is found within the iterable, it will be replaced by the string "already processed".
If a circular reference is found within the iterable, it will be replaced by the value of ALREADY_PROCESSED_PLACEHOLDER.

Parameters
----------
iterable
The iterable to check.
"""

def _key_to_hash(key):
def _key_to_hash(key) -> int:
Contributor:
Suggested change
def _key_to_hash(key) -> int:
def _key_to_hash(key: Any) -> int:

return zlib.crc32(json.dumps(key, cls=_CustomEncoder).encode())

def _iter_check_list(lst):
def _iter_check_list(lst: list[Any]) -> list[Any]:
@chopan050 (Contributor) commented on Aug 13, 2025:
Suggested change
def _iter_check_list(lst: list[Any]) -> list[Any]:
def _iter_check_list(lst: Sequence[Any]) -> list[Any]:

It's also technically pretty easy to rewrite this function to allow any iterables, but I'm not so sure about allowing potentially infinite iterables:

        def _iter_check_list(lst: Iterable[Any]) -> list[Any]:
            processed_list = []
            for el in lst:
                el = _Memoizer.check_already_processed(el)
                if isinstance(el, Iterable):
                    new_value = _iter_check_list(el)
                elif isinstance(el, dict):
                    new_value = _iter_check_dict(el)
                else:
                    new_value = el
                processed_list.append(new_value)
            return processed_list

processed_list = [None] * len(lst)
for i, el in enumerate(lst):
el = _Memoizer.check_already_processed(el)
Expand All @@ -261,7 +262,7 @@ def _iter_check_list(lst):
processed_list[i] = new_value
return processed_list

def _iter_check_dict(dct):
def _iter_check_dict(dct: dict) -> dict:
Contributor:
Suggested change
def _iter_check_dict(dct: dict) -> dict:
def _iter_check_dict(dct: dict[Any, Any]) -> dict[Any, Any]:

processed_dict = {}
for k, v in dct.items():
v = _Memoizer.check_already_processed(v)
@@ -285,8 +286,11 @@ def _iter_check_dict(dct):
return _iter_check_list(iterable)
Contributor:
Can you please replace

        if isinstance(iterable, (list, tuple)):

with

        if isinstance(iterable, Sequence):

?

Lists and tuples pass that check.

elif isinstance(iterable, dict):
return _iter_check_dict(iterable)
else:
# mypy requires this line, even though it should not be reached.
return iterable
Comment on lines +290 to +291
Contributor:
MyPy requires this line because, in the hypothetical case that _cleaned_iterable() receives an object that's neither an iterable nor a dictionary, the function would implicitly return None.

Now, as you mention, this shouldn't be reached, because the object should always be an iterable or a dictionary. In this case, I suggest raising an exception instead of silently returning the same iterable. That way, if _cleaned_iterable() is passed something different, we can catch the bug instead of silently hiding it:

Suggested change
# mypy requires this line, even though it should not be reached.
return iterable
raise TypeError("'iterable' is neither an iterable nor a dictionary.")


def encode(self, obj: Any):
def encode(self, obj: Any) -> str:
"""Overriding of :meth:`JSONEncoder.encode`, to make our own process.

Parameters
@@ -305,7 +309,7 @@
return super().encode(obj)


def get_json(obj: dict):
def get_json(obj: Any) -> str:
"""Recursively serialize `object` to JSON using the :class:`CustomEncoder` class.

Parameters
@@ -324,7 +328,7 @@
def get_hash_from_play_call(
scene_object: Scene,
camera_object: Camera | OpenGLCamera,
animations_list: Iterable[Animation],
animations_list: Iterable[Animation] | None,
Contributor:
An issue with this is that animations_list is expected to be an iterable inside the function. If it's None, it will crash.

I noticed that CairoRenderer.play(self, scene, *args, **kwargs) passes scene.animations, which is typed as list[Animation] | None. However, since scene.compile_animation_data(*args, **kwargs) is called before that, scene.animations will always be a list[Animation] at that point.

Therefore, instead of typing animations_list as Iterable[Animation] | None, you have to make an assertion inside CairoRenderer.play() that scene.animations is not None before the call to get_hash_from_play_call().
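A minimal sketch of that assertion inside CairoRenderer.play() (the surrounding body is abridged and the call shown is illustrative, not the PR's actual code):

    def play(
        self,
        scene: Scene,
        *args: Animation | Mobject | _AnimationBuilder,
        **kwargs: Any,
    ) -> None:
        ...
        scene.compile_animation_data(*args, **kwargs)
        # scene.animations is typed as list[Animation] | None, but after
        # compile_animation_data() it is always a list, so narrow it for mypy:
        assert scene.animations is not None
        animation_hash = get_hash_from_play_call(
            scene, self.camera, scene.animations, scene.mobjects
        )
        ...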

current_mobjects_list: Iterable[Mobject],
) -> str:
"""Take the list of animations and a list of mobjects and output their hashes. This is meant to be used for `scene.play` function.
Expand Down
3 changes: 0 additions & 3 deletions mypy.ini
@@ -126,9 +126,6 @@ ignore_errors = True
[mypy-manim.mobject.vector_field]
ignore_errors = True

[mypy-manim.renderer.cairo_renderer]
ignore_errors = True

[mypy-manim.renderer.opengl_renderer]
ignore_errors = True
