Skip to content

feat(orttraining): add class_axis parameter to CrossEntropyLoss #25807

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

MengAiDev
Copy link

Description

  • Introduce class_axis parameter to specify the class dimension
  • Handle negative indexing for class_axis
  • Validate class dimension index
  • Fallback to original behavior if invalid class_axis specified

Motivation and Context

Fix: #25792

- Introduce class_axis parameter to specify the class dimension
- Handle negative indexing for class_axis
- Validate class dimension index
- Fallback to original behavior if invalid class_axis specified
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[Training] Potential shape bug: del labels_input.type.tensor_type.shape.dim[-1] assumes channel-last format in loss.py
1 participant