use alloc::string::{String, ToString};
use alloc::vec;
use alloc::vec::Vec;
use core::fmt;
use getrandom::{getrandom, Error as RngError};
use hmac::{digest::InvalidLength, Hmac, Mac};
use pbkdf2::pbkdf2;
use sha1::{Digest, Sha1 as Sha1_hash};
use sha2::Sha256 as Sha256_hash;
use crate::common::Password;
use crate::secret;
use base64::{engine::general_purpose::STANDARD as Base64, Engine};
/// Produces a fresh nonce as a Base64 string.
///
/// A 32-byte buffer is filled by `getrandom` and then encoded with the
/// standard Base64 engine.
///
/// # Errors
///
/// Returns the `getrandom` error unchanged if the buffer could not be filled.
pub fn generate_nonce() -> Result<String, RngError> {
    let mut raw = [0u8; 32];
    getrandom(&mut raw).map(|()| Base64.encode(raw))
}
/// Reasons why deriving key material from a password can fail.
///
/// Every field is `Eq` and cheaply clonable, so the type derives `Clone` and
/// `Eq` in addition to `Debug` and `PartialEq`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DeriveError {
    /// A pre-derived password was produced with a different hashing method.
    /// Fields: the method stored with the password, then the method of the
    /// provider that was asked to use it.
    IncompatibleHashingMethod(String, String),
    /// Salt validation of a pre-derived password failed.
    IncorrectSalt,
    /// The underlying HMAC / PBKDF2 primitive reported an invalid length.
    InvalidLength,
    /// A pre-derived password was produced with a different iteration count.
    /// Fields: the count stored with the password, then the requested count.
    IncompatibleIterationCount(u32, u32),
}
impl fmt::Display for DeriveError {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
match self {
DeriveError::IncompatibleHashingMethod(one, two) => {
write!(fmt, "incompatible hashing method, {} is not {}", one, two)
}
DeriveError::IncorrectSalt => write!(fmt, "incorrect salt"),
DeriveError::InvalidLength => write!(fmt, "invalid length"),
DeriveError::IncompatibleIterationCount(one, two) => {
write!(fmt, "incompatible iteration count, {} is not {}", one, two)
}
}
}
}
// Empty impl: `DeriveError` uses the trait's default methods, so it reports no
// underlying source error; its text comes from the `Display` impl.
impl core::error::Error for DeriveError {}
impl From<hmac::digest::InvalidLength> for DeriveError {
fn from(_err: hmac::digest::InvalidLength) -> DeriveError {
DeriveError::InvalidLength
}
}
/// A hash-algorithm backend for SCRAM: bundles the digest, HMAC and password
/// derivation routines of one algorithm behind associated functions.
pub trait ScramProvider {
/// Secret type associated with this provider; must implement `secret::Secret`.
type Secret: secret::Secret;
/// Returns the algorithm's name (the implementations in this file return
/// `"SHA-1"` and `"SHA-256"`).
fn name() -> &'static str;
/// Hashes `data` and returns the digest bytes.
fn hash(data: &[u8]) -> Vec<u8>;
/// Computes the HMAC of `data` keyed with `key`. Note the argument order:
/// message first, key second. Fails with `InvalidLength` if the MAC cannot
/// be constructed from `key`.
fn hmac(data: &[u8], key: &[u8]) -> Result<Vec<u8>, InvalidLength>;
/// Derives key material from the password `data` using `salt` and
/// `iterations`. The implementations in this file run PBKDF2 for plain
/// passwords and, for pre-derived passwords, check the stored method, salt
/// and iteration count before returning the stored bytes.
fn derive(data: &Password, salt: &[u8], iterations: u32) -> Result<Vec<u8>, DeriveError>;
}
/// SCRAM provider backed by SHA-1.
pub struct Sha1;

impl ScramProvider for Sha1 {
    type Secret = secret::Pbkdf2Sha1;

