tokio_xmpp/xmlstream/
capture.rs

// Copyright (c) 2024 Jonas Schäfer <jonas@zombofant.net>
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.

//! Small helper wrapper to capture data read from an [`AsyncBufRead`], plus
//! helpers for trace-logging sent and received data.
8
9use core::pin::Pin;
10use core::task::{Context, Poll};
11use std::io::{self, IoSlice};
12
13use futures::ready;
14
15use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
16
17use super::LogXsoBuf;
18
19pin_project_lite::pin_project! {
20    /// Wrapper around [`AsyncBufRead`] which stores bytes which have been
21    /// read in an internal vector for later inspection.
22    ///
23    /// This struct implements [`AsyncRead`] and [`AsyncBufRead`] and passes
24    /// read requests down to the wrapped [`AsyncBufRead`].
25    ///
26    /// After capturing has been enabled using [`Self::enable_capture`], any
27    /// data which is read via the struct will be stored in an internal buffer
28    /// and can be extracted with [`Self::take_capture`] or discarded using
29    /// [`Self::discard_capture`].
30    ///
31    /// This can be used to log data which is being read from a source.
32    ///
33    /// In addition, this struct implements [`AsyncWrite`] if and only if `T`
34    /// implements [`AsyncWrite`]. Writing is unaffected by capturing and is
35    /// implemented solely for convenience purposes (to allow duplex usage
36    /// of a wrapped I/O object).
37    pub(super) struct CaptureBufRead<T> {
38        #[pin]
39        inner: T,
40        buf: Option<(Vec<u8>, usize)>,
41    }
42}
43
44impl<T> CaptureBufRead<T> {
45    /// Wrap a given [`AsyncBufRead`].
46    ///
47    /// Note that capturing of data which is being read is disabled by default
48    /// and needs to be enabled using [`Self::enable_capture`].
49    pub fn wrap(inner: T) -> Self {
50        Self { inner, buf: None }
51    }
52
53    /// Extract the inner [`AsyncBufRead`] and discard the capture buffer.
54    pub fn into_inner(self) -> T {
55        self.inner
56    }
57
58    /// Obtain a reference to the inner [`AsyncBufRead`].
59    pub fn inner(&self) -> &T {
60        &self.inner
61    }
62
63    /// Enable capturing of read data into the inner buffer.
64    ///
65    /// Any data which is read from now on will be copied into the internal
66    /// buffer. That buffer will grow indefinitely until calls to
67    /// [`Self::take_capture`] or [`Self::discard_capture`].
68    pub fn enable_capture(&mut self) {
69        self.buf = Some((Vec::new(), 0));
70    }
71
72    /// Discard the current buffer data, if any.
73    ///
74    /// Further data which is read will be captured again.
75    pub(super) fn discard_capture(self: Pin<&mut Self>) {
76        let this = self.project();
77        if let Some((buf, consumed_up_to)) = this.buf.as_mut() {
78            buf.drain(..*consumed_up_to);
79            *consumed_up_to = 0;
80        }
81    }
82
83    /// Take the currently captured data out of the inner buffer.
84    ///
85    /// Returns `None` unless capturing has been enabled using
86    /// [`Self::enable_capture`].
87    pub(super) fn take_capture(self: Pin<&mut Self>) -> Option<Vec<u8>> {
88        let this = self.project();
89        let (buf, consumed_up_to) = this.buf.as_mut()?;
90        let result = buf.drain(..*consumed_up_to).collect();
91        *consumed_up_to = 0;
92        Some(result)
93    }
94}
95
96impl<T: AsyncRead> AsyncRead for CaptureBufRead<T> {
97    fn poll_read(
98        self: Pin<&mut Self>,
99        cx: &mut Context,
100        read_buf: &mut ReadBuf,
101    ) -> Poll<io::Result<()>> {
102        let this = self.project();
103        let prev_len = read_buf.filled().len();
104        let result = ready!(this.inner.poll_read(cx, read_buf));
105        if let Some((buf, consumed_up_to)) = this.buf.as_mut() {
106            buf.truncate(*consumed_up_to);
107            buf.extend(&read_buf.filled()[prev_len..]);
108            *consumed_up_to = buf.len();
109        }
110        Poll::Ready(result)
111    }
112}
113
114impl<T: AsyncBufRead> AsyncBufRead for CaptureBufRead<T> {
115    fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<&[u8]>> {
116        let this = self.project();
117        let result = ready!(this.inner.poll_fill_buf(cx))?;
118        if let Some((buf, consumed_up_to)) = this.buf.as_mut() {
119            buf.truncate(*consumed_up_to);
120            buf.extend(result);
121        }
122        Poll::Ready(Ok(result))
123    }
124
125    fn consume(self: Pin<&mut Self>, amt: usize) {
126        let this = self.project();
127        this.inner.consume(amt);
128        if let Some((_, consumed_up_to)) = this.buf.as_mut() {
129            // Increase the amount of data to preserve.
130            *consumed_up_to = *consumed_up_to + amt;
131        }
132    }
133}
134
135impl<T: AsyncWrite> AsyncWrite for CaptureBufRead<T> {
136    fn poll_write(
137        self: Pin<&mut Self>,
138        cx: &mut Context<'_>,
139        buf: &[u8],
140    ) -> Poll<io::Result<usize>> {
141        self.project().inner.poll_write(cx, buf)
142    }
143
144    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
145        self.project().inner.poll_shutdown(cx)
146    }
147
148    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
149        self.project().inner.poll_flush(cx)
150    }
151
152    fn is_write_vectored(&self) -> bool {
153        self.inner.is_write_vectored()
154    }
155
156    fn poll_write_vectored(
157        self: Pin<&mut Self>,
158        cx: &mut Context,
159        bufs: &[IoSlice],
160    ) -> Poll<io::Result<usize>> {
161        self.project().inner.poll_write_vectored(cx, bufs)
162    }
163}
164
165/// Return true if logging via [`log_recv`] or [`log_send`] might be visible
166/// to the user.
167pub(super) fn log_enabled() -> bool {
168    log::log_enabled!(log::Level::Trace)
169}
170
171/// Log received data.
172///
173/// `err` is an error which may be logged alongside the received data.
174/// `capture` is the data which has been received and which should be logged.
175/// If built with the `syntax-highlighting` feature, `capture` data will be
176/// logged with XML syntax highlighting.
177///
178/// If both `err` and `capture` are None, nothing will be logged.
179pub(super) fn log_recv(err: Option<&xmpp_parsers::Error>, capture: Option<Vec<u8>>) {
180    match err {
181        Some(err) => match capture {
182            Some(capture) => {
183                log::trace!("RECV (error: {}) {}", err, LogXsoBuf(&capture));
184            }
185            None => {
186                log::trace!("RECV (error: {}) [data capture disabled]", err);
187            }
188        },
189        None => match capture {
190            Some(capture) => {
191                log::trace!("RECV (ok) {}", LogXsoBuf(&capture));
192            }
193            None => (),
194        },
195    }
196}
197
198/// Log sent data.
199///
200/// If built with the `syntax-highlighting` feature, `data` data will be
201/// logged with XML syntax highlighting.
202pub(super) fn log_send(data: &[u8]) {
203    log::trace!("SEND {}", LogXsoBuf(data));
204}
205
#[cfg(test)]
mod tests {
    use super::*;

