1use xso::{AsXml, FromXml};
8
9use crate::data_forms::DataForm;
10use crate::disco::{DiscoInfoQuery, DiscoInfoResult, Identity};
11use crate::hashes::{Algo, Hash};
12use crate::ns;
13use crate::presence::PresencePayload;
14use base64::{engine::general_purpose::STANDARD as Base64, Engine};
15use blake2::{
16 digest::{Update, VariableOutput},
17 Blake2bVar,
18};
19use digest::Digest;
20use sha1::Sha1;
21use sha2::{Sha256, Sha512};
22use sha3::{Sha3_256, Sha3_512};
23
/// Entity capabilities `<c/>` element in the [`ns::CAPS`] namespace.
///
/// `ver` is stored as raw digest bytes and is Base64-encoded in the XML
/// attribute.
#[derive(FromXml, AsXml, Debug, Clone)]
#[xml(namespace = ns::CAPS, name = "c")]
pub struct Caps {
    /// Optional `ext` attribute; `None` when the attribute is absent.
    #[xml(attribute(default))]
    pub ext: Option<String>,

    /// The `node` attribute; `query_caps` joins it with `ver` as `node#ver`.
    #[xml(attribute)]
    pub node: String,

    /// The `hash` attribute: the algorithm that produced `ver`.
    #[xml(attribute)]
    pub hash: Algo,

    /// The `ver` attribute, decoded from Base64 into the raw digest bytes.
    #[xml(attribute(codec = Base64))]
    pub ver: Vec<u8>,
}
49
// Marker impl: lets a `Caps` element be used wherever a `PresencePayload` is expected.
impl PresencePayload for Caps {}
51
52impl Caps {
53 pub fn new<N: Into<String>>(node: N, hash: Hash) -> Caps {
55 Caps {
56 ext: None,
57 node: node.into(),
58 hash: hash.algo,
59 ver: hash.hash,
60 }
61 }
62}
63
/// Encodes one string item for the verification string: its UTF-8 bytes
/// followed by the `'<'` delimiter.
fn compute_item(field: &str) -> Vec<u8> {
    format!("{}<", field).into_bytes()
}
69
/// Encodes every element of `things` with `encode`, sorts the encodings and
/// concatenates them.
///
/// `encode` is expected to return `'<'`-terminated byte strings (as
/// [`compute_item`] does). XEP-0115 orders items with the i;octet collation
/// applied to the *unterminated* text. A plain byte sort of the terminated
/// encodings gets this wrong when one item is a prefix of another. For example,
/// `mood+notify<` would come before `mood<`, because `'+'` (0x2B) is below
/// `'<'` (0x3C), while i;octet puts the prefix `mood` first. The comparison
/// below therefore treats the `'<'` delimiter as lower than every other byte.
fn compute_items<T, F: Fn(&T) -> Vec<u8>>(things: &[T], encode: F) -> Vec<u8> {
    // Maps each byte to a sort key in which `'<'` precedes any other byte.
    fn collation_key(bytes: &[u8]) -> impl Iterator<Item = (bool, u8)> + '_ {
        bytes.iter().map(|&byte| (byte != b'<', byte))
    }

    let mut encoded: Vec<Vec<u8>> = things.iter().map(encode).collect();
    encoded.sort_by(|a, b| collation_key(a).cmp(collation_key(b)));
    encoded.concat()
}
84
85fn compute_features(features: &[String]) -> Vec<u8> {
86 compute_items(features, |feature| compute_item(&feature))
87}
88
89fn compute_identities(identities: &[Identity]) -> Vec<u8> {
90 compute_items(identities, |identity| {
91 let lang = identity.lang.as_deref().unwrap_or_default();
92 let name = identity.name.as_deref().unwrap_or_default();
93 let string = format!("{}/{}/{}/{}", identity.category, identity.type_, lang, name);
94 let mut vec = string.as_bytes().to_vec();
95 vec.push(b'<');
96 vec
97 })
98}
99
100fn compute_extensions(extensions: &[DataForm]) -> Vec<u8> {
101 compute_items(extensions, |extension| {
102 let mut bytes = if let Some(form_type) = extension.form_type() {
104 form_type.as_bytes().to_vec()
105 } else {
106 vec![]
107 };
108 bytes.push(b'<');
109 for field in &extension.fields {
110 if field.var.as_deref() == Some("FORM_TYPE") {
111 continue;
112 }
113 if let Some(var) = &field.var {
114 bytes.append(&mut compute_item(var));
115 }
116 bytes.append(&mut compute_items(&field.values, |value| {
117 compute_item(value)
118 }));
119 }
120 bytes
121 })
122}
123
124pub fn compute_disco(disco: &DiscoInfoResult) -> Vec<u8> {
131 let features: Vec<_> = disco.features.iter().cloned().collect();
133
134 let identities_string = compute_identities(&disco.identities);
135 let features_string = compute_features(&features);
136 let extensions_string = compute_extensions(&disco.extensions);
137
138 let mut final_string = vec![];
139 final_string.extend(identities_string);
140 final_string.extend(features_string);
141 final_string.extend(extensions_string);
142 final_string
143}
144
145pub fn hash_caps(data: &[u8], algo: Algo) -> Result<Hash, String> {
148 Ok(Hash {
149 hash: match algo {
150 Algo::Sha_1 => {
151 let hash = Sha1::digest(data);
152 hash.to_vec()
153 }
154 Algo::Sha_256 => {
155 let hash = Sha256::digest(data);
156 hash.to_vec()
157 }
158 Algo::Sha_512 => {
159 let hash = Sha512::digest(data);
160 hash.to_vec()
161 }
162 Algo::Sha3_256 => {
163 let hash = Sha3_256::digest(data);
164 hash.to_vec()
165 }
166 Algo::Sha3_512 => {
167 let hash = Sha3_512::digest(data);
168 hash.to_vec()
169 }
170 Algo::Blake2b_256 => {
171 let mut hasher = Blake2bVar::new(32).unwrap();
172 hasher.update(data);
173 let mut vec = vec![0u8; 32];
174 hasher.finalize_variable(&mut vec).unwrap();
175 vec
176 }
177 Algo::Blake2b_512 => {
178 let mut hasher = Blake2bVar::new(64).unwrap();
179 hasher.update(data);
180 let mut vec = vec![0u8; 64];
181 hasher.finalize_variable(&mut vec).unwrap();
182 vec
183 }
184 Algo::Unknown(algo) => return Err(format!("Unknown algorithm: {}.", algo)),
185 },
186 algo,
187 })
188}
189
190pub fn query_caps(caps: Caps) -> DiscoInfoQuery {
193 DiscoInfoQuery {
194 node: Some(format!("{}#{}", caps.node, Base64.encode(&caps.ver))),
195 }
196}
197
#[cfg(test)]
mod tests {
    use super::*;
    use crate::caps;
    use minidom::Element;
    // Only used by the pedantic-only test below.
    #[cfg(feature = "pedantic")]
    use xso::error::{Error, FromElementError};