    /// Name of the hash algorithm, as compared against `Password::Pbkdf2::method`.
    fn name() -> &'static str {
        "SHA-1"
    }

    /// Returns the SHA-1 digest of `data`.
    fn hash(data: &[u8]) -> Vec<u8> {
        let hash = Sha1_hash::digest(data);
        hash.to_vec()
    }

    /// Returns HMAC-SHA-1 of `data` keyed with `key` (message first, key second).
    fn hmac(data: &[u8], key: &[u8]) -> Result<Vec<u8>, InvalidLength> {
        type HmacSha1 = Hmac<Sha1_hash>;
        let mut mac = HmacSha1::new_from_slice(key)?;
        mac.update(data);
        Ok(mac.finalize().into_bytes().to_vec())
    }

    /// Derives the 20-byte salted password.
    ///
    /// * `Password::Plain` – runs PBKDF2-HMAC-SHA-1 over the plain text with
    ///   `salt` and `iterations`.
    /// * `Password::Pbkdf2` – returns the stored bytes, but only when the
    ///   stored method, salt and iteration count all match the request.
    ///
    /// # Errors
    ///
    /// * `IncompatibleHashingMethod` – stored method is not `"SHA-1"`.
    /// * `IncorrectSalt` – stored salt differs from `salt`.
    /// * `IncompatibleIterationCount` – stored count differs from `iterations`.
    /// * `InvalidLength` – propagated from PBKDF2.
    fn derive(password: &Password, salt: &[u8], iterations: u32) -> Result<Vec<u8>, DeriveError> {
        match *password {
            Password::Plain(ref plain) => {
                // 20 bytes = SHA-1 output size.
                let mut result = vec![0; 20];
                pbkdf2::<Hmac<Sha1_hash>>(plain.as_bytes(), salt, iterations, &mut result)?;
                Ok(result)
            }
            Password::Pbkdf2 {
                ref method,
                salt: ref my_salt,
                iterations: my_iterations,
                ref data,
            } => {
                // A pre-derived secret is only valid for the exact parameters
                // it was derived with; any mismatch must be rejected.
                if method != Self::name() {
                    Err(DeriveError::IncompatibleHashingMethod(
                        method.to_string(),
                        Self::name().to_string(),
                    ))
                } else if my_salt != salt {
                    Err(DeriveError::IncorrectSalt)
                } else if my_iterations != iterations {
                    Err(DeriveError::IncompatibleIterationCount(
                        my_iterations,
                        iterations,
                    ))
                } else {
                    Ok(data.to_vec())
                }
            }
        }
    }
}
/// SCRAM provider backed by SHA-256.
pub struct Sha256;

impl ScramProvider for Sha256 {
    type Secret = secret::Pbkdf2Sha256;

    /// Name of the hash algorithm, as compared against `Password::Pbkdf2::method`.
    fn name() -> &'static str {
        "SHA-256"
    }

    /// Returns the SHA-256 digest of `data`.
    fn hash(data: &[u8]) -> Vec<u8> {
        let hash = Sha256_hash::digest(data);
        hash.to_vec()
    }

    /// Returns HMAC-SHA-256 of `data` keyed with `key` (message first, key second).
    fn hmac(data: &[u8], key: &[u8]) -> Result<Vec<u8>, InvalidLength> {
        type HmacSha256 = Hmac<Sha256_hash>;
        let mut mac = HmacSha256::new_from_slice(key)?;
        mac.update(data);
        Ok(mac.finalize().into_bytes().to_vec())
    }

    /// Derives the 32-byte salted password.
    ///
    /// * `Password::Plain` – runs PBKDF2-HMAC-SHA-256 over the plain text with
    ///   `salt` and `iterations`.
    /// * `Password::Pbkdf2` – returns the stored bytes, but only when the
    ///   stored method, salt and iteration count all match the request.
    ///
    /// # Errors
    ///
    /// * `IncompatibleHashingMethod` – stored method is not `"SHA-256"`.
    /// * `IncorrectSalt` – stored salt differs from `salt`.
    /// * `IncompatibleIterationCount` – stored count differs from `iterations`.
    /// * `InvalidLength` – propagated from PBKDF2.
    fn derive(password: &Password, salt: &[u8], iterations: u32) -> Result<Vec<u8>, DeriveError> {
        match *password {
            Password::Plain(ref plain) => {
                // 32 bytes = SHA-256 output size.
                let mut result = vec![0; 32];
                pbkdf2::<Hmac<Sha256_hash>>(plain.as_bytes(), salt, iterations, &mut result)?;
                Ok(result)
            }
            Password::Pbkdf2 {
                ref method,
                salt: ref my_salt,
                iterations: my_iterations,
                ref data,
            } => {
                // A pre-derived secret is only valid for the exact parameters
                // it was derived with; any mismatch must be rejected.
                if method != Self::name() {
                    Err(DeriveError::IncompatibleHashingMethod(
                        method.to_string(),
                        Self::name().to_string(),
                    ))
                } else if my_salt != salt {
                    Err(DeriveError::IncorrectSalt)
                } else if my_iterations != iterations {
                    Err(DeriveError::IncompatibleIterationCount(
                        my_iterations,
                        iterations,
                    ))
                } else {
                    Ok(data.to_vec())
                }
            }
        }
    }
}