1use xso::{AsXml, FromXml};
8
9use crate::data_forms::DataForm;
10use crate::disco::{DiscoInfoQuery, DiscoInfoResult, Identity};
11use crate::hashes::{Algo, Hash};
12use crate::ns;
13use crate::presence::PresencePayload;
14use base64::{engine::general_purpose::STANDARD as Base64, Engine};
15use blake2::Blake2bVar;
16use digest::{Digest, Update, VariableOutput};
17use sha1::Sha1;
18use sha2::{Sha256, Sha512};
19use sha3::{Sha3_256, Sha3_512};
20
/// The `<c/>` element of the `ns::CAPS` namespace.
///
/// It pairs a `node` with a digest (`ver`) and the name of the algorithm
/// that produced it (`hash`). `query_caps` joins `node` and the base64 form
/// of `ver` into the node of a disco#info query.
#[derive(FromXml, AsXml, Debug, Clone)]
#[xml(namespace = ns::CAPS, name = "c")]
pub struct Caps {
    /// The optional `ext` attribute; `None` when the attribute is absent.
    #[xml(attribute(default))]
    pub ext: Option<String>,

    /// The `node` attribute.
    #[xml(attribute)]
    pub node: String,

    /// The `hash` attribute: the algorithm the `ver` digest was made with.
    #[xml(attribute)]
    pub hash: Algo,

