Skip to content

Commit

Permalink
reintroducing session setting in tensorflow
Browse files Browse the repository at this point in the history
  • Loading branch information
ben9809 committed Apr 21, 2024
1 parent 9f63bdb commit bda18ee
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 5 deletions.
8 changes: 4 additions & 4 deletions tests/test_neural_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from keras.callbacks import History
from sklearn.model_selection import train_test_split
from sklearn.datasets import make_regression
from tensorflow.python.keras import backend

from arbitragelab.ml_approach.neural_networks import MultiLayerPerceptron, RecurrentNeuralNetwork, PiSigmaNeuralNetwork

Expand All @@ -28,10 +29,9 @@ def setUp(self):
seed_value = 0
np.random.seed(seed_value)
tf.random.set_seed(seed_value)

#session_conf = tf.compat.v1.ConfigProto(intra_op_parallelism_threads=1, inter_op_parallelism_threads=1)
#sess = tf.compat.v1.Session(graph=tf.compat.v1.get_default_graph(), config=session_conf)
#tf.compat.v1.keras.backend.set_session(sess)
session_conf = tf.compat.v1.ConfigProto(intra_op_parallelism_threads=1, inter_op_parallelism_threads=1)
sess = tf.compat.v1.Session(graph=tf.compat.v1.get_default_graph(), config=session_conf)
backend.set_session(sess)

self.seed_value = seed_value

Expand Down
6 changes: 5 additions & 1 deletion tests/test_spread_modeling_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np
import pandas as pd
import tensorflow as tf

from tensorflow.python.keras import backend
from arbitragelab.cointegration_approach.johansen import JohansenPortfolio
from arbitragelab.ml_approach.regressor_committee import RegressorCommittee
from arbitragelab.util.spread_modeling_helper import SpreadModelingHelper
Expand All @@ -26,6 +26,10 @@ def setUp(self):
np.random.seed(seed_value)
tf.random.set_seed(seed_value)

session_conf = tf.compat.v1.ConfigProto(intra_op_parallelism_threads=1, inter_op_parallelism_threads=1)
sess = tf.compat.v1.Session(graph=tf.compat.v1.get_default_graph(), config=session_conf)
backend.set_session(sess)

# Collect all contract price data.
project_path = os.path.dirname(__file__)

Expand Down

0 comments on commit bda18ee

Please sign in to comment.