From 4f5288392b19739e17945ad85e8b57f41d0423a9 Mon Sep 17 00:00:00 2001 From: Joel Dice Date: Thu, 9 Jan 2025 12:31:10 -0700 Subject: [PATCH] fix a few stream/future issues (#1118) - The generated lift/lower code for stream/future payloads was not always calculating module paths correctly when generating type names. - Also, we were moving raw pointers into `async move` blocks and returning them without capturing the pointed-to memory. This would have been caught by runtime tests, but we don't have those yet since the Wasmtime async PR hasn't been merged yet. Fortunately, it was easy enough to find and fix when I updated that PR to use the latest wit-bindgen. - The generated lift/lower code for reading and writing streams needs to return a `Box` that captures the lifetimes of the parameters. Signed-off-by: Joel Dice --- .../rt/src/async_support/stream_support.rs | 4 +- crates/rust/src/bindgen.rs | 4 +- crates/rust/src/interface.rs | 116 +++++++++--------- tests/codegen/streams.wit | 15 +++ 4 files changed, 74 insertions(+), 65 deletions(-) diff --git a/crates/guest-rust/rt/src/async_support/stream_support.rs b/crates/guest-rust/rt/src/async_support/stream_support.rs index ec2ed596f..6e7ee7049 100644 --- a/crates/guest-rust/rt/src/async_support/stream_support.rs +++ b/crates/guest-rust/rt/src/async_support/stream_support.rs @@ -28,11 +28,11 @@ fn ceiling(x: usize, y: usize) -> usize { #[doc(hidden)] pub struct StreamVtable { - pub write: fn(future: u32, values: &[T]) -> Pin>>>, + pub write: fn(future: u32, values: &[T]) -> Pin> + '_>>, pub read: fn( future: u32, values: &mut [MaybeUninit], - ) -> Pin>>>, + ) -> Pin> + '_>>, pub cancel_write: fn(future: u32), pub cancel_read: fn(future: u32), pub close_writable: fn(future: u32), diff --git a/crates/rust/src/bindgen.rs b/crates/rust/src/bindgen.rs index 693dc2dd9..b06ea9304 100644 --- a/crates/rust/src/bindgen.rs +++ b/crates/rust/src/bindgen.rs @@ -475,7 +475,7 @@ impl Bindgen for FunctionBindgen<'_, '_> { .as_ref() .map(|ty| { self.gen - .full_type_name_owned(ty, Identifier::StreamOrFuturePayload) + .type_name_owned_with_id(ty, Identifier::StreamOrFuturePayload) }) .unwrap_or_else(|| "()".into()); let ordinal = self.gen.gen.future_payloads.get_index_of(&name).unwrap(); @@ -496,7 +496,7 @@ impl Bindgen for FunctionBindgen<'_, '_> { let op = &operands[0]; let name = self .gen - .full_type_name_owned(payload, Identifier::StreamOrFuturePayload); + .type_name_owned_with_id(payload, Identifier::StreamOrFuturePayload); let ordinal = self.gen.gen.stream_payloads.get_index_of(&name).unwrap(); let path = self.gen.path_to_root(); results.push(format!( diff --git a/crates/rust/src/interface.rs b/crates/rust/src/interface.rs index 192dd80ac..93e1c6063 100644 --- a/crates/rust/src/interface.rs +++ b/crates/rust/src/interface.rs @@ -483,6 +483,8 @@ macro_rules! {macro_name} {{ } fn generate_payloads(&mut self, prefix: &str, func: &Function, interface: Option<&WorldKey>) { + let old_identifier = mem::replace(&mut self.identifier, Identifier::StreamOrFuturePayload); + for (index, ty) in func .find_futures_and_streams(self.resolve) .into_iter() @@ -500,7 +502,7 @@ macro_rules! {macro_name} {{ match &self.resolve.types[ty].kind { TypeDefKind::Future(payload_type) => { let name = if let Some(payload_type) = payload_type { - self.full_type_name_owned(payload_type, Identifier::StreamOrFuturePayload) + self.type_name_owned(payload_type) } else { "()".into() }; @@ -533,7 +535,7 @@ macro_rules! {macro_name} {{ (String::new(), "let value = ();\n".into()) }; - let box_ = format!("super::super::{}", self.path_to_box()); + let box_ = self.path_to_box(); let code = format!( r#" #[doc(hidden)] @@ -545,7 +547,7 @@ pub mod vtable{ordinal} {{ }} #[cfg(target_arch = "wasm32")] - {{ + {box_}::pin(async move {{ #[repr(align({align}))] struct Buffer([::core::mem::MaybeUninit::; {size}]); let mut buffer = Buffer([::core::mem::MaybeUninit::uninit(); {size}]); @@ -558,10 +560,8 @@ pub mod vtable{ordinal} {{ fn wit_import(_: u32, _: *mut u8) -> u32; }} - {box_}::pin(async move {{ - unsafe {{ {async_support}::await_future_result(wit_import, future, address).await }} - }}) - }} + unsafe {{ {async_support}::await_future_result(wit_import, future, address).await }} + }}) }} fn read(future: u32) -> ::core::pin::Pin<{box_}>>> {{ @@ -571,7 +571,7 @@ pub mod vtable{ordinal} {{ }} #[cfg(target_arch = "wasm32")] - {{ + {box_}::pin(async move {{ struct Buffer([::core::mem::MaybeUninit::; {size}]); let mut buffer = Buffer([::core::mem::MaybeUninit::uninit(); {size}]); let address = buffer.0.as_mut_ptr() as *mut u8; @@ -582,15 +582,13 @@ pub mod vtable{ordinal} {{ fn wit_import(_: u32, _: *mut u8) -> u32; }} - {box_}::pin(async move {{ - if unsafe {{ {async_support}::await_future_result(wit_import, future, address).await }} {{ - {lift} - Some(value) - }} else {{ - None - }} - }}) - }} + if unsafe {{ {async_support}::await_future_result(wit_import, future, address).await }} {{ + {lift} + Some(value) + }} else {{ + None + }} + }}) }} fn cancel_write(writer: u32) {{ @@ -691,8 +689,7 @@ pub mod vtable{ordinal} {{ } } TypeDefKind::Stream(payload_type) => { - let name = - self.full_type_name_owned(payload_type, Identifier::StreamOrFuturePayload); + let name = self.type_name_owned(payload_type); if !self.gen.stream_payloads.contains_key(&name) { let ordinal = self.gen.stream_payloads.len(); @@ -747,19 +744,19 @@ for (index, dst) in values.iter_mut().take(count).enumerate() {{ (address.clone(), lower, address, lift) }; - let box_ = format!("super::super::{}", self.path_to_box()); + let box_ = self.path_to_box(); let code = format!( r#" #[doc(hidden)] pub mod vtable{ordinal} {{ - fn write(stream: u32, values: &[{name}]) -> ::core::pin::Pin<{box_}>>> {{ + fn write(stream: u32, values: &[{name}]) -> ::core::pin::Pin<{box_}> + '_>> {{ #[cfg(not(target_arch = "wasm32"))] {{ unreachable!(); }} #[cfg(target_arch = "wasm32")] - {{ + {box_}::pin(async move {{ {lower_address} {lower} @@ -769,27 +766,25 @@ pub mod vtable{ordinal} {{ fn wit_import(_: u32, _: *mut u8, _: u32) -> u32; }} - {box_}::pin(async move {{ - unsafe {{ - {async_support}::await_stream_result( - wit_import, - stream, - address, - u32::try_from(values.len()).unwrap() - ).await - }} - }}) - }} + unsafe {{ + {async_support}::await_stream_result( + wit_import, + stream, + address, + u32::try_from(values.len()).unwrap() + ).await + }} + }}) }} - fn read(stream: u32, values: &mut [::core::mem::MaybeUninit::<{name}>]) -> ::core::pin::Pin<{box_}>>> {{ + fn read(stream: u32, values: &mut [::core::mem::MaybeUninit::<{name}>]) -> ::core::pin::Pin<{box_}> + '_>> {{ #[cfg(not(target_arch = "wasm32"))] {{ unreachable!(); }} #[cfg(target_arch = "wasm32")] - {{ + {box_}::pin(async move {{ {lift_address} #[link(wasm_import_module = "{module}")] @@ -798,22 +793,20 @@ pub mod vtable{ordinal} {{ fn wit_import(_: u32, _: *mut u8, _: u32) -> u32; }} - {box_}::pin(async move {{ - let count = unsafe {{ - {async_support}::await_stream_result( - wit_import, - stream, - address, - u32::try_from(values.len()).unwrap() - ).await - }}; - #[allow(unused)] - if let Some(count) = count {{ - {lift} - }} - count - }}) - }} + let count = unsafe {{ + {async_support}::await_stream_result( + wit_import, + stream, + address, + u32::try_from(values.len()).unwrap() + ).await + }}; + #[allow(unused)] + if let Some(count) = count {{ + {lift} + }} + count + }}) }} fn cancel_write(writer: u32) {{ @@ -916,6 +909,8 @@ pub mod vtable{ordinal} {{ _ => unreachable!(), } } + + self.identifier = old_identifier; } fn generate_guest_import(&mut self, func: &Function, interface: Option<&WorldKey>) { @@ -1699,25 +1694,24 @@ pub mod vtable{ordinal} {{ } } - pub(crate) fn full_type_name_owned(&mut self, ty: &Type, id: Identifier<'i>) -> String { - self.full_type_name( + pub(crate) fn type_name_owned_with_id(&mut self, ty: &Type, id: Identifier<'i>) -> String { + let old_identifier = mem::replace(&mut self.identifier, id); + let name = self.type_name_owned(ty); + self.identifier = old_identifier; + name + } + + fn type_name_owned(&mut self, ty: &Type) -> String { + self.type_name( ty, TypeMode { lifetime: None, lists_borrowed: false, style: TypeOwnershipStyle::Owned, }, - id, ) } - fn full_type_name(&mut self, ty: &Type, mode: TypeMode, id: Identifier<'i>) -> String { - let old_identifier = mem::replace(&mut self.identifier, id); - let name = self.type_name(ty, mode); - self.identifier = old_identifier; - name - } - fn type_name(&mut self, ty: &Type, mode: TypeMode) -> String { let old = mem::take(&mut self.src); self.print_ty(ty, mode); diff --git a/tests/codegen/streams.wit b/tests/codegen/streams.wit index fd00239b7..7ed696ed8 100644 --- a/tests/codegen/streams.wit +++ b/tests/codegen/streams.wit @@ -1,5 +1,19 @@ package foo:foo; +interface transmit { + variant control { + read-stream(string), + read-future(string), + write-stream(string), + write-future(string), + } + + exchange: func(control: stream, + caller-stream: stream, + caller-future1: future, + caller-future2: future) -> tuple, future, future>; +} + interface streams { stream-u8-param: func(x: stream); stream-u16-param: func(x: stream); @@ -82,4 +96,5 @@ interface streams { world the-streams { import streams; export streams; + export transmit; }