sasl/server/mechanisms/
scram.rs

1use alloc::borrow::ToOwned;
2use alloc::format;
3use alloc::string::{String, ToString};
4use alloc::vec::Vec;
5use core::marker::PhantomData;
6
7use base64::{engine::general_purpose::STANDARD as Base64, Engine};
8
9use crate::common::scram::{generate_nonce, ScramProvider};
10use crate::common::{parse_frame, xor, ChannelBinding, Identity};
11use crate::secret;
12use crate::secret::Pbkdf2Secret;
13use crate::server::{Mechanism, MechanismError, Provider, Response};
14
15enum ScramState {
16    Init,
17    SentChallenge {
18        initial_client_message: Vec<u8>,
19        initial_server_message: Vec<u8>,
20        gs2_header: Vec<u8>,
21        server_nonce: String,
22        identity: Identity,
23        salted_password: Vec<u8>,
24    },
25    Done,
26}
27
28pub struct Scram<S, P>
29where
30    S: ScramProvider,
31    P: Provider<S::Secret>,
32    S::Secret: secret::Pbkdf2Secret,
33{
34    name: String,
35    state: ScramState,
36    channel_binding: ChannelBinding,
37    provider: P,
38    _marker: PhantomData<S>,
39}
40
41impl<S, P> Scram<S, P>
42where
43    S: ScramProvider,
44    P: Provider<S::Secret>,
45    S::Secret: secret::Pbkdf2Secret,
46{
47    pub fn new(provider: P, channel_binding: ChannelBinding) -> Scram<S, P> {
48        Scram {
49            name: format!("SCRAM-{}", S::name()),
50            state: ScramState::Init,
51            channel_binding,
52            provider,
53            _marker: PhantomData,
54        }
55    }
56}
57
58impl<S, P> Mechanism for Scram<S, P>
59where
60    S: ScramProvider,
61    P: Provider<S::Secret>,
62    S::Secret: secret::Pbkdf2Secret,
63{
64    fn name(&self) -> &str {
65        &self.name
66    }
67
68    fn respond(&mut self, payload: &[u8]) -> Result<Response, MechanismError> {
69        let next_state;
70        let ret;
71        match self.state {
72            ScramState::Init => {
73                // TODO: really ugly, mostly because parse_frame takes a &[u8] and i don't
74                //       want to double validate utf-8
75                //
76                //       NEED TO CHANGE THIS THOUGH. IT'S AWFUL.
77                let mut commas = 0;
78                let mut idx = 0;
79                for &b in payload {
80                    idx += 1;
81                    if b == 0x2C {
82                        commas += 1;
83                        if commas >= 2 {
84                            break;
85                        }
86                    }
87                }
88                if commas < 2 {
89                    return Err(MechanismError::FailedToDecodeMessage);
90                }
91                let gs2_header = payload[..idx].to_vec();
92                let rest = payload[idx..].to_vec();
93                // TODO: process gs2 header properly, not this ugly stuff
94                match self.channel_binding {
95                    ChannelBinding::None | ChannelBinding::Unsupported => {
96                        // Not supported.
97                        if gs2_header[0] != 0x79 {
98                            // ord("y")
99                            return Err(MechanismError::ChannelBindingNotSupported);
100                        }
101                    }
102                    ref other => {
103                        // Supported.
104                        if gs2_header[0] == 0x79 {
105                            // ord("y")
106                            return Err(MechanismError::ChannelBindingIsSupported);
107                        } else if !other.supports("tls-unique") {
108                            // TODO: grab the data
109                            return Err(MechanismError::ChannelBindingMechanismIncorrect);
110                        }
111                    }
112                }
113                let frame =
114                    parse_frame(&rest).map_err(|_| MechanismError::CannotDecodeInitialMessage)?;
115                let username = frame.get(&'n').ok_or(MechanismError::NoUsername)?;
116                let identity = Identity::Username(username.to_owned());
117                let client_nonce = frame.get(&'r').ok_or(MechanismError::NoNonce)?;
118                let mut server_nonce = String::new();
119                server_nonce += client_nonce;
120                server_nonce +=
121                    &generate_nonce().map_err(|_| MechanismError::FailedToGenerateNonce)?;
122                let pbkdf2 = self.provider.provide(&identity)?;
123                let mut buf = Vec::new();
124                buf.extend(b"r=");
125                buf.extend(server_nonce.bytes());
126                buf.extend(b",s=");
127                buf.extend(Base64.encode(pbkdf2.salt()).bytes());
128                buf.extend(b",i=");
129                buf.extend(pbkdf2.iterations().to_string().bytes());
130                ret = Response::Proceed(buf.clone());
131                next_state = ScramState::SentChallenge {
132                    server_nonce,
133                    identity,
134                    salted_password: pbkdf2.digest().to_vec(),
135                    initial_client_message: rest,
136                    initial_server_message: buf,
137                    gs2_header,
138                };
139            }
140            ScramState::SentChallenge {
141                ref server_nonce,
142                ref identity,
143                ref salted_password,
144                ref gs2_header,
145                ref initial_client_message,
146                ref initial_server_message,
147            } => {
148                let frame =
149                    parse_frame(payload).map_err(|_| MechanismError::CannotDecodeResponse)?;
150                let mut cb_data: Vec<u8> = Vec::new();
151                cb_data.extend(gs2_header);
152                cb_data.extend(self.channel_binding.data());
153                let mut client_final_message_bare = Vec::new();
154                client_final_message_bare.extend(b"c=");
155                client_final_message_bare.extend(Base64.encode(&cb_data).bytes());
156                client_final_message_bare.extend(b",r=");
157                client_final_message_bare.extend(server_nonce.bytes());
158                let client_key = S::hmac(b"Client Key", salted_password)?;
159                let server_key = S::hmac(b"Server Key", salted_password)?;
160                let mut auth_message = Vec::new();
161                auth_message.extend(initial_client_message);
162                auth_message.extend(b",");
163                auth_message.extend(initial_server_message);
164                auth_message.extend(b",");
165                auth_message.extend(client_final_message_bare.clone());
166                let stored_key = S::hash(&client_key);
167                let client_signature = S::hmac(&auth_message, &stored_key)?;
168                let client_proof = xor(&client_key, &client_signature);
169                let sent_proof = frame.get(&'p').ok_or(MechanismError::NoProof)?;
170                let sent_proof = Base64
171                    .decode(sent_proof)
172                    .map_err(|_| MechanismError::CannotDecodeProof)?;
173                if client_proof != sent_proof {
174                    return Err(MechanismError::AuthenticationFailed);
175                }
176                let server_signature = S::hmac(&auth_message, &server_key)?;
177                let mut buf = Vec::new();
178                buf.extend(b"v=");
179                buf.extend(Base64.encode(server_signature).bytes());
180                ret = Response::Success(identity.clone(), buf);
181                next_state = ScramState::Done;
182            }
183            ScramState::Done => {
184                return Err(MechanismError::SaslSessionAlreadyOver);
185            }
186        }
187        self.state = next_state;
188        Ok(ret)
189    }
190}