Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 55 additions & 119 deletions NeoML/Python/neoml/Dnn/Split.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,48 @@
import numpy


class SplitChannels(Layer):
class SplitLayer(Layer):
"""The base (abstract) class for a split layer.
"""
def __init__(self, classname, input_layer, sizes, name):
assert hasattr(PythonWrapper, classname), 'Incorrect split layer specified: ' + classname

if type(input_layer) is getattr(PythonWrapper, classname):
super().__init__(input_layer)
return

layers, outputs = check_input_layers(input_layer, 1)

internal = getattr(PythonWrapper, classname)(str(name), layers[0], int(outputs[0]), self.__sizes_to_array(sizes))
super().__init__(internal)

@property
def output_sizes(self):
"""
"""
return self._internal.get_output_counts()

@output_sizes.setter
def output_sizes(self, value):
"""
"""
self._internal.set_output_counts(self.__sizes_to_array(value))

@staticmethod
def __sizes_to_array(sizes) -> numpy.ndarray:
sizes = numpy.array(sizes, dtype=numpy.int32)
if sizes.ndim != 1 or sizes.size > 3:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do you have the upper limit of 3 here? It should be up to 31...

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, it was in the original layer. But that's a mistake! It should be 31. (MathEngine can handle split/merge with up to 32 parts).

raise ValueError('The `sizes` must be a one-dimentional sequence containing not more than 3 elements.')

if numpy.any(sizes < 0):
raise ValueError('The `sizes` must contain only positive values.')

return sizes

# ----------------------------------------------------------------------------------------------------------------------


class SplitChannels(SplitLayer):
"""The layer that splits an input blob along the Channels dimension.

:param input_layer: The input layer and the number of its output. If no number
Expand Down Expand Up @@ -49,27 +90,12 @@ class SplitChannels(Layer):
- all other dimensions are the same as for the input
"""
def __init__(self, input_layer, sizes, name=None):
if type(input_layer) is PythonWrapper.SplitChannels:
super().__init__(input_layer)
return

layers, outputs = check_input_layers(input_layer, 1)

s = numpy.array(sizes, dtype=numpy.int32, copy=False)

if s.size > 3:
raise ValueError('The `sizes` must contain not more than 3 elements.')

if numpy.any(s < 0):
raise ValueError('The `sizes` must contain only positive values.')

internal = PythonWrapper.SplitChannels(str(name), layers[0], int(outputs[0]), s)
super().__init__(internal)
super().__init__("SplitChannels", input_layer, sizes, name)

# ----------------------------------------------------------------------------------------------------------------------


class SplitDepth(Layer):
class SplitDepth(SplitLayer):
"""The layer that splits an input blob along the Depth dimension.

:param input_layer: The input layer and the number of its output. If no number
Expand Down Expand Up @@ -98,27 +124,12 @@ class SplitDepth(Layer):
- all other dimensions are the same as for the input
"""
def __init__(self, input_layer, sizes, name=None):
if type(input_layer) is PythonWrapper.SplitDepth:
super().__init__(input_layer)
return

layers, outputs = check_input_layers(input_layer, 1)

s = numpy.array(sizes, dtype=numpy.int32, copy=False)

if s.size > 3:
raise ValueError('The `sizes` must contain not more than 3 elements.')

if numpy.any(s < 0):
raise ValueError('The `sizes` must contain only positive values.')

internal = PythonWrapper.SplitDepth(str(name), layers[0], int(outputs[0]), s)
super().__init__(internal)
super().__init__("SplitDepth", input_layer, sizes, name)

# ----------------------------------------------------------------------------------------------------------------------


class SplitWidth(Layer):
class SplitWidth(SplitLayer):
"""The layer that splits an input blob along the Width dimension.

:param input_layer: The input layer and the number of its output. If no number
Expand Down Expand Up @@ -147,27 +158,12 @@ class SplitWidth(Layer):
- all other dimensions are the same as for the input
"""
def __init__(self, input_layer, sizes, name=None):
if type(input_layer) is PythonWrapper.SplitWidth:
super().__init__(input_layer)
return

layers, outputs = check_input_layers(input_layer, 1)

s = numpy.array(sizes, dtype=numpy.int32, copy=False)

if s.size > 3:
raise ValueError('The `sizes` must contain not more than 3 elements.')

if numpy.any(s < 0):
raise ValueError('The `sizes` must contain only positive values.')

internal = PythonWrapper.SplitWidth(str(name), layers[0], int(outputs[0]), s)
super().__init__(internal)
super().__init__("SplitWidth", input_layer, sizes, name)

# ----------------------------------------------------------------------------------------------------------------------


