Skip to content

Commit

Permalink
fix async tests when component-model-async enabled
Browse files Browse the repository at this point in the history
Enabling this feature for all tests revealed various missing pieces in the new
`concurrent.rs` fiber mechanism, which I've addressed.

This adds a bunch of ugly `#[cfg(feature = "component-model-async")]` guards,
but those will all go away once I unify the two async fiber implementations.

Signed-off-by: Joel Dice <[email protected]>
  • Loading branch information
dicej committed Dec 21, 2024
1 parent c166a9f commit d35d87c
Show file tree
Hide file tree
Showing 16 changed files with 595 additions and 218 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ rustix = { workspace = true, features = ["mm", "param", "process"] }

[dev-dependencies]
# depend again on wasmtime to activate its default features for tests
wasmtime = { workspace = true, features = ['default', 'winch', 'pulley', 'all-arch', 'call-hook', 'memory-protection-keys', 'signals-based-traps'] }
wasmtime = { workspace = true, features = ['default', 'winch', 'pulley', 'all-arch', 'call-hook', 'memory-protection-keys', 'signals-based-traps', 'component-model-async'] }
env_logger = { workspace = true }
log = { workspace = true }
filecheck = { workspace = true }
Expand Down
2 changes: 1 addition & 1 deletion benches/call.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ fn bench_host_to_wasm<Params, Results>(
typed_results: Results,
) where
Params: WasmParams + ToVals + Copy,
Results: WasmResults + ToVals + Copy + PartialEq + Debug,
Results: WasmResults + ToVals + Copy + PartialEq + Debug + Sync + 'static,
{
// Benchmark the "typed" version, which should be faster than the versions
// below.
Expand Down
165 changes: 113 additions & 52 deletions crates/wasmtime/src/runtime/component/concurrent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ use {
future::Future,
marker::PhantomData,
mem::{self, MaybeUninit},
ops::Range,
pin::{pin, Pin},
ptr::{self, NonNull},
sync::{Arc, Mutex},
Expand Down Expand Up @@ -323,6 +324,23 @@ impl<T: Copy> Drop for Reset<T> {
}
}

#[derive(Clone, Copy)]
struct PollContext {
future_context: *mut Context<'static>,
guard_range_start: *mut u8,
guard_range_end: *mut u8,
}

impl Default for PollContext {
fn default() -> PollContext {
PollContext {
future_context: core::ptr::null_mut(),
guard_range_start: core::ptr::null_mut(),
guard_range_end: core::ptr::null_mut(),
}
}
}

struct AsyncState {
current_suspend: UnsafeCell<
*mut Suspend<
Expand All @@ -331,7 +349,7 @@ struct AsyncState {
(Option<*mut dyn VMStore>, Result<()>),
>,
>,
current_poll_cx: UnsafeCell<*mut Context<'static>>,
current_poll_cx: UnsafeCell<PollContext>,
}

unsafe impl Send for AsyncState {}
Expand All @@ -344,26 +362,35 @@ pub(crate) struct AsyncCx {
(Option<*mut dyn VMStore>, Result<()>),
>,
current_stack_limit: *mut usize,
current_poll_cx: *mut *mut Context<'static>,
current_poll_cx: *mut PollContext,
track_pkey_context_switch: bool,
}

impl AsyncCx {
pub(crate) fn new<T>(store: &mut StoreContextMut<T>) -> Self {
Self {
current_suspend: store.concurrent_state().async_state.current_suspend.get(),
current_stack_limit: store.0.runtime_limits().stack_limit.get(),
current_poll_cx: store.concurrent_state().async_state.current_poll_cx.get(),
track_pkey_context_switch: store.has_pkey(),
Self::try_new(store).unwrap()
}

pub(crate) fn try_new<T>(store: &mut StoreContextMut<T>) -> Option<Self> {
let current_poll_cx = store.concurrent_state().async_state.current_poll_cx.get();
if unsafe { (*current_poll_cx).future_context.is_null() } {
None
} else {
Some(Self {
current_suspend: store.concurrent_state().async_state.current_suspend.get(),
current_stack_limit: store.0.runtime_limits().stack_limit.get(),
current_poll_cx,
track_pkey_context_switch: store.has_pkey(),
})
}
}

unsafe fn poll<U>(&self, mut future: Pin<&mut (dyn Future<Output = U> + Send)>) -> Poll<U> {
let poll_cx = *self.current_poll_cx;
let _reset = Reset(self.current_poll_cx, poll_cx);
*self.current_poll_cx = ptr::null_mut();
assert!(!poll_cx.is_null());
future.as_mut().poll(&mut *poll_cx)
*self.current_poll_cx = PollContext::default();
assert!(!poll_cx.future_context.is_null());
future.as_mut().poll(&mut *poll_cx.future_context)
}

pub(crate) unsafe fn block_on<'a, T, U>(
Expand Down Expand Up @@ -420,6 +447,13 @@ pub struct ConcurrentState<T> {
_phantom: PhantomData<T>,
}

impl<T> ConcurrentState<T> {
pub(crate) fn async_guard_range(&self) -> Range<*mut u8> {
let context = unsafe { *self.async_state.current_poll_cx.get() };
context.guard_range_start..context.guard_range_end
}
}

impl<T> Default for ConcurrentState<T> {
fn default() -> Self {
Self {
Expand All @@ -428,7 +462,7 @@ impl<T> Default for ConcurrentState<T> {
futures: ReadyChunks::new(FuturesUnordered::new(), 1024),
async_state: AsyncState {
current_suspend: UnsafeCell::new(ptr::null_mut()),
current_poll_cx: UnsafeCell::new(ptr::null_mut()),
current_poll_cx: UnsafeCell::new(PollContext::default()),
},
instance_states: HashMap::new(),
yielding: HashSet::new(),
Expand Down Expand Up @@ -622,7 +656,7 @@ pub(crate) fn poll_and_block<'a, T, R: Send + Sync + 'static>(

pub(crate) async fn on_fiber<'a, R: Send + Sync + 'static, T: Send>(
mut store: StoreContextMut<'a, T>,
instance: RuntimeComponentInstanceIndex,
instance: Option<RuntimeComponentInstanceIndex>,
func: impl FnOnce(&mut StoreContextMut<T>) -> R + Send,
) -> Result<(R, StoreContextMut<'a, T>)> {
let result = Arc::new(Mutex::new(None));
Expand All @@ -634,7 +668,21 @@ pub(crate) async fn on_fiber<'a, R: Send + Sync + 'static, T: Send>(
}
})?;

store = poll_fn(store, move |_, mut store| {
let guard_range = fiber
.fiber
.as_ref()
.unwrap()
.stack()
.guard_range()
.map(|r| {
(
NonNull::new(r.start).map(SendSyncPtr::new),
NonNull::new(r.end).map(SendSyncPtr::new),
)
})
.unwrap_or((None, None));

store = poll_fn(store, guard_range, move |_, mut store| {
match resume_fiber(&mut fiber, store.take(), Ok(())) {
Ok(Ok((store, result))) => Ok(result.map(|()| store)),
Ok(Err(s)) => Err(s),
Expand Down Expand Up @@ -761,36 +809,40 @@ fn resume_stackful<'a, T>(
match resume_fiber(&mut fiber, Some(store), Ok(()))? {
Ok((mut store, result)) => {
result?;
store = maybe_resume_next_task(store, fiber.instance)?;
for (event, call, _) in mem::take(
&mut store
.concurrent_state()
.table
.get_mut(guest_task)
.with_context(|| format!("bad handle: {}", guest_task.rep()))?
.events,
) {
if event == events::EVENT_CALL_DONE {
log::trace!("resume_stackful will delete call {}", call.rep());
call.delete_all_from(store.as_context_mut())?;
}
}
match &store.concurrent_state().table.get(guest_task)?.caller {
Caller::Host(_) => {
log::trace!("resume_stackful will delete task {}", guest_task.rep());
AnyTask::Guest(guest_task).delete_all_from(store.as_context_mut())?;
Ok(store)
if let Some(instance) = fiber.instance {
store = maybe_resume_next_task(store, instance)?;
for (event, call, _) in mem::take(
&mut store
.concurrent_state()
.table
.get_mut(guest_task)
.with_context(|| format!("bad handle: {}", guest_task.rep()))?
.events,
) {
if event == events::EVENT_CALL_DONE {
log::trace!("resume_stackful will delete call {}", call.rep());
call.delete_all_from(store.as_context_mut())?;
}
}
Caller::Guest { task, .. } => {
let task = *task;
maybe_send_event(
store,
task,
events::EVENT_CALL_DONE,
AnyTask::Guest(guest_task),
0,
)
match &store.concurrent_state().table.get(guest_task)?.caller {
Caller::Host(_) => {
log::trace!("resume_stackful will delete task {}", guest_task.rep());
AnyTask::Guest(guest_task).delete_all_from(store.as_context_mut())?;
Ok(store)
}
Caller::Guest { task, .. } => {
let task = *task;
maybe_send_event(
store,
task,
events::EVENT_CALL_DONE,
AnyTask::Guest(guest_task),
0,
)
}
}
} else {
Ok(store)
}
}
Err(new_store) => {
Expand Down Expand Up @@ -1029,7 +1081,7 @@ struct StoreFiber<'a> {
(Option<*mut dyn VMStore>, Result<()>),
>,
stack_limit: *mut usize,
instance: RuntimeComponentInstanceIndex,
instance: Option<RuntimeComponentInstanceIndex>,
}

impl<'a> Drop for StoreFiber<'a> {
Expand All @@ -1054,7 +1106,7 @@ unsafe impl<'a> Sync for StoreFiber<'a> {}

fn make_fiber<'a, T>(
store: &mut StoreContextMut<T>,
instance: RuntimeComponentInstanceIndex,
instance: Option<RuntimeComponentInstanceIndex>,
fun: impl FnOnce(StoreContextMut<T>) -> Result<()> + 'a,
) -> Result<StoreFiber<'a>> {
let engine = store.engine().clone();
Expand Down Expand Up @@ -1118,9 +1170,11 @@ unsafe fn resume_fiber_raw<'a>(
fn poll_ready<'a, T>(mut store: StoreContextMut<'a, T>) -> Result<StoreContextMut<'a, T>> {
unsafe {
let cx = *store.concurrent_state().async_state.current_poll_cx.get();
assert!(!cx.is_null());
while let Poll::Ready(Some(ready)) =
store.concurrent_state().futures.poll_next_unpin(&mut *cx)
assert!(!cx.future_context.is_null());
while let Poll::Ready(Some(ready)) = store
.concurrent_state()
.futures
.poll_next_unpin(&mut *cx.future_context)
{
match handle_ready(store, ready) {
Ok(s) => {
Expand Down Expand Up @@ -1691,7 +1745,7 @@ fn do_start_call<'a, T>(
cx
}
} else {
let mut fiber = make_fiber(&mut cx, callee_instance, move |mut cx| {
let mut fiber = make_fiber(&mut cx, Some(callee_instance), move |mut cx| {
if !async_ {
cx.concurrent_state()
.instance_states
Expand Down Expand Up @@ -2017,12 +2071,12 @@ pub(crate) async fn poll_until<'a, T: Send, U>(
.await;

if ready.is_some() {
store = poll_fn(store, move |_, mut store| {
store = poll_fn(store, (None, None), move |_, mut store| {
Ok(handle_ready(store.take().unwrap(), ready.take().unwrap()))
})
.await?;
} else {
let (s, resumed) = poll_fn(store, move |_, mut store| {
let (s, resumed) = poll_fn(store, (None, None), move |_, mut store| {
Ok(unyield(store.take().unwrap()))
})
.await?;
Expand All @@ -2039,7 +2093,7 @@ pub(crate) async fn poll_until<'a, T: Send, U>(
Either::Left((None, future_again)) => break Ok((store, future_again.await)),
Either::Left((Some(ready), future_again)) => {
let mut ready = Some(ready);
store = poll_fn(store, move |_, mut store| {
store = poll_fn(store, (None, None), move |_, mut store| {
Ok(handle_ready(store.take().unwrap(), ready.take().unwrap()))
})
.await?;
Expand All @@ -2052,13 +2106,14 @@ pub(crate) async fn poll_until<'a, T: Send, U>(

async fn poll_fn<'a, T, R>(
mut store: StoreContextMut<'a, T>,
guard_range: (Option<SendSyncPtr<u8>>, Option<SendSyncPtr<u8>>),
mut fun: impl FnMut(
&mut Context,
Option<StoreContextMut<'a, T>>,
) -> Result<R, Option<StoreContextMut<'a, T>>>,
) -> R {
#[derive(Clone, Copy)]
struct PollCx(*mut *mut Context<'static>);
struct PollCx(*mut PollContext);

unsafe impl Send for PollCx {}

Expand All @@ -2068,7 +2123,13 @@ async fn poll_fn<'a, T, R>(

move |cx| unsafe {
let _reset = Reset(poll_cx.0, *poll_cx.0);
*poll_cx.0 = mem::transmute::<&mut Context<'_>, *mut Context<'static>>(cx);
let guard_range_start = guard_range.0.map(|v| v.as_ptr()).unwrap_or(ptr::null_mut());
let guard_range_end = guard_range.1.map(|v| v.as_ptr()).unwrap_or(ptr::null_mut());
*poll_cx.0 = PollContext {
future_context: mem::transmute::<&mut Context<'_>, *mut Context<'static>>(cx),
guard_range_start,
guard_range_end,
};
#[allow(dropping_copy_types)]
drop(poll_cx);

Expand Down
4 changes: 2 additions & 2 deletions crates/wasmtime/src/runtime/component/func.rs
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ impl Func {
let instance = store.0[self.0].component_instance;
// TODO: do we need to return the store here due to the possible
// invalidation of the reference we were passed?
concurrent::on_fiber(store, instance, move |store| {
concurrent::on_fiber(store, Some(instance), move |store| {
self.call_impl(store, params, results)
})
.await?
Expand Down Expand Up @@ -367,7 +367,7 @@ impl Func {
let instance = store.0[self.0].component_instance;
// TODO: do we need to return the store here due to the possible
// invalidation of the reference we were passed?
concurrent::on_fiber(store, instance, move |store| {
concurrent::on_fiber(store, Some(instance), move |store| {
self.start_call(store.as_context_mut(), params)
})
.await?
Expand Down
10 changes: 6 additions & 4 deletions crates/wasmtime/src/runtime/component/func/typed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -199,9 +199,11 @@ where
let instance = store.0[self.func.0].component_instance;
// TODO: do we need to return the store here due to the possible
// invalidation of the reference we were passed?
concurrent::on_fiber(store, instance, move |store| self.call_impl(store, params))
.await?
.0
concurrent::on_fiber(store, Some(instance), move |store| {
self.call_impl(store, params)
})
.await?
.0
}
#[cfg(not(feature = "component-model-async"))]
{
Expand Down Expand Up @@ -236,7 +238,7 @@ where
let instance = store.0[self.func.0].component_instance;
// TODO: do we need to return the store here due to the possible
// invalidation of the reference we were passed?
concurrent::on_fiber(store, instance, move |store| {
concurrent::on_fiber(store, Some(instance), move |store| {
self.start_call(store.as_context_mut(), params)
})
.await?
Expand Down
16 changes: 14 additions & 2 deletions crates/wasmtime/src/runtime/component/instance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -864,12 +864,24 @@ impl<T> InstancePre<T> {
where
T: Send + 'static,
{
let mut store = store.as_context_mut();
let store = store.as_context_mut();
assert!(
store.0.async_support(),
"must use sync instantiation when async support is disabled"
);
store.on_fiber(|store| self.instantiate_impl(store)).await?
#[cfg(feature = "component-model-async")]
{
// TODO: do we need to return the store here due to the possible
// invalidation of the reference we were passed?
concurrent::on_fiber(store, None, move |store| self.instantiate_impl(store))
.await?
.0
}
#[cfg(not(feature = "component-model-async"))]
{
let mut store = store;
store.on_fiber(|store| self.instantiate_impl(store)).await?
}
}

fn instantiate_impl(&self, mut store: impl AsContextMut<Data = T>) -> Result<Instance>
Expand Down
Loading

0 comments on commit d35d87c

Please sign in to comment.