Skip to content
This repository was archived by the owner on Apr 10, 2024. It is now read-only.

Commit 4911f12

Browse files
committed
Fixed unit tests according to style guide
1 parent 924378c commit 4911f12

File tree

2 files changed

+9
-4
lines changed

2 files changed

+9
-4
lines changed

lucid/optvis/objectives.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
# -*- coding: utf-8 -*-
2+
13
# Copyright 2018 The Lucid Authors. All Rights Reserved.
24
#
35
# Licensed under the Apache License, Version 2.0 (the "License");

tests/optvis/test_objectives.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,14 +65,17 @@ def f(a):
6565

6666
@pytest.mark.parametrize("cossim_pow", [0, 1, 2])
6767
def test_cossim(cossim_pow):
68+
true_values = [1.0, 2**(0.5)/2, 0.5]
6869
x = np.array([1,1], dtype = np.float32)
6970
y = np.array([1,0], dtype = np.float32)
7071
T = lambda _: tf.constant(x[None, None, None, :])
7172
objective = objectives.direction("dummy", y, cossim_pow=cossim_pow)
72-
obj = objective(T)
73-
sess = tf.Session()
74-
trueval = np.dot(x,y)*(np.dot(x,y)/(np.linalg.norm(x)*np.linalg.norm(y)))**cossim_pow
75-
assert abs(sess.run(obj) - trueval) < 1e-3
73+
objective_t = objective(T)
74+
with tf.Session() as sess:
75+
trueval = np.dot(x,y)*(np.dot(x,y)/(np.linalg.norm(x)*np.linalg.norm(y)))**cossim_pow
76+
print(cossim_pow, trueval)
77+
objective = sess.run(objective_t)
78+
assert abs(objective - true_values[cossim_pow]) < 1e-3
7679

7780

7881
def test_channel(inceptionv1):

0 commit comments

Comments
 (0)