Skip to content

Commit

Permalink
Merge pull request #6 from gwenchee/pep8-fixes
Browse files Browse the repository at this point in the history
Pep8 fixes
  • Loading branch information
gwenchee authored Mar 4, 2022
2 parents ff889ef + 98c5de6 commit ac89abd
Show file tree
Hide file tree
Showing 11 changed files with 133 additions and 59 deletions.
25 changes: 14 additions & 11 deletions rollo/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,11 @@

MPI.pickle.__init__(dill.dumps, dill.loads)
from mpi4py.futures import MPICommExecutor
except:
except BaseException:
warnings.warn(
"Failed to import mpi4py. (Only necessary for parallel method: mpi_evals). \
Please ignore this warning if you are using other parallel methods such \
as multiprocessing and none."
)
as multiprocessing and none.")


class Algorithm(object):
Expand Down Expand Up @@ -103,7 +102,7 @@ def generate(self):

pool = multiprocessing.Pool()
self.toolbox.register("map", pool.map)
except:
except BaseException:
warnings.warn(
"multiprocessing_on_dill failed to import, rollo will run serially."
)
Expand All @@ -115,7 +114,7 @@ def generate(self):
with MPICommExecutor(MPI.COMM_WORLD, root=0) as executor:
pass
sys.exit(0)
except:
except BaseException:
warnings.warn("MPI Failed.")
pass
if self.cp_file:
Expand All @@ -127,14 +126,16 @@ def generate(self):
pop = self.initialize_pop(pop)
self.cp_file = "checkpoint.pkl"
print(self.backend.results["logbook"])
for gen in range(self.backend.results["start_gen"] + 1, self.toolbox.ngen):
for gen in range(
self.backend.results["start_gen"] + 1,
self.toolbox.ngen):
pop = self.apply_algorithm_ngen(pop, gen)
print(self.backend.results["logbook"])
print("rollo Simulation Completed!")
if self.parallel_method == "mpi_evals":
try:
MPI.COMM_WORLD.bcast(False)
except:
except BaseException:
pass
return pop