    /// The `ver` attribute: digest bytes, transported as base64 in the XML
    /// attribute and held here in decoded form.
    #[xml(attribute(codec = Base64))]
    pub ver: Vec<u8>,
}
46
// Marks `Caps` as a `PresencePayload`; the impl body is empty, so nothing
// is overridden here.
impl PresencePayload for Caps {}
48
49impl Caps {
50 pub fn new<N: Into<String>>(node: N, hash: Hash) -> Caps {
52 Caps {
53 ext: None,
54 node: node.into(),
55 hash: hash.algo,
56 ver: hash.hash,
57 }
58 }
59}
60
/// Encodes a single entry of the capability string: the UTF-8 bytes of
/// `field` followed by one `<` separator byte.
fn compute_item(field: &str) -> Vec<u8> {
    let mut encoded = Vec::with_capacity(field.len() + 1);
    encoded.extend_from_slice(field.as_bytes());
    encoded.push(b'<');
    encoded
}
66
/// Encodes every element of `things` with `encode`, sorts the resulting
/// byte strings in ascending byte order, and returns their concatenation.
///
/// The ordering is applied to the *encoded* form of each element, not to
/// the input elements themselves.
fn compute_items<T, F: Fn(&T) -> Vec<u8>>(things: &[T], encode: F) -> Vec<u8> {
    let mut encoded: Vec<Vec<u8>> = things.iter().map(encode).collect();
    encoded.sort();
    encoded.concat()
}
81
82fn compute_features(features: &[String]) -> Vec<u8> {
83 compute_items(features, |feature| compute_item(&feature))
84}
85
86fn compute_identities(identities: &[Identity]) -> Vec<u8> {
87 compute_items(identities, |identity| {
88 let lang = identity.lang.as_deref().unwrap_or_default();
89 let name = identity.name.as_deref().unwrap_or_default();
90 let string = format!("{}/{}/{}/{}", identity.category, identity.type_, lang, name);
91 let mut vec = string.as_bytes().to_vec();
92 vec.push(b'<');
93 vec
94 })
95}
96
97fn compute_extensions(extensions: &[DataForm]) -> Vec<u8> {
98 compute_items(extensions, |extension| {
99 let mut bytes = if let Some(form_type) = extension.form_type() {
101 form_type.as_bytes().to_vec()
102 } else {
103 vec![]
104 };
105 bytes.push(b'<');
106 for field in &extension.fields {
107 if field.var.as_deref() == Some("FORM_TYPE") {
108 continue;
109 }
110 if let Some(var) = &field.var {
111 bytes.append(&mut compute_item(var));
112 }
113 bytes.append(&mut compute_items(&field.values, |value| {
114 compute_item(value)
115 }));
116 }
117 bytes
118 })
119}
120
121pub fn compute_disco(disco: &DiscoInfoResult) -> Vec<u8> {
128 let features: Vec<_> = disco.features.iter().cloned().collect();
130
131 let identities_string = compute_identities(&disco.identities);
132 let features_string = compute_features(&features);
133 let extensions_string = compute_extensions(&disco.extensions);
134
135 let mut final_string = vec![];
136 final_string.extend(identities_string);
137 final_string.extend(features_string);
138 final_string.extend(extensions_string);
139 final_string
140}
141
142pub fn hash_caps(data: &[u8], algo: Algo) -> Result<Hash, String> {
145 Ok(Hash {
146 hash: match algo {
147 Algo::Sha_1 => {
148 let hash = Sha1::digest(data);
149 hash.to_vec()
150 }
151 Algo::Sha_256 => {
152 let hash = Sha256::digest(data);
153 hash.to_vec()
154 }
155 Algo::Sha_512 => {
156 let hash = Sha512::digest(data);
157 hash.to_vec()
158 }
159 Algo::Sha3_256 => {
160 let hash = Sha3_256::digest(data);
161 hash.to_vec()
162 }
163 Algo::Sha3_512 => {
164 let hash = Sha3_512::digest(data);
165 hash.to_vec()
166 }
167 Algo::Blake2b_256 => {
168 let mut hasher = Blake2bVar::new(32).unwrap();
169 hasher.update(data);
170 let mut vec = vec![0u8; 32];
171 hasher.finalize_variable(&mut vec).unwrap();
172 vec
173 }
174 Algo::Blake2b_512 => {
175 let mut hasher = Blake2bVar::new(64).unwrap();
176 hasher.update(data);
177 let mut vec = vec![0u8; 64];
178 hasher.finalize_variable(&mut vec).unwrap();
179 vec
180 }
181 Algo::Unknown(algo) => return Err(format!("Unknown algorithm: {}.", algo)),
182 },
183 algo,
184 })
185}
186
187pub fn query_caps(caps: Caps) -> DiscoInfoQuery {
190 DiscoInfoQuery {
191 node: Some(format!("{}#{}", caps.node, Base64.encode(&caps.ver))),
192 }
193}
194
// Unit tests for parsing `<c/>` elements, building the capability string
// from disco#info results, and hashing it.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::caps;
    use minidom::Element;
    #[cfg(not(feature = "disable-validation"))]
    use xso::error::{Error, FromElementError};

    // Pins the in-memory size of `Caps` on 32-bit targets.
    #[cfg(target_pointer_width = "32")]
    #[test]
    fn test_size() {
        assert_size!(Caps, 48);
    }

    // Pins the in-memory size of `Caps` on 64-bit targets.
    #[cfg(target_pointer_width = "64")]
    #[test]
    fn test_size() {
        assert_size!(Caps, 96);
    }

    // A well-formed `<c/>` element parses, and `ver` is base64-decoded.
    #[test]
    fn test_parse() {
        let elem: Element = "<c xmlns='http://jabber.org/protocol/caps' hash='sha-256' node='coucou' ver='K1Njy3HZBThlo4moOD5gBGhn0U0oK7/CbfLlIUDi6o4='/>".parse().unwrap();
        let caps = Caps::try_from(elem).unwrap();
        assert_eq!(caps.node, String::from("coucou"));
        assert_eq!(caps.hash, Algo::Sha_256);
        assert_eq!(
            caps.ver,
            Base64
                .decode("K1Njy3HZBThlo4moOD5gBGhn0U0oK7/CbfLlIUDi6o4=")
                .unwrap()
        );
    }

    // A `<c/>` element carrying a child element must be rejected (only
    // checked while validation is enabled).
    #[cfg(not(feature = "disable-validation"))]
    #[test]
    fn test_invalid_child() {
        let elem: Element = "<c xmlns='http://jabber.org/protocol/caps' node='coucou' hash='sha-256' ver='K1Njy3HZBThlo4moOD5gBGhn0U0oK7/CbfLlIUDi6o4='><hash xmlns='urn:xmpp:hashes:2' algo='sha-256'>K1Njy3HZBThlo4moOD5gBGhn0U0oK7/CbfLlIUDi6o4=</hash></c>".parse().unwrap();
        let error = Caps::try_from(elem).unwrap_err();
        let message = match error {
            FromElementError::Invalid(Error::Other(string)) => string,
            _ => panic!(),
        };
        assert_eq!(message, "Unknown child in Caps element.");
    }

    // One identity and one feature: `client/pc//<` is 12 bytes and the
    // feature plus its `<` separator is 38 bytes, 50 in total.
    #[test]
    fn test_simple() {
        let elem: Element = "<query xmlns='http://jabber.org/protocol/disco#info'><identity category='client' type='pc'/><feature var='http://jabber.org/protocol/disco#info'/></query>".parse().unwrap();
        let disco = DiscoInfoResult::try_from(elem).unwrap();
        let caps = caps::compute_disco(&disco);
        assert_eq!(caps.len(), 50);
    }

    // Identity-and-features case; judging by the test name this mirrors an
    // example from section 5.2 of the XEP (confirm against the spec).
    #[test]
    fn test_xep_5_2() {
        let elem: Element = r#"<query xmlns='http://jabber.org/protocol/disco#info'
 node='http://psi-im.org#q07IKJEyjvHSyhy//CH0CxmKi8w='>
 <identity category='client' name='Exodus 0.9.1' type='pc'/>
 <feature var='http://jabber.org/protocol/caps'/>
 <feature var='http://jabber.org/protocol/disco#info'/>
 <feature var='http://jabber.org/protocol/disco#items'/>
 <feature var='http://jabber.org/protocol/muc'/>
</query>
"#
        .parse()
        .unwrap();

        let expected = b"client/pc//Exodus 0.9.1<http://jabber.org/protocol/caps<http://jabber.org/protocol/disco#info<http://jabber.org/protocol/disco#items<http://jabber.org/protocol/muc<".to_vec();
        let disco = DiscoInfoResult::try_from(elem).unwrap();
        let caps = caps::compute_disco(&disco);
        assert_eq!(caps, expected);

        // The SHA-1 of the capability string must match the expected `ver`.
        let sha_1 = caps::hash_caps(&caps, Algo::Sha_1).unwrap();
        assert_eq!(
            sha_1.hash,
            Base64.decode("QgayPKawpkPSDYmwT/WM94uAlu0=").unwrap()
        );
    }

    // Case with two localised identities and an extension data form;
    // judging by the test name this mirrors an example from section 5.3 of
    // the XEP (confirm against the spec).
    #[test]
    fn test_xep_5_3() {
        let elem: Element = r#"<query xmlns='http://jabber.org/protocol/disco#info'
 node='http://psi-im.org#q07IKJEyjvHSyhy//CH0CxmKi8w='>
 <identity xml:lang='en' category='client' name='Psi 0.11' type='pc'/>
 <identity xml:lang='el' category='client' name='Ψ 0.11' type='pc'/>
 <feature var='http://jabber.org/protocol/caps'/>
 <feature var='http://jabber.org/protocol/disco#info'/>
 <feature var='http://jabber.org/protocol/disco#items'/>
 <feature var='http://jabber.org/protocol/muc'/>
 <x xmlns='jabber:x:data' type='result'>
 <field var='FORM_TYPE' type='hidden'>
 <value>urn:xmpp:dataforms:softwareinfo</value>
 </field>
 <field var='ip_version'>
 <value>ipv4</value>
 <value>ipv6</value>
 </field>
 <field var='os'>
 <value>Mac</value>
 </field>
 <field var='os_version'>
 <value>10.5.1</value>
 </field>
 <field var='software'>
 <value>Psi</value>
 </field>
 <field var='software_version'>
 <value>0.11</value>
 </field>
 </x>
</query>
"#
        .parse()
        .unwrap();
        // The `el` identity sorts before the `en` one in the expected string.
        let expected = b"client/pc/el/\xce\xa8 0.11<client/pc/en/Psi 0.11<http://jabber.org/protocol/caps<http://jabber.org/protocol/disco#info<http://jabber.org/protocol/disco#items<http://jabber.org/protocol/muc<urn:xmpp:dataforms:softwareinfo<ip_version<ipv4<ipv6<os<Mac<os_version<10.5.1<software<Psi<software_version<0.11<".to_vec();
        let disco = DiscoInfoResult::try_from(elem).unwrap();
        let caps = caps::compute_disco(&disco);
        assert_eq!(caps, expected);

        // The SHA-1 equals the hash part of the `node` attribute above.
        let sha_1 = caps::hash_caps(&caps, Algo::Sha_1).unwrap();
        assert_eq!(
            sha_1.hash,
            Base64.decode("q07IKJEyjvHSyhy//CH0CxmKi8w=").unwrap()
        );
    }
}