Skip to content

Commit

Permalink
Merge pull request #4 from apockill/feature/add-template
Browse files Browse the repository at this point in the history
Wew, tests pass and run_headless doesn't immediately fail anymore
  • Loading branch information
apockill authored Dec 25, 2023
2 parents 0e4d029 + 9abbe09 commit 3a1622c
Show file tree
Hide file tree
Showing 14 changed files with 37 additions and 23 deletions.
3 changes: 3 additions & 0 deletions assets/background.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
3 changes: 3 additions & 0 deletions assets/critter.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
3 changes: 3 additions & 0 deletions assets/food.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1 change: 1 addition & 0 deletions crawlai/game_scripts/world.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,4 @@ def step(grid: Grid, pool: ThreadPool) -> None:

def close(self) -> None:
self.pool.terminate()
self.pool.join()
2 changes: 1 addition & 1 deletion crawlai/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __init__(self, width: int, height: int) -> None:
self._hash_cache: int | None = None

# Grid state
self.array: npt.NDArray[np.int8] = np.zeros(
self.array: npt.NDArray[np.int_] = np.zeros(
shape=(width, height), dtype=np.int_
)
"""Holds the instance ids of each object. 0 means empty"""
Expand Down
8 changes: 3 additions & 5 deletions crawlai/items/critter/base_critter.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
from typing import Any

from godot.bindings import ResourceLoader
from godot.bindings import Node, ResourceLoader

from crawlai.grid_item import GridItem
from crawlai.items.food import Food
from crawlai.math_utils import clamp
from crawlai.position import Position
from crawlai.turn import Turn

_critter_resource = ResourceLoader.load("res://Game/Critter/Critter.tscn")
_critter_resource = ResourceLoader.load("assets/critter.png")


class BaseCritter(GridItem):
Expand Down Expand Up @@ -38,7 +36,7 @@ def _tick_stats(self) -> None:
self.age += 1
self.health -= self.HEALTH_TICK_PENALTY

def _load_instance(self) -> Any:
def _load_instance(self) -> Node:
return _critter_resource.instance()

def perform_action_onto(self, other: "GridItem") -> None:
Expand Down
6 changes: 5 additions & 1 deletion crawlai/items/food.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from godot.bindings import Node, ResourceLoader

from crawlai.grid import Grid
from crawlai.grid_item import GridItem
from crawlai.math_utils import clamp

_food_resource = ResourceLoader.load("assets/food.png")


class Food(GridItem):
MAX_NUTRITION = 100
Expand All @@ -16,7 +20,7 @@ def take_nutrition(self, amount: int) -> int:
self.nutrition -= to_take
return amount

def _load_instance(self):
def _load_instance(self) -> Node:
return _food_resource.instance()

def perform_action_onto(self, other: "GridItem") -> None:
Expand Down
2 changes: 1 addition & 1 deletion crawlai/model/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def __init__(

# TODO: Figure out how to get len(AiCritterMixin.CHOICES)
self._action_spec = array_spec.BoundedArraySpec(
shape=(), dtype=np.int, minimum=0, maximum=n_choices - 1, name="action"
shape=(), dtype=np.int_, minimum=0, maximum=n_choices - 1, name="action"
)

@property
Expand Down
8 changes: 4 additions & 4 deletions crawlai/model/extract_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,21 @@
from crawlai.grid import Grid
from crawlai.position import Position

INPUT_DTYPE = np.int
INPUT_DTYPE = np.int_
"""The smallest int type accepted by tensorflow"""

# TODO: Implement caching
_generate_layered_grid_lock = RLock()
"""Because generating the full layered grid is a bit expensive, it's best for
one thread to process this and the rest of them to use the cached result. """
_instance_grid_cache: dict[Hashable, npt.NDArray[np.int8]] = {}
_instance_grid_cache: dict[Hashable, npt.NDArray[np.int_]] = {}
"""Holds a dictionary of a single value, of format
{hash(grid.array.data.tobytes(), ): instance_grid} """


def _generate_layered_grid(
grid: Grid, layers: dict[str, int], radius: int
) -> npt.NDArray[np.int8]:
) -> npt.NDArray[np.int_]:
"""Converts the grid of shape (x, y) to (x, y, obj_layers)
The 0th index of the grid always represents boundaries or walls.
Expand Down Expand Up @@ -52,7 +52,7 @@ def _generate_layered_grid(

def get_instance_grid(
grid: Grid, pos: Position, radius: int, layers: dict[str, int]
) -> npt.NDArray[np.int8]:
) -> npt.NDArray[np.int_]:
"""Get a numpy array of obj IDs surrounding a particular area.
This function will always return an array of shape (radius, radius),
where the value is the object ID.
Expand Down
5 changes: 3 additions & 2 deletions crawlai/scripts/run_headless.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from tests import monkeypatch_godot_import # noqa: F401 # isort: skip
import random

import tensorflow as tf
Expand All @@ -6,7 +7,7 @@
from tests.helpers import Timer


def main(ticks_per_report: int) -> None:
def main(ticks_per_report: int = 1000) -> None:
tf.random.set_seed(1)
random.seed("benchmark")

Expand All @@ -29,4 +30,4 @@ def main(ticks_per_report: int) -> None:


if __name__ == "__main__":
main(ticks_per_report=1000)
main()
3 changes: 2 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import tests.monkeypatch_godot_import # noqa: F401 # isort: skip
import random
from collections.abc import Generator

import pytest

from crawlai.game_scripts.world import World
from crawlai.model import extract_inputs
from tests import helpers, monkeypatch_godot_import # noqa:F401
from tests import helpers


@pytest.fixture(autouse=True)
Expand Down
6 changes: 3 additions & 3 deletions tests/godot_mock/bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@ def get_instance_id(self):
return self._instance_id

def add_child(self, node):
assert isinstance(node, Node)
assert isinstance(node, Node), f"Expected {Node}, got {type(node)}"

def set_position(self, vector2):
assert isinstance(vector2, Vector2)
assert isinstance(vector2, Vector2), f"Expected {Vector2}, got {type(vector2)}"
assert vector2
self._position = vector2

Expand All @@ -47,7 +47,7 @@ class ResourceLoader:
@staticmethod
def load(resource_path: str):
"""Verify the resource path exists"""
resource_path = resource_path.replace("res://", "")
resource_path = Path(resource_path).absolute().resolve(strict=True)
assert Path(resource_path).is_file()

class Unloaded:
Expand Down
4 changes: 2 additions & 2 deletions tests/test_get_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@
argvalues=test_get_grid_around_parameters,
)
def test_get_grid_around(
pos: tuple[int, int], radius: int, expected_occupied_layers: npt.NDArray[np.int8]
pos: tuple[int, int], radius: int, expected_occupied_layers: npt.NDArray[np.int_]
) -> None:
"""Creates a grid of shape:
[[0 0 0 0 0]
Expand Down Expand Up @@ -137,7 +137,7 @@ def test_instance_grid_is_cached() -> None:
), "The caches should start out empty!"

def validate_cache_changes(
changed: bool, last_cache: dict[Hashable, npt.NDArray[np.int8]]
changed: bool, last_cache: dict[Hashable, npt.NDArray[np.int_]]
) -> None:
cache = extract_inputs._instance_grid_cache
assert len(cache) == 1
Expand Down
6 changes: 3 additions & 3 deletions tests/test_world.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@ def test_random_movement_persists_safely(
undying_base_critter_type: BaseCritter, world: World
) -> None:
"""Basically, make sure things don't crash during normal use"""
world.min_num_critters = 10
world.min_num_food = 100
world.min_num_critters = 9
world.min_num_food = 101
world.grid_height = 25
world.grid_width = 15
n_ticks = 100
n_ticks = 132

world._ready()
validate_grid(world.grid)
Expand Down

0 comments on commit 3a1622c

Please sign in to comment.