261 changes: 88 additions & 173 deletions graph_net/paddle/test_compiler.py
@@ -10,8 +10,11 @@
import numpy as np
import random
import platform
import traceback

from graph_net.paddle import utils
from graph_net import path_utils
from graph_net import test_compiler_utils
from graph_net.benchmark_result import BenchmarkResult


@@ -49,21 +52,9 @@ def get_input_dict(args):
params = inputs_params["weight_info"]
inputs = inputs_params["input_info"]

param_dtypes = set()
for name, info in params.items():
dtype = str(info["info"]["dtype"])
if dtype not in param_dtypes:
param_dtypes.add(dtype)

input_dtypes = set()
for name, info in inputs.items():
dtype = str(info["info"]["dtype"])
if dtype not in input_dtypes:
input_dtypes.add(dtype)

params.update(inputs)
state_dict = {k: utils.replay_tensor(v) for k, v in params.items()}
return state_dict, list(input_dtypes), list(param_dtypes)
return state_dict


def get_input_spec(args):
@@ -89,62 +80,6 @@ def get_compiled_model(args, model):
return compiled_model


def regular_item(item):
assert isinstance(item, paddle.Tensor)
if item.dtype not in [paddle.float32, paddle.float64]:
item = item.astype("float32")
return item


def count_number_of_ops(args, model, eager_mode):
if eager_mode:
static_model = paddle.jit.to_static(
model,
input_spec=get_input_spec(args),
full_graph=True,
backend=None,
)
static_model.eval()
program = static_model.forward.concrete_program.main_program
else:
program = model.forward.concrete_program.main_program
print(program)

num_ops = 0
for block in program.blocks:
for op in block.ops:
if op.name() != "pd_op.data" and not op.name().startswith("builtin."):
num_ops += 1
print(f"Total: {num_ops} ops.")
print("")
return num_ops


@dataclass
class DurationBox:
value: int


@contextmanager
def naive_timer(duration_box, synchronizer_func):
synchronizer_func()
start = time.time()
yield
synchronizer_func()
end = time.time()
duration_box.value = (end - start) * 1000 # Store in milliseconds


def get_timing_stats(elapsed_times):
stats = {
"mean": float(f"{np.mean(elapsed_times):.6g}"),
"std": float(f"{np.std(elapsed_times):.6g}"),
"min": float(f"{np.min(elapsed_times):.6g}"),
"max": float(f"{np.max(elapsed_times):.6g}"),
}
return stats


def measure_performance(model_call, args, synchronizer_func):
stats = {}

@@ -168,8 +103,8 @@ def measure_performance(model_call, args, synchronizer_func):

for i in range(args.trials):
# End-to-end timing (naive_timer)
duration_box = DurationBox(-1)
with naive_timer(duration_box, synchronizer_func):
duration_box = test_compiler_utils.DurationBox(-1)
with test_compiler_utils.naive_timer(duration_box, synchronizer_func):
# GPU-only timing (CUDA Events)
start_event = paddle.device.Event(enable_timing=True)
end_event = paddle.device.Event(enable_timing=True)
@@ -182,11 +117,11 @@ def measure_performance(model_call, args, synchronizer_func):
e2e_times.append(duration_box.value)
gpu_times.append(gpu_time_ms)
print(
f"Trial {i + 1}: e2e={duration_box.value:.4f} ms, gpu={gpu_time_ms:.5g} ms"
f"Trial {i + 1}: e2e={duration_box.value:.5f} ms, gpu={gpu_time_ms:.5f} ms"
)

stats["e2e"] = get_timing_stats(e2e_times)
stats["gpu"] = get_timing_stats(gpu_times)
stats["e2e"] = test_compiler_utils.get_timing_stats(e2e_times)
stats["gpu"] = test_compiler_utils.get_timing_stats(gpu_times)
else: # CPU or other devices
hardware_name = platform.processor()
print(
@@ -195,12 +130,12 @@ def measure_performance(model_call, args, synchronizer_func):

e2e_times = []
for i in range(args.trials):
duration_box = DurationBox(-1)
with naive_timer(duration_box, compiler.synchronize):
duration_box = test_compiler_utils.DurationBox(-1)
with test_compiler_utils.naive_timer(duration_box, synchronizer_func):
outs = model_call()
print(f"Trial {i + 1}: e2e={duration_box.value:.4f} ms")
e2e_times.append(duration_box.value)
stats["e2e"] = get_timing_stats(e2e_times)
stats["e2e"] = test_compiler_utils.get_timing_stats(e2e_times)

return outs, stats

@@ -227,70 +162,58 @@ def init_benchmark_result(args):
return result_data


def test_single_model(args):
synchronizer_func = get_synchronizer_func(args)
input_dict, input_dtypes, param_dtypes = get_input_dict(args)
model = get_model(args)
model.eval()

# Collect model information
num_eager_ops = count_number_of_ops(args, model, eager_mode=True)

# Initialize benchmark result
result_data = init_benchmark_result(args)
result_data.update_model_info(num_eager_ops, input_dtypes, param_dtypes)

# Run on eager mode
expected_out, eager_time_stats = measure_performance(
lambda: model(**input_dict), args, synchronizer_func
)

# Run on compiling mode
compiled_model = get_compiled_model(args, model)
compiled_out, compiled_time_stats = measure_performance(
lambda: compiled_model(**input_dict), args, synchronizer_func
)

def check_outputs(args, expected_out, compiled_out):
if isinstance(expected_out, paddle.Tensor):
expected_out = [expected_out]
if isinstance(compiled_out, paddle.Tensor):
compiled_out = [compiled_out]
if isinstance(expected_out, list) or isinstance(expected_out, tuple):
output_dtypes = []
for a, b in zip(expected_out, compiled_out):
if (a is None and b is not None) or (a is not None and b is None):
raise ValueError("expected_out and compiled_out must both be non-None.")
if a is not None and b is not None:
assert (
a.dtype == b.dtype
), f"expected_out's dtype ({a.dtype}) is not the same as compiled_out's dtype ({b.dtype})."
output_dtypes.append(str(a.dtype))
result_data.update_corrrectness("num_outpus", len(output_dtypes))
result_data.update_corrrectness("output_dtyps", output_dtypes)

# Remove all None in outputs
expected_out = [x for x in expected_out if x is not None]
compiled_out = [x for x in compiled_out if x is not None]
expected_out = [
regular_item(item)
for item in expected_out
if item is not None and np.array(item).size != 0
]
compiled_out = [
regular_item(item)
for item in compiled_out
if item is not None and np.array(item).size != 0
]
else:
raise ValueError("Illegal return value.")

eager_output_dtypes = [None] * len(expected_out)
for i, tensor in enumerate(expected_out):
if tensor is not None:
eager_output_dtypes[i] = str(tensor.dtype)

compiled_output_dtypes = [None] * len(compiled_out)
for i, tensor in enumerate(compiled_out):
if tensor is not None:
compiled_output_dtypes[i] = str(tensor.dtype)

is_output_consistent = len(expected_out) == len(compiled_out)
for a, b in zip(expected_out, compiled_out):
if (a is None and b is not None) or (a is not None and b is None):
is_output_consistent = False
if a is not None and b is not None and a.dtype != b.dtype:
is_output_consistent = False

def regular_outputs(origin_outputs):
outputs = []
for item in origin_outputs:
if (
item is not None
and isinstance(item, paddle.Tensor)
and item.dtype not in [paddle.float32, paddle.float64]
):
item = item.astype("float32")
outputs.append(item)
return outputs

expected_out = regular_outputs(expected_out)
compiled_out = regular_outputs(compiled_out)

def print_cmp(key, func, **kwargs):
cmp_ret = func(expected_out, compiled_out, **kwargs)
result_data.update_corrrectness(key, cmp_ret)
try:
cmp_ret = func(expected_out, compiled_out, **kwargs)
except Exception as e:
cmp_ret = f"{key} failed: {str(e)}\n{traceback.format_exc()}"
print(
f"{args.log_prompt} {key} model_path:{args.model_path} {cmp_ret}",
file=sys.stderr,
)

print(
f"{args.log_prompt} output_dtypes model_path:{args.model_path} eager:{eager_output_dtypes} compiled:{compiled_output_dtypes}",
file=sys.stderr,
)
print_cmp("cmp.equal", get_cmp_equal)
print_cmp("cmp.all_close_atol8_rtol8", get_cmp_all_close, atol=1e-8, rtol=1e-8)
print_cmp("cmp.all_close_atol8_rtol5", get_cmp_all_close, atol=1e-8, rtol=1e-5)
@@ -305,29 +228,42 @@ def print_cmp(key, func, **kwargs):
print_cmp("cmp.diff_count_atol3_rtol2", get_cmp_diff_count, atol=1e-3, rtol=1e-2)
print_cmp("cmp.diff_count_atol2_rtol1", get_cmp_diff_count, atol=1e-2, rtol=1e-1)

print(
f"{args.log_prompt} information model_path:{args.model_path} {num_eager_ops} ops, param_dtypes:{param_dtypes}, input_dtypes:{input_dtypes}",
file=sys.stderr,
)

result_data.update_performance(eager_time_stats, compiled_time_stats)
duration_log = (
f"{args.log_prompt} [Duration] "
f"eager_e2e:{result_data.eager_e2e_time_ms:.4f} ms compiled_e2e:{result_data.compiled_e2e_time_ms:.4f} ms"
)
speedup_log = (
f"{args.log_prompt} [Speedup] " f"e2e_speedup:{result_data.e2e_speedup:.4f}"
)
def test_single_model(args):
synchronizer_func = get_synchronizer_func(args)
input_dict = get_input_dict(args)
model = get_model(args)
model.eval()

if "cuda" in args.device:
duration_log += f" eager_gpu:{result_data.eager_gpu_time_ms:.4f} ms compiled_gpu:{result_data.compiled_gpu_time_ms:.4f} ms"
speedup_log += f" gpu_speedup:{result_data.gpu_speedup:.4f}"
# Run on eager mode
running_eager_success = False
try:
print("Running model in eager mode.")
expected_out, eager_time_stats = measure_performance(
lambda: model(**input_dict), args, synchronizer_func
)
running_eager_success = True
except Exception as e:
print(f"Running model in eager mode failed: {str(e)}\n{traceback.format_exc()}")

# Run on compiling mode
running_compiled_success = False
try:
print("Running model in compiled mode.")
compiled_model = get_compiled_model(args, model)
compiled_out, compiled_time_stats = measure_performance(
lambda: compiled_model(**input_dict), args, synchronizer_func
)
running_compiled_success = True
except Exception as e:
print(f"Running model in compiled mode failed: {str(e)}\n{traceback.format_exc()}")

print(duration_log, file=sys.stderr)
print(speedup_log, file=sys.stderr)
if running_eager_success and running_compiled_success:
check_outputs(args, expected_out, compiled_out)

if args.output_dir:
result_data.write_to_json(args.output_dir)
test_compiler_utils.print_times_and_speedup(
args, eager_time_stats, compiled_time_stats
)


def get_cmp_equal(expected_out, compiled_out):
@@ -365,7 +301,7 @@ def get_cmp_diff_count(expected_out, compiled_out, atol, rtol):


def test_multi_models(args):
for model_path in get_recursively_model_path(args.model_path):
for model_path in path_utils.get_recursively_model_path(args.model_path):
cmd = "".join(
[
sys.executable,
Expand All @@ -383,27 +319,6 @@ def test_multi_models(args):
assert cmd_ret == 0, f"{cmd_ret=}, {cmd=}"


def get_recursively_model_path(root_dir):
for sub_dir in get_immediate_subdirectory_paths(root_dir):
if is_single_model_dir(sub_dir):
yield sub_dir
else:
yield from get_recursively_model_path(sub_dir)


def get_immediate_subdirectory_paths(parent_dir):
return [
sub_dir
for name in os.listdir(parent_dir)
for sub_dir in [os.path.join(parent_dir, name)]
if os.path.isdir(sub_dir)
]


def is_single_model_dir(model_dir):
return os.path.isfile(f"{model_dir}/graph_net.json")


def main(args):
assert os.path.isdir(args.model_path)
assert args.compiler == "cinn"
@@ -413,7 +328,7 @@ def main(args):
random.seed(random_seed)
np.random.seed(random_seed)

if is_single_model_dir(args.model_path):
if path_utils.is_single_model_dir(args.model_path):
test_single_model(args)
else:
test_multi_models(args)
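For reference, the timing helpers deleted above (DurationBox, naive_timer, get_timing_stats) now resolve through the new test_compiler_utils import. Below is a minimal sketch of graph_net/test_compiler_utils.py mirroring the removed definitions; print_times_and_speedup does not appear in this diff, so its body is an assumption inferred from its call site and the inline duration/speedup logging this PR removes.

# graph_net/test_compiler_utils.py -- a minimal sketch; the timer and stats
# helpers mirror the definitions removed from test_compiler.py above, while
# print_times_and_speedup is an assumed reconstruction.
import sys
import time
from contextlib import contextmanager
from dataclasses import dataclass

import numpy as np


@dataclass
class DurationBox:
    value: float  # elapsed time in milliseconds


@contextmanager
def naive_timer(duration_box, synchronizer_func):
    # Synchronize before and after the timed region so pending device work
    # is included in the measurement.
    synchronizer_func()
    start = time.time()
    yield
    synchronizer_func()
    duration_box.value = (time.time() - start) * 1000  # milliseconds


def get_timing_stats(elapsed_times):
    return {
        "mean": float(f"{np.mean(elapsed_times):.6g}"),
        "std": float(f"{np.std(elapsed_times):.6g}"),
        "min": float(f"{np.min(elapsed_times):.6g}"),
        "max": float(f"{np.max(elapsed_times):.6g}"),
    }


def print_times_and_speedup(args, eager_time_stats, compiled_time_stats):
    # Assumed implementation: reports mean e2e (and, on CUDA, gpu) times
    # plus their ratios, matching the inline logging removed by this PR.
    eager_e2e = eager_time_stats["e2e"]["mean"]
    compiled_e2e = compiled_time_stats["e2e"]["mean"]
    duration_log = (
        f"{args.log_prompt} [Duration] "
        f"eager_e2e:{eager_e2e:.4f} ms compiled_e2e:{compiled_e2e:.4f} ms"
    )
    speedup_log = (
        f"{args.log_prompt} [Speedup] e2e_speedup:{eager_e2e / compiled_e2e:.4f}"
    )
    if "cuda" in args.device:
        eager_gpu = eager_time_stats["gpu"]["mean"]
        compiled_gpu = compiled_time_stats["gpu"]["mean"]
        duration_log += (
            f" eager_gpu:{eager_gpu:.4f} ms compiled_gpu:{compiled_gpu:.4f} ms"
        )
        speedup_log += f" gpu_speedup:{eager_gpu / compiled_gpu:.4f}"
    print(duration_log, file=sys.stderr)
    print(speedup_log, file=sys.stderr)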
10 changes: 5 additions & 5 deletions graph_net/paddle/utils.py
@@ -122,7 +122,7 @@ def load_converted_list_from_text(file_path):
return [*weight_info, *input_info]


def ConvertToValidNumber(data_type, value):
def convert_to_valid_number(data_type, value):
if value is not None and data_type in [
paddle.float32,
paddle.float16,
@@ -160,10 +160,10 @@ def convert_meta_classes_to_tensors(file_path):
"shape": attrs.get("shape", []),
"dtype": data_type,
"device": attrs.get("device", "gpu"),
"mean": ConvertToValidNumber(data_type, attrs.get("mean", None)),
"std": ConvertToValidNumber(data_type, attrs.get("std", None)),
"min_val": ConvertToValidNumber(data_type, attrs.get("min_val", 0)),
"max_val": ConvertToValidNumber(data_type, attrs.get("max_val", 2)),
"mean": convert_to_valid_number(data_type, attrs.get("mean", None)),
"std": convert_to_valid_number(data_type, attrs.get("std", None)),
"min_val": convert_to_valid_number(data_type, attrs.get("min_val", 0)),
"max_val": convert_to_valid_number(data_type, attrs.get("max_val", 2)),
},
"data": data_value,
"name": attrs.get("name"),
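Likewise, the directory-walking helpers removed from test_compiler.py are now imported from graph_net.path_utils. A minimal sketch mirroring the removed definitions:

# graph_net/path_utils.py -- a minimal sketch; these bodies mirror the
# helpers this PR removes from test_compiler.py, and the module path comes
# from the new "from graph_net import path_utils" import.
import os


def is_single_model_dir(model_dir):
    # A model directory is identified by the presence of graph_net.json.
    return os.path.isfile(f"{model_dir}/graph_net.json")


def get_immediate_subdirectory_paths(parent_dir):
    return [
        sub_dir
        for name in os.listdir(parent_dir)
        for sub_dir in [os.path.join(parent_dir, name)]
        if os.path.isdir(sub_dir)
    ]


def get_recursively_model_path(root_dir):
    # Yield every model directory under root_dir, descending through
    # intermediate directories until a graph_net.json marker is found.
    for sub_dir in get_immediate_subdirectory_paths(root_dir):
        if is_single_model_dir(sub_dir):
            yield sub_dir
        else:
            yield from get_recursively_model_path(sub_dir)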