# PyTorch Backend
We will use this example project to show how to perform AI inference with a PyTorch model in WasmEdge and Rust.
## Prerequisite
Besides the regular WasmEdge and Rust requirements, please make sure that you have the WASI-NN plugin with the PyTorch backend installed.
## Quick start
Because the example already includes a WASM file compiled from the Rust code, we can use the WasmEdge CLI to execute the example directly. First, git clone the WasmEdge-WASINN-examples repo.
```bash
git clone https://github.com/second-state/WasmEdge-WASINN-examples.git
cd WasmEdge-WASINN-examples/pytorch-mobilenet-image/
```
Run the inference application in WasmEdge.
```bash
wasmedge --dir .:. wasmedge-wasinn-example-mobilenet-image.wasm mobilenet.pt input.jpg
```
If everything goes well, you should see the following output in the terminal:
```
Read torchscript binaries, size in bytes: 14376924
Loaded graph into wasi-nn with ID: 0
Created wasi-nn execution context with ID: 0
Read input tensor, size in bytes: 602112
Executed graph inference
1.) [954](20.6681)banana
2.) [940](12.1483)spaghetti squash
3.) [951](11.5748)lemon
4.) [950](10.4899)orange
5.) [953](9.4834)pineapple, ananas
```
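Each result line shows the rank, the ImageNet class index in square brackets, the output score in parentheses, and the class label. The scores appear to be the model's raw outputs rather than probabilities; if you wanted probabilities, you could post-process them yourself. A small sketch of such a step (not part of the example):

```rust
// Hypothetical post-processing (not in the example): convert raw scores
// into probabilities with a numerically stable softmax.
fn softmax(logits: &[f32]) -> Vec<f32> {
    let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|&x| (x - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}
```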
## Build and run
Let's build the WASM file from the Rust source code. First, git clone the WasmEdge-WASINN-examples repo.
```bash
git clone https://github.com/second-state/WasmEdge-WASINN-examples.git
cd WasmEdge-WASINN-examples/pytorch-mobilenet-image/rust
```
Second, use `cargo` to build the example project.
```bash
cargo build --target wasm32-wasi --release
```
The output WASM file is `target/wasm32-wasi/release/wasmedge-wasinn-example-mobilenet-image.wasm`. Next, use WasmEdge to load the PyTorch model and then use it to classify objects in your image.
```bash
wasmedge --dir .:. wasmedge-wasinn-example-mobilenet-image.wasm mobilenet.pt input.jpg
```
You can replace `input.jpg` with your own image file.
## Improve performance
You can make the inference program run faster by AOT compiling the WASM file first.
```bash
wasmedge compile wasmedge-wasinn-example-mobilenet-image.wasm out.wasm
wasmedge --dir .:. out.wasm mobilenet.pt input.jpg
```
## Understand the code
The `main.rs` file contains the complete Rust source for the example. First, it reads the file names of the PyTorch model and the input image from the command line.
```rust
let args: Vec<String> = env::args().collect();
let model_bin_name: &str = &args[1]; // File name for the PyTorch model
let image_name: &str = &args[2]; // File name for the input image
```
We use a helper function called `image_to_tensor()` to convert the input image into tensor data (the tensor type is `F32`); a sketch of such a helper is shown below.
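For illustration, here is a minimal sketch of what `image_to_tensor()` might look like, assuming the `image` crate for decoding and the standard ImageNet normalization constants; the actual helper in the repo may differ in detail:

```rust
use image::imageops::FilterType;

// Illustrative sketch (assumed, not the repo's exact code): decode an image,
// resize it, normalize it with the ImageNet mean/std, and return the raw
// bytes of an F32 tensor in NCHW (channels-first) layout.
fn image_to_tensor(path: String, height: u32, width: u32) -> Vec<u8> {
    let rgb = image::open(path)
        .unwrap()
        .resize_exact(width, height, FilterType::Triangle)
        .to_rgb8();
    let mean = [0.485f32, 0.456, 0.406];
    let std = [0.229f32, 0.224, 0.225];
    let (w, h) = (width as usize, height as usize);
    let mut nchw = vec![0f32; 3 * h * w];
    for (x, y, pixel) in rgb.enumerate_pixels() {
        for c in 0..3 {
            let v = pixel[c] as f32 / 255.0;
            nchw[c * h * w + (y as usize) * w + x as usize] = (v - mean[c]) / std[c];
        }
    }
    // Reinterpret the f32 values as little-endian bytes for wasi-nn.
    nchw.iter().flat_map(|f| f.to_le_bytes()).collect()
}
```

With the input tensor prepared, we can load the model, feed the tensor from the image into it, and get the inference output tensor.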
```rust
// Load the model.
let graph = wasi_nn::GraphBuilder::new(
    wasi_nn::GraphEncoding::Pytorch,
    wasi_nn::ExecutionTarget::CPU,
)
.build_from_files([model_bin_name])
.unwrap();
let mut context = graph.init_execution_context().unwrap();

// Load a tensor that precisely matches the graph input tensor:
// 1 image, 3 channels, 224x224 pixels.
let tensor_data = image_to_tensor(image_name.to_string(), 224, 224);
context
    .set_input(0, wasi_nn::TensorType::F32, &[1, 3, 224, 224], &tensor_data)
    .unwrap();

// Execute the inference.
context.compute().unwrap();

// Retrieve the output: one score per ImageNet class.
let mut output_buffer = vec![0f32; 1000];
context.get_output(0, &mut output_buffer).unwrap();
```
In the above code, `wasi_nn::GraphEncoding::Pytorch` means using the PyTorch backend, and `wasi_nn::ExecutionTarget::CPU` means running the computation on the CPU. Finally, we sort the output and then print the top-5 classification results.
```rust
let results = sort_results(&output_buffer);
for i in 0..5 {
    println!(
        " {}.) [{}]({:.4}){}",
        i + 1,
        results[i].0,
        results[i].1,
        imagenet_classes::IMAGENET_CLASSES[results[i].0]
    );
}
```
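The `sort_results()` helper pairs each class index with its score and orders the pairs by score. A minimal sketch of what it might look like (the version in the repo's `main.rs` may differ in detail):

```rust
// Illustrative sketch (assumed): pair each class index with its score and
// sort the pairs by score, highest first.
fn sort_results(buffer: &[f32]) -> Vec<(usize, f32)> {
    let mut results: Vec<(usize, f32)> =
        buffer.iter().copied().enumerate().collect();
    results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
    results
}
```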