use std::{
fmt,
io::ErrorKind,
pin::Pin,
sync::Arc,
task::{self, Poll},
time::Duration,
};
use crate::{McWaker, Progress};
#[doc(no_inline)]
pub use futures_lite::io::{
copy, empty, repeat, sink, split, AsyncBufRead, AsyncBufReadExt, AsyncRead, AsyncReadExt, AsyncSeek, AsyncSeekExt, AsyncWrite,
AsyncWriteExt, BoxedReader, BoxedWriter, BufReader, BufWriter, Cursor, ReadHalf, WriteHalf,
};
use parking_lot::Mutex;
use std::io::{Error, Result};
use zng_time::{DInstant, INSTANT};
use zng_txt::formatx;
use zng_unit::{ByteLength, ByteUnits};
use zng_var::impl_from_and_into_var;
/// Wraps a task (an async reader and/or writer) and records transfer [`Metrics`] for it.
///
/// Reads and writes are forwarded to the wrapped task; every successful non-empty
/// read or write updates the progress, speed and total time in the metrics.
pub struct Measure<T> {
// The wrapped value that reads and writes are forwarded to.
task: T,
// Latest metrics snapshot.
metrics: Metrics,
// Instant this measurement was created; the base for `Metrics::total_time`.
start_time: DInstant,
// Instant of the most recent non-empty write (initially the creation instant).
last_write: DInstant,
// Instant of the most recent non-empty read (initially the creation instant).
last_read: DInstant,
}
impl<T> Measure<T> {
    /// Starts measuring a new task.
    ///
    /// `total_read` and `total_write` are the expected total byte counts; the progress
    /// of both directions starts at zero.
    pub fn start(task: T, total_read: impl Into<ByteLength>, total_write: impl Into<ByteLength>) -> Self {
        Self::resume(task, (0, total_read), (0, total_write))
    }

    /// Continues measuring a task that already made some progress.
    ///
    /// Each progress tuple is `(bytes already transferred, total bytes expected)`.
    /// Speeds start at zero and the time measurement starts now.
    pub fn resume(
        task: T,
        read_progress: (impl Into<ByteLength>, impl Into<ByteLength>),
        write_progress: (impl Into<ByteLength>, impl Into<ByteLength>),
    ) -> Self {
        let now = INSTANT.now();
        Measure {
            task,
            metrics: Metrics {
                read_progress: (read_progress.0.into(), read_progress.1.into()),
                read_speed: 0.bytes(),
                write_progress: (write_progress.0.into(), write_progress.1.into()),
                write_speed: 0.bytes(),
                total_time: Duration::ZERO,
            },
            start_time: now,
            last_write: now,
            last_read: now,
        }
    }

    /// Current metrics.
    ///
    /// Only reads the snapshot, so a shared borrow is enough; callers that hold a
    /// mutable borrow can still call it unchanged.
    pub fn metrics(&self) -> &Metrics {
        &self.metrics
    }

