# NOTE(review): the companion hunk for
# compose_rl/algorithms/online/model_methods.py rewrites the assertion message
# as f'...{batch['action_mask'].shape}...'. Reusing the enclosing quote
# character inside a replacement field only parses on Python >= 3.12
# (PEP 701); keep the original {batch["action_mask"].shape} spelling.


class timeout:
    """Context manager that aborts its body with ``TimeoutError``.

    The limit is enforced with ``SIGALRM`` via ``signal.setitimer``, so it can
    interrupt pure-Python work such as a long symbolic simplification. Signal
    handlers can only be installed from the main thread of the main
    interpreter, and ``SIGALRM`` does not exist on Windows; in either case the
    body runs *without* a time limit rather than failing.

    Instances are not nestable: leaving an inner ``timeout`` cancels the
    process-wide real-time timer, including one armed by an outer instance.

    Args:
        seconds: Time limit in seconds. Fractions are honoured; a value
            ``<= 0`` disables the limit.
        error_message: Message carried by the raised ``TimeoutError``.
    """

    def __init__(self, seconds: float = 1, error_message: str = 'Timeout'):
        self.seconds = seconds
        self.error_message = error_message
        # Handler that was active before ``__enter__``; put back on exit.
        self._previous_handler = None
        # True only while this instance owns the SIGALRM handler and timer.
        self._armed = False

    def handle_timeout(self, signum, frame):
        """SIGALRM handler: raise ``TimeoutError`` in the interrupted code."""
        raise TimeoutError(self.error_message)

    def __enter__(self):
        import signal
        import threading

        can_use_alarm = (
            hasattr(signal, 'SIGALRM') and hasattr(signal, 'setitimer') and
            threading.current_thread() is threading.main_thread()
        )
        if not can_use_alarm:
            # signal.signal() would raise ValueError here; degrade to
            # "no limit" so callers keep working in worker threads.
            log.debug('SIGALRM unavailable here; running without a timeout')
            return self
        if self.seconds <= 0:
            return self

        self._previous_handler = signal.signal(
            signal.SIGALRM,
            self.handle_timeout,
        )
        signal.setitimer(signal.ITIMER_REAL, self.seconds)
        self._armed = True
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        import signal

        if not self._armed:
            return
        # Disarm first so the alarm cannot fire into the restored handler.
        signal.setitimer(signal.ITIMER_REAL, 0)
        previous = self._previous_handler
        if previous is None:
            # The earlier handler was not installed from Python and cannot be
            # reinstated through signal.signal(); fall back to the default.
            previous = signal.SIG_DFL
        signal.signal(signal.SIGALRM, previous)
        self._previous_handler = None
        self._armed = False


def is_equiv(x1: str, x2: str, timeout_seconds: float = 5) -> bool:
    """Checks mathematical equivalence between two normalized LaTeX strings.

    Both strings are parsed with ``parse_latex`` and considered equivalent
    when ``sympy.simplify`` reduces their difference to zero.

    Args:
        x1: First LaTeX expression.
        x2: Second LaTeX expression.
        timeout_seconds: Upper bound on the whole comparison. Parsing or
            simplification that exceeds it counts as "not equivalent".

    Returns:
        ``True`` if the difference simplifies to zero. ``False`` if it does
        not, if either string cannot be parsed, if the parsed values cannot be
        subtracted or simplified, or if the comparison times out.

    Raises:
        ImportError: Logged and re-raised when raised during the comparison.
    """
    try:
        with timeout(seconds=timeout_seconds):
            try:
                parsed_x1 = parse_latex(x1)
                parsed_x2 = parse_latex(x2)
            except (
                sympy.parsing.latex.  # pyright: ignore[reportGeneralTypeIssues]
                errors.LaTeXParsingError,
                sympy.SympifyError,
                TypeError,
            ):
                log.debug(f"couldn't parse one of {x1} or {x2}")
                return False

            try:
                diff = parsed_x1 - parsed_x2  # pyright: ignore[reportOptionalOperand]
            except TypeError:
                log.debug(f"couldn't subtract {x1} and {x2}")
                return False

            try:
                return sympy.simplify(diff) == 0
            except ValueError:
                log.debug(
                    f'Had some trouble simplifying when comparing {x1} and {x2}',
                )
                return False
    except TimeoutError:
        # Raised by the SIGALRM handler while parsing or simplifying.
        log.debug(f'Timed out comparing {x1} and {x2}')
        return False
    except ImportError as e:
        log.error(e)
        raise