Skip to content

Commit

Permalink
Fix tests for pytorch v0.4.1
Browse files Browse the repository at this point in the history
  • Loading branch information
r9y9 committed Aug 11, 2018
1 parent 7f7eb85 commit 5852640
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions tests/test_gantts.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,11 @@ def test_select_streams():

x = torch.arange(0, 63).expand(32, 100, 63)
assert (select_streams(x, static_stream_sizes, streams=[
False, False, False, True]) == x[:, :, -1]).all()
False, False, False, True]).squeeze(-1) == x[:, :, -1]).all()
assert (select_streams(x, static_stream_sizes, streams=[
False, False, True, False]) == x[:, :, -2]).all()
False, False, True, False]).squeeze(-1) == x[:, :, -2]).all()
assert (select_streams(x, static_stream_sizes, streams=[
False, True, False, False]) == x[:, :, -3]).all()
False, True, False, False]).squeeze(-1) == x[:, :, -3]).all()

# Multiple selects
y = select_streams(x, static_stream_sizes, streams=[
Expand Down Expand Up @@ -154,9 +154,9 @@ def test_multi_stream_mlpg():
bap = y[:, :, 62]

assert (unit_variance_mlpg(R, x[:, :, : 180]) == mgc).data.all()
assert (unit_variance_mlpg(R, x[:, :, 180: 180 + 3]) == lf0).data.all()
assert (unit_variance_mlpg(R, x[:, :, 180: 180 + 3]).squeeze(-1) == lf0).data.all()
assert (x[:, :, 183] == vuv).data.all()
assert (unit_variance_mlpg(R, x[:, :, 184: 184 + 3]) == bap).data.all()
assert (unit_variance_mlpg(R, x[:, :, 184: 184 + 3]).squeeze(-1) == bap).data.all()

static_features = get_static_features(
x, len(windows), stream_sizes, has_dynamic_features)
Expand Down

0 comments on commit 5852640

Please sign in to comment.