class SplitHeight(Layer):
class SplitHeight(SplitLayer):
"""The layer that splits an input blob along the Height dimension.

:param input_layer: The input layer and the number of its output. If no number
Expand Down Expand Up @@ -196,27 +192,12 @@ class SplitHeight(Layer):
- all other dimensions are the same as for the input
"""
def __init__(self, input_layer, sizes, name=None):
if type(input_layer) is PythonWrapper.SplitHeight:
super().__init__(input_layer)
return

layers, outputs = check_input_layers(input_layer, 1)

s = numpy.array(sizes, dtype=numpy.int32, copy=False)

if s.size > 3:
raise ValueError('The `sizes` must contain not more than 3 elements.')

if numpy.any(s < 0):
raise ValueError('The `sizes` must contain only positive values.')

internal = PythonWrapper.SplitHeight(str(name), layers[0], int(outputs[0]), s)
super().__init__(internal)
super().__init__("SplitHeight", input_layer, sizes, name)

# ----------------------------------------------------------------------------------------------------------------------


class SplitListSize(Layer):
class SplitListSize(SplitLayer):
"""The layer that splits an input blob along the ListSize dimension.

:param input_layer: The input layer and the number of its output. If no number
Expand Down Expand Up @@ -245,27 +226,12 @@ class SplitListSize(Layer):
- all other dimensions are the same as for the input
"""
def __init__(self, input_layer, sizes, name=None):
if type(input_layer) is PythonWrapper.SplitListSize:
super().__init__(input_layer)
return

layers, outputs = check_input_layers(input_layer, 1)

s = numpy.array(sizes, dtype=numpy.int32, copy=False)

if s.size > 3:
raise ValueError('The `sizes` must contain not more than 3 elements.')

if numpy.any(s < 0):
raise ValueError('The `sizes` must contain only positive values.')

internal = PythonWrapper.SplitListSize(str(name), layers[0], int(outputs[0]), s)
super().__init__(internal)
super().__init__("SplitListSize", input_layer, sizes, name)

# ----------------------------------------------------------------------------------------------------------------------


class SplitBatchWidth(Layer):
class SplitBatchWidth(SplitLayer):
"""The layer that splits an input blob along the BatchWidth dimension.

:param input_layer: The input layer and the number of its output. If no number
Expand Down Expand Up @@ -294,27 +260,12 @@ class SplitBatchWidth(Layer):
- all other dimensions are the same as for the input
"""
def __init__(self, input_layer, sizes, name=None):
if type(input_layer) is PythonWrapper.SplitBatchWidth:
super().__init__(input_layer)
return

layers, outputs = check_input_layers(input_layer, 1)

s = numpy.array(sizes, dtype=numpy.int32, copy=False)

if s.size > 3:
raise ValueError('The `sizes` must contain not more than 3 elements.')

if numpy.any(s < 0):
raise ValueError('The `sizes` must contain only positive values.')

internal = PythonWrapper.SplitBatchWidth(str(name), layers[0], int(outputs[0]), s)
super().__init__(internal)
super().__init__("SplitBatchWidth", input_layer, sizes, name)

# ----------------------------------------------------------------------------------------------------------------------


class SplitBatchLength(Layer):
class SplitBatchLength(SplitLayer):
"""The layer that splits an input blob along the BatchLength dimension.

:param input_layer: The input layer and the number of its output. If no number
Expand Down Expand Up @@ -343,19 +294,4 @@ class SplitBatchLength(Layer):
- all other dimensions are the same as for the input
"""
def __init__(self, input_layer, sizes, name=None):
if type(input_layer) is PythonWrapper.SplitBatchLength:
super().__init__(input_layer)
return

layers, outputs = check_input_layers(input_layer, 1)

s = numpy.array(sizes, dtype=numpy.int32, copy=False)

if s.size > 3:
raise ValueError('The `sizes` must contain not more than 3 elements.')

if numpy.any(s < 0):
raise ValueError('The `sizes` must contain only positive values.')

internal = PythonWrapper.SplitBatchLength(str(name), layers[0], int(outputs[0]), s)
super().__init__(internal)
super().__init__("SplitBatchLength", input_layer, sizes, name)
2 changes: 1 addition & 1 deletion NeoML/Python/neoml/Utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def check_input_layers(input_layers, layer_count):
layers.append(i._internal)
outputs.append(0)
elif isinstance(i, (list, tuple)) and len(i) == 2 and isinstance(i[0], Layer) and isinstance(i[1], int):
if int(i[1]) < 0 or int(i[1]) >= i[0].output_count():
if int(i[1]) < 0:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do you need this? Looks like a bug...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ты письмо читал?

raise ValueError('Invalid value `input_layers`.'
' It must be a list of layers or a list of (layer, output).')
layers.append(i[0]._internal)
Expand Down
Loading