From cc994a46b7e92601eb1100c04f25e4269dc42969 Mon Sep 17 00:00:00 2001 From: xiongkun Date: Wed, 26 Jul 2023 03:15:37 +0000 Subject: [PATCH 1/2] enable inaccurate unittests --- tests/run_all_paddle_ci.sh | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/run_all_paddle_ci.sh b/tests/run_all_paddle_ci.sh index bf2f45192..b826de00a 100644 --- a/tests/run_all_paddle_ci.sh +++ b/tests/run_all_paddle_ci.sh @@ -14,9 +14,6 @@ disabled_tests=( # Because range interrupts networking, Paddle.grad cannot be networked as a standalone API. # CAN BE OPEN AFTER: range is support. ${PADDLE_TEST_BASE}/test_grad.py - ${PADDLE_TEST_BASE}/test_ptb_lm.py # There is accuracy problem of the model in SOT - ${PADDLE_TEST_BASE}/test_ptb_lm_v2.py # There is accuracy problem of the model in SOT - ${PADDLE_TEST_BASE}/test_cycle_gan.py # This test has a precision problem when it reaches the maximum cache size ) for file in ${PADDLE_TEST_BASE}/*.py; do From e39aa195008d0778357ba178a3062b073b6743fc Mon Sep 17 00:00:00 2001 From: xiongkun Date: Tue, 1 Aug 2023 04:46:44 +0000 Subject: [PATCH 2/2] Fix Memory Problem for resume function! 
--- .../executor/function_graph.py | 33 +++++++- .../executor/opcode_executor.py | 13 +++- .../executor/pycode_generator.py | 75 +++++++++++++++++-- .../executor/variables/basic.py | 2 +- .../instruction_utils/instruction_utils.py | 13 +++- 5 files changed, 122 insertions(+), 14 deletions(-) diff --git a/sot/opcode_translator/executor/function_graph.py b/sot/opcode_translator/executor/function_graph.py index b778d41a4..fafe6eb7d 100644 --- a/sot/opcode_translator/executor/function_graph.py +++ b/sot/opcode_translator/executor/function_graph.py @@ -94,11 +94,20 @@ def __init__(self, frame, **kwargs): self.input_variables = [] # Store variables required within a function self.pycode_gen = PyCodeGen(frame, disable_eval_frame=True) self.side_effects = SideEffects() - self.py_frame = frame self._global_guarded_variables: OrderedSet[VariableBase] = OrderedSet() self._print_variables = [] self.build_strategy = kwargs.get('build_strategy', None) + def clear(self): + self.sir_ctx = None + self.inner_out = None + self.input_variables = None + self.pycode_gen = None + self.side_effects = None + self._global_guarded_variables = None + self._print_variables = None + self.build_strategy = None + @cached_property def _builtins(self): builtins_ = {} @@ -197,13 +206,28 @@ def start_compile_with_name_store(self, ret_vars, to_store_vars): class VariableLoader: def __init__(self, index_for_load, pycode_gen): self._index_for_load = index_for_load + self._save_cnt = {} self._pycode_gen = pycode_gen + def save(self, var): + cnt = self._save_cnt.get(var.id, 0) + self._save_cnt[var.id] = cnt + 1 + def load(self, var): if isinstance(var, DummyVariable): var.reconstruct(self._pycode_gen) return self._pycode_gen.gen_load_fast(self._index_for_load[var.id]) + self.delete(var) + + def delete(self, var): + if isinstance(var, DummyVariable): + return + self._save_cnt[var.id] -= 1 + if self._save_cnt[var.id] == 0: + self._pycode_gen.gen_delete_fast( + self._index_for_load[var.id] + ) # var_id -> 
local_name mapping index_for_load = {} @@ -223,9 +247,11 @@ def _log_fn(): log_do(4, _log_fn) + loader = VariableLoader(index_for_load, self.pycode_gen) for var in to_store_vars[::-1]: self.pycode_gen.gen_store_fast(index_for_load[var.id]) - return VariableLoader(index_for_load, self.pycode_gen) + loader.save(var) + return loader @event_register("start_compile") def start_compile(self, *ret_vars: VariableBase): @@ -293,6 +319,9 @@ def start_compile(self, *ret_vars: VariableBase): view_tracker(list(ret_vars), tracker_output_path, format="png") + # delete all the resume_xxx locals in f_locals + self.pycode_gen.gen_delete_resume_locals() + def call_paddle_api( self, func: Callable[..., Any], diff --git a/sot/opcode_translator/executor/opcode_executor.py b/sot/opcode_translator/executor/opcode_executor.py index 6747a778c..210e63b31 100644 --- a/sot/opcode_translator/executor/opcode_executor.py +++ b/sot/opcode_translator/executor/opcode_executor.py @@ -232,9 +232,10 @@ def start_translate(frame: types.FrameType, **kwargs) -> GuardedFunction | None: GuardedFunction | None: The translated code object and its guard function, or None if translation fails. """ simulator = OpcodeExecutor(frame, **kwargs) + code = simulator._code try: log(3, f"OriginCode: {simulator._code}\n") - log_do(3, lambda: dis.dis(simulator._code)) + log_do(3, lambda: dis.dis(code)) new_code, guard_fn = simulator.transform() log(3, f"NewCode: {new_code}\n") log_do(3, lambda: dis.dis(new_code)) @@ -255,6 +256,12 @@ def start_translate(frame: types.FrameType, **kwargs) -> GuardedFunction | None: return py_codegen.replace_dummy_variable() except Exception as e: raise InnerError(OpcodeExecutorBase.error_message_summary(e)) from e + finally: + # collect all the cyclic references. + # we gc after transformation but not after cache hit. 
+ import gc + + gc.collect() def tos_op_wrapper(fn: Callable): @@ -1649,6 +1656,7 @@ def _break_graph_in_call( self.pop_n(pop_n) stack_size = len(self._stack) + push_n resume_fn, _ = self._create_resume_fn(index + 1, stack_size) + passing_arg_num = len(resume_input_name) + stack_size if resume_fn: self._graph.pycode_gen.gen_load_object( resume_fn, resume_fn.__code__.co_name @@ -1656,8 +1664,9 @@ def _break_graph_in_call( self._graph.pycode_gen.gen_rot_n(stack_size + 1) for name in resume_input_name: var_loader.load(self.get_var(name)) + self._graph.pycode_gen.gen_build_list(passing_arg_num) self._graph.pycode_gen.gen_call_function( - argc=resume_fn.__code__.co_argcount, + argc=1, with_eval_frame=True, ) diff --git a/sot/opcode_translator/executor/pycode_generator.py b/sot/opcode_translator/executor/pycode_generator.py index 6c8bb69c1..fe3ad5ad9 100644 --- a/sot/opcode_translator/executor/pycode_generator.py +++ b/sot/opcode_translator/executor/pycode_generator.py @@ -5,6 +5,8 @@ from __future__ import annotations import dis +import os +import re import sys import types from typing import TYPE_CHECKING @@ -257,20 +259,33 @@ def gen_resume_fn_at(self, index, stack_size=0): return None, OrderedSet() inputs = analysis_inputs(self._instructions, index) fn_name = ResumeFnNameFactory().next() - stack_arg_str = fn_name + '_stack_{}' + stack_args_list = fn_name + '_list' + header = [] + for i in range(stack_size): + header.append(gen_instr('LOAD_FAST', argval=stack_args_list)) + header.append(gen_instr('LOAD_CONST', argval=0)) + header.append(gen_instr('BINARY_SUBSCR', argval=None)) + header.append(gen_instr('LOAD_FAST', argval=stack_args_list)) + header.append(gen_instr('LOAD_CONST', argval=0)) + header.append(gen_instr('DELETE_SUBSCR', argval=None)) + + for name in list(inputs): + header.append(gen_instr('LOAD_FAST', argval=stack_args_list)) + header.append(gen_instr('LOAD_CONST', argval=0)) + header.append(gen_instr('BINARY_SUBSCR', argval=None)) + 
header.append(gen_instr('STORE_FAST', argval=name)) + self._instructions = ( - [ - gen_instr('LOAD_FAST', argval=stack_arg_str.format(i)) - for i in range(stack_size) - ] + header + [gen_instr('JUMP_ABSOLUTE', jump_to=self._instructions[index])] + self._instructions ) - - self._code_options['co_argcount'] = len(inputs) + stack_size + if 0 not in self._code_options['co_consts']: + self._code_options['co_consts'] += [0] + self._code_options['co_argcount'] = 1 # inputs should be at the front of the co_varnames self._code_options['co_varnames'] = list( - [stack_arg_str.format(i) for i in range(stack_size)] + [stack_args_list] + list(inputs) + [ var_name @@ -292,6 +307,8 @@ def gen_resume_fn_at(self, index, stack_size=0): return fn, inputs def gen_disable_eval_frame(self): + if os.environ.get('CLEAN_CODE', None) is not None: + return self.gen_load_object( paddle.fluid.core.set_eval_frame, "paddle_set_eval_frame_fn" ) @@ -300,6 +317,8 @@ def gen_disable_eval_frame(self): self.gen_store_fast("___old_eval_frame") def gen_enable_eval_frame(self): + if os.environ.get('CLEAN_CODE', None) is not None: + return self.gen_load_object( paddle.fluid.core.set_eval_frame, "paddle_set_eval_frame_fn" ) @@ -481,6 +500,46 @@ def gen_load_fast(self, name): idx = self._code_options["co_varnames"].index(name) self._add_instr("LOAD_FAST", arg=idx, argval=name) + def gen_delete_resume_locals(self): + def dbg_func(): + import inspect + + print("dbg here.") + frame = inspect.currentframe().f_back + code = inspect.currentframe().f_back.f_code + print("locals = ", frame.f_locals) + print("code is = ", code) + import gc + import sys + + gc.collect() + if 'resume_0_stack_0' in frame.f_locals: + print( + "Ref: ", sys.getrefcount(frame.f_locals['resume_0_stack_0']) + ) + print( + "Refers: ", + gc.get_referrers(frame.f_locals['resume_0_stack_0']), + ) + import sys + + sys.xk_args = frame.f_locals['resume_0_stack_0'] + + # self.gen_dbg_function(dbg_func) + resume_local_pattern = 
"resume_[0-9]+_stack_[0-9]+" + for name in self._code_options['co_varnames']: + if ( + re.match(resume_local_pattern, name) + and name in self._frame.f_locals + ): + self.gen_delete_fast(name) + + def gen_delete_fast(self, name): + if name not in self._code_options["co_varnames"]: + self._code_options["co_varnames"].append(name) + idx = self._code_options["co_varnames"].index(name) + self._add_instr("DELETE_FAST", arg=idx, argval=name) + def gen_load_deref(self, name): if name not in self._code_options["co_cellvars"]: self._code_options["co_cellvars"].append(name) diff --git a/sot/opcode_translator/executor/variables/basic.py b/sot/opcode_translator/executor/variables/basic.py index 013fba686..c0ff419c2 100644 --- a/sot/opcode_translator/executor/variables/basic.py +++ b/sot/opcode_translator/executor/variables/basic.py @@ -228,7 +228,7 @@ def __init__( ): super().__init__(graph, tracker) if isinstance(tensor, paddle.Tensor): - self.value = tensor + self.value = None self.meta = MetaInfo.from_tensor(tensor) elif isinstance(tensor, MetaInfo): self.value = None diff --git a/sot/opcode_translator/instruction_utils/instruction_utils.py b/sot/opcode_translator/instruction_utils/instruction_utils.py index f1d0baabb..4e6f23cee 100644 --- a/sot/opcode_translator/instruction_utils/instruction_utils.py +++ b/sot/opcode_translator/instruction_utils/instruction_utils.py @@ -257,8 +257,14 @@ def bind_ex_arg_with_instr(ex_arg, instr): def modify_vars(instructions, code_options): co_names = code_options['co_names'] co_varnames = code_options['co_varnames'] + co_consts = code_options['co_consts'] for instrs in instructions: - if instrs.opname == 'LOAD_FAST' or instrs.opname == 'STORE_FAST': + if instrs.opname in [ + 'LOAD_FAST', + 'STORE_FAST', + 'DELETE_FAST', + 'STORE_FAST', + ]: assert ( instrs.argval in co_varnames ), f"`{instrs.argval}` not in {co_varnames}" @@ -268,6 +274,11 @@ def modify_vars(instructions, code_options): instrs.argval in co_names ), f"`{instrs.argval}` not 
in {co_varnames}" instrs.arg = co_names.index(instrs.argval) + elif instrs.opname == 'LOAD_CONST': + assert ( + instrs.argval in co_consts + ), f"`{instrs.argval}` not in {co_consts}" + instrs.arg = co_consts.index(instrs.argval) '''