Skip to content

Commit

Permalink
Remove Dropout from ENAS Trial container
Browse files (browse the repository at this point in the history)
Signed-off-by: Andrey Velichkevich <[email protected]>
  • Loading branch information
andreyvelich committed Nov 29, 2024
1 parent 2b41ae6 commit d1bd575
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import json

from keras.layers import Dense, Dropout, GlobalAveragePooling2D, Input
from keras.layers import Dense, GlobalAveragePooling2D, Input
from keras.models import Model
from op_library import concat, conv, dw_conv, reduction, sp_conv

Expand Down Expand Up @@ -67,8 +67,13 @@ def build_model(self):
# Final Layer
# Global Average Pooling, then Fully connected with softmax.
avgpooled = GlobalAveragePooling2D()(all_layers[self.num_layers])
dropped = Dropout(0.4)(avgpooled)
logits = Dense(units=self.output_size, activation="softmax")(dropped)

# TODO (andreyvelich): Currently, Dropout layer fails in distributed training.
# Error: creating distributed tf.Variable with aggregation=MEAN
# and a non-floating dtype is not supported, please use a different aggregation or dtype
# dropped = Dropout(0.4)(avgpooled)

logits = Dense(units=self.output_size, activation="softmax")(avgpooled)

# Encapsulate the model
self.model = Model(inputs=input_layer, outputs=logits)
Expand Down
3 changes: 3 additions & 0 deletions examples/v1beta1/trial-images/enas-cnn-cifar10/RunTrial.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,11 @@
num_physical_cpu = len(tf.config.experimental.list_physical_devices("CPU"))
devices = ["/cpu:" + str(j) for j in range(num_physical_cpu)]

print(f">>> Using devices: {devices}")

strategy = tf.distribute.MirroredStrategy(devices)
with strategy.scope():
print("Setup TensorFlow distributed training")
test_model = constructor.build_model()
test_model.summary()
test_model.compile(
Expand Down

0 comments on commit d1bd575

Please sign in to comment.