From 5a8a1472a4a6f3804a07d863657c8bc4c81cd313 Mon Sep 17 00:00:00 2001 From: Joel Dice Date: Thu, 9 Jan 2025 16:19:24 -0700 Subject: [PATCH 1/2] enable stream/future payload lift/lower for non-Wasm platforms Previously, we generated no code for non-Wasm platforms; which meant our codegen tests weren't really testing much as far as streams and futures go. Now that the tests actually do something, they uncovered a few issues which I've fixed: - Invalid code generation when using `duplicate_if_necessary: true` - Invalid code generation for stream or futures whose payloads contain one or more a streams or futures For the latter, I mimicked what we do for resources: use interior mutability to provide `take_handle` methods for `StreamReader` and `FutureReader`. Signed-off-by: Joel Dice --- .../rt/src/async_support/future_support.rs | 72 +++++--- .../rt/src/async_support/stream_support.rs | 171 ++++++++++-------- crates/rust/src/bindgen.rs | 16 +- crates/rust/src/interface.rs | 58 +++--- 4 files changed, 174 insertions(+), 143 deletions(-) diff --git a/crates/guest-rust/rt/src/async_support/future_support.rs b/crates/guest-rust/rt/src/async_support/future_support.rs index 93cf71872..8477ec481 100644 --- a/crates/guest-rust/rt/src/async_support/future_support.rs +++ b/crates/guest-rust/rt/src/async_support/future_support.rs @@ -11,8 +11,8 @@ use { collections::hash_map::Entry, fmt, future::{Future, IntoFuture}, - mem::ManuallyDrop, pin::Pin, + sync::atomic::{AtomicU32, Ordering::Relaxed}, task::{Context, Poll}, }, }; @@ -199,7 +199,8 @@ impl CancelableRead { fn cancel_mut(&mut self) -> FutureReader { let reader = self.reader.take().unwrap(); - super::with_entry(reader.handle, |entry| match entry { + let handle = reader.handle.load(Relaxed); + super::with_entry(handle, |entry| match entry { Entry::Vacant(_) => unreachable!(), Entry::Occupied(mut entry) => match entry.get() { Handle::LocalOpen @@ -209,7 +210,7 @@ impl CancelableRead { Handle::LocalWaiting(_) => { entry.insert(Handle::LocalOpen); } - Handle::Read => (reader.vtable.cancel_read)(reader.handle), + Handle::Read => (reader.vtable.cancel_read)(handle), }, }); reader @@ -226,7 +227,7 @@ impl Drop for CancelableRead { /// Represents the readable end of a Component Model `future`. pub struct FutureReader { - handle: u32, + handle: AtomicU32, vtable: &'static FutureVtable, } @@ -241,7 +242,10 @@ impl fmt::Debug for FutureReader { impl FutureReader { #[doc(hidden)] pub fn new(handle: u32, vtable: &'static FutureVtable) -> Self { - Self { handle, vtable } + Self { + handle: AtomicU32::new(handle), + vtable, + } } #[doc(hidden)] @@ -264,12 +268,16 @@ impl FutureReader { }, }); - Self { handle, vtable } + Self { + handle: AtomicU32::new(handle), + vtable, + } } #[doc(hidden)] - pub fn into_handle(self) -> u32 { - super::with_entry(self.handle, |entry| match entry { + pub fn take_handle(&self) -> u32 { + let handle = self.handle.swap(u32::MAX, Relaxed); + super::with_entry(handle, |entry| match entry { Entry::Vacant(_) => unreachable!(), Entry::Occupied(mut entry) => match entry.get() { Handle::LocalOpen => { @@ -282,7 +290,7 @@ impl FutureReader { }, }); - ManuallyDrop::new(self).handle + handle } } @@ -294,7 +302,7 @@ impl IntoFuture for FutureReader { /// written to the writable end of this `future` (yielding a `Some` result) /// or when the writable end is dropped (yielding a `None` result). fn into_future(self) -> Self::IntoFuture { - let handle = self.handle; + let handle = self.handle.load(Relaxed); let vtable = self.vtable; CancelableRead { reader: Some(self), @@ -325,24 +333,30 @@ impl IntoFuture for FutureReader { impl Drop for FutureReader { fn drop(&mut self) { - super::with_entry(self.handle, |entry| match entry { - Entry::Vacant(_) => unreachable!(), - Entry::Occupied(mut entry) => match entry.get_mut() { - Handle::LocalReady(..) => { - let Handle::LocalReady(_, waker) = entry.insert(Handle::LocalClosed) else { - unreachable!() - }; - waker.wake(); - } - Handle::LocalOpen | Handle::LocalWaiting(_) => { - entry.insert(Handle::LocalClosed); - } - Handle::Read | Handle::LocalClosed => { - entry.remove(); - (self.vtable.close_readable)(self.handle); - } - Handle::Write => unreachable!(), - }, - }); + match self.handle.load(Relaxed) { + u32::MAX => {} + handle => { + super::with_entry(handle, |entry| match entry { + Entry::Vacant(_) => unreachable!(), + Entry::Occupied(mut entry) => match entry.get_mut() { + Handle::LocalReady(..) => { + let Handle::LocalReady(_, waker) = entry.insert(Handle::LocalClosed) + else { + unreachable!() + }; + waker.wake(); + } + Handle::LocalOpen | Handle::LocalWaiting(_) => { + entry.insert(Handle::LocalClosed); + } + Handle::Read | Handle::LocalClosed => { + entry.remove(); + (self.vtable.close_readable)(handle); + } + Handle::Write => unreachable!(), + }, + }); + } + } } } diff --git a/crates/guest-rust/rt/src/async_support/stream_support.rs b/crates/guest-rust/rt/src/async_support/stream_support.rs index 6e7ee7049..06aec8423 100644 --- a/crates/guest-rust/rt/src/async_support/stream_support.rs +++ b/crates/guest-rust/rt/src/async_support/stream_support.rs @@ -15,8 +15,9 @@ use { fmt, future::Future, iter, - mem::{self, ManuallyDrop, MaybeUninit}, + mem::{self, MaybeUninit}, pin::Pin, + sync::atomic::{AtomicU32, Ordering::Relaxed}, task::{Context, Poll}, vec::Vec, }, @@ -246,7 +247,7 @@ impl Drop for CancelReadOnDrop { /// Represents the readable end of a Component Model `stream`. pub struct StreamReader { - handle: u32, + handle: AtomicU32, future: Option>> + 'static>>>, vtable: &'static StreamVtable, } @@ -273,7 +274,7 @@ impl StreamReader { #[doc(hidden)] pub fn new(handle: u32, vtable: &'static StreamVtable) -> Self { Self { - handle, + handle: AtomicU32::new(handle), future: None, vtable, } @@ -300,15 +301,16 @@ impl StreamReader { }); Self { - handle, + handle: AtomicU32::new(handle), future: None, vtable, } } #[doc(hidden)] - pub fn into_handle(self) -> u32 { - super::with_entry(self.handle, |entry| match entry { + pub fn take_handle(&self) -> u32 { + let handle = self.handle.swap(u32::MAX, Relaxed); + super::with_entry(handle, |entry| match entry { Entry::Vacant(_) => unreachable!(), Entry::Occupied(mut entry) => match entry.get() { Handle::LocalOpen => { @@ -321,7 +323,7 @@ impl StreamReader { }, }); - ManuallyDrop::new(self).handle + handle } } @@ -332,60 +334,65 @@ impl Stream for StreamReader { let me = self.get_mut(); if me.future.is_none() { - me.future = Some(super::with_entry(me.handle, |entry| match entry { - Entry::Vacant(_) => unreachable!(), - Entry::Occupied(mut entry) => match entry.get() { - Handle::Write | Handle::LocalWaiting(_) => unreachable!(), - Handle::Read => { - let handle = me.handle; - let vtable = me.vtable; - let mut cancel_on_drop = CancelReadOnDrop:: { - handle: Some(handle), - vtable, - }; - Box::pin(async move { - let mut buffer = iter::repeat_with(MaybeUninit::uninit) - .take(ceiling(64 * 1024, mem::size_of::())) - .collect::>(); - - let result = - if let Some(count) = (vtable.read)(handle, &mut buffer).await { - buffer.truncate(count); - Some(unsafe { - mem::transmute::>, Vec>(buffer) - }) - } else { - None - }; - cancel_on_drop.handle = None; - drop(cancel_on_drop); - result - }) as Pin>> - } - Handle::LocalOpen => { - let (tx, rx) = oneshot::channel(); - entry.insert(Handle::LocalWaiting(tx)); - let mut cancel_on_drop = CancelReadOnDrop:: { - handle: Some(me.handle), - vtable: me.vtable, - }; - Box::pin(async move { - let result = rx.map(|v| v.ok().map(|v| *v.downcast().unwrap())).await; - cancel_on_drop.handle = None; - drop(cancel_on_drop); - result - }) - } - Handle::LocalClosed => Box::pin(future::ready(None)), - Handle::LocalReady(..) => { - let Handle::LocalReady(v, waker) = entry.insert(Handle::LocalOpen) else { - unreachable!() - }; - waker.wake(); - Box::pin(future::ready(Some(*v.downcast().unwrap()))) - } + me.future = Some(super::with_entry( + me.handle.load(Relaxed), + |entry| match entry { + Entry::Vacant(_) => unreachable!(), + Entry::Occupied(mut entry) => match entry.get() { + Handle::Write | Handle::LocalWaiting(_) => unreachable!(), + Handle::Read => { + let handle = me.handle.load(Relaxed); + let vtable = me.vtable; + let mut cancel_on_drop = CancelReadOnDrop:: { + handle: Some(handle), + vtable, + }; + Box::pin(async move { + let mut buffer = iter::repeat_with(MaybeUninit::uninit) + .take(ceiling(64 * 1024, mem::size_of::())) + .collect::>(); + + let result = + if let Some(count) = (vtable.read)(handle, &mut buffer).await { + buffer.truncate(count); + Some(unsafe { + mem::transmute::>, Vec>(buffer) + }) + } else { + None + }; + cancel_on_drop.handle = None; + drop(cancel_on_drop); + result + }) as Pin>> + } + Handle::LocalOpen => { + let (tx, rx) = oneshot::channel(); + entry.insert(Handle::LocalWaiting(tx)); + let mut cancel_on_drop = CancelReadOnDrop:: { + handle: Some(me.handle.load(Relaxed)), + vtable: me.vtable, + }; + Box::pin(async move { + let result = + rx.map(|v| v.ok().map(|v| *v.downcast().unwrap())).await; + cancel_on_drop.handle = None; + drop(cancel_on_drop); + result + }) + } + Handle::LocalClosed => Box::pin(future::ready(None)), + Handle::LocalReady(..) => { + let Handle::LocalReady(v, waker) = entry.insert(Handle::LocalOpen) + else { + unreachable!() + }; + waker.wake(); + Box::pin(future::ready(Some(*v.downcast().unwrap()))) + } + }, }, - })); + )); } match me.future.as_mut().unwrap().as_mut().poll(cx) { @@ -402,24 +409,30 @@ impl Drop for StreamReader { fn drop(&mut self) { self.future = None; - super::with_entry(self.handle, |entry| match entry { - Entry::Vacant(_) => unreachable!(), - Entry::Occupied(mut entry) => match entry.get_mut() { - Handle::LocalReady(..) => { - let Handle::LocalReady(_, waker) = entry.insert(Handle::LocalClosed) else { - unreachable!() - }; - waker.wake(); - } - Handle::LocalOpen | Handle::LocalWaiting(_) => { - entry.insert(Handle::LocalClosed); - } - Handle::Read | Handle::LocalClosed => { - entry.remove(); - (self.vtable.close_readable)(self.handle); - } - Handle::Write => unreachable!(), - }, - }); + match self.handle.load(Relaxed) { + u32::MAX => {} + handle => { + super::with_entry(handle, |entry| match entry { + Entry::Vacant(_) => unreachable!(), + Entry::Occupied(mut entry) => match entry.get_mut() { + Handle::LocalReady(..) => { + let Handle::LocalReady(_, waker) = entry.insert(Handle::LocalClosed) + else { + unreachable!() + }; + waker.wake(); + } + Handle::LocalOpen | Handle::LocalWaiting(_) => { + entry.insert(Handle::LocalClosed); + } + Handle::Read | Handle::LocalClosed => { + entry.remove(); + (self.vtable.close_readable)(handle); + } + Handle::Write => unreachable!(), + }, + }); + } + } } } diff --git a/crates/rust/src/bindgen.rs b/crates/rust/src/bindgen.rs index b06ea9304..8d3c10269 100644 --- a/crates/rust/src/bindgen.rs +++ b/crates/rust/src/bindgen.rs @@ -19,6 +19,7 @@ pub(super) struct FunctionBindgen<'a, 'b> { pub import_return_pointer_area_size: usize, pub import_return_pointer_area_align: usize, pub handle_decls: Vec, + always_owned: bool, } impl<'a, 'b> FunctionBindgen<'a, 'b> { @@ -27,6 +28,7 @@ impl<'a, 'b> FunctionBindgen<'a, 'b> { params: Vec, async_: bool, wasm_import_module: &'b str, + always_owned: bool, ) -> FunctionBindgen<'a, 'b> { FunctionBindgen { gen, @@ -42,6 +44,7 @@ impl<'a, 'b> FunctionBindgen<'a, 'b> { import_return_pointer_area_size: 0, import_return_pointer_area_align: 0, handle_decls: Vec::new(), + always_owned, } } @@ -190,10 +193,11 @@ impl<'a, 'b> FunctionBindgen<'a, 'b> { } fn typename_lower(&self, id: TypeId) -> String { - let owned = match self.lift_lower() { - LiftLower::LowerArgsLiftResults => false, - LiftLower::LiftArgsLowerResults => true, - }; + let owned = self.always_owned + || match self.lift_lower() { + LiftLower::LowerArgsLiftResults => false, + LiftLower::LiftArgsLowerResults => true, + }; self.gen.type_path(id, owned) } @@ -465,7 +469,7 @@ impl Bindgen for FunctionBindgen<'_, '_> { Instruction::FutureLower { .. } => { let op = &operands[0]; - results.push(format!("({op}).into_handle() as i32")) + results.push(format!("({op}).take_handle() as i32")) } Instruction::FutureLift { payload, .. } => { @@ -488,7 +492,7 @@ impl Bindgen for FunctionBindgen<'_, '_> { Instruction::StreamLower { .. } => { let op = &operands[0]; - results.push(format!("({op}).into_handle() as i32")) + results.push(format!("({op}).take_handle() as i32")) } Instruction::StreamLift { payload, .. } => { diff --git a/crates/rust/src/interface.rs b/crates/rust/src/interface.rs index ac6c3a148..a5dfb1924 100644 --- a/crates/rust/src/interface.rs +++ b/crates/rust/src/interface.rs @@ -541,12 +541,6 @@ macro_rules! {macro_name} {{ #[doc(hidden)] pub mod vtable{ordinal} {{ fn write(future: u32, value: {name}) -> ::core::pin::Pin<{box_}>> {{ - #[cfg(not(target_arch = "wasm32"))] - {{ - unreachable!(); - }} - - #[cfg(target_arch = "wasm32")] {box_}::pin(async move {{ #[repr(align({align}))] struct Buffer([::core::mem::MaybeUninit::; {size}]); @@ -554,6 +548,12 @@ pub mod vtable{ordinal} {{ let address = buffer.0.as_mut_ptr() as *mut u8; {lower} + #[cfg(not(target_arch = "wasm32"))] + unsafe extern "C" fn wit_import(_: u32, _: *mut u8) -> u32 {{ + unreachable!() + }} + + #[cfg(target_arch = "wasm32")] #[link(wasm_import_module = "{module}")] extern "C" {{ #[link_name = "[async][future-write-{index}]{func_name}"] @@ -565,17 +565,17 @@ pub mod vtable{ordinal} {{ }} fn read(future: u32) -> ::core::pin::Pin<{box_}>>> {{ - #[cfg(not(target_arch = "wasm32"))] - {{ - unreachable!(); - }} - - #[cfg(target_arch = "wasm32")] {box_}::pin(async move {{ struct Buffer([::core::mem::MaybeUninit::; {size}]); let mut buffer = Buffer([::core::mem::MaybeUninit::uninit(); {size}]); let address = buffer.0.as_mut_ptr() as *mut u8; + #[cfg(not(target_arch = "wasm32"))] + unsafe extern "C" fn wit_import(_: u32, _: *mut u8) -> u32 {{ + unreachable!() + }} + + #[cfg(target_arch = "wasm32")] #[link(wasm_import_module = "{module}")] extern "C" {{ #[link_name = "[async][future-read-{index}]{func_name}"] @@ -750,16 +750,16 @@ for (index, dst) in values.iter_mut().take(count).enumerate() {{ #[doc(hidden)] pub mod vtable{ordinal} {{ fn write(stream: u32, values: &[{name}]) -> ::core::pin::Pin<{box_}> + '_>> {{ - #[cfg(not(target_arch = "wasm32"))] - {{ - unreachable!(); - }} - - #[cfg(target_arch = "wasm32")] {box_}::pin(async move {{ {lower_address} {lower} + #[cfg(not(target_arch = "wasm32"))] + unsafe extern "C" fn wit_import(_: u32, _: *mut u8, _: u32) -> u32 {{ + unreachable!() + }} + + #[cfg(target_arch = "wasm32")] #[link(wasm_import_module = "{module}")] extern "C" {{ #[link_name = "[async][stream-write-{index}]{func_name}"] @@ -778,15 +778,15 @@ pub mod vtable{ordinal} {{ }} fn read(stream: u32, values: &mut [::core::mem::MaybeUninit::<{name}>]) -> ::core::pin::Pin<{box_}> + '_>> {{ - #[cfg(not(target_arch = "wasm32"))] - {{ - unreachable!(); - }} - - #[cfg(target_arch = "wasm32")] {box_}::pin(async move {{ {lift_address} + #[cfg(not(target_arch = "wasm32"))] + unsafe extern "C" fn wit_import(_: u32, _: *mut u8, _: u32) -> u32 {{ + unreachable!() + }} + + #[cfg(target_arch = "wasm32")] #[link(wasm_import_module = "{module}")] extern "C" {{ #[link_name = "[async][stream-read-{index}]{func_name}"] @@ -965,13 +965,13 @@ pub mod vtable{ordinal} {{ } fn lower_to_memory(&mut self, address: &str, value: &str, ty: &Type, module: &str) -> String { - let mut f = FunctionBindgen::new(self, Vec::new(), true, module); + let mut f = FunctionBindgen::new(self, Vec::new(), true, module, true); abi::lower_to_memory(f.gen.resolve, &mut f, address.into(), value.into(), ty); format!("unsafe {{ {} }}", String::from(f.src)) } fn lift_from_memory(&mut self, address: &str, value: &str, ty: &Type, module: &str) -> String { - let mut f = FunctionBindgen::new(self, Vec::new(), true, module); + let mut f = FunctionBindgen::new(self, Vec::new(), true, module, true); let result = abi::lift_from_memory(f.gen.resolve, &mut f, address.into(), ty); format!( "let {value} = unsafe {{ {}\n{result} }};", @@ -986,7 +986,7 @@ pub mod vtable{ordinal} {{ params: Vec, async_: bool, ) { - let mut f = FunctionBindgen::new(self, params, async_, module); + let mut f = FunctionBindgen::new(self, params, async_, module, false); abi::call( f.gen.resolve, AbiVariant::GuestImport, @@ -1060,7 +1060,7 @@ pub mod vtable{ordinal} {{ ); } - let mut f = FunctionBindgen::new(self, params, async_, self.wasm_import_module); + let mut f = FunctionBindgen::new(self, params, async_, self.wasm_import_module, false); abi::call( f.gen.resolve, AbiVariant::GuestExport, @@ -1107,7 +1107,7 @@ pub mod vtable{ordinal} {{ let params = self.print_post_return_sig(func); self.src.push_str("{\n"); - let mut f = FunctionBindgen::new(self, params, async_, self.wasm_import_module); + let mut f = FunctionBindgen::new(self, params, async_, self.wasm_import_module, false); abi::post_return(f.gen.resolve, func, &mut f, async_); let FunctionBindgen { needs_cleanup_list, From acb8bfaef37a81f7e1b8e0b0fb5dadc25d6b1092 Mon Sep 17 00:00:00 2001 From: Joel Dice Date: Fri, 10 Jan 2025 07:35:03 -0700 Subject: [PATCH 2/2] avoid lowering same values more than once in `StreamVtable::write` Previously, we would optimistically lower all the values in the input array and then re-lower the subset which wasn't accepted the first time. Aside from being inefficient, that was also incorrect since re-lowering would fail in the cases of any resource handles, futures, or streams in the payload since we would have already taken the handles using `take_handle`. Signed-off-by: Joel Dice --- .../rt/src/async_support/stream_support.rs | 11 ++----- crates/rust/src/interface.rs | 32 +++++++++++++------ 2 files changed, 24 insertions(+), 19 deletions(-) diff --git a/crates/guest-rust/rt/src/async_support/stream_support.rs b/crates/guest-rust/rt/src/async_support/stream_support.rs index 06aec8423..c8e0a96b7 100644 --- a/crates/guest-rust/rt/src/async_support/stream_support.rs +++ b/crates/guest-rust/rt/src/async_support/stream_support.rs @@ -29,7 +29,7 @@ fn ceiling(x: usize, y: usize) -> usize { #[doc(hidden)] pub struct StreamVtable { - pub write: fn(future: u32, values: &[T]) -> Pin> + '_>>, + pub write: fn(future: u32, values: &[T]) -> Pin + '_>>, pub read: fn( future: u32, values: &mut [MaybeUninit], @@ -174,14 +174,7 @@ impl Sink> for StreamWriter { vtable, }; self.get_mut().future = Some(Box::pin(async move { - let mut offset = 0; - while offset < item.len() { - if let Some(count) = (vtable.write)(handle, &item[offset..]).await { - offset += count; - } else { - break; - } - } + (vtable.write)(handle, &item).await; cancel_on_drop.handle = None; drop(cancel_on_drop); })); diff --git a/crates/rust/src/interface.rs b/crates/rust/src/interface.rs index a5dfb1924..62ef4a286 100644 --- a/crates/rust/src/interface.rs +++ b/crates/rust/src/interface.rs @@ -698,8 +698,10 @@ pub mod vtable{ordinal} {{ let alloc = self.path_to_std_alloc_module(); let (lower_address, lower, lift_address, lift) = if stream_direct(payload_type) { - let lower_address = "let address = values.as_ptr() as _;".into(); - let lift_address = "let address = values.as_mut_ptr() as _;".into(); + let lower_address = + "let address = values.as_ptr() as *mut u8;".into(); + let lift_address = + "let address = values.as_mut_ptr() as *mut u8;".into(); ( lower_address, String::new(), @@ -749,7 +751,7 @@ for (index, dst) in values.iter_mut().take(count).enumerate() {{ r#" #[doc(hidden)] pub mod vtable{ordinal} {{ - fn write(stream: u32, values: &[{name}]) -> ::core::pin::Pin<{box_}> + '_>> {{ + fn write(stream: u32, values: &[{name}]) -> ::core::pin::Pin<{box_} + '_>> {{ {box_}::pin(async move {{ {lower_address} {lower} @@ -766,14 +768,24 @@ pub mod vtable{ordinal} {{ fn wit_import(_: u32, _: *mut u8, _: u32) -> u32; }} - unsafe {{ - {async_support}::await_stream_result( - wit_import, - stream, - address, - u32::try_from(values.len()).unwrap() - ).await + let mut total = 0; + while total < values.len() {{ + let count = unsafe {{ + {async_support}::await_stream_result( + wit_import, + stream, + address.add(total * {size}), + u32::try_from(values.len()).unwrap() + ).await + }}; + + if let Some(count) = count {{ + total += count; + }} else {{ + break + }} }} + total }}) }}