Skip to content

feat(orttraining): add class_axis parameter to CrossEntropyLoss #25807

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions orttraining/orttraining/python/training/onnxblock/loss/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class CrossEntropyLoss(blocks.Block):
contribute to the input gradient.
"""

def __init__(self, weight=None, reduction: str = "mean", ignore_index: int | None = None):
def __init__(self, weight=None, reduction: str = "mean", ignore_index: int | None = None, class_axis: int = -1):
super().__init__()

if reduction not in ["mean", "sum", "none"]:
Expand All @@ -70,6 +70,7 @@ def __init__(self, weight=None, reduction: str = "mean", ignore_index: int | Non
self._weight = weight
self._reduction = reduction
self._ignore_index = ignore_index
self._class_axis = class_axis

def build(self, scores_input_name: str, labels_name: str = "labels"):
"""Adds a CrossEntropyLoss subgraph on top of an onnx model.
Expand All @@ -92,10 +93,23 @@ def build(self, scores_input_name: str, labels_name: str = "labels"):
labels_input = copy.deepcopy(_graph_utils.get_output_from_output_name(self.base, scores_input_name))
labels_input.name = labels_name
labels_input.type.tensor_type.elem_type = onnx.TensorProto.INT64
# Assumes classes is the last dimension
# Remove the class dimension to create the labels shape
# By default, assumes classes is the last dimension (class_axis=-1)
# e.g., predictions: (num_examples, num_classes) -> labels: (num_examples,)
# or predictions: (batch_size, seq_len, vocab) -> labels: (batch_size, seq_len)
del labels_input.type.tensor_type.shape.dim[-1]
# For layouts where the class dimension is not last (e.g. channel-first), the user can specify class_axis to correctly identify the class dimension
if self._class_axis < 0:
# Handle negative indexing
class_dim_index = len(labels_input.type.tensor_type.shape.dim) + self._class_axis
else:
class_dim_index = self._class_axis

# Validate class dimension index
if 0 <= class_dim_index < len(labels_input.type.tensor_type.shape.dim):
del labels_input.type.tensor_type.shape.dim[class_dim_index]
else:
# Fallback to original behavior if invalid class_axis specified
del labels_input.type.tensor_type.shape.dim[-1]
self.base.graph.input.append(labels_input)

loss_node_input_names = [scores_input_name, labels_name]
Expand Down