Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Mahjong] Reduce compilation/run time #1070

Open
OkanoShinri opened this issue Oct 27, 2023 · 2 comments
Open

[Mahjong] Reduce compilation/run time #1070

OkanoShinri opened this issue Oct 27, 2023 · 2 comments

Comments

@OkanoShinri
Copy link
Collaborator

benchmark.py
from pgx._mahjong._mahjong2 import (
    Mahjong,
    _discard,
    _selfkan,
    _riichi,
    _tsumo,
    _ron,
    _pon,
    _minkan,
    _pass,
)
import jax
import time
import sys

# func(state, action): benchmarked with the fixed action 0.
functions1 = {"_discard": _discard, "_selfkan": _selfkan}

# func(state)
functions2 = {
    "_riichi": _riichi,
    "_tsumo": _tsumo,
    "_ron": _ron,
    "_pon": _pon,
    "_minkan": _minkan,
    "_pass": _pass,
}


def benchmark(func, args):
    """Print one markdown table row describing *func* called with *args*.

    Columns: number of lines in the printed jaxpr, and wall-clock time of
    the first ``jax.jit`` call (tracing + compilation + first execution).
    """
    time_sta = time.perf_counter()
    # JAX dispatches asynchronously; wait for the result so the timer does
    # not stop before the first execution has actually finished.
    jax.block_until_ready(jax.jit(func)(*args))
    time_end = time.perf_counter()
    delta = (time_end - time_sta) * 1000
    exp = jax.make_jaxpr(func)(*args)
    n_line = len(str(exp).split("\n"))
    print(f"| `{func.__name__}` | {n_line} | {delta:.1f}ms |")


def main():
    """Benchmark the function named by the single command-line argument."""
    if len(sys.argv) != 2:
        sys.exit(f"usage: {sys.argv[0]} FUNC_NAME")
    func_name = sys.argv[1]
    if func_name in functions1:
        func = functions1[func_name]
    elif func_name in functions2:
        func = functions2[func_name]
    else:
        # Fail loudly instead of silently emitting no table row.
        known = ", ".join([*functions1, *functions2])
        sys.exit(f"unknown function {func_name!r}; expected one of: {known}")

    env = Mahjong()
    key = jax.random.PRNGKey(352)
    state = env.init(key=key)
    args = (state, 0) if func_name in functions1 else (state,)
    benchmark(func, args)


if __name__ == "__main__":
    main()
benchmark.sh
# Markdown table header; benchmark.py prints one row per function below it.
printf '%s\n' '| function name | # expr lines | compile time |'
printf '%s\n' '| :--- | ---: | ---: |'
for funcname in _discard _selfkan _riichi _tsumo _ron _pon _minkan _pass; do
    python3 benchmark.py "$funcname"
done
  • step内の主要関数抜粋 (excerpt: the main functions called inside `step`)
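| function name | # expr lines | compile time |
| :--- | ---: | ---: |
| `_discard` | 12245 | 4594.9ms |
| `_selfkan` | 333 | 202.9ms |
| `_riichi` | 934 | 447.8ms |
| `_tsumo` | 2328 | 1114.8ms |
| `_ron` | 2280 | 1024.9ms |
| `_pon` | 261 | 112.6ms |
| `_minkan` | 324 | 122.2ms |
| `_pass` | 8539 | 2385.0ms |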
@sotetsuk sotetsuk changed the title [Mahjong] Reduce compilation time [Mahjong] Reduce compilation/run time Nov 8, 2023
@OkanoShinri
Copy link
Collaborator Author

OkanoShinri commented Nov 8, 2023

benchmark.py
from pgx.mahjong._env import (
    Mahjong,
    _discard,
    _selfkan,
    _riichi,
    _tsumo,
    _ron,
    _pon,
    _minkan,
    _pass,
)
import jax
import time
import sys
import timeit

# func(state, action): benchmarked with the fixed action 0.
functions1 = {"_discard": _discard, "_selfkan": _selfkan}

# func(state)
functions2 = {
    "_riichi": _riichi,
    "_tsumo": _tsumo,
    "_ron": _ron,
    "_pon": _pon,
    "_minkan": _minkan,
    "_pass": _pass,
}

# Number of timed calls averaged into the "running time" column.
N = 10


def benchmark(func, args):
    """Print one markdown table row describing *func* called with *args*.

    Columns: number of lines in the printed jaxpr, wall-clock time of the
    first jitted call (tracing + compilation + first execution), and the
    mean time per call of the already-compiled function over ``N`` calls.
    """
    jit_func = jax.jit(func)

    # JAX dispatches asynchronously, so every timed call blocks on its
    # result; otherwise mostly the dispatch overhead would be measured.
    time_sta = time.perf_counter()
    jax.block_until_ready(jit_func(*args))
    time_end = time.perf_counter()
    delta = (time_end - time_sta) * 1000

    exp = jax.make_jaxpr(func)(*args)
    n_line = len(str(exp).split("\n"))

    # Reuse the wrapper compiled above so the timed loop excludes any
    # tracing/compilation work.
    run_delta = timeit.timeit(
        lambda: jax.block_until_ready(jit_func(*args)), number=N
    )

    print(
        f"| `{func.__name__}` | {n_line} | {delta:.1f}ms | {run_delta/N*1000000:.1f}μs |"
    )


def main():
    """Benchmark the function named by the single command-line argument."""
    if len(sys.argv) != 2:
        sys.exit(f"usage: {sys.argv[0]} FUNC_NAME")
    func_name = sys.argv[1]
    if func_name in functions1:
        func = functions1[func_name]
    elif func_name in functions2:
        func = functions2[func_name]
    else:
        # Fail loudly instead of silently emitting no table row.
        known = ", ".join([*functions1, *functions2])
        sys.exit(f"unknown function {func_name!r}; expected one of: {known}")

    env = Mahjong()
    key = jax.random.PRNGKey(352)
    state = env.init(key=key)
    args = (state, 0) if func_name in functions1 else (state,)
    benchmark(func, args)


if __name__ == "__main__":
    main()

benchmark.sh
# Markdown table header; benchmark.py prints one row per function below it.
printf '%s\n' '| function name | # expr lines | compile time | running time |'
printf '%s\n' '| :--- | ---: | ---: | ---: |'
for funcname in _discard _selfkan _riichi _tsumo _ron _pon _minkan _pass; do
    python3 benchmark.py "$funcname"
done
  • 手元での実行時間 (execution times measured on a local machine)
| function name | # expr lines | compile time | running time |
| :--- | ---: | ---: | ---: |
| `_discard` | 12245 | 5012.2ms | 264.3μs |
| - `_draw` | 5517 | 2137.6ms | 244.9μs |
| -- `_make_legal_action_mask` | 3048 | 1281.3ms | 106.2μs |
| -- `_make_legal_action_mask_w_riichi` | 2394 | 1088.4ms | 78.7μs |
| `_selfkan` | 333 | 198.3ms | 64.7μs |
| `_riichi` | 934 | 427.5ms | 104.6μs |
| `_tsumo` | 2328 | 1095.3ms | 77.1μs |
| `_ron` | 2280 | 1048.6ms | 104.6μs |
| `_pon` | 261 | 117.2ms | 64.7μs |
| `_minkan` | 324 | 133.0ms | 69.7μs |
| `_pass` | 8539 | 2526.0ms | 223.7μs |
| `can_tsumo` | 225 | 115.3ms | 12.5μs |
| `can_ron` | 245 | 122.4ms | 10.6μs |
| `can_minkan` | 9 | 12.9ms | 8.6μs |
| `can_chi` | 206 | 63.3ms | 11.7μs |
| `can_riichi` | 424 | 289.8ms | 121.8μs |
| `is_tenpai` | 344 | 189.0ms | 11.9μs |
  • 優先順位 (priority order)
  • discard
  • pass
  • ron
  • tsumo

@sotetsuk
Copy link
Owner

sotetsuk commented Nov 8, 2023

shogi vs mahjong: throughput (steps/sec) at increasing batch sizes
{"game": "shogi", "library": "pgx/1dev", "total_steps": 200, "total_sec": 0.08473587036132812, "steps/sec": 2360.2755143384506, "batch_size": 2, "pgx.__version__": "2.0.0"}
{"game": "mahjong", "library": "pgx/1dev", "total_steps": 200, "total_sec": 2.7215688228607178, "steps/sec": 73.48702642388972, "batch_size": 2, "pgx.__version__": "2.0.0"}
{"game": "shogi", "library": "pgx/1dev", "total_steps": 400, "total_sec": 0.08438754081726074, "steps/sec": 4740.036220112051, "batch_size": 4, "pgx.__version__": "2.0.0"}
{"game": "mahjong", "library": "pgx/1dev", "total_steps": 400, "total_sec": 2.679060459136963, "steps/sec": 149.30607431265537, "batch_size": 4, "pgx.__version__": "2.0.0"}
{"game": "shogi", "library": "pgx/1dev", "total_steps": 800, "total_sec": 0.08239483833312988, "steps/sec": 9709.346072936349, "batch_size": 8, "pgx.__version__": "2.0.0"}
{"game": "mahjong", "library": "pgx/1dev", "total_steps": 800, "total_sec": 2.8103582859039307, "steps/sec": 284.6612134874774, "batch_size": 8, "pgx.__version__": "2.0.0"}
{"game": "shogi", "library": "pgx/1dev", "total_steps": 1600, "total_sec": 0.10672569274902344, "steps/sec": 14991.70404789563, "batch_size": 16, "pgx.__version__": "2.0.0"}
{"game": "mahjong", "library": "pgx/1dev", "total_steps": 1600, "total_sec": 2.893850564956665, "steps/sec": 552.896552218466, "batch_size": 16, "pgx.__version__": "2.0.0"}
{"game": "shogi", "library": "pgx/1dev", "total_steps": 3200, "total_sec": 0.1228034496307373, "steps/sec": 26057.899917487746, "batch_size": 32, "pgx.__version__": "2.0.0"}
{"game": "mahjong", "library": "pgx/1dev", "total_steps": 3200, "total_sec": 2.9354007244110107, "steps/sec": 1090.1407679668953, "batch_size": 32, "pgx.__version__": "2.0.0"}
{"game": "shogi", "library": "pgx/1dev", "total_steps": 6400, "total_sec": 0.14135956764221191, "steps/sec": 45274.614988910536, "batch_size": 64, "pgx.__version__": "2.0.0"}
{"game": "mahjong", "library": "pgx/1dev", "total_steps": 6400, "total_sec": 3.0983378887176514, "steps/sec": 2065.6236439883096, "batch_size": 64, "pgx.__version__": "2.0.0"}
{"game": "shogi", "library": "pgx/1dev", "total_steps": 12800, "total_sec": 0.19550633430480957, "steps/sec": 65471.024483758185, "batch_size": 128, "pgx.__version__": "2.0.0"}
{"game": "mahjong", "library": "pgx/1dev", "total_steps": 12800, "total_sec": 3.35774564743042, "steps/sec": 3812.0814808576847, "batch_size": 128, "pgx.__version__": "2.0.0"}
{"game": "shogi", "library": "pgx/1dev", "total_steps": 25600, "total_sec": 0.35773587226867676, "steps/sec": 71561.1767912757, "batch_size": 256, "pgx.__version__": "2.0.0"}
{"game": "mahjong", "library": "pgx/1dev", "total_steps": 25600, "total_sec": 4.034330606460571, "steps/sec": 6345.538454137644, "batch_size": 256, "pgx.__version__": "2.0.0"}
{"game": "shogi", "library": "pgx/1dev", "total_steps": 51200, "total_sec": 0.6315653324127197, "steps/sec": 81068.41425954246, "batch_size": 512, "pgx.__version__": "2.0.0"}
{"game": "mahjong", "library": "pgx/1dev", "total_steps": 51200, "total_sec": 4.863647699356079, "steps/sec": 10527.078268185132, "batch_size": 512, "pgx.__version__": "2.0.0"}
{"game": "shogi", "library": "pgx/1dev", "total_steps": 102400, "total_sec": 1.1985838413238525, "steps/sec": 85434.15693548627, "batch_size": 1024, "pgx.__version__": "2.0.0"}
{"game": "mahjong", "library": "pgx/1dev", "total_steps": 102400, "total_sec": 7.006359338760376, "steps/sec": 14615.293770832695, "batch_size": 1024, "pgx.__version__": "2.0.0"}
image

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants