Add LanguageReward for training models to think in target language #515
Open · casteryh wants to merge 24 commits into main from language-reward-feature
Changes from all commits (24 commits):
All commits by casteryh:

- `033fa60` Add LanguageReward for training models to think in target language
- `b12ed15` Update system prompt to instruct model to think in Japanese
- `b15f171` Add fallback reward for correct language without thinking blocks
- `a2c0237` Add debug logging and troubleshooting guide for LanguageReward
- `afca75c` Add debug printing to LanguageReward and strengthen system prompt
- `2625b28` Refactor to use configurable Japanese tags <思考> instead of English <t…
- `1a4d5fb` Remove old debug code from main.py
- `4e87a4d` Weaken system prompt to rely more on RL rewards
- `abb653e` Remove sandbox config and reference apps/grpo configs instead
- `7b4829c` Simplify LanguageReward logic to focus on language detection only
- `0ed798c` Add langid to dev dependencies for CI
- `5a3193e` Remove debug script
- `93a65b2` Clarify why English training won't work in TROUBLESHOOTING
- `f72be7f` Add unit test for ThinkingReward custom tag
- `6186f9f` Bump LanguageReward match_reward to 2.0
- `c640d37` Set KL divergence coefficient to zero in loss function
- `7fde86d` Change KL divergence coefficient to 1e-3
- `ffb6c43` Change KL divergence coefficient to 1e-4
- `7ffa20e` Enable multi-epoch training in sandbox/grpo_language app
- `1bf3cca` Fix recursive endpoint call - use while loop instead
- `7758b48` Simplify multi-epoch fix - use return next() instead of while loop
- `f71dbb6` fix
- `735af9a` change logging
- `ef39e46` git mv
```diff
@@ -47,6 +47,7 @@ dev = [
     "anyio",
     "pytest-asyncio",
     "multiprocess",
+    "langid",
 ]
 docs = [
     "sphinx==7.2.6",
```
```diff
@@ -57,15 +57,28 @@ def _to_float(self, text: str) -> float | None:


 class ThinkingReward:
-    """Reward class for evaluating use of <think> tags in reasoning."""
+    """Reward class for evaluating use of thinking tags in reasoning.
+
+    Args:
+        partial_reward: Reward for partial tag usage (incomplete/malformed)
+        full_reward: Reward for well-formed thinking blocks with content
+        tag: Tag name to use (default "think", can use "思考" for Japanese, etc.)
+    """

-    def __init__(self, partial_reward: float = 0.2, full_reward: float = 1.0):
+    def __init__(
+        self, partial_reward: float = 0.2, full_reward: float = 1.0, tag: str = "think"
+    ):
         self.partial_reward = partial_reward
         self.full_reward = full_reward
+        self.tag = tag
+        # Build regex patterns for the specified tag
         self._THINK_BLOCK_RE = re.compile(
-            r"<\s*think\s*>(.*?)<\s*/\s*think\s*>", re.IGNORECASE | re.DOTALL
+            rf"<\s*{re.escape(tag)}\s*>(.*?)<\s*/\s*{re.escape(tag)}\s*>",
+            re.IGNORECASE | re.DOTALL,
         )
-        self._THINK_TAG_ATTEMPT_RE = re.compile(r"<\s*/?\s*think\s*>", re.IGNORECASE)
+        self._THINK_TAG_ATTEMPT_RE = re.compile(
+            rf"<\s*/?\s*{re.escape(tag)}\s*>", re.IGNORECASE
+        )

     def __call__(self, prompt: str, response: str, target: str | None = None) -> float:
         """Compute thinking reward."""
```
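A quick sketch of how the new `tag` parameter behaves, assuming `ThinkingReward` is importable from `forge.data.rewards` (reward values reflect the defaults shown in the diff):

```python
from forge.data.rewards import ThinkingReward

reward = ThinkingReward(tag="思考")  # Japanese thinking tags

# Well-formed block with content -> full_reward (1.0 by default)
print(reward("", "<思考>まず式を整理する</思考> 答えは 42"))

# Bare tag attempt without a complete block -> partial_reward (0.2 by default)
print(reward("", "<思考> no closing tag here"))

# Old English tags no longer match once tag="思考" -> 0.0
print(reward("", "<think>reasoning in the old tags</think>"))
```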
```diff
@@ -80,3 +93,132 @@ def __call__(self, prompt: str, response: str, target: str | None = None) -> float:
         elif has_attempt:
             return self.partial_reward
         return 0.0
+
+
+class LanguageReward:
+    """Reward class for evaluating the language used in responses.
+
+    This reward uses langid to detect the language and rewards responses that use
+    the target language. The detection strategy depends on the format:
+    - If exactly one thinking block: detect language of the block content
+    - Otherwise (no blocks or multiple blocks): detect language of whole response
+
+    Note: Format enforcement (single vs multiple blocks) is handled by ThinkingReward.
+    This reward focuses purely on language detection.
+
+    Args:
+        target_language: ISO 639-1 language code (e.g., 'en', 'ja', 'zh', 'es')
+        match_reward: Reward when detected language matches target (default: 1.0)
+        no_match_reward: Reward when language doesn't match (default: 0.0)
+        tag: Tag name to use (default "思考" for multilingual, can use "think", etc.)
+        debug: If True, print debug samples showing model outputs and detected language
+        debug_sample_rate: Fraction of calls to debug (e.g., 0.1 = 10% of calls)
+
+    Note: Requires langid to be installed. Install with: pip install langid
+    """
+
+    def __init__(
+        self,
+        target_language: str = "en",
+        match_reward: float = 1.0,
+        no_match_reward: float = 0.0,
+        tag: str = "思考",
+        debug: bool = False,
+        debug_sample_rate: float = 0.1,
+    ):
+        self.target_language = target_language
+        self.match_reward = match_reward
+        self.no_match_reward = no_match_reward
+        self.tag = tag
+        self.debug = debug
+        self.debug_sample_rate = debug_sample_rate
+        self._debug_counter = 0
+        # Build regex pattern for the specified tag
+        self._THINK_BLOCK_RE = re.compile(
+            rf"<\s*{re.escape(tag)}\s*>(.*?)<\s*/\s*{re.escape(tag)}\s*>", re.DOTALL
+        )
+
+        # Lazy import langid with helpful error message
+        try:
+            import langid
+
+            self._langid = langid
+        except ImportError:
+            raise ImportError(
+                "langid is required for LanguageReward but is not installed. "
+                "Please install it with: pip install langid"
+            ) from None
```
**Contributor comment on lines +146 to +150:** Why do you want to surface the import error this late in the program? If this is not a "default" app we want everyone to install, is there a better option for managing the module dependencies? I think we'll get to this point soon to support different training backends.
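For reference, one conventional pattern the reviewer may be alluding to is resolving the optional dependency once at module import time instead of inside `__init__` (an illustrative sketch, not code from this PR):

```python
# Illustrative alternative: import once at module scope, fail helpfully on use.
try:
    import langid
except ImportError:  # module stays importable without the extra dependency
    langid = None


def _require_langid():
    """Return the langid module, raising a helpful error on first use."""
    if langid is None:
        raise ImportError(
            "langid is required for LanguageReward but is not installed. "
            "Please install it with: pip install langid"
        )
    return langid
```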
```diff
+    def __call__(self, prompt: str, response: str, target: str | None = None) -> float:
+        """Compute language reward based on detected language.
+
+        Detection strategy:
+        - If exactly one thinking block: detect language of block content
+        - Otherwise: detect language of whole response
+
+        Args:
+            prompt: The input prompt (unused but kept for signature consistency)
+            response: The model response
+            target: Optional target string (unused but kept for signature consistency)
+
+        Returns:
+            match_reward if detected language matches target, no_match_reward otherwise
+        """
+        # Increment counter for sampling
+        self._debug_counter += 1
+        should_debug = (
+            self.debug
+            and self.debug_sample_rate > 0
+            and (self._debug_counter % int(1 / self.debug_sample_rate)) == 0
+        )
+
+        if not response:
+            if should_debug:
+                print(
+                    f"\n[LanguageReward] Empty response | Reward: {self.no_match_reward}"
+                )
+            return self.no_match_reward
+
+        # Extract all thinking blocks
+        matches = self._THINK_BLOCK_RE.findall(response)
+
+        # Determine what text to analyze
+        if len(matches) == 1:
+            # Single block: detect language of block content only
+            text_to_analyze = matches[0].strip()
+            detection_mode = "single block"
+        else:
+            # No blocks or multiple blocks: detect language of whole response
+            text_to_analyze = response.strip()
+            detection_mode = f"{len(matches)} blocks, using whole response"
+
+        # Remove extra whitespace
+        text_to_analyze = re.sub(r"\s+", " ", text_to_analyze).strip()
+
+        if not text_to_analyze:
+            if should_debug:
+                print(f"\n[LanguageReward] Empty text | Reward: {self.no_match_reward}")
+            return self.no_match_reward
+
+        # Detect language using langid
+        detected_lang, confidence = self._langid.classify(text_to_analyze)
+
+        # Check if language matches target
+        reward = (
+            self.match_reward
+            if detected_lang == self.target_language
+            else self.no_match_reward
+        )
+
+        if should_debug:
+            sample = text_to_analyze[:150].replace("\n", " ")
+            match_symbol = "✓" if detected_lang == self.target_language else "✗"
+            print(
+                f"\n[LanguageReward] Detection mode: {detection_mode}"
+                f"\n  Target: {self.target_language} | Detected: {detected_lang} | "
+                f"Confidence: {confidence:.2f}"
+                f"\n  Sample: {sample}..."
+                f"\n  → Reward: {reward} {match_symbol}"
+            )
+
+        return reward
```
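Putting the detection strategy together, a minimal usage sketch (assumes `langid` is installed and `LanguageReward` is importable from `forge.data.rewards`; exact detections depend on langid's classifier):

```python
from forge.data.rewards import LanguageReward

reward = LanguageReward(target_language="ja", tag="思考")

# Exactly one thinking block: only the block content is analyzed, so a
# Japanese thought followed by an English answer can still match.
print(reward("", "<思考>三角形の面積を計算する</思考> The answer is 6."))  # expect 1.0

# No blocks: the whole response is analyzed and detected as English.
print(reward("", "The answer is 6."))  # expect 0.0

# Empty response short-circuits to no_match_reward.
print(reward("", ""))  # 0.0
```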
# GRPO with Language Reward

This sandbox app demonstrates using GRPO training with a language reward that encourages the model to think in a specific target language.

## Overview

This app extends the standard GRPO training (from `apps/grpo/`) by adding a `LanguageReward` that evaluates whether the model's thinking (text within `<思考></思考>` tags) is in the target language.

**Key Insight**: Uses Japanese tags `<思考>` (shikō, "thinking") instead of English `<think>` tags to break the model's association between thinking tags and the English language. This helps encourage multilingual thinking.

## Key Features

- **Multi-objective training**: Combines three rewards:
  - `MathReward`: Evaluates correctness of math answers
  - `ThinkingReward`: Encourages use of `<思考>` tags
  - `LanguageReward`: Rewards thinking in the target language (Japanese by default)
- **Japanese thinking tags**: Uses `<思考>` instead of `<think>` to encourage non-English reasoning
- **Language detection**: Uses `langid` to detect the language of thinking blocks
- **Configurable target language**: While this app defaults to Japanese (`ja`), the `LanguageReward` can be configured for any ISO 639-1 language code
- **Configurable tags**: Both rewards support custom tag names via the `tag` parameter

## Requirements

Before running this app, install the required language detection library:

```bash
pip install langid
```

## Usage

```bash
python -m sandbox.grpo_language.main --config apps/grpo/qwen3_1_7b.yaml
```

You can use any of the config files from `apps/grpo/` (e.g., `qwen3_1_7b.yaml`, `qwen3_8b.yaml`, `qwen3_32b.yaml`).
## How It Works

1. The model receives a math problem and is instructed to use `<思考>` tags for reasoning
2. During training, the model generates responses with thinking blocks
3. Three rewards are computed:
   - **MathReward**: Did it get the right answer?
   - **ThinkingReward**: Did it use `<思考>` tags properly? (single block = full reward, multiple blocks = partial reward)
   - **LanguageReward**: Did it use the target language? Detection strategy:
     - If exactly one thinking block: detect the language of the block content only
     - Otherwise (no blocks or multiple blocks): detect the language of the whole response
     - Returns `match_reward` (1.0) if the detected language matches the target, `no_match_reward` (0.0) otherwise
4. The model is trained to maximize all three rewards (combined as sketched below)

**Note**: ThinkingReward enforces format (single vs multiple blocks), while LanguageReward focuses purely on language detection. This separation of concerns allows each reward to specialize in one aspect of the desired behavior.
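For illustration, a minimal sketch of how such a multi-objective signal could be combined, assuming `MathReward` lives in `forge.data.rewards` alongside the other two and that the total is a plain average (consistent with the `avg_total_reward` metric below); the actual wiring is in `sandbox/grpo_language/main.py`:

```python
from forge.data.rewards import LanguageReward, MathReward, ThinkingReward

# The three objectives described above (Japanese thinking tags throughout).
reward_functions = [
    MathReward(),
    ThinkingReward(tag="思考"),
    LanguageReward(target_language="ja", tag="思考"),
]


def evaluate_response(prompt: str, response: str, target: str) -> float:
    """Average the per-reward scores into a single scalar training signal."""
    scores = [fn(prompt, response, target) for fn in reward_functions]
    return sum(scores) / len(scores)
```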
## Configuration

### Target Language

The target language is configured as Japanese in `main.py`:

```python
LanguageReward(target_language="ja", tag="思考")
ThinkingReward(tag="思考")
```

To use a different language, change `target_language` to the appropriate ISO 639-1 code (example below):
- English: `"en"`
- Chinese: `"zh"`
- Spanish: `"es"`
- French: `"fr"`
- etc.
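For instance, retargeting the reward to Spanish while keeping the language-neutral `<思考>` tags (an illustrative variation on the snippet above; the PR's commit history also experimented with `match_reward=2.0` to weight the language objective more heavily):

```python
LanguageReward(target_language="es", match_reward=2.0, tag="思考")
ThinkingReward(tag="思考")
```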
## Expected Behavior

Over the course of training, the model should learn to:
1. Solve math problems correctly
2. Use `<思考></思考>` tags for its reasoning
3. Write its thinking in Japanese (or the configured target language)

## Metrics

The following metrics are logged to W&B:
- `reward/evaluate_response/avg_LanguageReward_reward`: Average language reward
- `reward/evaluate_response/avg_MathReward_reward`: Average math reward
- `reward/evaluate_response/avg_ThinkingReward_reward`: Average thinking reward
- `reward/evaluate_response/avg_total_reward`: Average of all rewards

## Differences from Standard GRPO

This is a modified version of `apps/grpo/main.py` with:
1. An added import: `from forge.data.rewards import LanguageReward`
2. A modified reward functions list that includes `LanguageReward(target_language="ja")`
3. An updated config that uses a different W&B group name

All other training logic remains the same.
Review comment: to match the default value for `tag`.