tokio_xmpp/xmlstream/
capture.rs1use core::pin::Pin;
10use core::task::{Context, Poll};
11use std::io::{self, IoSlice};
12
13use futures::ready;
14
15use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
16
17use super::LogXsoBuf;
18
19pin_project_lite::pin_project! {
20 pub(super) struct CaptureBufRead<T> {
38 #[pin]
39 inner: T,
40 buf: Option<(Vec<u8>, usize)>,
41 }
42}
43
44impl<T> CaptureBufRead<T> {
45 pub fn wrap(inner: T) -> Self {
50 Self { inner, buf: None }
51 }
52
53 pub fn into_inner(self) -> T {
55 self.inner
56 }
57
58 pub fn inner(&self) -> &T {
60 &self.inner
61 }
62
63 pub fn enable_capture(&mut self) {
69 self.buf = Some((Vec::new(), 0));
70 }
71
72 pub(super) fn discard_capture(self: Pin<&mut Self>) {
76 let this = self.project();
77 if let Some((buf, consumed_up_to)) = this.buf.as_mut() {
78 buf.drain(..*consumed_up_to);
79 *consumed_up_to = 0;
80 }
81 }
82
83 pub(super) fn take_capture(self: Pin<&mut Self>) -> Option<Vec<u8>> {
88 let this = self.project();
89 let (buf, consumed_up_to) = this.buf.as_mut()?;
90 let result = buf.drain(..*consumed_up_to).collect();
91 *consumed_up_to = 0;
92 Some(result)
93 }
94}
95
96impl<T: AsyncRead> AsyncRead for CaptureBufRead<T> {
97 fn poll_read(
98 self: Pin<&mut Self>,
99 cx: &mut Context,
100 read_buf: &mut ReadBuf,
101 ) -> Poll<io::Result<()>> {
102 let this = self.project();
103 let prev_len = read_buf.filled().len();
104 let result = ready!(this.inner.poll_read(cx, read_buf));
105 if let Some((buf, consumed_up_to)) = this.buf.as_mut() {
106 buf.truncate(*consumed_up_to);
107 buf.extend(&read_buf.filled()[prev_len..]);
108 *consumed_up_to = buf.len();
109 }
110 Poll::Ready(result)
111 }
112}
113
114impl<T: AsyncBufRead> AsyncBufRead for CaptureBufRead<T> {
115 fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<&[u8]>> {
116 let this = self.project();
117 let result = ready!(this.inner.poll_fill_buf(cx))?;
118 if let Some((buf, consumed_up_to)) = this.buf.as_mut() {
119 buf.truncate(*consumed_up_to);
120 buf.extend(result);
121 }
122 Poll::Ready(Ok(result))
123 }
124
125 fn consume(self: Pin<&mut Self>, amt: usize) {
126 let this = self.project();
127 this.inner.consume(amt);
128 if let Some((_, consumed_up_to)) = this.buf.as_mut() {
129 *consumed_up_to = *consumed_up_to + amt;
131 }
132 }
133}
134
135impl<T: AsyncWrite> AsyncWrite for CaptureBufRead<T> {
136 fn poll_write(
137 self: Pin<&mut Self>,
138 cx: &mut Context<'_>,
139 buf: &[u8],
140 ) -> Poll<io::Result<usize>> {
141 self.project().inner.poll_write(cx, buf)
142 }
143
144 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
145 self.project().inner.poll_shutdown(cx)
146 }
147
148 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
149 self.project().inner.poll_flush(cx)
150 }
151
152 fn is_write_vectored(&self) -> bool {
153 self.inner.is_write_vectored()
154 }
155
156 fn poll_write_vectored(
157 self: Pin<&mut Self>,
158 cx: &mut Context,
159 bufs: &[IoSlice],
160 ) -> Poll<io::Result<usize>> {
161 self.project().inner.poll_write_vectored(cx, bufs)
162 }
163}
164
165pub(super) fn log_enabled() -> bool {
168 log::log_enabled!(log::Level::Trace)
169}
170
171pub(super) fn log_recv(err: Option<&xmpp_parsers::Error>, capture: Option<Vec<u8>>) {
180 match err {
181 Some(err) => match capture {
182 Some(capture) => {
183 log::trace!("RECV (error: {}) {}", err, LogXsoBuf(&capture));
184 }
185 None => {
186 log::trace!("RECV (error: {}) [data capture disabled]", err);
187 }
188 },
189 None => match capture {
190 Some(capture) => {
191 log::trace!("RECV (ok) {}", LogXsoBuf(&capture));
192 }
193 None => (),
194 },
195 }
196}
197
198pub(super) fn log_send(data: &[u8]) {
203 log::trace!("SEND {}", LogXsoBuf(data));
204}
205
#[cfg(test)]
mod tests {
    use super::*;

    use tokio::io::{AsyncBufReadExt, AsyncReadExt, BufReader};

    /// Payload shared by every test below.
    const MESSAGE: &[u8] = b"Hello World!";

    /// Concrete reader type produced by [`capturing_reader`].
    type Reader = CaptureBufRead<BufReader<&'static [u8]>>;

    /// Wrap `MESSAGE` in a buffered reader and a `CaptureBufRead` with
    /// capturing already switched on.
    fn capturing_reader() -> Reader {
        let mut reader = CaptureBufRead::wrap(BufReader::new(MESSAGE));
        reader.enable_capture();
        reader
    }

    /// Take the current capture, failing the test if capturing is disabled.
    fn take(reader: &mut Reader) -> Vec<u8> {
        Pin::new(reader)
            .take_capture()
            .expect("capture should be enabled")
    }

    #[tokio::test]
    async fn captures_data_read_via_async_read() {
        let mut reader = capturing_reader();

        let mut out = [0u8; 8];
        let count = reader.read(&mut out[..]).await.unwrap();
        assert_eq!(count, 8);
        assert_eq!(&out, b"Hello Wo");
        assert_eq!(take(&mut reader), b"Hello Wo");
    }

    #[tokio::test]
    async fn captures_data_read_via_async_buf_read() {
        let mut reader = capturing_reader();

        assert_eq!(reader.fill_buf().await.unwrap().len(), 12);
        // Filling the buffer alone does not count as consuming anything.
        assert!(take(&mut reader).is_empty());

        reader.consume(5);
        assert_eq!(take(&mut reader), b"Hello");

        reader.consume(6);
        assert_eq!(take(&mut reader), b" World");
    }

    #[tokio::test]
    async fn discard_capture_drops_consumed_data() {
        let mut reader = capturing_reader();

        assert_eq!(reader.fill_buf().await.unwrap().len(), 12);
        assert!(take(&mut reader).is_empty());

        reader.consume(5);
        Pin::new(&mut reader).discard_capture();

        reader.consume(6);
        assert_eq!(take(&mut reader), b" World");
    }

    #[tokio::test]
    async fn captured_data_accumulates() {
        let mut reader = capturing_reader();

        assert_eq!(reader.fill_buf().await.unwrap().len(), 12);
        assert!(take(&mut reader).is_empty());

        reader.consume(5);
        reader.consume(6);
        assert_eq!(take(&mut reader), b"Hello World");
    }
}