    /// Stops measuring and returns the wrapped task with the final metrics.
    ///
    /// The total time is refreshed one last time before returning.
    pub fn finish(mut self) -> (T, Metrics) {
        self.metrics.total_time = self.start_time.elapsed();
        (self.task, self.metrics)
    }
}
/// Converts `bytes` transferred over `elapsed` into a bytes-per-second rate.
///
/// The rate is `bytes * 1e9 / elapsed_nanos`. A zero `elapsed` is treated as one
/// nanosecond so the function never divides by zero, and a rate that does not fit
/// in `usize` saturates at `usize::MAX`.
fn bytes_per_sec(bytes: ByteLength, elapsed: Duration) -> ByteLength {
    const NANOS_PER_SEC: u128 = 1_000_000_000;

    // Two reads can land on the same instant (e.g. a coarse or paused clock).
    let elapsed_nanos = elapsed.as_nanos().max(1);
    let rate = (bytes.0 as u128).saturating_mul(NANOS_PER_SEC) / elapsed_nanos;

    ByteLength(rate.min(usize::MAX as u128) as usize)
}
impl<T: AsyncRead + Unpin> AsyncRead for Measure<T> {
    /// Forwards the read to the wrapped task and, when at least one byte was read,
    /// updates the read progress, read speed and total time.
    fn poll_read(self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &mut [u8]) -> Poll<Result<usize>> {
        let this = self.get_mut();
        let outcome = Pin::new(&mut this.task).poll_read(cx, buf);

        // Pending, errors and empty reads are passed through without touching the metrics.
        if let Poll::Ready(Ok(count)) = &outcome {
            let count = *count;
            if count > 0 {
                let chunk = count.bytes();
                this.metrics.read_progress.0 += chunk;

                let now = INSTANT.now();
                let since_last_read = now - this.last_read;
                this.last_read = now;

                this.metrics.read_speed = bytes_per_sec(chunk, since_last_read);
                this.metrics.total_time = now - this.start_time;
            }
        }

        outcome
    }
}
impl<T: AsyncWrite + Unpin> AsyncWrite for Measure<T> {
    /// Forwards the write to the wrapped task and, when at least one byte was written,
    /// updates the write progress, write speed and total time.
    fn poll_write(self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
        let this = self.get_mut();
        let outcome = Pin::new(&mut this.task).poll_write(cx, buf);

        // Pending, errors and empty writes are passed through without touching the metrics.
        if let Poll::Ready(Ok(count)) = &outcome {
            let count = *count;
            if count > 0 {
                let chunk = count.bytes();
                this.metrics.write_progress.0 += chunk;

                let now = INSTANT.now();
                let since_last_write = now - this.last_write;
                this.last_write = now;

                this.metrics.write_speed = bytes_per_sec(chunk, since_last_write);
                this.metrics.total_time = now - this.start_time;
            }
        }

        outcome
    }

    /// Flushes the wrapped task; metrics are not affected.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Result<()>> {
        let this = self.get_mut();
        Pin::new(&mut this.task).poll_flush(cx)
    }

    /// Closes the wrapped task; metrics are not affected.
    fn poll_close(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Result<()>> {
        let this = self.get_mut();
        Pin::new(&mut this.task).poll_close(cx)
    }
}
/// Snapshot of the transfer metrics recorded by [`Measure`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Metrics {
/// Read progress as `(bytes read so far, total bytes expected)`.
///
/// A zero total means no read total is known; such metrics show no read line
/// when displayed.
pub read_progress: (ByteLength, ByteLength),
/// Read rate estimated from the most recent non-empty read.
pub read_speed: ByteLength,
/// Write progress as `(bytes written so far, total bytes expected)`.
///
/// A zero total means no write total is known; such metrics show no write line
/// when displayed.
pub write_progress: (ByteLength, ByteLength),
/// Write rate estimated from the most recent non-empty write.
pub write_speed: ByteLength,
/// Time elapsed since the measurement started, as of the last update.
pub total_time: Duration,
}
impl Metrics {
    /// Metrics with every progress value, speed and the total time set to zero.
    pub fn zero() -> Self {
        let nothing = 0.bytes();
        Self {
            read_progress: (nothing, nothing),
            read_speed: nothing,
            write_progress: (nothing, nothing),
            write_speed: nothing,
            total_time: Duration::ZERO,
        }
    }
}
impl fmt::Display for Metrics {
    /// Writes one line per direction that has a known (non-zero) total.
    ///
    /// While a direction is in progress the line shows progress, total and speed;
    /// once the progress equals the total it shows the amount and the total time.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let shows_read = self.read_progress.1 > 0.bytes();
        let shows_write = self.write_progress.1 > 0.bytes();

        if shows_read {
            if self.read_progress.0 == self.read_progress.1 {
                write!(f, "↓ {} . {:?}", self.read_progress.0, self.total_time)?;
            } else {
                // NOTE(review): the separator here is `-` while the write line uses ` - `.
                write!(f, "↓ {}-{}, {}/s", self.read_progress.0, self.read_progress.1, self.read_speed)?;
            }
        }

