Skip to content

Commit de886f9

Browse files
committed
Install tpu requirements by default in pypi
1 parent c92d9d9 commit de886f9

File tree

2 files changed

+44
-0
lines changed

2 files changed

+44
-0
lines changed

build_hooks.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Copyright 2023–2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Custom build hooks for PyPI."""
16+
17+
import os
18+
from hatchling.builders.hooks.plugin.interface import BuildHookInterface
19+
20+
TPU_REQUIREMENTS_PATH = "dependencies/requirements/generated_requirements/tpu-requirements.txt"
21+
22+
23+
def get_tpu_dependencies():
24+
"""Reads the TPU requirements file and returns a list of dependencies."""
25+
if not os.path.exists(TPU_REQUIREMENTS_PATH):
26+
print(f"Warning: TPU requirements file not found at {TPU_REQUIREMENTS_PATH}. Skipping dependency injection.")
27+
return []
28+
29+
with open(TPU_REQUIREMENTS_PATH, "r") as f: # pylint: disable=unspecified-encoding
30+
# Filter out comments and empty lines
31+
deps = [line.strip() for line in f if line.strip() and not line.strip().startswith("#")]
32+
return deps
33+
34+
35+
class CustomBuildHook(BuildHookInterface):
36+
"""A custom hook to inject TPU dependencies into the core wheel dependencies."""
37+
38+
def initialize(self, version, build_data): # pylint: disable=unused-argument
39+
tpu_deps = get_tpu_dependencies()
40+
build_data["dependencies"] = tpu_deps
41+
print(f"Successfully injected {len(tpu_deps)} TPU dependencies into the wheel's core requirements.")

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,5 +39,8 @@ allow-direct-references = true
3939
[tool.hatch.build.targets.wheel]
4040
packages = ["src/MaxText", "src/install_maxtext_extra_deps"]
4141

42+
[tool.hatch.build.targets.wheel.hooks.custom]
43+
path = "build_hooks.py"
44+
4245
[project.scripts]
4346
install_maxtext_github_deps = "install_maxtext_extra_deps.install_github_deps:main"

0 commit comments

Comments
 (0)