Expand Down Expand Up @@ -165,7 +166,7 @@ def initialize_pop(self, pop):
MPI.COMM_WORLD.bcast(True)
with MPICommExecutor(MPI.COMM_WORLD, root=0) as executor:
fitnesses = executor.map(self.toolbox.evaluate, list(pop))
except:
except BaseException:
warnings.warn("MPI Failed, rollo will run serially.")
fitnesses = self.toolbox.map(self.toolbox.evaluate, pop)
else:
Expand Down Expand Up @@ -212,10 +213,12 @@ def apply_algorithm_ngen(self, pop, gen):
try:
MPI.COMM_WORLD.bcast(True)
with MPICommExecutor(MPI.COMM_WORLD, root=0) as executor:
fitnesses = executor.map(self.toolbox.evaluate, list(invalids))
except:
fitnesses = executor.map(
self.toolbox.evaluate, list(invalids))
except BaseException:
warnings.warn("MPI Failed, rollo will run serially.")
fitnesses = self.toolbox.map(self.toolbox.evaluate, list(invalids))
fitnesses = self.toolbox.map(
self.toolbox.evaluate, list(invalids))
else:
fitnesses = self.toolbox.map(self.toolbox.evaluate, list(invalids))
# assign fitness values to individuals
Expand Down
4 changes: 2 additions & 2 deletions rollo/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,9 +192,9 @@ def update_backend(self, pop, gen, invalid_ind, rndstate):
self.input_file["evaluators"][solver]["output_script"], "r"
) as file:
evaluator_files[solver + "_output"] = file.read()
except:
except BaseException:
pass
except:
except BaseException:
pass
cp = dict(
input_file=self.input_file,
Expand Down
23 changes: 17 additions & 6 deletions rollo/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ def __init__(self):
self.input_scripts = {}
self.output_scripts = {}
# Developers can add new solvers to self.eval_dict below
self.eval_dict = {"openmc": OpenMCEvaluation(), "moltres": MoltresEvaluation()}
self.eval_dict = {
"openmc": OpenMCEvaluation(),
"moltres": MoltresEvaluation()}

def add_evaluator(self, solver_name, input_script, output_script):
"""Adds information about an evaluator to the Evaluation class object
Expand All @@ -59,7 +61,7 @@ def add_evaluator(self, solver_name, input_script, output_script):
self.input_scripts[solver_name] = input_script
try:
self.output_scripts[solver_name] = output_script
except:
except BaseException:
pass
return

Expand Down Expand Up @@ -141,7 +143,13 @@ def solver_order(self, input_evaluators):
order[input_evaluators[solver]["order"]] = solver
return order

def get_output_vals(self, output_vals, solver, output_dict, control_vars, path):
def get_output_vals(
self,
output_vals,
solver,
output_dict,
control_vars,
path):
"""Returns a populated list with output values for each solver
Parameters
Expand All @@ -166,7 +174,8 @@ def get_output_vals(self, output_vals, solver, output_dict, control_vars, path):
"""

if self.output_scripts[solver]:
# copy rendered output script into a new file in the particular solver's run
# copy rendered output script into a new file in the particular
# solver's run
shutil.copyfile(
self.output_scripts[solver], path + "/" + solver + "_output.py"
)
Expand Down Expand Up @@ -273,7 +282,8 @@ def render_jinja_template_python(self, script, control_vars_solver):
template = nativetypes.NativeTemplate(imported_script)
render_str = "template.render("
for inp in control_vars_solver:
render_str += "**{'" + inp + "':" + str(control_vars_solver[inp]) + "},"
render_str += "**{'" + inp + "':" + \
str(control_vars_solver[inp]) + "},"
render_str += ")"
rendered_template = eval(render_str)
return rendered_template
Expand Down Expand Up @@ -301,7 +311,8 @@ def openmc_run(self, rendered_openmc_script):
f.write(rendered_openmc_script)
f.close()
with open("output.txt", "wb") as output:
subprocess.call(["python", "-u", "./openmc_input.py"], stdout=output)
subprocess.call(
["python", "-u", "./openmc_input.py"], stdout=output)
return

def moltres_run(self, rendered_moltres_script):
Expand Down
5 changes: 3 additions & 2 deletions rollo/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ def execute(self):
iv.validate()
complete_input_dict = iv.input
# organize control variables and output dict
control_dict, output_dict = self.organize_input_output(complete_input_dict)
control_dict, output_dict = self.organize_input_output(
complete_input_dict)
# generate evaluator function
evaluator_fn = self.load_evaluator(
control_dict, output_dict, complete_input_dict
Expand Down Expand Up @@ -185,7 +186,7 @@ def load_evaluator(self, control_dict, output_dict, input_dict):
solver_dict = input_evaluators[solver]
try:
output_script = solver_dict["output_script"]
except:
except BaseException:
output_script = None
evaluator.add_evaluator(
solver_name=solver,
Expand Down
19 changes: 13 additions & 6 deletions rollo/input_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,12 @@ def add_all_defaults(self):
input_evaluators[solver], "keep_files", "all"
)
input_algorithm = input_dict["algorithm"]
input_algorithm = self.default_check(input_algorithm, "objective", "min")
input_algorithm = self.default_check(
input_algorithm, "objective", "min")
input_algorithm = self.default_check(input_algorithm, "weight", [1.0])
input_algorithm = self.default_check(input_algorithm, "pop_size", 60)
input_algorithm = self.default_check(input_algorithm, "generations", 10)
input_algorithm = self.default_check(
input_algorithm, "generations", 10)
input_algorithm = self.default_check(
input_algorithm, "mutation_probability", 0.23
)
Expand Down Expand Up @@ -319,7 +321,10 @@ def validate_constraints(self, input_constraints, input_evaluators):
for evaluator in input_evaluators:
allowed_constraints += input_evaluators[evaluator]["outputs"]
for constraint in input_constraints:
self.validate_in_list(constraint, allowed_constraints, "Constraints")
self.validate_in_list(
constraint,
allowed_constraints,
"Constraints")
# schema validation
schema_constraints = {"type": "object", "properties": {}}
for constraint in input_constraints:
Expand Down Expand Up @@ -386,8 +391,8 @@ def validate_ctrl_vars(self, input_ctrl_vars):
# key validation
for var in variables:
self.validate_correct_keys(
input_ctrl_vars[var], ["min", "max"], [], "control variable: " + var
)
input_ctrl_vars[var], [
"min", "max"], [], "control variable: " + var)

# validate special control variables
# add validation here if developer adds new special input variable
Expand All @@ -410,7 +415,9 @@ def validate_ctrl_vars(self, input_ctrl_vars):
"height": {"type": "number"},
},
}
validate(instance=input_ctrl_vars_poly, schema=schema_ctrl_vars_poly)
validate(
instance=input_ctrl_vars_poly,
schema=schema_ctrl_vars_poly)
# key validation
self.validate_correct_keys(
input_ctrl_vars_poly,
Expand Down
52 changes: 42 additions & 10 deletions rollo/toolbox_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,12 @@ class ToolboxGenerator(object):
"""The ToolboxGenerator class initializes DEAP's toolbox and creator
modules with genetic algorithm hyperparameters defined in the input file."""

def setup(self, evaluator_fn, input_algorithm, input_ctrl_vars, control_dict):
def setup(
self,
evaluator_fn,
input_algorithm,
input_ctrl_vars,
control_dict):
"""sets up DEAP toolbox with user-defined inputs
Parameters
Expand Down Expand Up @@ -47,11 +52,22 @@ def setup(self, evaluator_fn, input_algorithm, input_ctrl_vars, control_dict):
for var in input_ctrl_vars:
if var not in special_control_vars:
var_dict = input_ctrl_vars[var]
toolbox.register(var, random.uniform, var_dict["min"], var_dict["max"])
toolbox.register(
var,
random.uniform,
var_dict["min"],
var_dict["max"])
toolbox.register(
"individual", self.individual_values, input_ctrl_vars, control_dict, toolbox
)
toolbox.register("population", tools.initRepeat, list, toolbox.individual)
"individual",
self.individual_values,
input_ctrl_vars,
control_dict,
toolbox)
toolbox.register(
"population",
tools.initRepeat,
list,
toolbox.individual)
toolbox.register("evaluate", evaluator_fn)
min_list, max_list = self.min_max_list(control_dict, input_ctrl_vars)
toolbox.min_list, toolbox.max_list = min_list, max_list
Expand Down Expand Up @@ -138,8 +154,13 @@ def min_max_list(self, control_dict, input_ctrl_vars):
return min_list, max_list