        if shows_write {
            // The write line goes on its own line when a read line was written.
            if shows_read {
                writeln!(f)?;
            }
            if self.write_progress.0 == self.write_progress.1 {
                write!(f, "↑ {} . {:?}", self.write_progress.0, self.total_time)?;
            } else {
                write!(f, "↑ {} - {}, {}/s", self.write_progress.0, self.write_progress.1, self.write_speed)?;
            }
        }

        Ok(())
    }
}
// Converts `Metrics` into a `Progress`.
//
// The progress is indeterminate when neither total is known, the read or write
// progress when only one total is known, and the combination of both otherwise.
// The displayed metrics become the message and the metrics themselves are stored
// in the progress metadata under `METRICS_ID`.
impl_from_and_into_var! {
    fn from(metrics: Metrics) -> Progress {
        let has_read_total = metrics.read_progress.1 > 0.bytes();
        let has_write_total = metrics.write_progress.1 > 0.bytes();

        let mut status = if has_read_total {
            Progress::from_n_of(metrics.read_progress.0 .0, metrics.read_progress.1 .0)
        } else {
            Progress::indeterminate()
        };

        if has_write_total {
            let write_status = Progress::from_n_of(metrics.write_progress.0 .0, metrics.write_progress.1 .0);
            status = if status.is_indeterminate() {
                write_status
            } else {
                status.and_fct(write_status.fct())
            };
        }

        let msg = formatx!("{metrics}");
        status.with_msg(msg).with_meta_mut(|mut meta| {
            meta.set(*METRICS_ID, metrics);
        })
    }
}
// State id under which the source `Metrics` are stored in the metadata of a
// `Progress` that was converted from them.
zng_state_map::static_id! {
pub static ref METRICS_ID: zng_state_map::StateId<Metrics>;
}
/// Extension methods for inspecting errors produced by [`McBufReader`].
pub trait McBufErrorExt {
    /// Returns `true` if this is the error a lazy reader gets when no non-lazy
    /// reader is left to advance the shared source.
    fn is_only_lazy_left(&self) -> bool;
}

impl McBufErrorExt for std::io::Error {
    fn is_only_lazy_left(&self) -> bool {
        if self.kind() != ErrorKind::Other {
            return false;
        }
        // The error is recognized by its message, looked up in the debug representation.
        let debug_repr = format!("{self:?}");
        debug_repr.contains(ONLY_NON_LAZY_ERROR_MSG)
    }
}

/// Message of the error returned to lazy readers when only lazy readers remain.
const ONLY_NON_LAZY_ERROR_MSG: &str = "no non-lazy readers left to read";
/// A reader handle over a source that is shared between all clones of the handle.
///
/// Every clone reads the same bytes from a shared buffer, each at its own offset.
/// A *lazy* handle never polls the source itself; it only reads what non-lazy
/// handles have already buffered.
pub struct McBufReader<S: AsyncRead> {
// State shared by all clones: source, buffer and per-handle offsets.
inner: Arc<Mutex<McBufInner<S>>>,
// Position of this handle's read offset in `McBufInner::clones`.
index: usize,
// Whether this handle is lazy.
lazy: bool,
}
// State shared by every `McBufReader` clone of the same source.
struct McBufInner<S: AsyncRead> {
// The underlying source; set to `None` once it reached EOF or failed.
source: Option<S>,
// Wakers of non-lazy handles waiting on the source poll.
waker: McWaker,
// Wakers of lazy handles waiting for more buffered data.
lazy_wakers: Vec<task::Waker>,
// Bytes read from the source that at least one handle has not consumed yet.
buf: Vec<u8>,
// Read offset into `buf` for each handle, indexed by `McBufReader::index`;
// a dropped handle's entry is set to `usize::MAX`.
clones: Vec<usize>,
// Number of live handles that are not lazy.
non_lazy_count: usize,
// Whether the source is still running, ended, or failed.
result: ReadState,
}
impl<S: AsyncRead> McBufReader<S> {
    /// Creates the first, non-lazy, reader handle for `source`.
    pub fn new(source: S) -> Self {
        let mut clones = Vec::with_capacity(2);
        clones.push(0);
        McBufReader {
            inner: Arc::new(Mutex::new(McBufInner {
                source: Some(source),
                waker: McWaker::empty(),
                lazy_wakers: vec![],
                buf: Vec::with_capacity(10.kilobytes().0),
                clones,
                non_lazy_count: 1,
                result: ReadState::Running,
            })),
            index: 0,
            lazy: false,
        }
    }

