diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml index eaabfbb..1edf8b3 100644 --- a/.github/workflows/main.yaml +++ b/.github/workflows/main.yaml @@ -21,7 +21,7 @@ jobs: - name: Build Wasmtime + wasi-nn run: | - rustup target add wasm32-wasi + rustup target add wasm32-wasip1 rustup target add wasm32-unknown-unknown git clone https://github.com/bytecodealliance/wasmtime --branch v16.0.0 --depth 1 --recursive cd wasmtime @@ -52,13 +52,13 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 - - run: rustup target add wasm32-wasi + - run: rustup target add wasm32-wasip1 - name: Run tests on native architecture working-directory: rust run: cargo test - - name: Build on wasm32-wasi + - name: Build on wasm32-wasip1 working-directory: rust - run: cargo build --target=wasm32-wasi + run: cargo build --target=wasm32-wasip1 - name: Check dry-run publish to crates.io working-directory: rust run: cargo publish --dry-run @@ -68,10 +68,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 - - run: rustup target add wasm32-wasi - - name: Build for wasm32-wasi + - run: rustup target add wasm32-wasip1 + - name: Build for wasm32-wasip1 working-directory: image2tensor - run: cargo build --target=wasm32-wasi + run: cargo build --target=wasm32-wasip1 - name: Check dry-run publish to crates.io working-directory: image2tensor run: cargo publish --dry-run diff --git a/build.sh b/build.sh index 358dd60..bc968db 100755 --- a/build.sh +++ b/build.sh @@ -24,13 +24,13 @@ else echo "The first argument: $1" FIXTURE=https://github.com/intel/openvino-rs/raw/main/crates/openvino/tests/fixtures/mobilenet pushd $WASI_NN_DIR/rust/ - cargo build --release --target=wasm32-wasi + cargo build --release --target=wasm32-wasip1 mkdir -p $WASI_NN_DIR/rust/examples/classification-example/build RUST_BUILD_DIR=$(realpath $WASI_NN_DIR/rust/examples/classification-example/build/) cp -rn examples/images $RUST_BUILD_DIR pushd examples/classification-example - cargo build --release --target=wasm32-wasi - cp target/wasm32-wasi/release/wasi-nn-example.wasm $RUST_BUILD_DIR + cargo build --release --target=wasm32-wasip1 + cp target/wasm32-wasip1/release/wasi-nn-example.wasm $RUST_BUILD_DIR pushd build wget --no-clobber --directory-prefix=$RUST_BUILD_DIR $FIXTURE/mobilenet.bin wget --no-clobber --directory-prefix=$RUST_BUILD_DIR $FIXTURE/mobilenet.xml diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 1ed646d..acf1c6f 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -177,7 +177,7 @@ checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" [[package]] name = "wasi-nn" -version = "25.0.2" +version = "0.8.0" dependencies = [ "wit-bindgen", ] diff --git a/rust/Cargo.toml b/rust/Cargo.toml index f9553f6..ab899f9 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "wasi-nn" -version = "0.7.0" +version = "0.8.0" authors = ["The Bytecode Alliance Developers"] description = "High-level Rust bindings for wasi-nn" license = "Apache-2.0" diff --git a/rust/examples/classification-example/Cargo.lock b/rust/examples/classification-example/Cargo.lock index 6423cee..5947425 100644 --- a/rust/examples/classification-example/Cargo.lock +++ b/rust/examples/classification-example/Cargo.lock @@ -298,7 +298,7 @@ checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" [[package]] name = "wasi-nn" -version = "25.0.2" +version = "0.8.0" dependencies = [ "wit-bindgen", ] diff --git a/rust/examples/classification-example/src/main.rs b/rust/examples/classification-example/src/main.rs index ba3bb3a..a1f104b 100644 --- a/rust/examples/classification-example/src/main.rs +++ b/rust/examples/classification-example/src/main.rs @@ -37,16 +37,18 @@ fn main() { TensorType::Fp32, &data, ); - exec_context.set_input("data", tensor).unwrap(); - println!("Set input tensor"); - // Execute the inferencing - exec_context.compute().unwrap(); + let output_tensor_vec = exec_context.compute(vec![("data".to_string(), tensor)]).unwrap(); println!("Executed graph inference"); - // Get the inferencing result (bytes) and convert it to f32 - println!("Getting inferencing output"); - let output_data = exec_context.get_output("squeezenet0_flatten0_reshape0").unwrap().data(); + let output_tensor = output_tensor_vec.iter().find_map(|(tensor_name, tensor)| { + if tensor_name == "squeezenet0_flatten0_reshape0" { + Some(tensor) + } else { + None + } + }); + let output_data = output_tensor.expect("No output tensor").data(); println!("Retrieved output data with length: {}", output_data.len()); let output_f32 = bytes_to_f32_vec(output_data); diff --git a/rust/wit/wasi-nn.wit b/rust/wit/wasi-nn.wit index 2ae5015..d8734dd 100644 --- a/rust/wit/wasi-nn.wit +++ b/rust/wit/wasi-nn.wit @@ -1,4 +1,4 @@ -package wasi:nn@0.2.0-rc-2024-08-19; +package wasi:nn@0.2.0-rc-2024-10-28; /// `wasi-nn` is a WASI API for performing machine learning (ML) inference. The API is not (yet) /// capable of performing ML training. WebAssembly programs that want to use a host's ML @@ -110,25 +110,19 @@ interface graph { /// `graph` to input tensors before `compute`-ing an inference: interface inference { use errors.{error}; - use tensor.{tensor, tensor-data}; + use tensor.{tensor}; + + /// Identify a tensor by name; this is necessary to associate tensors to + /// graph inputs and outputs. + type named-tensor = tuple; /// Bind a `graph` to the input and output tensors for an inference. /// /// TODO: this may no longer be necessary in WIT /// (https://github.com/WebAssembly/wasi-nn/issues/43) resource graph-execution-context { - /// Define the inputs to use for inference. - set-input: func(name: string, tensor: tensor) -> result<_, error>; - /// Compute the inference on the given inputs. - /// - /// Note the expected sequence of calls: `set-input`, `compute`, `get-output`. TODO: this - /// expectation could be removed as a part of - /// https://github.com/WebAssembly/wasi-nn/issues/43. - compute: func() -> result<_, error>; - - /// Extract the outputs after inference. - get-output: func(name: string) -> result; + compute: func(inputs: list) -> result, error>; } } @@ -163,4 +157,4 @@ interface errors { /// Errors can propagated with backend specific status through a string value. data: func() -> string; } -} +} \ No newline at end of file