// sasl/common/scram.rs

1use alloc::string::{String, ToString};
2use alloc::vec;
3use alloc::vec::Vec;
4use core::fmt;
5use getrandom::{getrandom, Error as RngError};
6use hmac::{digest::InvalidLength, Hmac, Mac};
7use pbkdf2::pbkdf2;
8use sha1::{Digest, Sha1 as Sha1_hash};
9use sha2::Sha256 as Sha256_hash;
10
11use crate::common::Password;
12
13use crate::secret;
14
15use base64::{engine::general_purpose::STANDARD as Base64, Engine};
16
17/// Generate a nonce for SCRAM authentication.
18pub fn generate_nonce() -> Result<String, RngError> {
19    let mut data = [0u8; 32];
20    getrandom(&mut data)?;
21    Ok(Base64.encode(data))
22}
23
/// Reasons why a SCRAM salted password could not be derived.
#[derive(PartialEq, Debug)]
pub enum DeriveError {
    /// A stored PBKDF2 secret was made with another hash:
    /// `(stored method, this provider's name)`.
    IncompatibleHashingMethod(String, String),
    /// A stored PBKDF2 secret was derived with a different salt.
    IncorrectSalt,
    /// An HMAC/PBKDF2 primitive reported an invalid length.
    InvalidLength,
    /// A stored PBKDF2 secret used another iteration count:
    /// `(stored count, requested count)`.
    IncompatibleIterationCount(u32, u32),
}
31
32impl fmt::Display for DeriveError {
33    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
34        match self {
35            DeriveError::IncompatibleHashingMethod(one, two) => {
36                write!(fmt, "incompatible hashing method, {} is not {}", one, two)
37            }
38            DeriveError::IncorrectSalt => write!(fmt, "incorrect salt"),
39            DeriveError::InvalidLength => write!(fmt, "invalid length"),
40            DeriveError::IncompatibleIterationCount(one, two) => {
41                write!(fmt, "incompatible iteration count, {} is not {}", one, two)
42            }
43        }
44    }
45}
46
// `DeriveError` derives `Debug` and implements `Display`, so an empty impl is
// enough to make it a standard error type; no underlying `source()` is reported.
impl core::error::Error for DeriveError {}
48
49impl From<hmac::digest::InvalidLength> for DeriveError {
50    fn from(_err: hmac::digest::InvalidLength) -> DeriveError {
51        DeriveError::InvalidLength
52    }
53}
54
55/// A trait which defines the needed methods for SCRAM.
56pub trait ScramProvider {
57    /// The kind of secret this `ScramProvider` requires.
58    type Secret: secret::Secret;
59
60    /// The name of the hash function.
61    fn name() -> &'static str;
62
63    /// A function which hashes the data using the hash function.
64    fn hash(data: &[u8]) -> Vec<u8>;
65
66    /// A function which performs an HMAC using the hash function.
67    fn hmac(data: &[u8], key: &[u8]) -> Result<Vec<u8>, InvalidLength>;
68
69    /// A function which does PBKDF2 key derivation using the hash function.
70    fn derive(data: &Password, salt: &[u8], iterations: u32) -> Result<Vec<u8>, DeriveError>;
71}
72
73/// A `ScramProvider` which provides SCRAM-SHA-1 and SCRAM-SHA-1-PLUS
74pub struct Sha1;
75
76impl ScramProvider for Sha1 {
77    type Secret = secret::Pbkdf2Sha1;
78
79    fn name() -> &'static str {
80        "SHA-1"
81    }
82
83    fn hash(data: &[u8]) -> Vec<u8> {
84        let hash = Sha1_hash::digest(data);
85        hash.to_vec()
86    }
87
88    fn hmac(data: &[u8], key: &[u8]) -> Result<Vec<u8>, InvalidLength> {
89        type HmacSha1 = Hmac<Sha1_hash>;
90        let mut mac = HmacSha1::new_from_slice(key)?;
91        mac.update(data);
92        Ok(mac.finalize().into_bytes().to_vec())
93    }
94
95    fn derive(password: &Password, salt: &[u8], iterations: u32) -> Result<Vec<u8>, DeriveError> {
96        match *password {
97            Password::Plain(ref plain) => {
98                let mut result = vec![0; 20];
99                pbkdf2::<Hmac<Sha1_hash>>(plain.as_bytes(), salt, iterations, &mut result)?;
100                Ok(result)
101            }
102            Password::Pbkdf2 {
103                ref method,
104                salt: ref my_salt,
105                iterations: my_iterations,
106                ref data,
107            } => {
108                if method != Self::name() {
109                    Err(DeriveError::IncompatibleHashingMethod(
110                        method.to_string(),
111                        Self::name().to_string(),
112                    ))
113                } else if my_salt == salt {
114                    Err(DeriveError::IncorrectSalt)
115                } else if my_iterations == iterations {
116                    Err(DeriveError::IncompatibleIterationCount(
117                        my_iterations,
118                        iterations,
119                    ))
120                } else {
121                    Ok(data.to_vec())
122                }
123            }
124        }
125    }
126}
127
128/// A `ScramProvider` which provides SCRAM-SHA-256 and SCRAM-SHA-256-PLUS
129pub struct Sha256;
130
131impl ScramProvider for Sha256 {
132    type Secret = secret::Pbkdf2Sha256;
133
134    fn name() -> &'static str {
135        "SHA-256"
136    }
137
138    fn hash(data: &[u8]) -> Vec<u8> {
139        let hash = Sha256_hash::digest(data);
140        hash.to_vec()
141    }
142
143    fn hmac(data: &[u8], key: &[u8]) -> Result<Vec<u8>, InvalidLength> {
144        type HmacSha256 = Hmac<Sha256_hash>;
145        let mut mac = HmacSha256::new_from_slice(key)?;
146        mac.update(data);
147        Ok(mac.finalize().into_bytes().to_vec())
148    }
149
150    fn derive(password: &Password, salt: &[u8], iterations: u32) -> Result<Vec<u8>, DeriveError> {
151        match *password {
152            Password::Plain(ref plain) => {
153                let mut result = vec![0; 32];
154                pbkdf2::<Hmac<Sha256_hash>>(plain.as_bytes(), salt, iterations, &mut result)?;
155                Ok(result)
156            }
157            Password::Pbkdf2 {
158                ref method,
159                salt: ref my_salt,
160                iterations: my_iterations,
161                ref data,
162            } => {
163                if method != Self::name() {
164                    Err(DeriveError::IncompatibleHashingMethod(
165                        method.to_string(),
166                        Self::name().to_string(),
167                    ))
168                } else if my_salt == salt {
169                    Err(DeriveError::IncorrectSalt)
170                } else if my_iterations == iterations {
171                    Err(DeriveError::IncompatibleIterationCount(
172                        my_iterations,
173                        iterations,
174                    ))
175                } else {
176                    Ok(data.to_vec())
177                }
178            }
179        }
180    }
181}