Skip to content

Commit 5ef1b51

Browse files
committed
Add random reshuffle
1 parent a0989bc commit 5ef1b51

File tree

2 files changed

+11
-5
lines changed

2 files changed

+11
-5
lines changed

graphs

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,5 @@ checkpoint19='experimentOutputs/list_aic=1.0_arity=3_BO=True_CO=True_ET=720_HR=0
77

88
srun --job-name=graphsHeldout$1 --output=data/batch_12_2_2018/graphs --ntasks=1 --mem-per-cpu=5000 --cpus-per-task 1 --time=5:00 --qos=tenenbaum \
99
singularity exec -B /om2 sklearn-container.img \
10-
python graphs.py --checkpoints $checkpoint19 --showEpochs --export data/batch_12_2_2018/test.png \
10+
python graphs.py --checkpoints $checkpoint19 --showEpochs --showTraining --export data/batch_12_2_2018/test.png \
1111
&

taskBatcher.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def getTaskBatch(self, ec_result, tasks, taskBatchSize, currIteration):
3838

3939
class RandomShuffleTaskBatcher:
4040
"""Randomly shuffles the task batch first, and then iterates through task batches of the specified size like DefaultTaskBatcher.
41-
Uses a fixed shuffling across iterations - intended as benchmark comparison to test the task ordering."""
41+
Reshuffles across iterations - intended as benchmark comparison to test the task ordering."""
4242
def __init__(self):
4343
pass
4444

@@ -49,13 +49,19 @@ def getTaskBatch(self, ec_result, tasks, taskBatchSize, currIteration):
4949
eprint("Task batch size is greater than total number of tasks, aborting.")
5050
assert False
5151

52-
# Shuffles tasks with a set seed across iterations.
52+
# Reshuffles tasks in a fixed way across epochs for reproducibility.
53+
baseSeed = 0
54+
currEpoch = int(int(currIteration * taskBatchSize) / int(len(tasks)))
55+
5356
shuffledTasks = tasks.copy() # Since shuffle works in place.
54-
random.Random(0).shuffle(shuffledTasks)
57+
random.Random(baseSeed + currEpoch).shuffle(shuffledTasks)
58+
59+
shuffledTasksWrap = tasks.copy() # Since shuffle works in place.
60+
random.Random(baseSeed + currEpoch + 1).shuffle(shuffledTasksWrap)
5561

5662
start = (taskBatchSize * currIteration) % len(shuffledTasks)
5763
end = start + taskBatchSize
58-
taskBatch = (shuffledTasks + shuffledTasks)[start:end] # Handle wraparound.
64+
taskBatch = (shuffledTasks + shuffledTasksWrap)[start:end] # Wraparound nicely.
5965

6066
return taskBatch
6167

0 commit comments

Comments
 (0)