Skip to content

Commit 0c6b52c

Browse files
FindHaofacebook-github-bot
authored andcommitted
Support Triton repo's triton_kernels custom types via lazy import; add test for single launch and Tensor type (#101)
Summary: - Add first-class support for custom data types provided by the Triton repository's `triton_kernels` module (e.g., `Tensor`, `Storage`, `StridedLayout`). - Avoid top-level imports for optional `triton_kernels.tensor` by probing availability and importing lazily at use sites. - Update CUDA test to emit a single launch event to logs and verify: - Exactly one `.ndjson` log file is produced in the temp logs directory - Exactly one `launch` event exists in that file - `extracted_args["X"].type == "triton_kernels.tensor.Tensor"` ## Motivation - The custom types (`Tensor`, `Storage`, `StridedLayout`) are defined in the Triton repo under `triton_kernels` and may not be installed by default. Importing them at module import time causes hard failures for users without the extra. - We still want to reliably validate parsing and logging by asserting that the launch event is recorded and carries the expected structured argument metadata. ## Changes ### tritonparse/reproducer/utils.py - Add top-level probe (non-import) for optional dependency: - `TRITON_KERNELS_CUSTOM_TYPES = importlib.util.find_spec("triton_kernels.tensor") is not None` - Add lazy loader with caching: - `_get_triton_tensor_types()` loads `Tensor`, `Storage`, `StridedLayout` on first use and is decorated with `functools.lru_cache(maxsize=1)` - Update `_create_arg_from_info` branches to: - Guard with `if not TRITON_KERNELS_CUSTOM_TYPES: raise RuntimeError(...)` - Import via the cached lazy loader at the point of use - Preserve existing tensor dtype handling for random data generation ### tests/test_tritonparse.py - Modify `test_triton_kernels_Tensor` to: - Import `from tritonparse.reproducer import utils as reproducer_utils` - Initialize `tritonparse.structured_logging` with a temp logs directory - Run `_topk_forward` once to generate a single `launch` event - Read the only `.ndjson` file under the temp logs directory and assert: - Exactly one file - Exactly one `launch` event - `extracted_args["X"].type == "triton_kernels.tensor.Tensor"` - Synchronize CUDA and clean up the temp directory (unless `TEST_KEEP_OUTPUT=1`) ## Backward compatibility - When the optional package is absent, constructing the custom types now raises a clear `RuntimeError` instructing that the dependency is missing. Callers can catch this or ensure the extra is installed. - Code paths that do not touch the custom types are unaffected. ## Testing - Unit test updated: `TestTritonparseCUDA.test_triton_kernels_Tensor` - How to run (CUDA required): - `python -m unittest tests.test_tritonparse -v -k test_triton_kernels_Tensor` - Expected: test passes, asserting single launch event and correct `X` type in `extracted_args`. Pull Request resolved: #101 Reviewed By: sfzhu93 Differential Revision: D82361016 Pulled By: FindHao fbshipit-source-id: 83ccfa40a7f6a21aa3ba377ba18b3cb5390d9116
1 parent 3954565 commit 0c6b52c

File tree

8 files changed

+890
-17
lines changed

8 files changed

+890
-17
lines changed