def add_toolbox_operators(
self, toolbox, selection_dict, mutation_dict, mating_dict, min_list, max_list
):
self,
toolbox,
selection_dict,
mutation_dict,
mating_dict,
min_list,
max_list):
"""Adds selection, mutation, and mating operators to the deap toolbox
Parameters
Expand Down Expand Up @@ -198,12 +219,20 @@ def add_selection_operators(self, toolbox, selection_dict):
tournsize=selection_dict["tournsize"],
)
elif operator == "selNSGA2":
toolbox.register("select", tools.selNSGA2, k=selection_dict["inds"])
toolbox.register(
"select",
tools.selNSGA2,
k=selection_dict["inds"])
elif operator == "selBest":
toolbox.register("select", tools.selBest, k=selection_dict["inds"])
return toolbox

def add_mutation_operators(self, toolbox, mutation_dict, min_list, max_list):
def add_mutation_operators(
self,
toolbox,
mutation_dict,
min_list,
max_list):
"""Adds mutation operator to the deap toolbox
Parameters
Expand Down Expand Up @@ -255,7 +284,10 @@ def add_mating_operators(self, toolbox, mating_dict):
if operator == "cxOnePoint":
toolbox.register("mate", tools.cxOnePoint)
elif operator == "cxUniform":
toolbox.register("mate", tools.cxUniform, indpb=mating_dict["indpb"])
toolbox.register(
"mate",
tools.cxUniform,
indpb=mating_dict["indpb"])
elif operator == "cxBlend":
toolbox.register("mate", tools.cxBlend, alpha=mating_dict["alpha"])
return toolbox
2 changes: 1 addition & 1 deletion tests/test_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def test_initialize_pop():
assert ind.fitness.values[0] < 3
assert ind.fitness.values[0] > 1
assert ind.output[1] == 5
assert type(ind) is creator.Ind
assert isinstance(ind, creator.Ind)
assert ind[0] < 1
assert ind[0] > 0
assert ind[1] > 1
Expand Down
35 changes: 24 additions & 11 deletions tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,19 +52,27 @@ def evaluator_fn(ind):

def test_initialize_new_backend():
b = BackEnd(
"square_checkpoint.pkl", creator, control_dict, output_dict, input_file, 0
)
"square_checkpoint.pkl",
creator,
control_dict,
output_dict,
input_file,
0)
b.initialize_new_backend()
assert b.results["start_gen"] == 0
assert type(b.results["halloffame"]) == tools.HallOfFame
assert type(b.results["logbook"]) == tools.Logbook
assert type(b.results["all"]) == dict
assert isinstance(b.results["halloffame"], tools.HallOfFame)
assert isinstance(b.results["logbook"], tools.Logbook)
assert isinstance(b.results["all"], dict)


def test_ind_naming():
b = BackEnd(
"square_checkpoint.pkl", creator, control_dict, output_dict, input_file, 0
)
"square_checkpoint.pkl",
creator,
control_dict,
output_dict,
input_file,
0)
ind_dict = b.ind_naming()
expected_ind_dict = {
"packing_fraction": 0,
Expand All @@ -78,8 +86,12 @@ def test_ind_naming():

def test_output_naming():
b = BackEnd(
"square_checkpoint.pkl", creator, control_dict, output_dict, input_file, 0
)
"square_checkpoint.pkl",
creator,
control_dict,
output_dict,
input_file,
0)
oup_dict = b.output_naming()
expected_oup_dict = {
"packing_fraction": 0,
Expand Down Expand Up @@ -114,7 +126,7 @@ def test_initialize_checkpoint_backend():
assert ind.fitness.values[0] < 3
assert b.results["start_gen"] == 0
assert b.results["halloffame"].items[0] == max(pop, key=lambda x: x[2])
assert type(b.results["logbook"]) == tools.Logbook
assert isinstance(b.results["logbook"], tools.Logbook)
assert len(b.results["logbook"]) == 1
os.remove("./input_test_files/test_checkpoint.pkl")

Expand Down Expand Up @@ -145,7 +157,8 @@ def test_update_backend():
rndstate = random.getstate()
b.update_backend(new_pop, gen, invalids, rndstate)
pop = b.results["population"]
assert b.results["halloffame"].items[0] == max(pop + new_pop, key=lambda x: x[2])
assert b.results["halloffame"].items[0] == max(
pop + new_pop, key=lambda x: x[2])
assert len(b.results["logbook"]) == 2
bb = BackEnd(
"input_test_files/test_checkpoint.pkl",
Expand Down
Loading

0 comments on commit ac89abd

Please sign in to comment.