Skip to content
Open
43 changes: 21 additions & 22 deletions src/qonnx/custom_op/general/multithreshold.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def get_nodeattr_types(self):
"out_dtype": ("s", True, ""),
"out_scale": ("f", False, 1.0),
"out_bias": ("f", False, 0.0),
"data_layout": ("s", False, "NCHW", {"NCHW", "NHWC"}),
"data_layout": ("s", False, ""),
}

def make_shape_compatible_op(self, model):
Expand Down Expand Up @@ -122,29 +122,28 @@ def execute_node(self, context, graph):
# retrieve attributes if output scaling is used
out_scale = self.get_nodeattr("out_scale")
out_bias = self.get_nodeattr("out_bias")
# transpose input if NHWC data layout is chosen

# Consider the data layout for transposing the input into the format
# accepted by the multithreshold function above, i.e, the channel
# dimension is along the axis with index 1.
data_layout = self.get_nodeattr("data_layout")
if data_layout == "NHWC":
if v.ndim == 4:
# NHWC -> NCHW
v = np.transpose(v, (0, 3, 1, 2))
elif v.ndim == 2:
# no HW dimension means NHWC and NCHW layouts are equivalent
pass
else:
raise Exception("Unknown data_layout and input ndim" " combination for MultiThreshold.")
# calculate output
# If there is no layout annotation, guess based on rank of the
# tensor
if not data_layout and len(v.shape) < 5:
# Maps tensor rank to layout annotation
rank_to_layout = {0: None, 1: "C", 2: "NC", 3: "NWC", 4: "NCHW"}
# Lookup the layout required by this input shape
data_layout = rank_to_layout[len(v.shape)]
# Lookup the index of the channel dimension in the data layout
# Note: Assumes there is at most one "C" which denotes the channel
# dimension
cdim = data_layout.index("C") if "C" in data_layout else 1
# Rearrange the input to the expected (N, C, ...) layout
v = v.swapaxes(cdim, 1)
# Now we can use the multithreshold function to calculate output
output = multithreshold(v, thresholds, out_scale, out_bias)
# setting context according to output
if data_layout == "NHWC":
if output.ndim == 4:
# NCHW -> NHWC
output = np.transpose(output, (0, 2, 3, 1))
elif output.ndim == 2:
# no HW dimension means NHWC and NCHW layouts are equivalent
pass
else:
raise Exception("Unknown data_layout and output ndim" " combination for MultiThreshold.")
# Rearrange the output back to the original layout
output = output.swapaxes(cdim, 1)
context[node.output[0]] = output

def verify_node(self):
Expand Down
4 changes: 3 additions & 1 deletion tests/core/test_custom_onnx_exec.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,9 @@ def test_execute_custom_node_multithreshold():
assert (execution_context["out"] == outputs_nhwc).all()
# check the set of allowed values
op_inst = getCustomOp(node_def)
assert op_inst.get_nodeattr_allowed_values("data_layout") == {"NCHW", "NHWC"}
# TODO: Removed this check to generalize the supported data layouts, but do
# we need some other check to verify the validity of data layouts?
# assert op_inst.get_nodeattr_allowed_values("data_layout") == {"NCHW", "NHWC", "NC", "NWC", "NCW"}
# exercise the allowed value checks
# try to set attribute to non-allowed value, should raise an exception
try:
Expand Down
Loading