Skip to content

Commit

Permalink
Merge pull request #128 from Oneflow-Inc/add_assertation_for_mobilenetv2
Browse files Browse the repository at this point in the history
add assertation and remove transpose for mobilenet
  • Loading branch information
Flowingsun007 authored Sep 14, 2020
2 parents d55b671 + e76857e commit 5c2b305
Showing 1 changed file with 12 additions and 8 deletions.
20 changes: 12 additions & 8 deletions Classification/cnns/mobilenet_v2_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,11 @@ def _relu6(data, prefix):


def mobilenet_unit(data, num_filter=1, kernel=(1, 1), stride=(1, 1), pad=(0, 0), num_group=1, data_format="NCHW", if_act=True, use_bias=False, prefix=''):
conv = flow.layers.conv2d(inputs=data, filters=num_filter, kernel_size=kernel, strides=stride, padding=pad, data_format=data_format, dilation_rate=1, groups=num_group, activation=None, use_bias=use_bias, kernel_initializer=_get_initializer("weight"), bias_initializer=_get_initializer("bias"), kernel_regularizer=_get_regularizer("weight"), bias_regularizer=_get_regularizer("bias"), name=prefix)
conv = flow.layers.conv2d(inputs=data, filters=num_filter, kernel_size=kernel, strides=stride,
padding=pad, data_format=data_format, dilation_rate=1, groups=num_group, activation=None,
use_bias=use_bias, kernel_initializer=_get_initializer("weight"),
bias_initializer=_get_initializer("bias"), kernel_regularizer=_get_regularizer("weight"),
bias_regularizer=_get_regularizer("bias"), name=prefix)
bn = _batch_norm(conv, axis=1, momentum=0.9, epsilon=1e-5, name='%s-BatchNorm'%prefix)
if if_act:
act = _relu6(bn, prefix)
Expand Down Expand Up @@ -156,11 +160,9 @@ def __init__(self, data_wh, multiplier, **kargs):
else:
self.config_map=MNETV2_CONFIGS_MAP[(224, 224)]

def build_network(self, input_data, need_transpose, data_format, class_num=1000, prefix="", **configs):
def build_network(self, input_data, data_format, class_num=1000, prefix="", **configs):
self.config_map.update(configs)

if need_transpose:
input_data = flow.transpose(input_data, name="transpose", perm=[0, 3, 1, 2])
first_c = int(round(self.config_map['firstconv_filter_num']*self.multiplier))
first_layer = mobilenet_unit(
data=input_data,
Expand Down Expand Up @@ -233,11 +235,13 @@ def build_network(self, input_data, need_transpose, data_format, class_num=1000,
)
return fc

def __call__(self, input_data, need_transpose, class_num=1000, prefix = "", **configs):
sym = self.build_network(input_data, need_transpose, class_num=class_num, prefix=prefix, **configs)
def __call__(self, input_data, class_num=1000, prefix = "", **configs):
sym = self.build_network(input_data, class_num=class_num, prefix=prefix, **configs)
return sym

def Mobilenet(input_data, trainable=True, need_transpose=False, training=True, data_format="NCHW", num_classes=1000, multiplier=1.0, prefix = ""):
def Mobilenet(input_data, trainable=True, training=True, channel_last=False, num_classes=1000, multiplier=1.0, prefix = ""):
assert channel_last==False, "Mobilenet does not support channel_last mode, set channel_last=False will be right!"
data_format="NHWC" if channel_last else "NCHW"
mobilenetgen = MobileNetV2((224,224), multiplier=multiplier)
out = mobilenetgen(input_data, need_transpose, data_format=data_format, class_num=num_classes, prefix = "MobilenetV2")
out = mobilenetgen(input_data, data_format=data_format, class_num=num_classes, prefix = "MobilenetV2")
return out

0 comments on commit 5c2b305

Please sign in to comment.