Skip to content

Commit f362e7e

Browse files
authored
Fix torch/LAPIS integration (#47)
* Create variables to test if torch-mlir/mpact enabled * Get rid of test translations in lapis-translate This just duplicates what mlir-translate already does * Add lower_and_emit_kokkos CAPI/py function Given an MLIR source string for a module, it parses, lowers and emits Kokkos in one shot Works for matadd when called directly from Python * Add lapis-emit CLI utility Does end-to-end lowering and emitting. Takes input as MLIR source, as file or stdin (default=stdin) matadd.py and torchscript_resnet18.py work again when called from Python. Dense inputs and outputs are in device space so they're suitable for benchmarking from C++ driver. * Fixes for MPACT version of spmv - handle casted integers in loop mapping pass (CSR pattern detection) - support memref.subview in emitter README - use https, not ssh URLs for repositories - add torch, torchvision SHAs and instruction link to build from source. * Modify util functions to make params/returns DualView * WIP: updates to emitter - multiple results - making all memref returns of non-extern funcs DualView * More detailed IR dumps from lapis-emit * Add KokkosBackend.compile_mpact function Runs new sparse pipeline and emits C++. Python wrapper still broken but C++ is good. * Minimal emitter fixes to get sparseMHA to compile * Update examples (from Dec review) - make a copy of resnet18 that uses dynamic batch size (resnet18_dynamic.py), since E3SM needs this to work. - Rename torchscript_resnet18.py to resnet18_static.py since it has a compile-time batch dimension (1). - Add MPACT based spmv example * Sync results to host before returning to numpy --------- Signed-off-by: Brian Kelley <[email protected]>
1 parent d3fa8a8 commit f362e7e

40 files changed

+1634
-1608
lines changed

CMakeLists.txt

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,20 @@ include_directories(${STANDALONE_BINARY_DIR}/include)
8787
set(LAPIS_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}")
8888
set(LAPIS_BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}")
8989

90+
if("torch-mlir" IN_LIST LLVM_EXTERNAL_PROJECTS)
91+
message(STATUS "LAPIS has torch-mlir support")
92+
set(LAPIS_HAS_TORCH_MLIR ON)
93+
else()
94+
set(LAPIS_HAS_TORCH_MLIR OFF)
95+
endif()
96+
97+
if("mpact" IN_LIST LLVM_EXTERNAL_PROJECTS)
98+
message(STATUS "LAPIS has MPACT support")
99+
set(LAPIS_HAS_MPACT ON)
100+
else()
101+
set(LAPIS_HAS_MPACT OFF)
102+
endif()
103+
90104
link_directories(${LLVM_BUILD_LIBRARY_DIR})
91105
add_definitions(${LLVM_DEFINITIONS})
92106

README.md

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,19 +29,19 @@ are set to the paths of these repositories.
2929

3030
The following commands will clone the correct versions of all the repositories and set these environment variables.
3131
```
32-
git clone git@github.com:MPACT-ORG/mpact-compiler
32+
git clone https://github.com/MPACT-ORG/mpact-compiler
3333
cd mpact-compiler
3434
git checkout 556009cd
3535
git submodule update --init --recursive
3636
export MPACT_SRC=`pwd`
3737
export TORCH_MLIR_SRC="$MPACT_SRC/externals/torch-mlir"
3838
export LLVM_SRC="$TORCH_MLIR_SRC/externals/llvm-project"
3939
cd ..
40-
git clone git@github.com:sandialabs/LAPIS
40+
git clone https://github.com/sandialabs/LAPIS
4141
cd LAPIS
4242
export LAPIS_SRC=`pwd`
4343
cd ..
44-
git clone -b master git@github.com:kokkos/kokkos
44+
git clone -b master https://github.com/kokkos/kokkos
4545
```
4646

4747
Building with ninja is not required but useful as it automatically uses all cores for parallel compilation. Pass ``-Gninja`` to
@@ -92,15 +92,35 @@ cd ..
9292
This recipe builds LAPIS as an external project with LLVM.
9393
torch-mlir and mpact require this recipe, but torch-mlir and mpact are still optional.
9494
mpact requires torch-mlir, however.
95-
**This requires ninja due to an issue in torch-mlir. make will not work.**
95+
**ninja is required due to an issue in torch-mlir. make will not work.**
9696
```
9797
# If enabling torch-mlir, need to install Python dependencies first.
9898
# This can be done inside a python virtual env.
9999
100100
cd $TORCH_MLIR_SRC
101-
pip install -r requirements.txt
102-
pip install -r torchvision-requirements.txt
101+
pip install -r build-requirements.txt
102+
pip install -r test-requirements.txt
103+
```
104+
105+
Then install the torch and torchvision Python packages from source.
106+
It needs to be from source because torch-mlir only works with specific nightly versions
107+
of these packages, but the package repository only stores nightly binaries for a limited amount of time.
103108

109+
A CPU-only build is sufficient for LAPIS.
110+
* [torch repository](https://github.com/pytorch/pytorch), version 995ec16c)
111+
* [torchvision repository](https://github.com/pytorch/vision), version c7ea645b)
112+
[Instructions for building from source can be found here](https://github.com/pytorch/pytorch#from-source)
113+
114+
These Git SHAs correspond to the nightly versions 2.5.0.dev20240909 and 0.20.0.dev20240909 respectively.
115+
*Note for developers:* from installed nightly versions, the exact Git versions can be found with:
116+
```
117+
import torch
118+
import torchvision
119+
print(torch.version.git_version)
120+
print(torchvision.version.git_version)
121+
```
122+
123+
```
104124
cd $WORKSPACE
105125
mkdir build
106126
cd build

examples/MTTKRP.py

Lines changed: 0 additions & 63 deletions
This file was deleted.

examples/fx_resnet18.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,10 @@
3535
torch.ones(1, 3, 224, 224),
3636
output_type="linalg-on-tensors",
3737
func_name=resnet18.__class__.__name__,
38+
experimental_support_mutation=True
3839
)
39-
backend = KokkosBackend.KokkosBackend()
40-
compiled = backend.compile_sparse(module)
40+
backend = KokkosBackend.KokkosBackend(dump_mlir=True)
41+
compiled = backend.compile(module)
4142
#fx_module = backend.load(compiled)
4243

4344
params = {

examples/matadd.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def main():
2222

2323
mlir_module = torchscript.compile(m, (a, b), output_type='linalg-on-tensors')
2424

25-
backend = KokkosBackend(dump_mlir=True)
25+
backend = KokkosBackend.KokkosBackend(dump_mlir=True)
2626
k_backend = backend.compile(mlir_module)
2727

2828
c = k_backend.forward(a, b)
@@ -34,3 +34,4 @@ def main():
3434

3535
if __name__ == "__main__":
3636
main()
37+

examples/matadd_pytaco.py

Lines changed: 0 additions & 39 deletions
This file was deleted.

examples/pytaco_cpp_driver/main.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
#define PYTACO_CPP_DRIVER
1+
#define LAPIS_CPP_DRIVER
2+
23
#include <iostream>
34
#include <fstream>
45
#include "mlir/Dialect/SparseTensor/IR/Enums.h"

examples/resnet18_dynamic.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2+
# See https://llvm.org/LICENSE.txt for license information.
3+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4+
# Also available under a BSD-style license. See LICENSE.
5+
6+
import sys
7+
from pathlib import Path
8+
9+
import torch
10+
import torchvision.models as models
11+
from torch_mlir import torchscript
12+
from torch_mlir.compiler_utils import TensorPlaceholder
13+
from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend
14+
from lapis import KokkosBackend
15+
16+
sys.path.append(str(Path(__file__).absolute().parent))
17+
from _example_utils import (
18+
top3_possibilities,
19+
load_and_preprocess_image,
20+
load_labels,
21+
DEFAULT_IMAGE_URL,
22+
)
23+
24+
def predictions(torch_func, kokkos_func, img, labels):
25+
golden_prediction = top3_possibilities(torch_func(img), labels)
26+
print("PyTorch prediction")
27+
print(golden_prediction)
28+
prediction = top3_possibilities(torch.from_numpy(kokkos_func(img.numpy())), labels)
29+
print("LAPIS prediction")
30+
print(prediction)
31+
32+
print("load image from " + DEFAULT_IMAGE_URL, file=sys.stderr)
33+
img = load_and_preprocess_image(DEFAULT_IMAGE_URL)
34+
labels = load_labels()
35+
#print("Dumping preprocessed dog image to dog.bin")
36+
#img.numpy().tofile('dog.bin')
37+
38+
resnet18 = models.resnet18(pretrained=True).eval()
39+
#print(help(torchscript.compile))
40+
41+
module = torchscript.compile(
42+
resnet18, torch.ones(1, 3, 224, 224), output_type="linalg-on-tensors"
43+
)
44+
backend = refbackend.RefBackendLinalgOnTensorsBackend()
45+
compiled = backend.compile(module)
46+
jit_module = backend.load(compiled)
47+
48+
imgPH = TensorPlaceholder([-1, 3, 224, 224], torch.float32)
49+
#imgPH = TensorPlaceholder([1, 3, 224, 224], torch.float32)
50+
51+
kModule = torchscript.compile(resnet18, imgPH, output_type="linalg-on-tensors")
52+
kBackend = KokkosBackend.KokkosBackend(dump_mlir=True)
53+
kCompiledModule = kBackend.compile(kModule)
54+
55+
predictions(resnet18.forward, kCompiledModule.forward, img, labels)
56+

examples/resnet18_static.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2+
# See https://llvm.org/LICENSE.txt for license information.
3+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4+
# Also available under a BSD-style license. See LICENSE.
5+
6+
import sys
7+
from pathlib import Path
8+
9+
import torch
10+
import torchvision.models as models
11+
from torch_mlir import torchscript
12+
from torch_mlir.compiler_utils import TensorPlaceholder
13+
from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend
14+
from lapis import KokkosBackend
15+
16+
sys.path.append(str(Path(__file__).absolute().parent))
17+
from _example_utils import (
18+
top3_possibilities,
19+
load_and_preprocess_image,
20+
load_labels,
21+
DEFAULT_IMAGE_URL,
22+
)
23+
24+
def predictions(torch_func, kokkos_func, img, labels):
25+
golden_prediction = top3_possibilities(torch_func(img), labels)
26+
print("PyTorch prediction")
27+
print(golden_prediction)
28+
prediction = top3_possibilities(torch.from_numpy(kokkos_func(img.numpy())), labels)
29+
print("LAPIS prediction")
30+
print(prediction)
31+
32+
print("load image from " + DEFAULT_IMAGE_URL, file=sys.stderr)
33+
img = load_and_preprocess_image(DEFAULT_IMAGE_URL)
34+
labels = load_labels()
35+
#print("Dumping preprocessed dog image to dog.bin")
36+
#img.numpy().tofile('dog.bin')
37+
38+
resnet18 = models.resnet18(pretrained=True).eval()
39+
#print(help(torchscript.compile))
40+
41+
module = torchscript.compile(
42+
resnet18, torch.ones(1, 3, 224, 224), output_type="linalg-on-tensors"
43+
)
44+
backend = refbackend.RefBackendLinalgOnTensorsBackend()
45+
compiled = backend.compile(module)
46+
jit_module = backend.load(compiled)
47+
48+
imgPH = TensorPlaceholder([1, 3, 224, 224], torch.float32)
49+
50+
kModule = torchscript.compile(resnet18, imgPH, output_type="linalg-on-tensors")
51+
kBackend = KokkosBackend.KokkosBackend(dump_mlir=True)
52+
kCompiledModule = kBackend.compile(kModule)
53+
54+
predictions(resnet18.forward, kCompiledModule.forward, img, labels)
55+

examples/sparse_pytaco.py

Lines changed: 0 additions & 48 deletions
This file was deleted.

0 commit comments

Comments
 (0)