Skip to content
This repository was archived by the owner on Jan 24, 2024. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 31 additions & 2 deletions sot/opcode_translator/executor/function_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,20 @@ def __init__(self, frame, **kwargs):
self.input_variables = [] # Store variables required within a function
self.pycode_gen = PyCodeGen(frame, disable_eval_frame=True)
self.side_effects = SideEffects()
self.py_frame = frame
self._global_guarded_variables: OrderedSet[VariableBase] = OrderedSet()
self._print_variables = []
self.build_strategy = kwargs.get('build_strategy', None)

def clear(self):
    """Reset this graph's state attributes to ``None``.

    Every attribute named below is rebound to ``None``, so this object no
    longer references the values they previously held.
    """
    # NOTE(review): presumably called once compilation is finished so the
    # referenced objects can be reclaimed -- confirm against the callers.
    for attr_name in (
        "sir_ctx",
        "inner_out",
        "input_variables",
        "pycode_gen",
        "side_effects",
        "_global_guarded_variables",
        "_print_variables",
        "build_strategy",
    ):
        setattr(self, attr_name, None)

@cached_property
def _builtins(self):
builtins_ = {}
Expand Down Expand Up @@ -197,13 +206,28 @@ def start_compile_with_name_store(self, ret_vars, to_store_vars):
class VariableLoader:
    """Emit loads of stored variables and delete each local after its last use.

    ``save`` counts how many times a variable was stored. Every ``load``
    consumes one count, and when the count reaches zero a delete-fast
    instruction is generated for that variable's local name.
    """

    def __init__(self, index_for_load, pycode_gen):
        """Keep the ``var.id -> local name`` mapping and the code generator."""
        self._pycode_gen = pycode_gen
        self._index_for_load = index_for_load
        # var.id -> number of outstanding (saved but not yet consumed) uses.
        self._save_cnt = {}

    def save(self, var):
        """Register one more outstanding use of ``var``."""
        self._save_cnt[var.id] = self._save_cnt.get(var.id, 0) + 1

    def load(self, var):
        """Generate code that pushes ``var``, then consume one saved use."""
        if not isinstance(var, DummyVariable):
            self._pycode_gen.gen_load_fast(self._index_for_load[var.id])
            self.delete(var)
            return
        # Dummy variables are rebuilt by their own reconstruct() and are
        # never counted or deleted.
        var.reconstruct(self._pycode_gen)

    def delete(self, var):
        """Consume one saved use of ``var``; delete its local on the last one.

        Raises ``KeyError`` if ``var`` was never passed to ``save``.
        """
        if isinstance(var, DummyVariable):
            return
        remaining = self._save_cnt[var.id] - 1
        self._save_cnt[var.id] = remaining
        if remaining == 0:
            local_name = self._index_for_load[var.id]
            self._pycode_gen.gen_delete_fast(local_name)

# var_id -> local_name mapping
index_for_load = {}
Expand All @@ -223,9 +247,11 @@ def _log_fn():

log_do(4, _log_fn)

loader = VariableLoader(index_for_load, self.pycode_gen)
for var in to_store_vars[::-1]:
self.pycode_gen.gen_store_fast(index_for_load[var.id])
return VariableLoader(index_for_load, self.pycode_gen)
loader.save(var)
return loader

@event_register("start_compile")
def start_compile(self, *ret_vars: VariableBase):
Expand Down Expand Up @@ -293,6 +319,9 @@ def start_compile(self, *ret_vars: VariableBase):

view_tracker(list(ret_vars), tracker_output_path, format="png")

# delete all the resume_xxx locals in f_locals
self.pycode_gen.gen_delete_resume_locals()

