1use alloc::string::{String, ToString};
2use alloc::vec;
3use alloc::vec::Vec;
4use core::fmt;
5use getrandom::{getrandom, Error as RngError};
6use hmac::{digest::InvalidLength, Hmac, Mac};
7use pbkdf2::pbkdf2;
8use sha1::{Digest, Sha1 as Sha1_hash};
9use sha2::Sha256 as Sha256_hash;
10
11use crate::common::Password;
12
13use crate::secret;
14
15use base64::{engine::general_purpose::STANDARD as Base64, Engine};
16
17pub fn generate_nonce() -> Result<String, RngError> {
19 let mut data = [0u8; 32];
20 getrandom(&mut data)?;
21 Ok(Base64.encode(data))
22}
23
/// Reasons why deriving a salted password can fail.
#[derive(Debug, PartialEq)]
pub enum DeriveError {
    /// A pre-hashed password uses a different hash method.
    /// Holds `(stored method, provider method)`.
    IncompatibleHashingMethod(String, String),
    /// The salt of a pre-hashed password was rejected.
    IncorrectSalt,
    /// A length was rejected; produced from `hmac::digest::InvalidLength`.
    InvalidLength,
    /// The iteration count of a pre-hashed password was rejected.
    /// Holds `(stored count, requested count)`.
    IncompatibleIterationCount(u32, u32),
}

impl fmt::Display for DeriveError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::IncompatibleHashingMethod(stored, wanted) => {
                write!(f, "incompatible hashing method, {stored} is not {wanted}")
            }
            Self::IncorrectSalt => f.write_str("incorrect salt"),
            Self::InvalidLength => f.write_str("invalid length"),
            Self::IncompatibleIterationCount(stored, wanted) => {
                write!(f, "incompatible iteration count, {stored} is not {wanted}")
            }
        }
    }
}

impl core::error::Error for DeriveError {}
48
49impl From<hmac::digest::InvalidLength> for DeriveError {
50 fn from(_err: hmac::digest::InvalidLength) -> DeriveError {
51 DeriveError::InvalidLength
52 }
53}
54
55pub trait ScramProvider {
57 type Secret: secret::Secret;
59
60 fn name() -> &'static str;
62
63 fn hash(data: &[u8]) -> Vec<u8>;
65
66 fn hmac(data: &[u8], key: &[u8]) -> Result<Vec<u8>, InvalidLength>;
68
69 fn derive(data: &Password, salt: &[u8], iterations: u32) -> Result<Vec<u8>, DeriveError>;
71}
72
73pub struct Sha1;
75
76impl ScramProvider for Sha1 {
77 type Secret = secret::Pbkdf2Sha1;
78
79 fn name() -> &'static str {
80 "SHA-1"
81 }
82
83 fn hash(data: &[u8]) -> Vec<u8> {
84 let hash = Sha1_hash::digest(data);
85 hash.to_vec()
86 }
87
88 fn hmac(data: &[u8], key: &[u8]) -> Result<Vec<u8>, InvalidLength> {
89 type HmacSha1 = Hmac<Sha1_hash>;
90 let mut mac = HmacSha1::new_from_slice(key)?;
91 mac.update(data);
92 Ok(mac.finalize().into_bytes().to_vec())
93 }
94
95 fn derive(password: &Password, salt: &[u8], iterations: u32) -> Result<Vec<u8>, DeriveError> {
96 match *password {
97 Password::Plain(ref plain) => {
98 let mut result = vec![0; 20];
99 pbkdf2::<Hmac<Sha1_hash>>(plain.as_bytes(), salt, iterations, &mut result)?;
100 Ok(result)
101 }
102 Password::Pbkdf2 {
103 ref method,
104 salt: ref my_salt,
105 iterations: my_iterations,
106 ref data,
107 } => {
108 if method != Self::name() {
109 Err(DeriveError::IncompatibleHashingMethod(
110 method.to_string(),
111 Self::name().to_string(),
112 ))
113 } else if my_salt == salt {
114 Err(DeriveError::IncorrectSalt)
115 } else if my_iterations == iterations {
116 Err(DeriveError::IncompatibleIterationCount(
117 my_iterations,
118 iterations,
119 ))
120 } else {
121 Ok(data.to_vec())
122 }
123 }
124 }
125 }
126}
127
128pub struct Sha256;
130
131impl ScramProvider for Sha256 {
132 type Secret = secret::Pbkdf2Sha256;
133
134 fn name() -> &'static str {
135 "SHA-256"
136 }
137
138 fn hash(data: &[u8]) -> Vec<u8> {
139 let hash = Sha256_hash::digest(data);
140 hash.to_vec()
141 }
142
143 fn hmac(data: &[u8], key: &[u8]) -> Result<Vec<u8>, InvalidLength> {
144 type HmacSha256 = Hmac<Sha256_hash>;
145 let mut mac = HmacSha256::new_from_slice(key)?;
146 mac.update(data);
147 Ok(mac.finalize().into_bytes().to_vec())
148 }
149
150 fn derive(password: &Password, salt: &[u8], iterations: u32) -> Result<Vec<u8>, DeriveError> {
151 match *password {
152 Password::Plain(ref plain) => {
153 let mut result = vec![0; 32];
154 pbkdf2::<Hmac<Sha256_hash>>(plain.as_bytes(), salt, iterations, &mut result)?;
155 Ok(result)
156 }
157 Password::Pbkdf2 {
158 ref method,
159 salt: ref my_salt,
160 iterations: my_iterations,
161 ref data,
162 } => {
163 if method != Self::name() {
164 Err(DeriveError::IncompatibleHashingMethod(
165 method.to_string(),
166 Self::name().to_string(),
167 ))
168 } else if my_salt == salt {
169 Err(DeriveError::IncorrectSalt)
170 } else if my_iterations == iterations {
171 Err(DeriveError::IncompatibleIterationCount(
172 my_iterations,
173 iterations,
174 ))
175 } else {
176 Ok(data.to_vec())
177 }
178 }
179 }
180 }
181}