@@ -4,7 +4,7 @@ use base64;
4
4
use generic_array:: typenum:: U32 ;
5
5
use generic_array:: GenericArray ;
6
6
use hmac:: { Hmac , Mac } ;
7
- use rand:: { OsRng , Rng } ;
7
+ use rand:: { self , Rng } ;
8
8
use sha2:: { Digest , Sha256 } ;
9
9
use std:: fmt:: Write ;
10
10
use std:: io;
@@ -17,6 +17,8 @@ const NONCE_LENGTH: usize = 24;
17
17
18
18
/// The identifier of the SCRAM-SHA-256 SASL authentication mechanism.
19
19
pub const SCRAM_SHA_256 : & ' static str = "SCRAM-SHA-256" ;
20
+ /// The identifier of the SCRAM-SHA-256-PLUS SASL authentication mechanism.
21
+ pub const SCRAM_SHA_256_PLUS : & ' static str = "SCRAM-SHA-256-PLUS" ;
20
22
21
23
// since postgres passwords are not required to exclude saslprep-prohibited
22
24
// characters or even be valid UTF8, we run saslprep if possible and otherwise
@@ -54,10 +56,61 @@ fn hi(str: &[u8], salt: &[u8], i: u32) -> GenericArray<u8, U32> {
54
56
hi
55
57
}
56
58
59
+ enum ChannelBindingInner {
60
+ Unrequested ,
61
+ Unsupported ,
62
+ TlsUnique ( Vec < u8 > ) ,
63
+ TlsServerEndPoint ( Vec < u8 > ) ,
64
+ }
65
+
66
+ /// The channel binding configuration for a SCRAM authentication exchange.
67
+ pub struct ChannelBinding ( ChannelBindingInner ) ;
68
+
69
+ impl ChannelBinding {
70
+ /// The server did not request channel binding.
71
+ pub fn unrequested ( ) -> ChannelBinding {
72
+ ChannelBinding ( ChannelBindingInner :: Unrequested )
73
+ }
74
+
75
+ /// The server requested channel binding but the client is unable to provide it.
76
+ pub fn unsupported ( ) -> ChannelBinding {
77
+ ChannelBinding ( ChannelBindingInner :: Unsupported )
78
+ }
79
+
80
+ /// The server requested channel binding and the client will use the `tls-unique` method.
81
+ pub fn tls_unique ( finished : Vec < u8 > ) -> ChannelBinding {
82
+ ChannelBinding ( ChannelBindingInner :: TlsUnique ( finished) )
83
+ }
84
+
85
+ /// The server requested channel binding and the client will use the `tls-server-end-point`
86
+ /// method.
87
+ pub fn tls_server_end_point ( signature : Vec < u8 > ) -> ChannelBinding {
88
+ ChannelBinding ( ChannelBindingInner :: TlsServerEndPoint ( signature) )
89
+ }
90
+
91
+ fn gs2_header ( & self ) -> & ' static str {
92
+ match self . 0 {
93
+ ChannelBindingInner :: Unrequested => "y,," ,
94
+ ChannelBindingInner :: Unsupported => "n,," ,
95
+ ChannelBindingInner :: TlsUnique ( _) => "p=tls-unique,," ,
96
+ ChannelBindingInner :: TlsServerEndPoint ( _) => "p=tls-server-end-point,," ,
97
+ }
98
+ }
99
+
100
+ fn cbind_data ( & self ) -> & [ u8 ] {
101
+ match self . 0 {
102
+ ChannelBindingInner :: Unrequested | ChannelBindingInner :: Unsupported => & [ ] ,
103
+ ChannelBindingInner :: TlsUnique ( ref buf)
104
+ | ChannelBindingInner :: TlsServerEndPoint ( ref buf) => buf,
105
+ }
106
+ }
107
+ }
108
+
57
109
enum State {
58
110
Update {
59
111
nonce : String ,
60
112
password : Vec < u8 > ,
113
+ channel_binding : ChannelBinding ,
61
114
} ,
62
115
Finish {
63
116
salted_password : GenericArray < u8 , U32 > ,
@@ -66,7 +119,8 @@ enum State {
66
119
Done ,
67
120
}
68
121
69
- /// A type which handles the client side of the SCRAM-SHA-256 authentication process.
122
+ /// A type which handles the client side of the SCRAM-SHA-256/SCRAM-SHA-256-PLUS authentication
123
+ /// process.
70
124
///
71
125
/// During the authentication process, if the backend sends an `AuthenticationSASL` message which
72
126
/// includes `SCRAM-SHA-256` as an authentication mechanism, this type can be used.
@@ -85,11 +139,11 @@ pub struct ScramSha256 {
85
139
state : State ,
86
140
}
87
141
88
- #[ allow( missing_docs) ]
89
142
impl ScramSha256 {
90
143
/// Constructs a new instance which will use the provided password for authentication.
91
- pub fn new ( password : & [ u8 ] ) -> io:: Result < ScramSha256 > {
92
- let mut rng = OsRng :: new ( ) ?;
144
+ pub fn new ( password : & [ u8 ] , channel_binding : ChannelBinding ) -> io:: Result < ScramSha256 > {
145
+ // rand 0.5's ThreadRng is cryptographically secure
146
+ let mut rng = rand:: thread_rng ( ) ;
93
147
let nonce = ( 0 ..NONCE_LENGTH )
94
148
. map ( |_| {
95
149
let mut v = rng. gen_range ( 0x21u8 , 0x7e ) ;
@@ -100,21 +154,20 @@ impl ScramSha256 {
100
154
} )
101
155
. collect :: < String > ( ) ;
102
156
103
- ScramSha256 :: new_inner ( password, nonce)
157
+ ScramSha256 :: new_inner ( password, channel_binding , nonce)
104
158
}
105
159
106
- fn new_inner ( password : & [ u8 ] , nonce : String ) -> io:: Result < ScramSha256 > {
107
- // the docs say to use pg_same_as_startup_message as the username, but
108
- // psql uses an empty string, so we'll go with that.
109
- let message = format ! ( "n,,n=,r={}" , nonce) ;
110
-
111
- let password = normalize ( password) ;
112
-
160
+ fn new_inner (
161
+ password : & [ u8 ] ,
162
+ channel_binding : ChannelBinding ,
163
+ nonce : String ,
164
+ ) -> io:: Result < ScramSha256 > {
113
165
Ok ( ScramSha256 {
114
- message : message ,
166
+ message : format ! ( "{}n=,r={}" , channel_binding . gs2_header ( ) , nonce ) ,
115
167
state : State :: Update {
116
- nonce : nonce,
117
- password : password,
168
+ nonce,
169
+ password : normalize ( password) ,
170
+ channel_binding : channel_binding,
118
171
} ,
119
172
} )
120
173
}
@@ -131,10 +184,15 @@ impl ScramSha256 {
131
184
///
132
185
/// This should be called when an `AuthenticationSASLContinue` message is received.
133
186
pub fn update ( & mut self , message : & [ u8 ] ) -> io:: Result < ( ) > {
134
- let ( client_nonce, password) = match mem:: replace ( & mut self . state , State :: Done ) {
135
- State :: Update { nonce, password } => ( nonce, password) ,
136
- _ => return Err ( io:: Error :: new ( io:: ErrorKind :: Other , "invalid SCRAM state" ) ) ,
137
- } ;
187
+ let ( client_nonce, password, channel_binding) =
188
+ match mem:: replace ( & mut self . state , State :: Done ) {
189
+ State :: Update {
190
+ nonce,
191
+ password,
192
+ channel_binding,
193
+ } => ( nonce, password, channel_binding) ,
194
+ _ => return Err ( io:: Error :: new ( io:: ErrorKind :: Other , "invalid SCRAM state" ) ) ,
195
+ } ;
138
196
139
197
let message =
140
198
str:: from_utf8 ( message) . map_err ( |e| io:: Error :: new ( io:: ErrorKind :: InvalidInput , e) ) ?;
@@ -161,8 +219,13 @@ impl ScramSha256 {
161
219
hash. input ( client_key. as_slice ( ) ) ;
162
220
let stored_key = hash. result ( ) ;
163
221
222
+ let mut cbind_input = vec ! [ ] ;
223
+ cbind_input. extend ( channel_binding. gs2_header ( ) . as_bytes ( ) ) ;
224
+ cbind_input. extend ( channel_binding. cbind_data ( ) ) ;
225
+ let cbind_input = base64:: encode ( & cbind_input) ;
226
+
164
227
self . message . clear ( ) ;
165
- write ! ( & mut self . message, "c=biws ,r={}" , parsed. nonce) . unwrap ( ) ;
228
+ write ! ( & mut self . message, "c={} ,r={}" , cbind_input , parsed. nonce) . unwrap ( ) ;
166
229
167
230
let auth_message = format ! ( "n=,r={},{},{}" , client_nonce, message, self . message) ;
168
231
@@ -420,7 +483,11 @@ mod test {
420
483
1NTlQYNs5BTeQjdHdk7lOflDo5re2an8=";
421
484
let server_final = "v=U+ppxD5XUKtradnv8e2MkeupiA8FU87Sg8CXzXHDAzw=" ;
422
485
423
- let mut scram = ScramSha256 :: new_inner ( password. as_bytes ( ) , nonce. to_string ( ) ) . unwrap ( ) ;
486
+ let mut scram = ScramSha256 :: new_inner (
487
+ password. as_bytes ( ) ,
488
+ ChannelBinding :: unsupported ( ) ,
489
+ nonce. to_string ( ) ,
490
+ ) . unwrap ( ) ;
424
491
assert_eq ! ( str :: from_utf8( scram. message( ) ) . unwrap( ) , client_first) ;
425
492
426
493
scram. update ( server_first. as_bytes ( ) ) . unwrap ( ) ;
0 commit comments