    // Pins the in-memory size of `Caps` on 32-bit targets.
    #[cfg(target_pointer_width = "32")]
    #[test]
    fn test_size() {
        assert_size!(Caps, 48);
    }

    // Pins the in-memory size of `Caps` on 64-bit targets.
    #[cfg(target_pointer_width = "64")]
    #[test]
    fn test_size() {
        assert_size!(Caps, 96);
    }

    // Parses a `<c/>` element and checks every attribute, including that `ver`
    // is decoded from Base64 into raw bytes.
    #[test]
    fn test_parse() {
        let elem: Element = "<c xmlns='http://jabber.org/protocol/caps' hash='sha-256' node='coucou' ver='K1Njy3HZBThlo4moOD5gBGhn0U0oK7/CbfLlIUDi6o4='/>".parse().unwrap();
        let caps = Caps::try_from(elem).unwrap();
        assert_eq!(caps.node, String::from("coucou"));
        assert_eq!(caps.hash, Algo::Sha_256);
        assert_eq!(
            caps.ver,
            Base64
                .decode("K1Njy3HZBThlo4moOD5gBGhn0U0oK7/CbfLlIUDi6o4=")
                .unwrap()
        );
    }

    // With the `pedantic` feature, an unexpected child element is rejected.
    #[cfg(feature = "pedantic")]
    #[test]
    fn test_invalid_child() {
        let elem: Element = "<c xmlns='http://jabber.org/protocol/caps' node='coucou' hash='sha-256' ver='K1Njy3HZBThlo4moOD5gBGhn0U0oK7/CbfLlIUDi6o4='><hash xmlns='urn:xmpp:hashes:2' algo='sha-256'>K1Njy3HZBThlo4moOD5gBGhn0U0oK7/CbfLlIUDi6o4=</hash></c>".parse().unwrap();
        let error = Caps::try_from(elem).unwrap_err();
        let message = match error {
            FromElementError::Invalid(Error::Other(string)) => string,
            _ => panic!(),
        };
        assert_eq!(message, "Unknown child in Caps element.");
    }

