Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/qonnx/custom_op/channels_last/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
from qonnx.custom_op.channels_last.batch_normalization import BatchNormalization
from qonnx.custom_op.channels_last.conv import Conv
from qonnx.custom_op.channels_last.max_pool import MaxPool
from qonnx.custom_op.channels_last.concat import Concat
from qonnx.custom_op.channels_last.resize import Resize


custom_op = dict()

custom_op["Conv"] = Conv
custom_op["MaxPool"] = MaxPool
custom_op["BatchNormalization"] = BatchNormalization
custom_op["Concat"] = Concat
custom_op["Resize"] = Resize

112 changes: 112 additions & 0 deletions src/qonnx/custom_op/channels_last/concat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
import numpy as np
from onnx import TensorProto, helper

from qonnx.custom_op.channels_last.base_wrapped_op import ChannelsLastWrappedOp

class Concat(ChannelsLastWrappedOp):
def get_nodeattr_types(self):
"""Returns a dict of permitted attributes for node, where:
ret_dict[attribute_name] = (dtype, require, default_value, <allowed_values>)
- dtype indicates which member of the ONNX AttributeProto
will be utilized
- require indicates whether this attribute is required
- default_val indicates the default value that will be used if the
attribute is not set
- <allowed_values> (if specified) indicates that this attribute can only
be set to one of the values in the set <allowed_values>. If not specified,
all values permitted by dtype are allowed.
"""
return {
# axis attribute of Concat layer, default 1
"axis": ("i", True, 1)
}

def make_shape_compatible_op(self, model):
"""Returns a standard ONNX op which is compatible with this CustomOp
for performing shape inference."""

node = self.onnx_node
iname0 = node.input[0]
iname1 = node.input[1]
ishape0 = model.get_tensor_shape(iname0)
ishape1 = model.get_tensor_shape(iname1)
# axis = self.get_nodeattr("axis")
# not sure about what's the shape of inputs, don't know how to check it
# check that ishape0[1] == ishape1[1] and ishape0[2] == ishape1[2]
assert ishape0[1] == ishape1[1], "Input shape [1] has to be the same between the 2 input nodes of concat"
assert ishape0[2] == ishape1[2], "Input shape [2] has to be the same between the 2 input nodes of concat"

# implement tensor with correct shape
output_shape = [1, ishape0[1], ishape0[2], ishape0[3] + ishape1[3]]

# implement tensor with correct shape
values = np.random.randn(*output_shape).astype(np.float32)
return helper.make_node(
"Constant",
inputs=[],
outputs=[self.onnx_node.output[0]],
value=helper.make_tensor(
name="const_tensor",
data_type=TensorProto.FLOAT,
dims=values.shape,
vals=values.flatten().astype(float),
),
name=self.onnx_node.name,
)

def verify_node(self):
node = self.onnx_node

verification_successful = True
info_messages = []

wrapper_info = ChannelsLastWrappedOp.verify_node(self)
info_messages.extend(wrapper_info)

# verify number of attributes
num_of_attr_min = 1
num_of_attr_max = 1
if (len(node.attribute) >= num_of_attr_min) and len(node.attribute) <= num_of_attr_max:
info_messages.append("The number of attributes is correct")
else:
info_messages.append(
"""The number of attributes is incorrect,
{} should have between {} and {} attributes""".format(
node.op_type, num_of_attr_min, num_of_attr_max
)
)
verification_successful = False

# verify that all necessary attributes exist
try:
self.get_nodeattr("axis")
info_messages.append("All necessary attributes exist")
except Exception:
info_messages.append(
"""The necessary attributes do not exist.
Concat needs the following attributes:
axis"""
)
verification_successful = False

# verify that attributes have the correct datatype.
try:
assert isinstance(self.get_nodeattr("axis"), int)
info_messages.append("All attributes are of the correct type")
except Exception:
info_messages.append("One or more attributes are of the wrong datatype")
verification_successful = False

# verify the number of inputs
if len(node.input) == 2:
info_messages.append("The number of inputs is correct")
else:
info_messages.append("{} needs 2 data input".format(node.op_type))
verification_successful = False

if not verification_successful:
raise RuntimeError(
f"Verification of node {node.name} failed, please check the " f"attached info messages: {info_messages}"
)

