From 5852640c61ddc3f094cead3872fa204498ffe86b Mon Sep 17 00:00:00 2001 From: Ryuichi Yamamoto Date: Sat, 11 Aug 2018 23:32:06 +0900 Subject: [PATCH] Fix tests for pytorch v0.4.1 --- tests/test_gantts.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/test_gantts.py b/tests/test_gantts.py index 9d27616..8287218 100644 --- a/tests/test_gantts.py +++ b/tests/test_gantts.py @@ -69,11 +69,11 @@ def test_select_streams(): x = torch.arange(0, 63).expand(32, 100, 63) assert (select_streams(x, static_stream_sizes, streams=[ - False, False, False, True]) == x[:, :, -1]).all() + False, False, False, True]).squeeze(-1) == x[:, :, -1]).all() assert (select_streams(x, static_stream_sizes, streams=[ - False, False, True, False]) == x[:, :, -2]).all() + False, False, True, False]).squeeze(-1) == x[:, :, -2]).all() assert (select_streams(x, static_stream_sizes, streams=[ - False, True, False, False]) == x[:, :, -3]).all() + False, True, False, False]).squeeze(-1) == x[:, :, -3]).all() # Multiple selects y = select_streams(x, static_stream_sizes, streams=[ @@ -154,9 +154,9 @@ def test_multi_stream_mlpg(): bap = y[:, :, 62] assert (unit_variance_mlpg(R, x[:, :, : 180]) == mgc).data.all() - assert (unit_variance_mlpg(R, x[:, :, 180: 180 + 3]) == lf0).data.all() + assert (unit_variance_mlpg(R, x[:, :, 180: 180 + 3]).squeeze(-1) == lf0).data.all() assert (x[:, :, 183] == vuv).data.all() - assert (unit_variance_mlpg(R, x[:, :, 184: 184 + 3]) == bap).data.all() + assert (unit_variance_mlpg(R, x[:, :, 184: 184 + 3]).squeeze(-1) == bap).data.all() static_features = get_static_features( x, len(windows), stream_sizes, has_dynamic_features)