    // One identity and one feature: `client/pc//<` (12 bytes) followed by
    // `http://jabber.org/protocol/disco#info<` (38 bytes) gives 50 bytes.
    #[test]
    fn test_simple() {
        let elem: Element = "<query xmlns='http://jabber.org/protocol/disco#info'><identity category='client' type='pc'/><feature var='http://jabber.org/protocol/disco#info'/></query>".parse().unwrap();
        let disco = DiscoInfoResult::try_from(elem).unwrap();
        let caps = caps::compute_disco(&disco);
        assert_eq!(caps.len(), 50);
    }

    // The XEP-0115 §5.2 example: the exact verification string and its SHA-1 `ver`.
    #[test]
    fn test_xep_5_2() {
        let elem: Element = r#"<query xmlns='http://jabber.org/protocol/disco#info'
 node='http://psi-im.org#q07IKJEyjvHSyhy//CH0CxmKi8w='>
 <identity category='client' name='Exodus 0.9.1' type='pc'/>
 <feature var='http://jabber.org/protocol/caps'/>
 <feature var='http://jabber.org/protocol/disco#info'/>
 <feature var='http://jabber.org/protocol/disco#items'/>
 <feature var='http://jabber.org/protocol/muc'/>
</query>
"#
        .parse()
        .unwrap();

        let expected = b"client/pc//Exodus 0.9.1<http://jabber.org/protocol/caps<http://jabber.org/protocol/disco#info<http://jabber.org/protocol/disco#items<http://jabber.org/protocol/muc<".to_vec();
        let disco = DiscoInfoResult::try_from(elem).unwrap();
        let caps = caps::compute_disco(&disco);
        assert_eq!(caps, expected);

        let sha_1 = caps::hash_caps(&caps, Algo::Sha_1).unwrap();
        assert_eq!(
            sha_1.hash,
            Base64.decode("QgayPKawpkPSDYmwT/WM94uAlu0=").unwrap()
        );
    }

    // The XEP-0115 §5.3 example: two identities that differ only in `xml:lang`,
    // plus a data form whose FORM_TYPE value is emitted before its other fields.
    #[test]
    fn test_xep_5_3() {
        let elem: Element = r#"<query xmlns='http://jabber.org/protocol/disco#info'
 node='http://psi-im.org#q07IKJEyjvHSyhy//CH0CxmKi8w='>
 <identity xml:lang='en' category='client' name='Psi 0.11' type='pc'/>
 <identity xml:lang='el' category='client' name='Ψ 0.11' type='pc'/>
 <feature var='http://jabber.org/protocol/caps'/>
 <feature var='http://jabber.org/protocol/disco#info'/>
 <feature var='http://jabber.org/protocol/disco#items'/>
 <feature var='http://jabber.org/protocol/muc'/>
 <x xmlns='jabber:x:data' type='result'>
 <field var='FORM_TYPE' type='hidden'>
 <value>urn:xmpp:dataforms:softwareinfo</value>
 </field>
 <field var='ip_version'>
 <value>ipv4</value>
 <value>ipv6</value>
 </field>
 <field var='os'>
 <value>Mac</value>
 </field>
 <field var='os_version'>
 <value>10.5.1</value>
 </field>
 <field var='software'>
 <value>Psi</value>
 </field>
 <field var='software_version'>
 <value>0.11</value>
 </field>
 </x>
</query>
"#
        .parse()
        .unwrap();
        let expected = b"client/pc/el/\xce\xa8 0.11<client/pc/en/Psi 0.11<http://jabber.org/protocol/caps<http://jabber.org/protocol/disco#info<http://jabber.org/protocol/disco#items<http://jabber.org/protocol/muc<urn:xmpp:dataforms:softwareinfo<ip_version<ipv4<ipv6<os<Mac<os_version<10.5.1<software<Psi<software_version<0.11<".to_vec();
        let disco = DiscoInfoResult::try_from(elem).unwrap();
        let caps = caps::compute_disco(&disco);
        assert_eq!(caps, expected);

        let sha_1 = caps::hash_caps(&caps, Algo::Sha_1).unwrap();
        assert_eq!(
            sha_1.hash,
            Base64.decode("q07IKJEyjvHSyhy//CH0CxmKi8w=").unwrap()
        );
    }
}