Skip to content
26 changes: 20 additions & 6 deletions crates/guest-rust/macro/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@ use proc_macro2::{Span, TokenStream};
use quote::ToTokens;
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::str::FromStr;
use std::sync::atomic::{AtomicUsize, Ordering::Relaxed};
use syn::parse::{Error, Parse, ParseStream, Result};
use syn::punctuated::Punctuated;
use syn::spanned::Spanned;
use syn::{braced, token, LitStr, Token};
use wit_bindgen_core::wit_parser::{PackageId, Resolve, UnresolvedPackageGroup, WorldId};
use wit_bindgen_core::AsyncFilterSet;
use wit_bindgen_rust::{Opts, Ownership, WithOption};
use wit_bindgen_rust::{Opts, Ownership, StubsMode, WithOption};

#[proc_macro]
pub fn generate(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
Expand Down Expand Up @@ -106,9 +107,7 @@ impl Parse for Config {
Opt::Skip(list) => opts.skip.extend(list.iter().map(|i| i.value())),
Opt::RuntimePath(path) => opts.runtime_path = Some(path.value()),
Opt::BitflagsPath(path) => opts.bitflags_path = Some(path.value()),
Opt::Stubs => {
opts.stubs = true;
}
Opt::Stubs(mode) => opts.stubs = mode,
Opt::ExportPrefix(prefix) => opts.export_prefix = Some(prefix.value()),
Opt::AdditionalDerives(paths) => {
opts.additional_derive_attributes = paths
Expand Down Expand Up @@ -380,7 +379,7 @@ enum Opt {
Ownership(Ownership),
RuntimePath(syn::LitStr),
BitflagsPath(syn::LitStr),
Stubs,
Stubs(StubsMode),
ExportPrefix(syn::LitStr),
// Parse as paths so we can take the concrete types/macro names rather than raw strings
AdditionalDerives(Vec<syn::Path>),
Expand Down Expand Up @@ -486,7 +485,22 @@ impl Parse for Opt {
Ok(Opt::BitflagsPath(input.parse()?))
} else if l.peek(kw::stubs) {
input.parse::<kw::stubs>()?;
Ok(Opt::Stubs)
input.parse::<Token![:]>()?;
let stubs_mode = input.parse::<syn::Ident>()?;
Ok(Opt::Stubs(match stubs_mode.to_string().as_str() {
"Omit" => StubsMode::Omit,
"Embedded" => StubsMode::Embedded,
"Separate" => StubsMode::Separate,
name => {
return Err(Error::new(
stubs_mode.span(),
format!(
"unrecognized stubs mode: `{name}`; \
expected `Omit`, `Embedded`, or `Separate`"
),
));
}
}))
} else if l.peek(kw::export_prefix) {
input.parse::<kw::export_prefix>()?;
input.parse::<Token![:]>()?;
Expand Down
230 changes: 162 additions & 68 deletions crates/rust/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@ use heck::*;
use indexmap::{IndexMap, IndexSet};
use std::collections::{BTreeMap, HashMap, HashSet};
use std::fmt::{self, Write as _};
use std::hint::unreachable_unchecked;
use std::mem;
use std::str::FromStr;
use std::sync::LazyLock;
use wit_bindgen_core::abi::{Bitcast, WasmType};
use wit_bindgen_core::{
dealias, name_package_module, uwrite, uwriteln, wit_parser::*, AsyncFilterSet, Files,
Expand All @@ -30,6 +32,8 @@ struct RustWasm {
types: Types,
src_preamble: Source,
src: Source,
/// Used when stubs == StubsMode::Separate
stubs_src: Source,
opts: Opts,
import_modules: Vec<(String, Vec<String>)>,
export_modules: Vec<(String, Vec<String>)>,
Expand Down Expand Up @@ -166,10 +170,18 @@ pub struct Opts {
#[cfg_attr(feature = "clap", arg(long, value_name = "NAME"))]
pub skip: Vec<String>,

/// If true, generate stub implementations for any exported functions,
/// Whether to generate stub implementations for any exported functions,
/// interfaces, and/or resources.
#[cfg_attr(feature = "clap", arg(long))]
pub stubs: bool,
///
/// Valid values are:
///
/// - `omit`: Stubs will not be generated.
///
/// - `embedded`: Stubs will be generated in the bindings file.
///
/// - `separate`: Stubs will be generated in a separate _impl file.
#[cfg_attr(feature = "clap", arg(long, default_value_t = StubsMode::Omit))]
pub stubs: StubsMode,

/// Optionally prefix any export names with the specified value.
///
Expand Down Expand Up @@ -421,7 +433,11 @@ impl RustWasm {
uwriteln!(self.src, "#[rustfmt::skip]");
}

self.src.push_str("mod _rt {\n");
if matches!(self.opts.stubs, StubsMode::Separate) {
self.src.push_str("pub(crate) mod _rt {\n");
} else {
self.src.push_str("mod _rt {\n");
}
self.src.push_str("#![allow(dead_code, clippy::all)]\n");
let mut emitted = IndexSet::new();
while !self.rt_module.is_empty() {
Expand Down Expand Up @@ -839,9 +855,18 @@ macro_rules! __export_{world_name}_impl {{
{use_vis} use __export_{world_name}_impl as {export_macro_name};"
);

if self.opts.stubs {
uwriteln!(self.src, "export!(Stub);");
}
match self.opts.stubs {
StubsMode::Embedded => uwriteln!(self.src, "export!(Stub);"),
StubsMode::Separate => {
let name = &resolve.worlds[world_id].name;
let module_name = to_rust_module_raw(name.to_snake_case().as_str());
uwriteln!(
self.stubs_src,
"export!(Stub with_types_in crate::{module_name});"
)
}
StubsMode::Omit => {}
};
}

/// Generates a `#[link_section]` custom section to get smuggled through
Expand Down Expand Up @@ -981,8 +1006,26 @@ impl WorldGenerator for RustWasm {
if !self.opts.skip.is_empty() {
uwriteln!(self.src_preamble, "// * skip: {:?}", self.opts.skip);
}
if self.opts.stubs {
uwriteln!(self.src_preamble, "// * stubs");
if !matches!(self.opts.stubs, StubsMode::Omit) {
uwriteln!(self.src_preamble, "// * stubs: {:?}", self.opts.stubs);
if matches!(self.opts.stubs, StubsMode::Separate) {
uwriteln!(self.stubs_src, "#[allow(warnings)]\n");
let name = &resolve.worlds[world].name;
let module_name = name.to_snake_case();
let module_name_ident = to_rust_module_raw(module_name.as_str());
if module_name_ident != to_rust_ident_raw(module_name.as_str()) {
// In no_std, `core` is automatically in scope at the crate root,
// so a `mod core;` here would conflict with the built-in `core` crate
uwriteln!(self.stubs_src, "#[path = \"{module_name}.rs\"]\n");
}
uwriteln!(
self.stubs_src,
r#"mod {module_name_ident};
#[allow(warnings)]
use crate::{module_name_ident}::*;"#
);
self.stubs_src.push_str("\n");
}
}
if let Some(export_prefix) = &self.opts.export_prefix {
uwriteln!(
Expand Down Expand Up @@ -1195,7 +1238,7 @@ impl WorldGenerator for RustWasm {
self.export_macros
.push((macro_name, self.interface_names[&id].path.clone()));

if self.opts.stubs {
if !matches!(self.opts.stubs, StubsMode::Omit) {
let world_id = self.world.unwrap();
let mut r#gen = self.interface(
Identifier::World(world_id),
Expand All @@ -1205,7 +1248,12 @@ impl WorldGenerator for RustWasm {
);
r#gen.generate_stub(Some((id, name)), resolve.interfaces[id].functions.values());
let stub = r#gen.finish();
self.src.push_str(&stub);
let stubs = match self.opts.stubs {
StubsMode::Omit => unsafe { unreachable_unchecked() },
StubsMode::Embedded => &mut self.src,
StubsMode::Separate => &mut self.stubs_src,
};
stubs.push_str(&stub);
}
Ok(())
}
Expand All @@ -1223,12 +1271,17 @@ impl WorldGenerator for RustWasm {
self.src.push_str(&src);
self.export_macros.push((macro_name, String::new()));

if self.opts.stubs {
if !matches!(self.opts.stubs, StubsMode::Omit) {
let mut r#gen =
self.interface(Identifier::World(world), "[export]$root", resolve, false);
r#gen.generate_stub(None, funcs.iter().map(|f| f.1));
let stub = r#gen.finish();
self.src.push_str(&stub);
let stubs = match self.opts.stubs {
StubsMode::Omit => unsafe { unreachable_unchecked() },
StubsMode::Embedded => &mut self.src,
StubsMode::Separate => &mut self.stubs_src,
};
stubs.push_str(&stub);
}
Ok(())
}
Expand Down Expand Up @@ -1342,8 +1395,13 @@ impl WorldGenerator for RustWasm {
},
);

if self.opts.stubs {
self.src.push_str("\n#[derive(Debug)]\npub struct Stub;\n");
if !matches!(self.opts.stubs, StubsMode::Omit) {
let stubs = match self.opts.stubs {
StubsMode::Omit => unsafe { unreachable_unchecked() },
StubsMode::Embedded => &mut self.src,
StubsMode::Separate => &mut self.stubs_src,
};
stubs.push_str("\n#[derive(Debug)]\npub struct Stub;\n");
}

let mut src = mem::take(&mut self.src);
Expand All @@ -1360,6 +1418,15 @@ impl WorldGenerator for RustWasm {
let module_name = name.to_snake_case();
files.push(&format!("{module_name}.rs"), src.as_bytes());

if matches!(self.opts.stubs, StubsMode::Separate) {
let mut src = mem::take(&mut self.stubs_src);
if self.opts.format {
let syntax_tree = syn::parse_file(src.as_str()).unwrap();
*src.as_mut_string() = prettyplease::unparse(&syntax_tree);
}
files.push(&format!("{module_name}_impl.rs"), src.as_bytes());
}

let remapped_keys = self
.with
.iter()
Expand Down Expand Up @@ -1483,6 +1550,50 @@ impl fmt::Display for Ownership {
}
}

#[derive(Default, Debug, Clone, Copy)]
#[cfg_attr(
feature = "serde",
derive(serde::Deserialize),
serde(rename_all = "kebab-case")
)]
pub enum StubsMode {
/// Stubs will not be generated.
#[default]
Omit,

/// Stubs will be generated in the bindings file.
Embedded,

/// Stubs will be generated in a separate file.
Separate,
}

impl FromStr for StubsMode {
type Err = String;

fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"omit" => Ok(Self::Omit),
"embedded" => Ok(Self::Embedded),
"separate" => Ok(Self::Separate),
_ => Err(format!(
"unrecognized stubsMode: `{s}`; \
expected 'omit', `embedded`, or `separate`"
)),
}
}
}

impl fmt::Display for StubsMode {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.write_str(match self {
StubsMode::Omit => "omit",
StubsMode::Embedded => "embedded",
StubsMode::Separate => "separate",
})
}
}

/// Options for with "with" remappings.
#[derive(Debug, Clone)]
#[cfg_attr(
Expand Down Expand Up @@ -1533,61 +1644,44 @@ impl FnSig {
}
}

static RUST_KEYWORDS: LazyLock<HashSet<&'static str>> = LazyLock::new(|| {
// Source: https://doc.rust-lang.org/reference/keywords.html
HashSet::from([
"as", "break", "const", "continue", "crate", "else", "enum", "extern", "false", "fn",
"for", "if", "impl", "in", "let", "loop", "match", "mod", "move", "mut", "pub", "ref",
"return", "self", "Self", "static", "struct", "super", "trait", "true", "type", "unsafe",
"use", "where", "while", "async", "await", "dyn", "abstract", "become", "box", "do",
"final", "macro", "override", "priv", "typeof", "unsized", "virtual", "yield", "try",
"gen",
])
});

pub fn to_rust_ident(name: &str) -> String {
match name {
if RUST_KEYWORDS.contains(name) {
// Escape Rust keywords.
// Source: https://doc.rust-lang.org/reference/keywords.html
"as" => "as_".into(),
"break" => "break_".into(),
"const" => "const_".into(),
"continue" => "continue_".into(),
"crate" => "crate_".into(),
"else" => "else_".into(),
"enum" => "enum_".into(),
"extern" => "extern_".into(),
"false" => "false_".into(),
"fn" => "fn_".into(),
"for" => "for_".into(),
"if" => "if_".into(),
"impl" => "impl_".into(),
"in" => "in_".into(),
"let" => "let_".into(),
"loop" => "loop_".into(),
"match" => "match_".into(),
"mod" => "mod_".into(),
"move" => "move_".into(),
"mut" => "mut_".into(),
"pub" => "pub_".into(),
"ref" => "ref_".into(),
"return" => "return_".into(),
"self" => "self_".into(),
"static" => "static_".into(),
"struct" => "struct_".into(),
"super" => "super_".into(),
"trait" => "trait_".into(),
"true" => "true_".into(),
"type" => "type_".into(),
"unsafe" => "unsafe_".into(),
"use" => "use_".into(),
"where" => "where_".into(),
"while" => "while_".into(),
"async" => "async_".into(),
"await" => "await_".into(),
"dyn" => "dyn_".into(),
"abstract" => "abstract_".into(),
"become" => "become_".into(),
"box" => "box_".into(),
"do" => "do_".into(),
"final" => "final_".into(),
"macro" => "macro_".into(),
"override" => "override_".into(),
"priv" => "priv_".into(),
"typeof" => "typeof_".into(),
"unsized" => "unsized_".into(),
"virtual" => "virtual_".into(),
"yield" => "yield_".into(),
"try" => "try_".into(),
s => s.to_snake_case(),
format!("{}_", name)
} else {
name.to_snake_case()
}
}

pub fn to_rust_ident_raw(name: &str) -> String {
if RUST_KEYWORDS.contains(name) {
// Turn Rust keywords into raw identifiers.
format!("r#{}", name)
} else {
name.to_snake_case()
}
}

static RUST_CRATE_NAMES: LazyLock<HashSet<&'static str>> =
LazyLock::new(|| HashSet::from(["core", "alloc", "std"]));

pub fn to_rust_module_raw(name: &str) -> String {
if RUST_CRATE_NAMES.contains(name) {
format!("bindings_{}", name)
} else {
to_rust_ident_raw(name)
}
}

Expand Down
1 change: 1 addition & 0 deletions crates/test/src/c.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ impl LanguageMethods for C {

fn codegen_test_variants(&self) -> &[(&str, &[&str])] {
&[
("base", &[]),
("no-sig-flattening", &["--no-sig-flattening"]),
("autodrop", &["--autodrop-borrows=yes"]),
("async", &["--async=all"]),
Expand Down
Loading
Loading