    /// Returns `true` if this handle is lazy, that is, it never polls the source itself.
    pub fn is_lazy(&self) -> bool {
        self.lazy
    }

    /// Changes whether this handle is lazy.
    ///
    /// If this was the last non-lazy handle, every lazy reader waiting for data is
    /// woken so it can observe that no non-lazy reader is left, instead of waiting
    /// forever. This mirrors what dropping the last non-lazy handle does.
    pub fn set_lazy(&mut self, lazy: bool) {
        if self.lazy == lazy {
            return;
        }

        {
            let mut inner = self.inner.lock();
            if lazy {
                inner.non_lazy_count -= 1;
                if inner.non_lazy_count == 0 {
                    for waker in inner.lazy_wakers.drain(..) {
                        waker.wake();
                    }
                }
            } else {
                inner.non_lazy_count += 1;
            }
        }

        self.lazy = lazy;
    }
}
impl<S: AsyncRead> Clone for McBufReader<S> {
fn clone(&self) -> Self {
let mut inner = self.inner.lock();
let offset = inner.clones[self.index];
let index = inner.clones.len();
inner.clones.push(offset);
if !self.lazy {
inner.non_lazy_count += 1;
}
Self {
inner: self.inner.clone(),
index,
lazy: self.lazy,
}
}
}
impl<S: AsyncRead> Drop for McBufReader<S> {
    /// Retires this handle's offset entry and, if it was the last non-lazy handle,
    /// wakes every waiting lazy reader.
    fn drop(&mut self) {
        let mut shared = self.inner.lock();

        // Marks the entry as no longer in use.
        shared.clones[self.index] = usize::MAX;

        if self.lazy {
            return;
        }

        shared.non_lazy_count -= 1;
        if shared.non_lazy_count == 0 {
            shared.lazy_wakers.drain(..).for_each(|waker| waker.wake());
        }
    }
}
// Reads from the shared buffer at this handle's own offset. The source is polled
// for more data only when this handle has consumed everything buffered so far and
// the handle is not lazy.
impl<S: AsyncRead> AsyncRead for McBufReader<S> {
fn poll_read(self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &mut [u8]) -> Poll<Result<usize>> {
let self_ = self.as_ref();
let mut inner = self_.inner.lock();
let inner = &mut *inner;
// `i` is this handle's read offset into the shared buffer.
let mut i = inner.clones[self_.index];
// `ready` is the slice of buffered bytes this handle has not consumed yet.
let mut ready;
match &inner.result {
ReadState::Running => {
ready = &inner.buf[i..];
if ready.is_empty() {
// Nothing buffered for this handle; more data must come from the source.
if self.lazy {
// Lazy handles never poll the source: they fail if no non-lazy handle
// is left, otherwise they wait to be woken by one.
if inner.non_lazy_count == 0 {
return Poll::Ready(Err(Error::new(ErrorKind::Other, ONLY_NON_LAZY_ERROR_MSG)));
} else {
inner.lazy_wakers.push(cx.waker().clone());
return Poll::Pending;
}
}
ready = &[];
// NOTE(review): `McWaker` is defined elsewhere; from this call site `push` appears
// to hand a waker only to the caller that should drive the source poll, other
// callers just register and wait — confirm against `McWaker`.
let waker = match inner.waker.push(cx.waker().clone()) {
Some(w) => w,
None => {
return Poll::Pending;
}
};
// Discard the prefix that every handle has already consumed. Dropped handles
// hold `usize::MAX`, so they do not hold back the minimum.
let min_i = inner.clones.iter().copied().min().unwrap();
if min_i > 0 {
inner.buf.copy_within(min_i.., 0);
inner.buf.truncate(inner.buf.len() - min_i);
i -= min_i;
// NOTE(review): this also decrements the `usize::MAX` entries of dropped
// handles, so they drift below `usize::MAX` over time.
for i in &mut inner.clones {
*i -= min_i;
}
}
// Grow the buffer with a zeroed region for the source to read into, sized to
// the larger of the caller's buffer and 10 kilobytes.
let new_start = inner.buf.len();
inner.buf.resize(inner.buf.len() + buf.len().max(10.kilobytes().0), 0);
let mut inner_cx = task::Context::from_waker(&waker);
// NOTE(review): pinning here relies on the source never being moved while it is
// stored in the shared state; in this file it is only dropped in place by
// setting `source` to `None` — confirm no other code moves it.
let source = unsafe { Pin::new_unchecked(inner.source.as_mut().unwrap()) };
let result = source.poll_read(&mut inner_cx, &mut inner.buf[new_start..]);
match result {
Poll::Ready(result) => {
// Any ready result is news for the waiting lazy readers.
for waker in inner.lazy_wakers.drain(..) {
waker.wake();
}
match result {
Ok(0) => {
// End of source: drop the unused region and the source itself.
inner.waker.cancel();
inner.buf.truncate(new_start);
inner.result = ReadState::Eof;
inner.source = None;
}
Ok(read) => {
// Keep only the bytes that were actually read.
inner.waker.cancel();
inner.buf.truncate(new_start + read);
ready = &inner.buf[i..];
}
Err(e) => {
// Remember a cloneable copy of the error for later polls by any handle;
// this poll returns the original error.
inner.waker.cancel();
inner.result = ReadState::Err(CloneableError::new(&e));
inner.buf = vec![];
inner.source = None;
return Poll::Ready(Err(e));
}
}
}
Poll::Pending => {
// Drop the unused region; the source will wake through `waker`.
inner.buf.truncate(new_start);
return Poll::Pending;
}
}
}
}
ReadState::Eof => {
// The source ended; only what remains buffered can be read.
ready = &inner.buf[i..];
}
ReadState::Err(e) => return Poll::Ready(e.err()),
}
// Copy as much as fits in the caller's buffer and advance this handle's offset.
let max_ready = buf.len().min(ready.len());
buf[..max_ready].copy_from_slice(&ready[..max_ready]);
i += max_ready;
inner.clones[self_.index] = i;
Poll::Ready(Ok(max_ready))
}
}
/// A cloneable snapshot of an [`Error`].
///
/// OS errors are kept by their raw code; any other error is kept as its kind and
/// display message.
#[derive(Clone)]
pub struct CloneableError {
    info: ErrorInfo,
}