return info_messages
133 changes: 133 additions & 0 deletions src/qonnx/custom_op/channels_last/resize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
import struct
import numpy as np
from onnx import TensorProto, helper

from qonnx.custom_op.channels_last.base_wrapped_op import ChannelsLastWrappedOp

class Resize(ChannelsLastWrappedOp):
def get_nodeattr_types(self):
"""Returns a dict of permitted attributes for node, where:
ret_dict[attribute_name] = (dtype, require, default_value, <allowed_values>)
- dtype indicates which member of the ONNX AttributeProto
will be utilized
- require indicates whether this attribute is required
- default_val indicates the default value that will be used if the
attribute is not set
- <allowed_values> (if specified) indicates that this attribute can only
be set to one of the values in the set <allowed_values>. If not specified,
all values permitted by dtype are allowed.
"""
return {
"coordinate_transformation_mode": ("s", True, "half_pixel"),
"cubic_coeff_a": ("f", True, -0.75),
"mode": ("s", True, "linear"),
"nearest_mode": ("s", True, "floor")
}

def _get_initializer_from_name(self, model, iname):
for i in model.graph.initializer:
if i.name == iname:
return i

def _compute_fmt(self, tensor_shape):
fmt = "<"
for _ in range(tensor_shape):
fmt += "f"
return fmt

def _compute_resize_output_shape(self, scales, input_shape):
assert len(scales) == len(input_shape)
scales = [int(i) for i in scales]
output_shape = input_shape.copy()
output_shape[1], output_shape[-1] = output_shape[-1], output_shape[1]
for i in range(len(input_shape)):
output_shape[i] *= scales[i]
output_shape[1], output_shape[-1] = output_shape[-1], output_shape[1]
return output_shape

def make_shape_compatible_op(self, model):
"""Returns a standard ONNX op which is compatible with this CustomOp
for performing shape inference."""
node = self.onnx_node
iscalesn = node.input[2]
inode = node.input[0]
inodes = model.get_tensor_shape(inode)
iscalesns = model.get_tensor_shape(iscalesn)
i = self._get_initializer_from_name(model, iscalesn).raw_data
fmt = self._compute_fmt(iscalesns[0])
scales = struct.unpack(fmt, i)

# implement tensor with correct shape
output_shape = self._compute_resize_output_shape(scales, inodes)

# implement tensor with correct shape
values = np.random.randn(*output_shape).astype(np.float32)
return helper.make_node(
"Constant",
inputs=[],
outputs=[self.onnx_node.output[0]],
value=helper.make_tensor(
name="const_tensor",
data_type=TensorProto.FLOAT,
dims=values.shape,
vals=values.flatten().astype(float),
),
name=self.onnx_node.name,
)

def verify_node(self):
node = self.onnx_node

verification_successful = True
info_messages = []

wrapper_info = ChannelsLastWrappedOp.verify_node(self)
info_messages.extend(wrapper_info)

# verify number of attributes
num_of_attr_min = 1
num_of_attr_max = 1
if (len(node.attribute) >= num_of_attr_min) and len(node.attribute) <= num_of_attr_max:
info_messages.append("The number of attributes is correct")
else:
info_messages.append(
"""The number of attributes is incorrect,
{} should have between {} and {} attributes""".format(
node.op_type, num_of_attr_min, num_of_attr_max
)
)
verification_successful = False

# verify that all necessary attributes exist
try:
self.get_nodeattr("axis")
info_messages.append("All necessary attributes exist")
except Exception:
info_messages.append(
"""The necessary attributes do not exist.
Concat needs the following attributes:
axis"""
)
verification_successful = False

# verify that attributes have the correct datatype.
try:
assert isinstance(self.get_nodeattr("axis"), int)
info_messages.append("All attributes are of the correct type")
except Exception:
info_messages.append("One or more attributes are of the wrong datatype")
verification_successful = False

# verify the number of inputs
if len(node.input) == 2:
info_messages.append("The number of inputs is correct")
else:
info_messages.append("{} needs 2 data input".format(node.op_type))
verification_successful = False

if not verification_successful:
raise RuntimeError(
f"Verification of node {node.name} failed, please check the " f"attached info messages: {info_messages}"
)

return info_messages
Loading