@@ -33,25 +33,26 @@ fn normalize(pass: &[u8]) -> Vec<u8> {
33
33
}
34
34
}
35
35
36
- fn hi ( str : & [ u8 ] , salt : & [ u8 ] , i : u32 ) -> GenericArray < u8 , U32 > {
37
- let mut hmac = Hmac :: < Sha256 > :: new ( str) ;
36
+ fn hi ( str : & [ u8 ] , salt : & [ u8 ] , i : u32 ) -> io:: Result < GenericArray < u8 , U32 > > {
37
+ let mut hmac = Hmac :: < Sha256 > :: new ( str)
38
+ . map_err ( |_| invalid_key_length_error ( ) ) ?;
38
39
hmac. input ( salt) ;
39
40
hmac. input ( & [ 0 , 0 , 0 , 1 ] ) ;
40
- let mut prev = hmac. result ( ) ;
41
+ let mut prev = hmac. result ( ) . code ( ) ;
41
42
42
- let mut hi = GenericArray :: < u8 , U32 > :: clone_from_slice ( prev. code ( ) ) ;
43
+ let mut hi = GenericArray :: < u8 , U32 > :: clone_from_slice ( & prev) ;
43
44
44
45
for _ in 1 ..i {
45
- let mut hmac = Hmac :: < Sha256 > :: new ( str) ;
46
- hmac. input ( prev. code ( ) ) ;
47
- prev = hmac. result ( ) ;
46
+ let mut hmac = Hmac :: < Sha256 > :: new ( str) . expect ( "already checked above" ) ;
47
+ hmac. input ( prev. as_slice ( ) ) ;
48
+ prev = hmac. result ( ) . code ( ) ;
48
49
49
- for ( hi, prev) in hi. iter_mut ( ) . zip ( prev. code ( ) ) {
50
- * hi ^= * prev;
50
+ for ( hi, prev) in hi. iter_mut ( ) . zip ( prev) {
51
+ * hi ^= prev;
51
52
}
52
53
}
53
54
54
- hi
55
+ Ok ( hi )
55
56
}
56
57
57
58
enum State {
@@ -148,28 +149,30 @@ impl ScramSha256 {
148
149
Err ( e) => return Err ( io:: Error :: new ( io:: ErrorKind :: InvalidInput , e) ) ,
149
150
} ;
150
151
151
- let salted_password = hi ( & password, & salt, parsed. iteration_count ) ;
152
+ let salted_password = hi ( & password, & salt, parsed. iteration_count ) ? ;
152
153
153
- let mut hmac = Hmac :: < Sha256 > :: new ( & salted_password) ;
154
+ let mut hmac = Hmac :: < Sha256 > :: new ( & salted_password)
155
+ . map_err ( |_| invalid_key_length_error ( ) ) ?;
154
156
hmac. input ( b"Client Key" ) ;
155
- let client_key = hmac. result ( ) ;
157
+ let client_key = hmac. result ( ) . code ( ) ;
156
158
157
159
let mut hash = Sha256 :: default ( ) ;
158
- hash. input ( client_key. code ( ) ) ;
160
+ hash. input ( client_key. as_slice ( ) ) ;
159
161
let stored_key = hash. result ( ) ;
160
162
161
163
self . message . clear ( ) ;
162
164
write ! ( & mut self . message, "c=biws,r={}" , parsed. nonce) . unwrap ( ) ;
163
165
164
166
let auth_message = format ! ( "n=,r={},{},{}" , client_nonce, message, self . message) ;
165
167
166
- let mut hmac = Hmac :: < Sha256 > :: new ( & stored_key) ;
168
+ let mut hmac = Hmac :: < Sha256 > :: new ( & stored_key)
169
+ . map_err ( |_| invalid_key_length_error ( ) ) ?;
167
170
hmac. input ( auth_message. as_bytes ( ) ) ;
168
171
let client_signature = hmac. result ( ) ;
169
172
170
- let mut client_proof = GenericArray :: < u8 , U32 > :: clone_from_slice ( client_key. code ( ) ) ;
173
+ let mut client_proof = GenericArray :: < u8 , U32 > :: clone_from_slice ( & client_key) ;
171
174
for ( proof, signature) in client_proof. iter_mut ( ) . zip ( client_signature. code ( ) ) {
172
- * proof ^= * signature;
175
+ * proof ^= signature;
173
176
}
174
177
175
178
write ! ( & mut self . message, ",p={}" , base64:: encode( & * client_proof) ) . unwrap ( ) ;
@@ -215,20 +218,18 @@ impl ScramSha256 {
215
218
Err ( e) => return Err ( io:: Error :: new ( io:: ErrorKind :: InvalidInput , e) ) ,
216
219
} ;
217
220
218
- let mut hmac = Hmac :: < Sha256 > :: new ( & salted_password) ;
221
+ let mut hmac = Hmac :: < Sha256 > :: new ( & salted_password)
222
+ . map_err ( |_| invalid_key_length_error ( ) ) ?;
219
223
hmac. input ( b"Server Key" ) ;
220
224
let server_key = hmac. result ( ) ;
221
225
222
- let mut hmac = Hmac :: < Sha256 > :: new ( server_key. code ( ) ) ;
226
+ let mut hmac = Hmac :: < Sha256 > :: new ( & server_key. code ( ) )
227
+ . map_err ( |_| invalid_key_length_error ( ) ) ?;
223
228
hmac. input ( auth_message. as_bytes ( ) ) ;
224
- if hmac. verify ( & verifier) {
225
- Ok ( ( ) )
226
- } else {
227
- Err ( io:: Error :: new (
228
- io:: ErrorKind :: InvalidInput ,
229
- "SCRAM verification error" ,
230
- ) )
231
- }
229
+ hmac. verify ( & verifier) . map_err ( |_| io:: Error :: new (
230
+ io:: ErrorKind :: InvalidInput ,
231
+ "SCRAM verification error" ,
232
+ ) )
232
233
}
233
234
}
234
235
@@ -398,6 +399,10 @@ enum ServerFinalMessage<'a> {
398
399
Verifier ( & ' a str ) ,
399
400
}
400
401
402
+ fn invalid_key_length_error ( ) -> io:: Error {
403
+ io:: Error :: new ( io:: ErrorKind :: InvalidInput , "invalid key length" )
404
+ }
405
+
401
406
#[ cfg( test) ]
402
407
mod test {
403
408
use super :: * ;
0 commit comments