/// The data retained from the source error.
#[derive(Clone)]
enum ErrorInfo {
    /// Raw OS error code.
    OsError(i32),
    /// Kind and display message of a non-OS error.
    Other(ErrorKind, String),
}

impl CloneableError {
    /// Captures the relevant information of `e`.
    pub fn new(e: &Error) -> Self {
        let info = match e.raw_os_error() {
            Some(code) => ErrorInfo::OsError(code),
            None => ErrorInfo::Other(e.kind(), e.to_string()),
        };
        Self { info }
    }

    /// Returns an `Err` holding a new [`Error`] rebuilt from the captured information.
    pub fn err<T>(&self) -> Result<T> {
        Err(Error::from(self.clone()))
    }
}

impl From<CloneableError> for Error {
    /// Rebuilds an [`Error`] from the captured OS code or kind and message.
    fn from(e: CloneableError) -> Self {
        let CloneableError { info } = e;
        match info {
            ErrorInfo::OsError(code) => Error::from_raw_os_error(code),
            ErrorInfo::Other(kind, msg) => Error::new(kind, msg),
        }
    }
}
/// Wraps a reader and fails reads once a byte limit is reached.
pub struct ReadLimited<S, L> {
// The wrapped reader.
source: S,
// Number of bytes that can still be read before the limit is reached.
limit: usize,
// Produces the error returned when the limit is reached.
on_limit: L,
}
impl<S, L> ReadLimited<S, L>
where
S: AsyncRead,
L: Fn() -> std::io::Error,
{
pub fn new(source: S, limit: ByteLength, on_limit: L) -> Self {
Self {
source,
limit: limit.0,
on_limit,
}
}
}
impl<S, L> AsyncRead for ReadLimited<S, L>
where
    S: AsyncRead,
    L: Fn() -> std::io::Error,
{
    /// Reads from the source without exceeding the remaining limit.
    ///
    /// Returns the `on_limit` error if the limit was already reached, and also when
    /// this read consumes the last of it.
    fn poll_read(self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &mut [u8]) -> Poll<Result<usize>> {
        // NOTE(review): the fields are only accessed in place here, `source` is never
        // moved out of the pinned value — confirm this holds for the whole type.
        let this = unsafe { self.get_unchecked_mut() };

        if this.limit == 0 {
            return Poll::Ready(Err((this.on_limit)()));
        }

        // Never offer the source more room than the remaining limit.
        let window = buf.len().min(this.limit);
        let buf = &mut buf[..window];

        let source = unsafe { Pin::new_unchecked(&mut this.source) };
        let read = match source.poll_read(cx, buf) {
            Poll::Ready(Ok(read)) => read,
            other => return other,
        };

        this.limit = this.limit.saturating_sub(read);
        if this.limit == 0 {
            Poll::Ready(Err((this.on_limit)()))
        } else {
            Poll::Ready(Ok(read))
        }
    }
}
// State of the source shared by all `McBufReader` clones.
enum ReadState {
// The source may still produce data.
Running,
// The source returned a zero-length read; only buffered data remains.
Eof,
// The source failed; every later read returns a copy of this error.
Err(CloneableError),
}
#[cfg(test)]
mod tests {
use super::*;
use crate as task;
use zng_unit::TimeUnits;
#[test]
pub fn mc_buf_reader_parallel() {
    // Reads a handle to the end and returns everything it produced.
    async fn drain(mut reader: McBufReader<Data>) -> Vec<u8> {
        let mut bytes = vec![];
        reader.read_to_end(&mut bytes).await.unwrap();
        bytes
    }

    let source = Data::new(60.kilobytes().0);
    let mut expected = vec![0; source.len];
    let _ = source.clone().blocking_read(&mut expected[..]);

    let first = McBufReader::new(source);
    let second = first.clone();
    let third = first.clone();

    // Three handles read the same source concurrently; each must see all the data.
    let (a, b, c) = async_test(async move {
        let a = task::run(drain(first));
        let b = task::run(drain(second));
        let c = task::run(drain(third));
        task::all!(a, b, c).await
    });

    crate::assert_vec_eq!(expected, a);
    crate::assert_vec_eq!(expected, b);
    crate::assert_vec_eq!(expected, c);
}
#[test]
pub fn mc_buf_reader_single() {
    let source = Data::new(60.kilobytes().0);
    let mut expected = vec![0; source.len];
    let _ = source.clone().blocking_read(&mut expected[..]);

    let mut reader = McBufReader::new(source);

    // A single handle must read the whole source.
    let actual = async_test(async move {
        task::run(async move {
            let mut bytes = vec![];
            reader.read_to_end(&mut bytes).await.unwrap();
            bytes
        })
        .await
    });

    crate::assert_vec_eq!(expected, actual);
}
#[test]
pub fn mc_buf_reader_sequential() {
    let source = Data::new(60.kilobytes().0);
    let mut expected = vec![0; source.len];
    let _ = source.clone().blocking_read(&mut expected[..]);

    // The original handle followed by five clones of it.
    let mut readers = Vec::with_capacity(6);
    readers.push(McBufReader::new(source));
    for _ in 1..=5 {
        let extra = readers[0].clone();
        readers.push(extra);
    }

    // Handles are read one after the other; each must still see all the data.
    let outputs = async_test(async move {
        let mut outputs = vec![];
        for mut reader in readers {
            let mut bytes = vec![];
            reader.read_to_end(&mut bytes).await.unwrap();
            outputs.push(bytes);
        }
        outputs
    });

    for output in outputs {
        crate::assert_vec_eq!(expected, output);
    }
}
#[test]
pub fn mc_buf_reader_completed() {
    let source = Data::new(60.kilobytes().0);
    let mut bytes = Vec::with_capacity(source.len);
    let mut reader = McBufReader::new(source);

    // A handle cloned after the original finished has nothing left to read.
    let late_len = async_test(async move {
        reader.read_to_end(&mut bytes).await.unwrap();

        let mut late = reader.clone();
        bytes.clear();
        late.read_to_end(&mut bytes).await.unwrap();

        bytes.len()
    });

    assert_eq!(0, late_len);
}
#[test]
pub fn mc_buf_reader_error() {
    // Reads a handle until it fails and returns the error.
    async fn drain_err(mut reader: McBufReader<Data>) -> Error {
        let mut bytes = vec![];
        reader.read_to_end(&mut bytes).await.unwrap_err()
    }

    let mut source = Data::new(20.kilobytes().0);
    source.set_error();

    // NOTE(review): `expected` is filled but never compared in this test.
    let mut expected = vec![0; source.len];
    let _ = source.clone().blocking_read(&mut expected[..]);

    let first = McBufReader::new(source);
    let second = first.clone();

    // Both concurrent handles must observe the source error.
    let (a, b) = async_test(async move {
        let a = task::run(drain_err(first));
        let b = task::run(drain_err(second));
        task::all!(a, b).await
    });

    assert_eq!(ErrorKind::InvalidData, a.kind());
    assert_eq!(ErrorKind::InvalidData, b.kind());
}
#[test]
pub fn mc_buf_reader_error_completed() {
    let mut source = Data::new(20.kilobytes().0);
    source.set_error();

    let mut bytes = Vec::with_capacity(source.len);
    let mut reader = McBufReader::new(source);

    // A handle cloned after the failure must observe the same error.
    let (first_err, late_err) = async_test(async move {
        let first_err = reader.read_to_end(&mut bytes).await.unwrap_err();

        let mut late = reader.clone();
        bytes.clear();
        let late_err = late.read_to_end(&mut bytes).await.unwrap_err();

        (first_err, late_err)
    });

    assert_eq!(ErrorKind::InvalidData, first_err.kind());
    assert_eq!(ErrorKind::InvalidData, late_err.kind());
}
#[test]
pub fn mc_buf_reader_parallel_with_delay1() {
    // Reads a handle to the end and returns everything it produced.
    async fn drain(mut reader: McBufReader<Data>) -> Vec<u8> {
        let mut bytes = vec![];
        reader.read_to_end(&mut bytes).await.unwrap();
        bytes
    }

    // The source answers every other poll with `Pending`.
    let mut source = Data::new(60.kilobytes().0);
    source.enable_pending();
    let mut expected = vec![0; source.len];
    let _ = source.clone().blocking_read(&mut expected[..]);

    let first = McBufReader::new(source);
    let second = first.clone();
    let third = first.clone();

    let (a, b, c) = async_test(async move {
        let a = task::run(drain(first));
        let b = task::run(drain(second));
        let c = task::run(drain(third));
        task::all!(a, b, c).await
    });

    crate::assert_vec_eq!(expected, a);
    crate::assert_vec_eq!(expected, b);
    crate::assert_vec_eq!(expected, c);
}
#[test]
pub fn mc_buf_reader_parallel_with_delay2() {
    // Reads a handle to the end and returns everything it produced.
    async fn drain(mut reader: McBufReader<Data>) -> Vec<u8> {
        let mut bytes = vec![];
        reader.read_to_end(&mut bytes).await.unwrap();
        bytes
    }

    // The source answers every other poll with `Pending`.
    let mut source = Data::new(60.kilobytes().0);
    source.enable_pending();
    let mut expected = vec![0; source.len];
    let _ = source.clone().blocking_read(&mut expected[..]);

    let first = McBufReader::new(source);
    let second = first.clone();
    let third = first.clone();

    let (a, b, c) = async_test(async move {
        let a = task::run(drain(first));
        // The second handle starts reading only after a short delay.
        let b = task::run(async move {
            task::deadline(5.ms()).await;
            drain(second).await
        });
        let c = task::run(drain(third));
        task::all!(a, b, c).await
    });

    crate::assert_vec_eq!(expected, a);
    crate::assert_vec_eq!(expected, b);
    crate::assert_vec_eq!(expected, c);
}
// Test source that produces `len` bytes of an incrementing (wrapping) byte pattern.
#[derive(Clone)]
struct Data {
// Next byte value to produce.
b: u8,
// Number of bytes still to produce.
len: usize,
// Error returned by reads once no bytes are left, if set.
error: Option<CloneableError>,
// When non-zero, async reads alternate between `Pending` and ready, waking after this delay.
delay: Duration,
// Whether the last async poll answered `Pending`.
pending: bool,
}
impl Data {
pub fn new(len: usize) -> Self {
Self {
b: 0,
len,
error: None,
delay: 0.ms(),
pending: false,
}
}
pub fn blocking_read(&mut self, buf: &mut [u8]) -> Result<usize> {
let len = self.len;
for b in buf.iter_mut().take(len) {
*b = self.b;
self.len -= 1;
self.b = self.b.wrapping_add(1);
}
if len == 0 {
if let Some(e) = &self.error {
return e.err();
}
}
Ok(buf.len().min(len))
}
pub fn set_error(&mut self) {
self.error = Some(CloneableError::new(&Error::new(ErrorKind::InvalidData, "test error")));
}
pub fn enable_pending(&mut self) {
self.delay = 3.ms();
}
}
impl AsyncRead for Data {
    /// Async version of `blocking_read`; with a delay configured, every other poll
    /// answers `Pending` and schedules a wake after the delay.
    fn poll_read(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, buf: &mut [u8]) -> Poll<Result<usize>> {
        let this = self.get_mut();

        if this.delay > Duration::ZERO {
            this.pending = !this.pending;
            if this.pending {
                let delay = this.delay;
                let waker = cx.waker().clone();
                task::spawn(async move {
                    task::deadline(delay).await;
                    waker.wake();
                });
                return Poll::Pending;
            }
        }

        Poll::Ready(this.blocking_read(buf))
    }
}
/// Blocks on `test` with a 5 second deadline and returns its output, panicking if
/// the deadline is reached first.
#[track_caller]
fn async_test<F: std::future::Future>(test: F) -> F::Output {
    let limited = task::with_deadline(test, 5.secs());
    task::block_on(limited).unwrap()
}
/// Asserts that two vectors (or slices) are equal.
///
/// On failure the panic message names both expressions and reports a length
/// mismatch and/or the first index where the contents differ.
#[macro_export]
macro_rules! assert_vec_eq {
    ($a:expr, $b: expr) => {
        // `match` keeps temporaries in `$a` and `$b` alive for the whole comparison.
        match (&$a, &$b) {
            (a, b) => {
                let len_not_eq = a.len() != b.len();
                let data_not_eq = a.iter().zip(b.iter()).position(|(x, y)| x != y);

                if len_not_eq || data_not_eq.is_some() {
                    // Only the trait is needed; a glob import would also pull
                    // `Result`, `Error` and others into the caller's scope.
                    use std::fmt::Write as _;

                    let mut error = format!("`{}` != `{}`", stringify!($a), stringify!($b));
                    if len_not_eq {
                        let _ = write!(&mut error, "\n lengths not equal: {} != {}", a.len(), b.len());
                    }
                    if let Some(i) = data_not_eq {
                        let _ = write!(&mut error, "\n data not equal at index {}: {:?} != {:?}", i, a[i], b[i]);
                    }
                    panic!("{}", error)
                }
            }
        }
    };
}
}