Skip to main content

xso_proc/
common.rs

1// Copyright (c) 2024 Jonas Schäfer <jonas@zombofant.net>
2//
3// This Source Code Form is subject to the terms of the Mozilla Public
4// License, v. 2.0. If a copy of the MPL was not distributed with this
5// file, You can obtain one at http://mozilla.org/MPL/2.0/.
6
7//! Definitions common to both enums and structs
8
9use proc_macro2::TokenStream;
10use quote::{quote, ToTokens};
11use syn::*;
12
13use std::collections::{HashMap, HashSet};
14
15use crate::types::{
16    iter_referenced_names_in_bound, iter_referenced_names_in_generic_param,
17    iter_referenced_names_in_ty, ReferencedIdent,
18};
19
20/// Template which renders to a `xso::fromxml::XmlNameMatcher` value.
21pub(crate) enum XmlNameMatcher {
22    /// Renders as `xso::fromxml::XmlNameMatcher::Any`.
23    #[allow(dead_code)] // We keep it for completeness.
24    Any,
25
26    /// Renders as `xso::fromxml::XmlNameMatcher::InNamespace(#0)`.
27    InNamespace(TokenStream),
28
29    /// Renders as `xso::fromxml::XmlNameMatcher::Specific(#0, #1)`.
30    Specific(TokenStream, TokenStream),
31
32    /// Renders as `#0`.
33    ///
34    /// This is an escape hatch for more complicated constructs, e.g. when
35    /// a superset of multiple matchers is required.
36    Custom(TokenStream),
37}
38
39impl ToTokens for XmlNameMatcher {
40    fn to_tokens(&self, tokens: &mut TokenStream) {
41        match self {
42            Self::Any => tokens.extend(quote! {
43                ::xso::fromxml::XmlNameMatcher::<'static>::Any
44            }),
45            Self::InNamespace(ref namespace) => tokens.extend(quote! {
46                ::xso::fromxml::XmlNameMatcher::<'static>::InNamespace(#namespace)
47            }),
48            Self::Specific(ref namespace, ref name) => tokens.extend(quote! {
49                ::xso::fromxml::XmlNameMatcher::<'static>::Specific(#namespace, #name)
50            }),
51            Self::Custom(ref stream) => tokens.extend(stream.clone()),
52        }
53    }
54}
55
56/// Parts necessary to construct a `::xso::FromXml` implementation.
57pub(crate) struct FromXmlParts {
58    /// Additional items necessary for the implementation.
59    pub(crate) defs: TokenStream,
60
61    /// The body of the `::xso::FromXml::from_xml` function.
62    pub(crate) from_events_body: TokenStream,
63
64    /// The type which is the `::xso::FromXml::Builder`.
65    pub(crate) builder_ty: Type,
66
67    /// The `XmlNameMatcher` to pre-select elements for this implementation.
68    pub(crate) name_matcher: XmlNameMatcher,
69}
70
71/// Parts necessary to construct a `::xso::AsXml` implementation.
72pub(crate) struct AsXmlParts {
73    /// Additional items necessary for the implementation.
74    pub(crate) defs: TokenStream,
75
76    /// The body of the `::xso::AsXml::as_xml_iter` function.
77    pub(crate) as_xml_iter_body: TokenStream,
78
79    /// The type which is the `::xso::AsXml::ItemIter`.
80    pub(crate) item_iter_ty: Type,
81
82    /// The lifetime name used in `item_iter_ty`.
83    pub(crate) item_iter_ty_lifetime: Lifetime,
84}
85
86/// Generic parameters (incl. bounds) which can be interpolated into
87/// `quote!{}`.
88#[derive(Clone, Debug)]
89pub(crate) struct GenericParams {
90    /// The `<` token.
91    pub(crate) lt_token: token::Lt,
92
93    /// The list of parameters, with bounds.
94    pub(crate) params: punctuated::Punctuated<GenericParam, token::Comma>,
95
96    /// The `>` token.
97    pub(crate) gt_token: token::Gt,
98}
99
100impl GenericParams {
101    fn new(item: GenericParam) -> Self {
102        Self {
103            lt_token: <Token![<]>::default(),
104            params: [item].into_iter().collect(),
105            gt_token: <Token![>]>::default(),
106        }
107    }
108
109    /// Split [`syn::Generics`] into [`GenericParams`] and an optional
110    /// [`WhereClause`].
111    fn split(generics: Generics) -> (Option<Self>, Option<WhereClause>) {
112        let Generics {
113            lt_token,
114            params,
115            gt_token,
116            where_clause,
117        } = generics;
118        let this = if params.len() > 0 {
119            if let Some((lt_token, gt_token)) = lt_token.zip(gt_token) {
120                Some(Self {
121                    lt_token,
122                    params,
123                    gt_token,
124                })
125            } else {
126                panic!("no lt and gt tokens, but non-empty generic parameter list!");
127            }
128        } else {
129            None
130        };
131        (this, where_clause)
132    }
133}
134
135impl ToTokens for GenericParams {
136    fn to_tokens(&self, tokens: &mut TokenStream) {
137        self.lt_token.to_tokens(tokens);
138        self.params.to_tokens(tokens);
139        self.gt_token.to_tokens(tokens);
140    }
141}
142
143/// Collection of generics information.
144#[derive(Default, Debug)]
145pub(crate) struct GenericsInfo {
146    /// List of parameters incl. their bounds, enclosed in `< .. >`, for use
147    /// after the `impl` keyword.
148    pub(crate) decl: Option<GenericParams>,
149
150    /// List of parameters without their bounds, for use as ref_ when
151    /// referencing the type.
152    pub(crate) ref_: Option<AngleBracketedGenericArguments>,
153
154    /// Where clause of the type, if any.
155    pub(crate) where_clause: Option<WhereClause>,
156
157    /// Index of the first bound which was added via add_bound_if_generic.
158    new_bounds_at: usize,
159}
160
161impl GenericsInfo {
162    /// Extract the relevant parts from an [`Item`]'s [`Generics`] so
163    /// that they can be used inside [`quote::quote`] to form `impl` items.
164    pub(crate) fn bake(generics: Generics) -> Self {
165        let (decl, where_clause) = GenericParams::split(generics);
166        let ref_ = if let Some(decl) = decl
167            .as_ref()
168            .and_then(|x| if x.params.len() > 0 { Some(x) } else { None })
169            .as_ref()
170        {
171            let mut args = punctuated::Punctuated::new();
172            for pair in decl.params.pairs() {
173                args.push_value(match pair.value() {
174                    GenericParam::Lifetime(lt) => GenericArgument::Lifetime(lt.lifetime.clone()),
175                    GenericParam::Type(ty) => GenericArgument::Type(Type::Path(TypePath {
176                        qself: None,
177                        path: ty.ident.clone().into(),
178                    })),
179                    GenericParam::Const(cst) => GenericArgument::Const(Expr::Path(ExprPath {
180                        attrs: Vec::new(),
181                        qself: None,
182                        path: cst.ident.clone().into(),
183                    })),
184                });
185                if let Some(punct) = pair.punct() {
186                    args.push_punct(**punct);
187                }
188            }
189
190            Some(AngleBracketedGenericArguments {
191                colon2_token: None,
192                lt_token: decl.lt_token.clone(),
193                args,
194                gt_token: decl.gt_token.clone(),
195            })
196        } else {
197            None
198        };
199        Self {
200            decl,
201            ref_,
202            new_bounds_at: where_clause
203                .as_ref()
204                .map(|x| x.predicates.len())
205                .unwrap_or(0),
206            where_clause,
207        }
208    }
209
210    fn where_clause_mut(&mut self) -> &mut WhereClause {
211        self.where_clause.get_or_insert_with(|| WhereClause {
212            where_token: <Token![where]>::default(),
213            predicates: punctuated::Punctuated::default(),
214        })
215    }
216
217    pub fn add_bound_if_generic(&mut self, ty: &Type, trait_: Path) {
218        let Some(decl) = self.decl.as_ref() else {
219            // no declarations -> no generics.
220            return;
221        };
222
223        let mut known_names = HashSet::new();
224        known_names.extend(decl.params.iter().filter_map(|x| match x {
225            GenericParam::Type(ty) => Some(&ty.ident),
226            _ => None,
227        }));
228
229        let mut is_generic = false;
230        iter_referenced_names_in_ty(&ty, &mut |reference| match reference {
231            ReferencedIdent::Name(name) => {
232                if known_names.contains(name) {
233                    is_generic = true;
234                }
235                Ok(())
236            }
237            _ => Ok(()),
238        })
239        .unwrap();
240        if !is_generic {
241            return;
242        }
243
244        self.where_clause_mut()
245            .predicates
246            .push(WherePredicate::Type(PredicateType {
247                lifetimes: None,
248                bounded_ty: ty.clone(),
249                colon_token: <Token![:]>::default(),
250                bounds: [TypeParamBound::Trait(TraitBound {
251                    paren_token: None,
252                    modifier: TraitBoundModifier::None,
253                    lifetimes: None,
254                    path: trait_,
255                })]
256                .into_iter()
257                .collect(),
258            }))
259    }
260
261    pub fn insert_lifetime(&mut self, lifetime: &Lifetime) {
262        let param = GenericParam::Lifetime(LifetimeParam {
263            attrs: vec![],
264            lifetime: lifetime.clone(),
265            colon_token: None,
266            bounds: punctuated::Punctuated::default(),
267        });
268        match self.decl.as_mut() {
269            Some(decl) => {
270                for param in decl.params.iter_mut() {
271                    match param {
272                        GenericParam::Lifetime(other_lifetime) => {
273                            other_lifetime.colon_token.get_or_insert_default();
274                            other_lifetime.bounds.push(lifetime.clone());
275                        }
276                        GenericParam::Type(ty) => {
277                            ty.colon_token.get_or_insert_default();
278                            ty.bounds.push(TypeParamBound::Lifetime(lifetime.clone()));
279                        }
280                        _ => (),
281                    }
282                }
283                decl.params.insert(0, param)
284            }
285            None => self.decl = Some(GenericParams::new(param)),
286        }
287
288        let arg = GenericArgument::Lifetime(lifetime.clone());
289        match self.ref_.as_mut() {
290            Some(v) => v.args.insert(0, arg),
291            None => {
292                self.ref_ = Some(AngleBracketedGenericArguments {
293                    colon2_token: None,
294                    lt_token: <Token![<]>::default(),
295                    args: [arg].into_iter().collect(),
296                    gt_token: <Token![>]>::default(),
297                })
298            }
299        }
300    }
301
302    pub(crate) fn path_arguments(&self) -> PathArguments {
303        match self.ref_.as_ref() {
304            None => PathArguments::None,
305            Some(v) => PathArguments::AngleBracketed(v.clone()),
306        }
307    }
308
309    pub fn ty_with_arguments(&self, ident: Ident) -> Type {
310        Type::Path(TypePath {
311            qself: None,
312            path: Path {
313                leading_colon: None,
314                segments: [PathSegment {
315                    ident,
316                    arguments: self.path_arguments(),
317                }]
318                .into_iter()
319                .collect(),
320            },
321        })
322    }
323
324    pub fn subscope(&self) -> Self {
325        Self {
326            decl: self.decl.clone(),
327            ref_: self.ref_.clone(),
328            where_clause: self.where_clause.clone(),
329            new_bounds_at: self
330                .where_clause
331                .as_ref()
332                .map(|x| x.predicates.len())
333                .unwrap_or(0),
334        }
335    }
336
337    /// Return a subset of this `GenericsInfo` which only contains parameters
338    /// and their constraints relevant for the given `ty`.
339    ///
340    /// This preserves:
341    /// - parameters which are referenced in `ty` (and their argument form)
342    /// - parameters which are referenced by trait bounds of parameters
343    ///   referenced by `ty`
344    /// - where clauses where all referenced parameters occur in the set of
345    ///   parameters included.
346    ///
347    /// The same holds for lifetimes and consts.
348    pub(crate) fn scoped_for(&self, ty: &Type) -> Result<Self> {
349        let (decl, ref_) = if let Some(decl) = self.decl.as_ref() {
350            if let Some(ref_) = self.ref_.as_ref() {
351                (decl, ref_)
352            } else {
353                panic!("internal xso-proc error: ref_ is None, but decl is not.");
354            }
355        } else {
356            assert!(self.ref_.is_none());
357            return Ok(self.subscope());
358        };
359
360        // First, we gather all declared names (types, consts and lifetimes).
361        let mut declarations = HashMap::new();
362        for param in decl.params.iter() {
363            match param {
364                GenericParam::Type(ty) => {
365                    declarations.insert(ReferencedIdent::Name(&ty.ident), param)
366                }
367                GenericParam::Const(cst) => {
368                    declarations.insert(ReferencedIdent::Name(&cst.ident), param)
369                }
370                GenericParam::Lifetime(lt) => {
371                    declarations.insert(ReferencedIdent::Lifetime(&lt.lifetime.ident), param)
372                }
373            };
374        }
375
376        // Then, we visit all names in `ty` and add them to the `used_names`
377        // hash set.
378        let mut used_names = HashSet::new();
379        match crate::types::iter_referenced_names_in_ty(ty, &mut |ident| {
380            let declared = declarations.contains_key(&ident);
381            if declared {
382                used_names.insert(ident);
383            }
384            Ok(())
385        }) {
386            Ok(()) => (),
387            Err(e) => {
388                return Err({
389                    let mut my_err = Error::new_spanned(
390                        ty,
391                        "cannot build subset of generic types needed for this type",
392                    );
393                    my_err.combine(e);
394                    my_err
395                })
396            }
397        };
398
399        // Now we iterate over `used_names` and add them to `keep`, while also
400        // looking for any further names referenced in their bounds which
401        // we'll then also have to keep.
402        let mut keep = HashSet::new();
403        while let Some(name) = used_names.iter().next() {
404            let name = *name;
405            used_names.remove(&name);
406            let rhs = declarations.get(&name).unwrap();
407            match iter_referenced_names_in_generic_param(&rhs, &mut |ident| {
408                if keep.contains(&ident) {
409                    // ^ this is important, otherwise we end up in an infinite
410                    // loop when we have cyclic trait bounds (like
411                    // `A: From<B>, B: From<A>`).
412                    return Ok(());
413                }
414                let declared = declarations.contains_key(&ident);
415                if declared {
416                    used_names.insert(ident);
417                }
418                Ok(())
419            }) {
420                Ok(()) => (),
421                Err(e) => {
422                    return Err({
423                        let mut my_err = Error::new_spanned(
424                            ty,
425                            "cannot build subset of generic types needed for this type",
426                        );
427                        my_err.combine(e);
428                        my_err
429                    })
430                }
431            };
432            keep.insert(name);
433        }
434
435        // Now that we know all the things, we keep those which are, in fact,
436        // relevant.
437        let mut result_params = punctuated::Punctuated::new();
438        let mut result_arguments = punctuated::Punctuated::new();
439        for (param, argument) in decl.params.pairs().zip(ref_.args.pairs()) {
440            let key = match param.value() {
441                GenericParam::Type(ty) => ReferencedIdent::Name(&ty.ident),
442                GenericParam::Const(cst) => ReferencedIdent::Name(&cst.ident),
443                GenericParam::Lifetime(lt) => ReferencedIdent::Lifetime(&lt.lifetime.ident),
444            };
445
446            if !keep.contains(&key) {
447                continue;
448            }
449
450            result_params.push_value((*param.value()).clone());
451            if let Some(punct) = param.punct() {
452                result_params.push_punct(**punct);
453            }
454
455            result_arguments.push_value((*argument.value()).clone());
456            if let Some(punct) = argument.punct() {
457                result_arguments.push_punct(**punct);
458            }
459        }
460
461        // And finally, we gather all where clauses which only contain names
462        // which are already in the set to keep (ignoring names which
463        // haven't been declared by us, such as `std`).
464        let result_where = if let Some(where_clause) = self.where_clause.as_ref() {
465            let mut result_predicates = punctuated::Punctuated::new();
466            // Note that I'm not 100% sure this logic is complete; there may be
467            // cases where additional clauses need to be kept... For now I
468            // went with "keep only those where all names are already kept"
469            // because I assume that we'd run into issues with
470            // "type parameter not used" errors.
471            for predicate in where_clause.predicates.pairs() {
472                // we start with true, and once we hit a declared name which
473                // isn't kept, we set it to false.
474                let mut keep_this_predicate = true;
475                match predicate.value() {
476                    WherePredicate::Lifetime(predicate) => {
477                        keep_this_predicate = keep_this_predicate
478                            && keep.contains(&ReferencedIdent::Lifetime(&predicate.lifetime.ident));
479                        for bound in predicate.bounds.iter() {
480                            let name = ReferencedIdent::Lifetime(&bound.ident);
481                            if declarations.contains_key(&name) {
482                                keep_this_predicate = keep_this_predicate && keep.contains(&name);
483                            }
484                        }
485                    }
486                    WherePredicate::Type(predicate) => {
487                        iter_referenced_names_in_ty(&predicate.bounded_ty, &mut |ident| {
488                            if declarations.contains_key(&ident) {
489                                keep_this_predicate = keep_this_predicate && keep.contains(&ident);
490                            }
491                            Ok(())
492                        })?;
493                        for bound in predicate.bounds.iter() {
494                            iter_referenced_names_in_bound(&bound, &mut |ident| {
495                                if declarations.contains_key(&ident) {
496                                    keep_this_predicate =
497                                        keep_this_predicate && keep.contains(&ident);
498                                }
499                                Ok(())
500                            })?;
501                        }
502                    }
503                    other => {
504                        return Err({
505                            let mut my_err = Error::new_spanned(
506                                ty,
507                                "cannot build subset of generic types needed for this type",
508                            );
509                            my_err.combine(Error::new_spanned(
510                                other,
511                                "this kind of where predicate is not supported",
512                            ));
513                            my_err
514                        })
515                    }
516                }
517                if keep_this_predicate {
518                    result_predicates.push_value((*predicate.value()).clone());
519                    if let Some(punct) = predicate.punct() {
520                        result_predicates.push_punct(**punct);
521                    }
522                }
523            }
524
525            if result_predicates.len() > 0 {
526                Some(WhereClause {
527                    where_token: where_clause.where_token.clone(),
528                    predicates: result_predicates,
529                })
530            } else {
531                None
532            }
533        } else {
534            None
535        };
536
537        Ok(Self {
538            decl: if result_params.len() > 0 {
539                Some(GenericParams {
540                    lt_token: decl.lt_token.clone(),
541                    params: result_params,
542                    gt_token: decl.gt_token.clone(),
543                })
544            } else {
545                None
546            },
547            ref_: if result_arguments.len() > 0 {
548                Some(AngleBracketedGenericArguments {
549                    colon2_token: ref_.colon2_token.clone(),
550                    lt_token: ref_.lt_token.clone(),
551                    args: result_arguments,
552                    gt_token: ref_.gt_token.clone(),
553                })
554            } else {
555                None
556            },
557            new_bounds_at: result_where
558                .as_ref()
559                .map(|x| x.predicates.len())
560                .unwrap_or(0),
561            where_clause: result_where,
562        })
563    }
564
565    pub fn into_new_bounds(self) -> impl Iterator<Item = WherePredicate> {
566        self.where_clause
567            .map(|x| x.predicates)
568            .unwrap_or_else(|| punctuated::Punctuated::default())
569            .into_iter()
570            .skip(self.new_bounds_at)
571    }
572
573    pub fn extend_bounds(&mut self, iter: impl IntoIterator<Item = WherePredicate>) {
574        self.where_clause_mut().predicates.extend(iter)
575    }
576}
577
578/// Trait describing the definition of the XML (de-)serialisation for an item
579/// (enum or struct).
580pub(crate) trait ItemDef {
581    /// Construct the parts necessary for the caller to build an
582    /// `xso::FromXml` implementation for the item.
583    fn make_from_events_builder(
584        &self,
585        vis: &Visibility,
586        generics: &mut GenericsInfo,
587        name_ident: &Ident,
588        attrs_ident: &Ident,
589    ) -> Result<FromXmlParts>;
590
591    /// Construct the parts necessary for the caller to build an `xso::AsXml`
592    /// implementation for the item.
593    fn make_as_xml_iter(&self, vis: &Visibility, generics: &mut GenericsInfo)
594        -> Result<AsXmlParts>;
595
596    /// Return true iff the user requested debug output.
597    fn debug(&self) -> bool;
598}