1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
// Copyright (c) 2024 Emmanuel Gil Peyrot <linkmauve@linkmauve.fr>
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.

use xso::error::{Error, FromElementError};

use crate::ns;
use crate::Element;
use std::str::FromStr;

/// One type of channel-binding, as registered by the IANA in the
/// [Channel-Binding Types registry](https://www.iana.org/assignments/channel-binding-types/channel-binding-types.xhtml).
///
/// This is a fieldless tag, so it is cheap to copy, compare and hash.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Type {
    /// The `tls-unique` channel-binding type.
    ///
    /// See RFC 5929.
    TlsUnique,

    /// The `tls-server-end-point` channel-binding type.
    ///
    /// See RFC 5929.
    TlsServerEndPoint,

    /// The `tls-unique-for-telnet` channel-binding type.
    ///
    /// See RFC 5929.
    TlsUniqueForTelnet,

    /// The EKM value obtained from the current TLS connection (`tls-exporter`).
    ///
    /// See RFC 9266.
    TlsExporter,
}

impl FromStr for Type {
    type Err = Error;

    fn from_str(s: &str) -> Result<Type, Self::Err> {
        Ok(match s {
            "tls-unique" => Type::TlsUnique,
            "tls-server-end-point" => Type::TlsServerEndPoint,
            "tls-unique-for-telnet" => Type::TlsUniqueForTelnet,
            "tls-exporter" => Type::TlsExporter,

            _ => return Err(Error::Other("Unknown value '{s}' for 'type' attribute.")),
        })
    }
}

/// Stream feature advertising the channel-binding types the server supports.
///
/// Built from a `<sasl-channel-binding/>` element holding one `<channel-binding/>`
/// child per supported type.
// NOTE(review): the manual `TryFrom<Element>` impl looks like a stand-in for an xso derive.
// The intended attributes appear to be
// `#[xml(namespace = ns::SASL_CB, name = "sasl-channel-binding")]` on this struct and
// `#[xml(children(namespace = ns::SASL_CB, name = "channel-binding", extract(attribute = "type")))]`
// on `types`. Confirm before migrating.
#[derive(PartialEq, Clone, Debug)]
pub struct SaslChannelBinding {
    /// Supported channel-binding types, in the order the server listed them.
    pub types: Vec<Type>,
}

impl TryFrom<Element> for SaslChannelBinding {
    type Error = FromElementError;

    fn try_from(root: Element) -> Result<SaslChannelBinding, Self::Error> {
        check_self!(root, "sasl-channel-binding", SASL_CB);
        check_no_attributes!(root, "sasl-channel-binding");

        let mut types = Vec::new();
        for child in root.children() {
            if child.is("channel-binding", ns::SASL_CB) {
                check_no_children!(child, "channel-binding");
                check_no_unknown_attributes!(child, "channel-binding", ["type"]);
                types.push(get_attr!(child, "type", Required));
            } else {
                return Err(Error::Other("Unknown element in SaslChannelBinding.").into());
            }
        }
        Ok(SaslChannelBinding { types })
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::Element;

    #[test]
    fn test_size() {
        assert_size!(Type, 1);
        assert_size!(SaslChannelBinding, 24);
    }

    #[test]
    fn test_simple() {
        let elem: Element = "<sasl-channel-binding xmlns='urn:xmpp:sasl-cb:0'><channel-binding type='tls-server-end-point'/><channel-binding type='tls-exporter'/></sasl-channel-binding>".parse().unwrap();
        let sasl_cb = SaslChannelBinding::try_from(elem).unwrap();
        assert_eq!(sasl_cb.types, [Type::TlsServerEndPoint, Type::TlsExporter]);
    }

    /// Every registered name maps to its variant.
    #[test]
    fn test_all_types() {
        assert_eq!("tls-unique".parse::<Type>().unwrap(), Type::TlsUnique);
        assert_eq!(
            "tls-server-end-point".parse::<Type>().unwrap(),
            Type::TlsServerEndPoint
        );
        assert_eq!(
            "tls-unique-for-telnet".parse::<Type>().unwrap(),
            Type::TlsUniqueForTelnet
        );
        assert_eq!("tls-exporter".parse::<Type>().unwrap(), Type::TlsExporter);
    }

    /// A feature with no `<channel-binding/>` children parses to an empty list.
    #[test]
    fn test_empty() {
        let elem: Element = "<sasl-channel-binding xmlns='urn:xmpp:sasl-cb:0'/>"
            .parse()
            .unwrap();
        let sasl_cb = SaslChannelBinding::try_from(elem).unwrap();
        assert!(sasl_cb.types.is_empty());
    }

    /// Unregistered type names are rejected both directly and inside the feature.
    #[test]
    fn test_unknown_type() {
        assert!("tls-bogus".parse::<Type>().is_err());

        let elem: Element = "<sasl-channel-binding xmlns='urn:xmpp:sasl-cb:0'><channel-binding type='tls-bogus'/></sasl-channel-binding>".parse().unwrap();
        assert!(SaslChannelBinding::try_from(elem).is_err());
    }

    /// The `type` attribute is required and no other attribute is allowed.
    #[test]
    fn test_invalid_attributes() {
        let elem: Element = "<sasl-channel-binding xmlns='urn:xmpp:sasl-cb:0'><channel-binding/></sasl-channel-binding>".parse().unwrap();
        assert!(SaslChannelBinding::try_from(elem).is_err());

        let elem: Element = "<sasl-channel-binding xmlns='urn:xmpp:sasl-cb:0'><channel-binding type='tls-exporter' foo='bar'/></sasl-channel-binding>".parse().unwrap();
        assert!(SaslChannelBinding::try_from(elem).is_err());
    }

    /// Any child other than `<channel-binding/>` is rejected with a specific message.
    #[test]
    fn test_invalid_child() {
        let elem: Element =
            "<sasl-channel-binding xmlns='urn:xmpp:sasl-cb:0'><foo/></sasl-channel-binding>"
                .parse()
                .unwrap();
        let error = SaslChannelBinding::try_from(elem).unwrap_err();
        let message = match error {
            FromElementError::Invalid(Error::Other(string)) => string,
            _ => panic!(),
        };
        assert_eq!(message, "Unknown element in SaslChannelBinding.");
    }
}