diff --git a/orttraining/orttraining/python/training/onnxblock/loss/loss.py b/orttraining/orttraining/python/training/onnxblock/loss/loss.py index e0624c6722519..2b1b2a40eb00c 100644 --- a/orttraining/orttraining/python/training/onnxblock/loss/loss.py +++ b/orttraining/orttraining/python/training/onnxblock/loss/loss.py @@ -61,7 +61,7 @@ class CrossEntropyLoss(blocks.Block): contribute to the input gradient. """ - def __init__(self, weight=None, reduction: str = "mean", ignore_index: int | None = None): + def __init__(self, weight=None, reduction: str = "mean", ignore_index: int | None = None, class_axis: int = -1): super().__init__() if reduction not in ["mean", "sum", "none"]: @@ -70,6 +70,7 @@ def __init__(self, weight=None, reduction: str = "mean", ignore_index: int | Non self._weight = weight self._reduction = reduction self._ignore_index = ignore_index + self._class_axis = class_axis def build(self, scores_input_name: str, labels_name: str = "labels"): """Adds a CrossEntropyLoss subgraph on top of an onnx model. @@ -92,10 +93,23 @@ def build(self, scores_input_name: str, labels_name: str = "labels"): labels_input = copy.deepcopy(_graph_utils.get_output_from_output_name(self.base, scores_input_name)) labels_input.name = labels_name labels_input.type.tensor_type.elem_type = onnx.TensorProto.INT64 - # Assumes classes is the last dimension + # Remove the class dimension to create the labels shape + # By default, assumes classes is the last dimension (class_axis=-1) # e.g., predictions: (num_examples, num_classes) -> labels: (num_examples,) # or predictions: (batch_size, seq_len, vocab) -> labels: (batch_size, seq_len) - del labels_input.type.tensor_type.shape.dim[-1] + # For channel-last formats, user can specify class_axis to correctly identify the class dimension + if self._class_axis < 0: + # Handle negative indexing + class_dim_index = len(labels_input.type.tensor_type.shape.dim) + self._class_axis + else: + class_dim_index = self._class_axis + + # Validate class dimension index + if 0 <= class_dim_index < len(labels_input.type.tensor_type.shape.dim): + del labels_input.type.tensor_type.shape.dim[class_dim_index] + else: + # Fallback to original behavior if invalid class_axis specified + del labels_input.type.tensor_type.shape.dim[-1] self.base.graph.input.append(labels_input) loss_node_input_names = [scores_input_name, labels_name]