.ci/install-triton-kernels.sh

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
#!/bin/bash
2+
3+
# Install Triton kernels from triton-lang/triton/python/triton_kernels
4+
5+
set -e
6+
7+
echo "🚀 Installing Triton kernels from triton-lang/triton/python/triton_kernels..."
8+
START_TIME=$(date +%s)
9+
10+
# Function to show elapsed time
11+
show_elapsed() {
12+
CURRENT_TIME=$(date +%s)
13+
ELAPSED=$((CURRENT_TIME - START_TIME))
14+
echo "⏱️ Elapsed time: ${ELAPSED}s"
15+
}
16+
17+
# Set Triton version/commit for cache consistency
18+
TRITON_COMMIT=${TRITON_COMMIT:-"main"}
19+
echo "🎯 Target Triton commit/branch: $TRITON_COMMIT"
20+
TRITON_SOURCE_DIR="/tmp/triton"
21+
22+
# Ensure we're in the conda environment
23+
if [ -z "$CONDA_ENV" ]; then
24+
echo "ERROR: CONDA_ENV is not set"
25+
exit 1
26+
fi
27+
28+
# Activate conda environment
29+
source /opt/miniconda3/etc/profile.d/conda.sh
30+
conda activate "$CONDA_ENV"
31+
32+
# Ensure TRITON_SOURCE_DIR contains Triton source; otherwise, clone it
33+
echo "🔧 Ensuring Triton source exists at $TRITON_SOURCE_DIR..."
34+
35+
if [ -d "$TRITON_SOURCE_DIR/.git" ]; then
36+
REMOTE_URL=$(git -C "$TRITON_SOURCE_DIR" remote get-url origin 2>/dev/null || echo "")
37+
if [[ "$REMOTE_URL" == *"triton-lang/triton"* ]]; then
38+
echo "✅ Found existing Triton repository: $REMOTE_URL"
39+
else
40+
echo "⚠️ Existing directory is not triton-lang/triton (origin: $REMOTE_URL). Re-cloning..."
41+
rm -rf "$TRITON_SOURCE_DIR"
42+
fi
43+
fi
44+
45+
if [ ! -d "$TRITON_SOURCE_DIR/.git" ]; then
46+
echo "Cloning Triton repository..."
47+
if ! git clone https://github.com/triton-lang/triton.git "$TRITON_SOURCE_DIR"; then
48+
echo "❌ ERROR: Failed to clone Triton repository"
49+
echo "This might be due to network issues or GitHub rate limiting"
50+
exit 1
51+
fi
52+
fi
53+
54+
echo "Checking out Triton commit/branch: $TRITON_COMMIT"
55+
if ! git -C "$TRITON_SOURCE_DIR" checkout "$TRITON_COMMIT"; then
56+
echo "❌ ERROR: Failed to checkout $TRITON_COMMIT"
57+
exit 1
58+
fi
59+
60+
# Install triton_kernels in editable mode
61+
KERNELS_DIR="$TRITON_SOURCE_DIR/python/triton_kernels"
62+
if [ ! -d "$KERNELS_DIR" ]; then
63+
echo "❌ ERROR: triton_kernels directory not found at $KERNELS_DIR"
64+
exit 1
65+
fi
66+
67+
echo "📦 Installing triton_kernels from $KERNELS_DIR (editable)..."
68+
pip install -e "$KERNELS_DIR"
69+
show_elapsed
70+
71+
# Verify installation with a simple import
72+
echo "🔎 Verifying triton_kernels installation..."
73+
set +e
74+
KERNELS_IMPORT_OUTPUT=$(python -c "import triton_kernels; import os; print('triton_kernels OK'); print(getattr(triton_kernels, '__file__', 'no_file'))" 2>&1)
75+
KERNELS_IMPORT_EXITCODE=$?
76+
set -e
77+
78+
echo "Import exit code: $KERNELS_IMPORT_EXITCODE"
79+
echo "Import output: $KERNELS_IMPORT_OUTPUT"
80+
81+
if [ $KERNELS_IMPORT_EXITCODE -ne 0 ]; then
82+
echo "❌ ERROR: Failed to import triton_kernels"
83+
exit 1
84+
fi
85+
86+
echo "✅ triton_kernels installation verified"
87+
show_elapsed

.github/workflows/test.yml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ jobs:
142142
TRITON_COMMIT: ${{ steps.triton-commit.outputs.commit }}
143143
run: |
144144
bash .ci/install-triton.sh
145-
145+
bash .ci/install-triton-kernels.sh
146146
- name: Install project dependencies
147147
env:
148148
CONDA_ENV: tritonparse
@@ -192,6 +192,12 @@ jobs:
192192
run: |
193193
bash .ci/setup.sh
194194
195+
- name: Install Triton kernels
196+
env:
197+
CONDA_ENV: tritonparse-pip
198+
run: |
199+
bash .ci/install-triton-kernels.sh
200+
195201
- name: Install project dependencies
196202
env:
197203
CONDA_ENV: tritonparse-pip

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ parsed_output/
5454
*.ndjson.gz
5555
*.json
5656
!tests/example_output/
57+
!tests/example_output/repro/**
5758
!tests/example_output/logs/**
5859
!tests/example_output/parsed_output/**
5960
!tests/example_output/parsed_output_complex/**

0 commit comments

Comments
 (0)