Skip to content

Commit fe324d3

Browse files
PawelSwider2000 and Pawel Swider
authored
Modify files in install_xpu_headers only if they change (#2138)
Observed that recompilations are triggered when the install_xpu_headers.py script updates files. It turns out that the script does not change the files' content in any way: it rewrites the same content into them, which updates their timestamps and causes multiple dependent files to recompile. This PR makes sure that `install_xpu_headers.py` changes or creates files only when their content would actually change. This speeds up recompilation several times over; in my experience, from a few minutes to a few seconds. This fixes: #2093 --------- Co-authored-by: Pawel Swider <[email protected]>
1 parent cae6ba3 commit fe324d3

File tree

1 file changed

+76
-27
lines changed

1 file changed

+76
-27
lines changed

tools/codegen/install_xpu_headers.py

Lines changed: 76 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import shutil
55
from pathlib import Path
66

7+
VERBOSE = False
78

89
parser = argparse.ArgumentParser(description="Utils for append ops headers")
910
parser.add_argument(
@@ -20,34 +21,53 @@ def append_xpu_function_header(src, dst):
2021
r"""
2122
Cleans trailing empty lines from the destination file, then appends #include
2223
lines from the source file that match `#include <ATen/ops/...` to the destination.
24+
Only modifies file if changes are actually needed.
2325
"""
2426
if args.dry_run:
2527
return
2628

27-
# Read source file and match header lines
28-
with open(src, encoding="utf-8") as fr:
29-
src_text = fr.read()
29+
try:
30+
with open(src, encoding="utf-8") as fr:
31+
src_text = fr.read()
32+
except OSError as e:
33+
if VERBOSE:
34+
print(f"Warning: Could not read source file {src}: {e}")
35+
return
36+
3037
pattern = r"^#include <ATen/ops/.*>\s*\r?\n"
3138
matches = re.findall(pattern, src_text, re.MULTILINE)
3239
if not matches:
3340
return
3441

35-
with open(dst, "r+", encoding="utf-8") as f:
36-
dst_lines = f.readlines()
37-
dst_text = "".join(dst_lines)
38-
missing_headers = [match for match in matches if match not in dst_text]
39-
if not missing_headers:
40-
return
42+
try:
43+
with open(dst, encoding="utf-8") as f:
44+
dst_lines = f.readlines()
45+
dst_text = "".join(dst_lines)
46+
except OSError as e:
47+
if VERBOSE:
48+
print(f"Warning: Could not read destination file {dst}: {e}")
49+
return
50+
51+
missing_headers = [match for match in matches if match not in dst_text]
52+
if not missing_headers:
53+
return
54+
55+
new_dst_lines = dst_lines.copy()
56+
57+
while new_dst_lines and not new_dst_lines[-1].strip():
58+
new_dst_lines.pop()
59+
new_dst_lines.extend(missing_headers)
4160

42-
# Remove trailing empty lines from dst_lines
43-
while dst_lines and not dst_lines[-1].strip():
44-
dst_lines.pop()
61+
new_content = "".join(new_dst_lines)
62+
old_content = "".join(dst_lines)
4563

46-
f.seek(0)
47-
f.truncate()
48-
f.writelines(dst_lines)
49-
# Append missing headers to the end of the file
50-
f.writelines(missing_headers)
64+
if new_content != old_content:
65+
try:
66+
with open(dst, "w", encoding="utf-8") as f:
67+
f.writelines(new_dst_lines)
68+
except OSError as e:
69+
if VERBOSE:
70+
print(f"Error: Could not write to {dst}: {e}")
5171

5272

5373
def parse_ops_headers(src):
@@ -78,18 +98,36 @@ def classify_ops_headers(src_dir, dst_dir):
7898

7999
def generate_xpu_ops_headers_cmake(src_dir, dst_dir, xpu_ops_headers):
80100
r"""
81-
Generate XPU ops headers xpu_ops_generated_headers.cmake
101+
Generate XPU ops headers xpu_ops_generated_headers.cmake only if content changes
82102
"""
83-
with open(os.path.join(src_dir, "xpu_ops_generated_headers.cmake"), "w", encoding="utf-8") as fw:
84-
fw.write("set(xpu_ops_generated_headers\n")
85-
for header in xpu_ops_headers:
86-
fw.write(f' "{Path(os.path.join(dst_dir, header)).as_posix()}"\n')
87-
fw.write(")\n")
103+
output_file = os.path.join(src_dir, "xpu_ops_generated_headers.cmake")
104+
105+
# Generate new content
106+
new_content = "set(xpu_ops_generated_headers\n"
107+
for header in xpu_ops_headers:
108+
new_content += f' "{Path(os.path.join(dst_dir, header)).as_posix()}"\n'
109+
new_content += ")\n"
110+
111+
# Check if file exists and has same content
112+
should_write = True
113+
if os.path.exists(output_file):
114+
try:
115+
with open(output_file, encoding="utf-8") as f:
116+
existing_content = f.read()
117+
should_write = existing_content != new_content
118+
except OSError:
119+
# If we can't read the file, write it anyway
120+
should_write = True
121+
122+
if should_write:
123+
with open(output_file, "w", encoding="utf-8") as fw:
124+
fw.write(new_content)
88125

89126

90127
def append_xpu_ops_headers(src_dir, dst_dir, common_headers, xpu_ops_headers):
91128
r"""
92129
For XPU-specific ops headers, copy them to destination build and append XPU declarations to common headers.
130+
Copies and appends are done only if leading to file changes to prevent unnecessary recompilations.
93131
"""
94132
if args.dry_run:
95133
return
@@ -99,7 +137,16 @@ def append_xpu_ops_headers(src_dir, dst_dir, common_headers, xpu_ops_headers):
99137
# assert "xpu" in f, f"Error: The function signature or namespace in '{f}' is incorrect. Expected 'xpu' to be present."
100138
src = os.path.join(src_dir, f)
101139
dst = os.path.join(dst_dir, f)
102-
shutil.copy(src, dst)
140+
# Only copy if src and dst differ or dst does not exist
141+
should_copy = True
142+
if os.path.exists(dst):
143+
try:
144+
with open(src, "rb") as fsrc, open(dst, "rb") as fdst:
145+
should_copy = fsrc.read() != fdst.read()
146+
except OSError:
147+
should_copy = True
148+
if should_copy:
149+
shutil.copy(src, dst)
103150

104151
for f in common_headers:
105152
src = os.path.join(src_dir, f)
@@ -118,6 +165,7 @@ def append_xpu_ops_headers(src_dir, dst_dir, common_headers, xpu_ops_headers):
118165
with open(dst, "r+", encoding="utf-8") as f:
119166
dst_lines = f.readlines()
120167
dst_text = "".join(dst_lines)
168+
old_content = "".join(dst_lines)
121169
missing_declarations = []
122170
insertion_index = None
123171
for index, line in enumerate(dst_lines):
@@ -133,9 +181,10 @@ def append_xpu_ops_headers(src_dir, dst_dir, common_headers, xpu_ops_headers):
133181
break
134182
assert (insertion_index is not None), f"Error: No TORCH_API declaration found in {dst}."
135183

136-
f.seek(0)
137-
f.writelines(dst_lines)
138-
f.truncate()
184+
if old_content != "".join(dst_lines):
185+
f.seek(0)
186+
f.writelines(dst_lines)
187+
f.truncate()
139188

140189

141190
def main():

0 commit comments

Comments
 (0)