Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

network_params_from_definition_string: allow setting CNN strides #355

Open
wants to merge 1 commit into
base: calamari/1.0
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 14 additions & 13 deletions calamari_ocr/proto/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def set_default_network_params(params):


def network_params_from_definition_string(str, params):
cnn_matcher = re.compile(r"^([\d]+)(:([\d]+)(x([\d]+))?)?$")
cnn_matcher = re.compile(r"^([\d]+)(:([\d]+)x([\d]+))?(:([\d]+)x([\d]+))?$")
db_matcher = re.compile(r"^([\d]+):([\d]+)(:([\d]+)(x([\d]+))?)?$")
concat_matcher = re.compile(r"^([\-\d]+):([\-\d]+)$")
pool_matcher = re.compile(r"^([\d]+)(x([\d]+))?(:([\d]+)x([\d]+))?$")
Expand Down Expand Up @@ -66,7 +66,7 @@ def network_params_from_definition_string(str, params):

match = db_matcher.match(value)
if match is None:
raise Exception("Dilated block structure needs: db=[filters]:[depth>0]:[h]x[w]")
raise Exception("Dilated block structure needs: db=[filters]:[depth>0]:[h]x[w] but got {}".format(value))

match = match.groups()
kernel_size = [2, 2]
Expand All @@ -89,37 +89,38 @@ def network_params_from_definition_string(str, params):

match = cnn_matcher.match(value)
if match is None:
raise Exception("CNN structure needs: cnn=[filters]:[h]x[w] but got {}".format(value))
raise Exception("CNN structure needs: cnn=[filters]:[h]x[w]:[sx]x[sy] but got {}".format(value))

match = match.groups()
kernel_size = [2, 2]
stride = [1, 1]
if match[1] is not None:
kernel_size = [int(match[2])] * 2
if match[3] is not None:
kernel_size = [int(match[2]), int(match[4])]
kernel_size = [int(match[2]), int(match[3])]
if match[4] is not None:
stride = [int(match[5]), int(match[6])]

layer = params.layers.add()
layer.type = LayerParams.CONVOLUTIONAL
layer.filters = int(match[0])
layer.kernel_size.x = kernel_size[0]
layer.kernel_size.y = kernel_size[1]
layer.stride.x = 1
layer.stride.y = 1
layer.stride.x = stride[0]
layer.stride.y = stride[1]
elif label == "tcnn":
if lstm_appeared:
raise Exception("LSTM layers must be placed proceeding to CNN/Pool")

match = cnn_matcher.match(value)
if match is None:
raise Exception("Transposed CNN structure needs: tcnn=[filters]:[sx]x[sy]")
raise Exception("Transposed CNN structure needs: tcnn=[filters]:[sx]x[sy]:[h]x[w] but got {}".format(value))

match = match.groups()
kernel_size = [2, 2]
stride = [2, 2]
if match[1] is not None:
stride = [int(match[2])] * 2
if match[3] is not None:
stride = [int(match[2]), int(match[4])]
stride = [int(match[2]), int(match[3])]
if match[4] is not None:
kernel_size = [int(match[5]), int(match[6])]

layer = params.layers.add()
layer.type = LayerParams.TRANSPOSED_CONVOLUTIONAL
Expand All @@ -133,7 +134,7 @@ def network_params_from_definition_string(str, params):
raise Exception("LSTM layers must be placed proceeding to CNN/Pool")
match = pool_matcher.match(value)
if match is None:
raise Exception("Pool structure needs: pool=[h];[w]")
raise Exception("Pool structure needs: pool=[h]x[w]:[sx]x[sy] but got {}".format(value))

match = match.groups()
kernel_size = [int(match[0])] * 2
Expand Down