    use tokio::io::{AsyncBufReadExt, AsyncReadExt};

    /// Payload read by every test case.
    const PAYLOAD: &[u8] = b"Hello World!";

    /// Concrete reader type used by the tests.
    type TestReader = CaptureBufRead<tokio::io::BufReader<&'static [u8]>>;

    /// Build a buffered, capturing reader over [`PAYLOAD`] with capture on.
    fn capturing_reader() -> TestReader {
        let mut reader = CaptureBufRead::wrap(tokio::io::BufReader::new(PAYLOAD));
        reader.enable_capture();
        reader
    }

    /// Take the captured bytes out of `reader`.
    fn take_bytes(reader: &mut TestReader) -> Vec<u8> {
        Pin::new(reader)
            .take_capture()
            .expect("capturing is enabled for every test reader")
    }

    #[tokio::test]
    async fn captures_data_read_via_async_read() {
        let mut reader = capturing_reader();
        let mut out = [0u8; 8];

        let n = reader.read(&mut out[..]).await.unwrap();
        assert_eq!(n, 8);
        assert_eq!(&out, b"Hello Wo");
        assert_eq!(take_bytes(&mut reader), b"Hello Wo");
    }

    #[tokio::test]
    async fn captures_data_read_via_async_buf_read() {
        let mut reader = capturing_reader();

        assert_eq!(reader.fill_buf().await.unwrap().len(), PAYLOAD.len());
        // Nothing has been consumed yet, so nothing may be reported.
        assert!(take_bytes(&mut reader).is_empty());

        reader.consume(5);
        assert_eq!(take_bytes(&mut reader), b"Hello");

        reader.consume(6);
        assert_eq!(take_bytes(&mut reader), b" World");
    }

    #[tokio::test]
    async fn discard_capture_drops_consumed_data() {
        let mut reader = capturing_reader();

        assert_eq!(reader.fill_buf().await.unwrap().len(), PAYLOAD.len());
        // Nothing has been consumed yet, so nothing may be reported.
        assert!(take_bytes(&mut reader).is_empty());

        reader.consume(5);
        Pin::new(&mut reader).discard_capture();

        reader.consume(6);
        assert_eq!(take_bytes(&mut reader), b" World");
    }

    #[tokio::test]
    async fn captured_data_accumulates() {
        let mut reader = capturing_reader();

        assert_eq!(reader.fill_buf().await.unwrap().len(), PAYLOAD.len());
        // Nothing has been consumed yet, so nothing may be reported.
        assert!(take_bytes(&mut reader).is_empty());

        reader.consume(5);
        reader.consume(6);
        assert_eq!(take_bytes(&mut reader), b"Hello World");
    }
}
276}