This commit is contained in:
nora 2024-06-27 19:44:59 +02:00
parent 9bab547bcf
commit 8c59c7b3ae
28 changed files with 233 additions and 1306 deletions

7
async-experiments/Cargo.lock generated Normal file
View file

@ -0,0 +1,7 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
version = 3
[[package]]
name = "async-experiments"
version = "0.1.0"

View file

@ -0,0 +1,6 @@
[package]
name = "async-experiments"
version = "0.1.0"
edition = "2021"
[dependencies]

View file

@ -0,0 +1,38 @@
use std::{
future::Future,
pin::pin,
sync::Arc,
task::{Context, Poll, Wake, Waker},
};
pub struct Executor {}
impl Executor {
pub fn new() -> Self {
Executor {}
}
pub fn block_on<F: Future>(&self, fut: F) -> F::Output {
let mut fut = pin!(fut);
let this_thread = std::thread::current();
let waker = Waker::from(Arc::new(WakeFn(move || {
this_thread.unpark();
})));
let mut ctx = Context::from_waker(&waker);
loop {
let result = fut.as_mut().poll(&mut ctx);
match result {
Poll::Ready(output) => return output,
Poll::Pending => std::thread::park(),
}
}
}
}
struct WakeFn<F>(F);
impl<F: Fn()> Wake for WakeFn<F> {
fn wake(self: std::sync::Arc<Self>) {
(self.0)()
}
}

View file

@ -0,0 +1,75 @@
use std::{
fmt::Debug,
future::Future,
pin::Pin,
task::{Context, Poll},
};
pub fn join2<F1, F2>(fut1: F1, fut2: F2) -> Join2<F1, F2>
where
F1: Future,
F2: Future,
{
Join2(JoinState::Pending(fut1), JoinState::Pending(fut2))
}
pub struct Join2<F1: Future, F2: Future>(JoinState<F1>, JoinState<F2>);
#[derive(Debug)]
enum JoinState<F: Future> {
Pending(F),
Ready(F::Output),
Stolen,
}
impl<F: Future> JoinState<F> {
fn steal(&mut self) -> F::Output {
match std::mem::replace(self, JoinState::Stolen) {
JoinState::Ready(output) => output,
_ => unreachable!("tried to take output of non-ready join state"),
}
}
}
impl<F1: Future, F2: Future> Future for Join2<F1, F2> {
type Output = (F1::Output, F2::Output);
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = unsafe { self.get_unchecked_mut() };
fn make_progress<F: Future>(field: &mut JoinState<F>, cx: &mut Context<'_>) {
match field {
JoinState::Pending(fut) => match unsafe { Pin::new_unchecked(fut) }.poll(cx) {
Poll::Ready(result) => {
*field = JoinState::Ready(result);
}
Poll::Pending => {}
},
JoinState::Ready(_) => {}
JoinState::Stolen => unreachable!("future polled after completion"),
}
}
make_progress(&mut this.0, cx);
make_progress(&mut this.1, cx);
if let (JoinState::Ready(_), JoinState::Ready(_)) = (&this.0, &this.1) {
return Poll::Ready((this.0.steal(), this.1.steal()));
}
Poll::Pending
}
}
impl<F1: Future + Debug, F2: Future + Debug> Debug for Join2<F1, F2>
where
F1::Output: Debug,
F2::Output: Debug,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_tuple("Join2")
.field(&self.0)
.field(&self.1)
.finish()
}
}

View file

@ -0,0 +1,7 @@
mod executor;
mod spawn_blocking;
mod join;
pub use executor::*;
pub use spawn_blocking::*;
pub use join::*;

View file

@ -0,0 +1,76 @@
use std::{
fmt::Debug,
future::Future,
sync::{Arc, Mutex},
task::{Poll, Waker},
};
#[derive(Debug)]
pub struct JoinHandle<T> {
inner: Arc<Inner<T>>,
}
struct Inner<T> {
result: Mutex<Option<T>>,
waker: Mutex<Option<Waker>>,
}
pub fn spawn_blocking<F, R>(f: F) -> JoinHandle<R>
where
R: Send + 'static,
F: Send + FnOnce() -> R + 'static,
{
let inner = Arc::new(Inner {
result: Mutex::new(None),
waker: Mutex::new(None),
});
let inner2 = inner.clone();
std::thread::spawn(move || {
let result = f();
*inner2.result.lock().unwrap() = Some(result);
if let Some(waker) = inner2.waker.lock().unwrap().take() {
waker.wake();
}
});
JoinHandle { inner }
}
impl<T> Future for JoinHandle<T> {
type Output = T;
fn poll(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Self::Output> {
let mut result = self.inner.result.lock().unwrap();
match result.take() {
Some(result) => Poll::Ready(result),
None => {
*self.inner.waker.lock().unwrap() = Some(cx.waker().clone());
Poll::Pending
}
}
}
}
impl<T: Debug> Debug for Inner<T> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Inner")
.field("result", &self.result)
.finish_non_exhaustive()
}
}
#[cfg(test)]
mod tests {
use crate::Executor;
#[test]
fn spawn_value() {
let executor = Executor::new();
let result = executor.block_on(super::spawn_blocking(|| 1 + 1));
assert_eq!(result, 2);
}
}

View file

@ -0,0 +1,23 @@
use async_experiments::Executor;
#[test]
fn execute() {
let executor = Executor::new();
executor.block_on(async {});
executor.block_on(async {});
}
#[test]
fn join2() {
let exec = Executor::new();
let r = exec.block_on(async {
let t1 = async_experiments::spawn_blocking(|| 1);
let t2 = async_experiments::spawn_blocking(|| 2);
let (r1, r2) = async_experiments::join2(t1, t2).await;
r1 + r2
});
assert_eq!(r, 3)
}