1use alloc::string::{String, ToString};
2use alloc::vec;
3use alloc::vec::Vec;
4use core::fmt;
5use hmac::{digest::InvalidLength, Hmac, Mac};
6use pbkdf2::pbkdf2;
7use sha1::{Digest, Sha1 as Sha1_hash};
8use sha2::Sha256 as Sha256_hash;
9
10use crate::common::Password;
11
12use crate::secret;
13
14use base64::{engine::general_purpose::STANDARD as Base64, Engine};
15
16pub fn generate_nonce() -> Result<String, getrandom::Error> {
18 let mut data = [0u8; 32];
19 getrandom::fill(&mut data)?;
20 Ok(Base64.encode(data))
21}
22
/// Errors produced while deriving a salted password in [`ScramProvider::derive`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DeriveError {
    /// A stored `Password::Pbkdf2` names a different hash method than the
    /// provider: `(stored_method, provider_method)`.
    IncompatibleHashingMethod(String, String),
    /// A stored `Password::Pbkdf2` was derived with a different salt.
    IncorrectSalt,
    /// The underlying HMAC/PBKDF2 primitive reported an invalid length.
    InvalidLength,
    /// A stored `Password::Pbkdf2` was derived with a different iteration
    /// count: `(stored_iterations, requested_iterations)`.
    IncompatibleIterationCount(u32, u32),
}

impl fmt::Display for DeriveError {
    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
        match self {
            Self::IncompatibleHashingMethod(stored, expected) => write!(
                fmt,
                "incompatible hashing method, {stored} is not {expected}"
            ),
            Self::IncorrectSalt => fmt.write_str("incorrect salt"),
            Self::InvalidLength => fmt.write_str("invalid length"),
            Self::IncompatibleIterationCount(stored, requested) => write!(
                fmt,
                "incompatible iteration count, {stored} is not {requested}"
            ),
        }
    }
}

impl core::error::Error for DeriveError {}
47
48impl From<hmac::digest::InvalidLength> for DeriveError {
49 fn from(_err: hmac::digest::InvalidLength) -> DeriveError {
50 DeriveError::InvalidLength
51 }
52}
53
/// Hash-function primitives that a SCRAM mechanism is parameterised over.
pub trait ScramProvider {
    /// The secret type associated with this provider (e.g. `secret::Pbkdf2Sha1`).
    type Secret: secret::Secret;

    /// The hash function's name (e.g. `"SHA-1"`). The implementations in this
    /// file compare it against the `method` of a stored `Password::Pbkdf2`.
    fn name() -> &'static str;

    /// Hashes `data` and returns the digest bytes.
    fn hash(data: &[u8]) -> Vec<u8>;

    /// Computes the HMAC of `data` keyed with `key`.
    ///
    /// # Errors
    ///
    /// Returns `InvalidLength` if the underlying HMAC rejects the key length.
    fn hmac(data: &[u8], key: &[u8]) -> Result<Vec<u8>, InvalidLength>;

    /// Produces the salted password for `data` using `salt` and `iterations`.
    ///
    /// A plain password is run through PBKDF2. A stored `Password::Pbkdf2`
    /// is returned as-is, but only after checking its method, salt and
    /// iteration count.
    ///
    /// # Errors
    ///
    /// Returns a [`DeriveError`] when a stored password's parameters do not fit
    /// or when the PBKDF2 primitive reports an invalid length.
    fn derive(data: &Password, salt: &[u8], iterations: u32) -> Result<Vec<u8>, DeriveError>;
}
71
72pub struct Sha1;
74
75impl ScramProvider for Sha1 {
76 type Secret = secret::Pbkdf2Sha1;
77
78 fn name() -> &'static str {
79 "SHA-1"
80 }
81
82 fn hash(data: &[u8]) -> Vec<u8> {
83 let hash = Sha1_hash::digest(data);
84 hash.to_vec()
85 }
86
87 fn hmac(data: &[u8], key: &[u8]) -> Result<Vec<u8>, InvalidLength> {
88 type HmacSha1 = Hmac<Sha1_hash>;
89 let mut mac = HmacSha1::new_from_slice(key)?;
90 mac.update(data);
91 Ok(mac.finalize().into_bytes().to_vec())
92 }
93
94 fn derive(password: &Password, salt: &[u8], iterations: u32) -> Result<Vec<u8>, DeriveError> {
95 match *password {
96 Password::Plain(ref plain) => {
97 let mut result = vec![0; 20];
98 pbkdf2::<Hmac<Sha1_hash>>(plain.as_bytes(), salt, iterations, &mut result)?;
99 Ok(result)
100 }
101 Password::Pbkdf2 {
102 ref method,
103 salt: ref my_salt,
104 iterations: my_iterations,
105 ref data,
106 } => {
107 if method != Self::name() {
108 Err(DeriveError::IncompatibleHashingMethod(
109 method.to_string(),
110 Self::name().to_string(),
111 ))
112 } else if my_salt == salt {
113 Err(DeriveError::IncorrectSalt)
114 } else if my_iterations == iterations {
115 Err(DeriveError::IncompatibleIterationCount(
116 my_iterations,
117 iterations,
118 ))
119 } else {
120 Ok(data.to_vec())
121 }
122 }
123 }
124 }
125}
126
127pub struct Sha256;
129
130impl ScramProvider for Sha256 {
131 type Secret = secret::Pbkdf2Sha256;
132
133 fn name() -> &'static str {
134 "SHA-256"
135 }
136
137 fn hash(data: &[u8]) -> Vec<u8> {
138 let hash = Sha256_hash::digest(data);
139 hash.to_vec()
140 }
141
142 fn hmac(data: &[u8], key: &[u8]) -> Result<Vec<u8>, InvalidLength> {
143 type HmacSha256 = Hmac<Sha256_hash>;
144 let mut mac = HmacSha256::new_from_slice(key)?;
145 mac.update(data);
146 Ok(mac.finalize().into_bytes().to_vec())
147 }
148
149 fn derive(password: &Password, salt: &[u8], iterations: u32) -> Result<Vec<u8>, DeriveError> {
150 match *password {
151 Password::Plain(ref plain) => {
152 let mut result = vec![0; 32];
153 pbkdf2::<Hmac<Sha256_hash>>(plain.as_bytes(), salt, iterations, &mut result)?;
154 Ok(result)
155 }
156 Password::Pbkdf2 {
157 ref method,
158 salt: ref my_salt,
159 iterations: my_iterations,
160 ref data,
161 } => {
162 if method != Self::name() {
163 Err(DeriveError::IncompatibleHashingMethod(
164 method.to_string(),
165 Self::name().to_string(),
166 ))
167 } else if my_salt == salt {
168 Err(DeriveError::IncorrectSalt)
169 } else if my_iterations == iterations {
170 Err(DeriveError::IncompatibleIterationCount(
171 my_iterations,
172 iterations,
173 ))
174 } else {
175 Ok(data.to_vec())
176 }
177 }
178 }
179 }
180}