Skip to content

Commit

Permalink
run test_monitor through pytest; fix the test, add flake8 to bench di…
Browse files Browse the repository at this point in the history
…reectory - like PR 891 (openai#921)
  • Loading branch information
pzhokhov authored May 31, 2019
1 parent ff8d36a commit 1c872ca
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 28 deletions.
1 change: 1 addition & 0 deletions baselines/bench/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
# flake8: noqa F403
from baselines.bench.benchmarks import *
from baselines.bench.monitor import *
1 change: 0 additions & 1 deletion baselines/bench/benchmarks.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import re
import os.path as osp
import os
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))

Expand Down
26 changes: 0 additions & 26 deletions baselines/bench/monitor.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
__all__ = ['Monitor', 'get_monitor_files', 'load_results']

import gym
from gym.core import Wrapper
import time
from glob import glob
import csv
import os.path as osp
import json
import numpy as np

class Monitor(Wrapper):
EXT = "monitor.csv"
Expand Down Expand Up @@ -162,27 +160,3 @@ def load_results(dir):
df['t'] -= min(header['t_start'] for header in headers)
df.headers = headers # HACK to preserve backwards compatibility
return df

def test_monitor():
env = gym.make("CartPole-v1")
env.seed(0)
mon_file = "/tmp/baselines-test-%s.monitor.csv" % uuid.uuid4()
menv = Monitor(env, mon_file)
menv.reset()
for _ in range(1000):
_, _, done, _ = menv.step(0)
if done:
menv.reset()

f = open(mon_file, 'rt')

firstline = f.readline()
assert firstline.startswith('#')
metadata = json.loads(firstline[1:])
assert metadata['env_id'] == "CartPole-v1"
assert set(metadata.keys()) == {'env_id', 'gym_version', 't_start'}, "Incorrect keys in monitor metadata"

last_logline = pandas.read_csv(f, index_col=None)
assert set(last_logline.keys()) == {'l', 't', 'r'}, "Incorrect keys in monitor logline"
f.close()
os.remove(mon_file)
31 changes: 31 additions & 0 deletions baselines/bench/test_monitor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from .monitor import Monitor
import gym
import json

def test_monitor():
import pandas
import os
import uuid

env = gym.make("CartPole-v1")
env.seed(0)
mon_file = "/tmp/baselines-test-%s.monitor.csv" % uuid.uuid4()
menv = Monitor(env, mon_file)
menv.reset()
for _ in range(1000):
_, _, done, _ = menv.step(0)
if done:
menv.reset()

f = open(mon_file, 'rt')

firstline = f.readline()
assert firstline.startswith('#')
metadata = json.loads(firstline[1:])
assert metadata['env_id'] == "CartPole-v1"
assert set(metadata.keys()) == {'env_id', 't_start'}, "Incorrect keys in monitor metadata"

last_logline = pandas.read_csv(f, index_col=None)
assert set(last_logline.keys()) == {'l', 't', 'r'}, "Incorrect keys in monitor logline"
f.close()
os.remove(mon_file)
1 change: 0 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,3 @@ exclude =
.git,
__pycache__,
baselines/ppo1,
baselines/bench,

0 comments on commit 1c872ca

Please sign in to comment.