def call_paddle_api(
self,
func: Callable[..., Any],
Expand Down
13 changes: 11 additions & 2 deletions sot/opcode_translator/executor/opcode_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,9 +232,10 @@ def start_translate(frame: types.FrameType, **kwargs) -> GuardedFunction | None:
GuardedFunction | None: The translated code object and its guard function, or None if translation fails.
"""
simulator = OpcodeExecutor(frame, **kwargs)
code = simulator._code
try:
log(3, f"OriginCode: {simulator._code}\n")
log_do(3, lambda: dis.dis(simulator._code))
log_do(3, lambda: dis.dis(code))
new_code, guard_fn = simulator.transform()
log(3, f"NewCode: {new_code}\n")
log_do(3, lambda: dis.dis(new_code))
Expand All @@ -255,6 +256,12 @@ def start_translate(frame: types.FrameType, **kwargs) -> GuardedFunction | None:
return py_codegen.replace_dummy_variable()
except Exception as e:
raise InnerError(OpcodeExecutorBase.error_message_summary(e)) from e
finally:
# Collect all cyclic references.
# We run gc after a transformation, but not after a cache hit.
import gc

gc.collect()


def tos_op_wrapper(fn: Callable):
Expand Down Expand Up @@ -1649,15 +1656,17 @@ def _break_graph_in_call(
self.pop_n(pop_n)
stack_size = len(self._stack) + push_n
resume_fn, _ = self._create_resume_fn(index + 1, stack_size)
passing_arg_num = len(resume_input_name) + stack_size
if resume_fn:
self._graph.pycode_gen.gen_load_object(
resume_fn, resume_fn.__code__.co_name
)
self._graph.pycode_gen.gen_rot_n(stack_size + 1)
for name in resume_input_name:
var_loader.load(self.get_var(name))
self._graph.pycode_gen.gen_build_list(passing_arg_num)
self._graph.pycode_gen.gen_call_function(
argc=resume_fn.__code__.co_argcount,
argc=1,
with_eval_frame=True,
)

Expand Down
75 changes: 67 additions & 8 deletions sot/opcode_translator/executor/pycode_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from __future__ import annotations

import dis
import os
import re
import sys
import types
from typing import TYPE_CHECKING
Expand Down Expand Up @@ -257,20 +259,33 @@ def gen_resume_fn_at(self, index, stack_size=0):
return None, OrderedSet()
inputs = analysis_inputs(self._instructions, index)
fn_name = ResumeFnNameFactory().next()
stack_arg_str = fn_name + '_stack_{}'
stack_args_list = fn_name + '_list'
header = []
for i in range(stack_size):
header.append(gen_instr('LOAD_FAST', argval=stack_args_list))
header.append(gen_instr('LOAD_CONST', argval=0))
header.append(gen_instr('BINARY_SUBSCR', argval=None))
header.append(gen_instr('LOAD_FAST', argval=stack_args_list))
header.append(gen_instr('LOAD_CONST', argval=0))
header.append(gen_instr('DELETE_SUBSCR', argval=None))

for name in list(inputs):
header.append(gen_instr('LOAD_FAST', argval=stack_args_list))
header.append(gen_instr('LOAD_CONST', argval=0))
header.append(gen_instr('BINARY_SUBSCR', argval=None))
header.append(gen_instr('STORE_FAST', argval=name))

self._instructions = (
[
gen_instr('LOAD_FAST', argval=stack_arg_str.format(i))
for i in range(stack_size)
]
header
+ [gen_instr('JUMP_ABSOLUTE', jump_to=self._instructions[index])]
+ self._instructions
)

self._code_options['co_argcount'] = len(inputs) + stack_size
if 0 not in self._code_options['co_consts']:
self._code_options['co_consts'] += [0]
self._code_options['co_argcount'] = 1
# inputs should be at the front of the co_varnames
self._code_options['co_varnames'] = list(
[stack_arg_str.format(i) for i in range(stack_size)]
[stack_args_list]
+ list(inputs)
+ [
var_name
Expand All @@ -292,6 +307,8 @@ def gen_resume_fn_at(self, index, stack_size=0):
return fn, inputs

def gen_disable_eval_frame(self):
if os.environ.get('CLEAN_CODE', None) is not None:
return
self.gen_load_object(
paddle.fluid.core.set_eval_frame, "paddle_set_eval_frame_fn"
)
Expand All @@ -300,6 +317,8 @@ def gen_disable_eval_frame(self):
self.gen_store_fast("___old_eval_frame")

def gen_enable_eval_frame(self):
if os.environ.get('CLEAN_CODE', None) is not None:
return
self.gen_load_object(
paddle.fluid.core.set_eval_frame, "paddle_set_eval_frame_fn"
)
Expand Down Expand Up @@ -481,6 +500,46 @@ def gen_load_fast(self, name):
idx = self._code_options["co_varnames"].index(name)
self._add_instr("LOAD_FAST", arg=idx, argval=name)

def gen_delete_resume_locals(self):
    """Generate delete-fast instructions for resume stack locals.

    Each name in ``co_varnames`` that matches the resume stack-argument
    pattern and is also present in the frame's ``f_locals`` is passed to
    ``gen_delete_fast``.
    """
    # NOTE(review): ``re.match`` anchors only at the start of the string, so
    # a name that merely begins with this pattern also matches -- confirm
    # that this is intended.
    resume_local_pattern = r"resume_[0-9]+_stack_[0-9]+"
    for name in self._code_options['co_varnames']:
        if (
            re.match(resume_local_pattern, name)
            and name in self._frame.f_locals
        ):
            self.gen_delete_fast(name)

def gen_delete_fast(self, name):
    """Append a ``DELETE_FAST`` instruction for the local called ``name``.

    ``name`` is added to ``co_varnames`` first when it is not already
    there; the instruction's ``arg`` is the name's index in that list.
    """
    varnames = self._code_options["co_varnames"]
    if name not in varnames:
        varnames.append(name)
    self._add_instr("DELETE_FAST", arg=varnames.index(name), argval=name)

def gen_load_deref(self, name):
if name not in self._code_options["co_cellvars"]:
self._code_options["co_cellvars"].append(name)
Expand Down
2 changes: 1 addition & 1 deletion sot/opcode_translator/executor/variables/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def __init__(
):
super().__init__(graph, tracker)
if isinstance(tensor, paddle.Tensor):
self.value = tensor
self.value = None
self.meta = MetaInfo.from_tensor(tensor)
elif isinstance(tensor, MetaInfo):
self.value = None
Expand Down
13 changes: 12 additions & 1 deletion sot/opcode_translator/instruction_utils/instruction_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,8 +257,14 @@ def bind_ex_arg_with_instr(ex_arg, instr):
def modify_vars(instructions, code_options):
co_names = code_options['co_names']
co_varnames = code_options['co_varnames']
co_consts = code_options['co_consts']
for instrs in instructions:
if instrs.opname == 'LOAD_FAST' or instrs.opname == 'STORE_FAST':
if instrs.opname in [
'LOAD_FAST',
'STORE_FAST',
'DELETE_FAST',
'STORE_FAST',
]:
assert (
instrs.argval in co_varnames
), f"`{instrs.argval}` not in {co_varnames}"
Expand All @@ -268,6 +274,11 @@ def modify_vars(instructions, code_options):
instrs.argval in co_names
), f"`{instrs.argval}` not in {co_varnames}"
instrs.arg = co_names.index(instrs.argval)
elif instrs.opname == 'LOAD_CONST':
assert (
instrs.argval in co_consts
), f"`{instrs.argval}` not in {co_consts}"
instrs.arg = co_consts.index(instrs.argval)


'''
Expand Down
3 changes: 0 additions & 3 deletions tests/run_all_paddle_ci.sh
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,6 @@ disabled_tests=(
# Because range interrupts networking, Paddle.grad cannot be networked as a standalone API.
# CAN BE RE-ENABLED AFTER: range is supported.
${PADDLE_TEST_BASE}/test_grad.py
${PADDLE_TEST_BASE}/test_ptb_lm.py # There is accuracy problem of the model in SOT
${PADDLE_TEST_BASE}/test_ptb_lm_v2.py # There is accuracy problem of the model in SOT
${PADDLE_TEST_BASE}/test_cycle_gan.py # This test has a precision problem when it reaches the maximum cache size
)

for file in ${PADDLE_TEST_BASE}/*.py; do
Expand Down