1818#include < boost/mysql/impl/internal/protocol/impl/protocol_types.hpp>
1919#include < boost/mysql/impl/internal/protocol/impl/serialization_context.hpp>
2020#include < boost/mysql/impl/internal/protocol/static_buffer.hpp>
21+ #include < boost/mysql/impl/internal/sansio/auth_plugin_common.hpp>
2122#include < boost/mysql/impl/internal/sansio/connection_state_data.hpp>
23+ #include < boost/mysql/impl/internal/sansio/csha2p_encrypt_password.hpp>
2224
25+ #include < boost/asio/ssl/error.hpp>
26+ #include < boost/container/small_vector.hpp>
2327#include < boost/core/span.hpp>
2428#include < boost/system/result.hpp>
29+ #include < boost/system/system_category.hpp>
2530
2631#include < array>
32+ #include < cstddef>
2733#include < cstdint>
2834#include < openssl/sha.h>
2935
@@ -35,59 +41,53 @@ namespace mysql {
3541namespace detail {
3642
3743// Constants
38- BOOST_INLINE_CONSTEXPR std::size_t csha2p_challenge_length = 20 ;
39- BOOST_INLINE_CONSTEXPR std::size_t csha2p_response_length = 32 ;
44+ BOOST_INLINE_CONSTEXPR std::size_t csha2p_hash_size = 32 ;
4045BOOST_INLINE_CONSTEXPR const char * csha2p_plugin_name = " caching_sha2_password" ;
46+ static_assert (csha2p_hash_size <= max_hash_size, " " );
47+ static_assert (csha2p_hash_size == SHA256_DIGEST_LENGTH, " Buffer size mismatch" );
4148
4249inline void csha2p_hash_password_impl (
4350 string_view password,
44- span<const std::uint8_t , csha2p_challenge_length> challenge ,
45- span<std::uint8_t , csha2p_response_length > output
51+ span<const std::uint8_t , scramble_size> scramble ,
52+ span<std::uint8_t , csha2p_hash_size > output
4653)
4754{
48- static_assert (csha2p_response_length == SHA256_DIGEST_LENGTH, " Buffer size mismatch" );
49-
50- // SHA(SHA(password_sha) concat challenge) XOR password_sha
55+ // SHA(SHA(password_sha) concat scramble) XOR password_sha
5156 // hash1 = SHA(pass)
52- std::array<std::uint8_t , csha2p_response_length > password_sha;
57+ std::array<std::uint8_t , csha2p_hash_size > password_sha;
5358 SHA256 (reinterpret_cast <const unsigned char *>(password.data ()), password.size (), password_sha.data ());
5459
55- // SHA(password_sha) concat challenge = buffer
56- std::array<std::uint8_t , csha2p_response_length + csha2p_challenge_length > buffer;
60+ // SHA(password_sha) concat scramble = buffer
61+ std::array<std::uint8_t , csha2p_hash_size + scramble_size > buffer;
5762 SHA256 (password_sha.data (), password_sha.size (), buffer.data ());
58- std::memcpy (buffer.data () + csha2p_response_length, challenge .data (), csha2p_challenge_length );
63+ std::memcpy (buffer.data () + csha2p_hash_size, scramble .data (), scramble. size () );
5964
60- // SHA(SHA(password_sha) concat challenge ) = SHA(buffer) = salted_password
61- std::array<std::uint8_t , csha2p_response_length > salted_password;
65+ // SHA(SHA(password_sha) concat scramble ) = SHA(buffer) = salted_password
66+ std::array<std::uint8_t , csha2p_hash_size > salted_password;
6267 SHA256 (buffer.data (), buffer.size (), salted_password.data ());
6368
6469 // salted_password XOR password_sha
65- for (unsigned i = 0 ; i < csha2p_response_length ; ++i)
70+ for (unsigned i = 0 ; i < csha2p_hash_size ; ++i)
6671 {
6772 output[i] = salted_password[i] ^ password_sha[i];
6873 }
6974}
7075
71- inline system::result< static_buffer<32 > > csha2p_hash_password (
76+ inline static_buffer<max_hash_size > csha2p_hash_password (
7277 string_view password,
73- span<const std::uint8_t > challenge
78+ span<const std::uint8_t , scramble_size> scramble
7479)
7580{
76- // If the challenge doesn't match the expected size,
77- // something wrong is going on and we should fail
78- if (challenge.size () != csha2p_challenge_length)
79- return client_errc::protocol_value_error;
80-
8181 // Empty passwords are not hashed
8282 if (password.empty ())
8383 return {};
8484
8585 // Run the algorithm
86- static_buffer<32 > res (csha2p_response_length );
86+ static_buffer<max_hash_size > res (csha2p_hash_size );
8787 csha2p_hash_password_impl (
8888 password,
89- span< const std:: uint8_t , csha2p_challenge_length>(challenge) ,
90- span<std::uint8_t , csha2p_response_length >(res.data (), csha2p_response_length )
89+ scramble ,
90+ span<std::uint8_t , csha2p_hash_size >(res.data (), csha2p_hash_size )
9191 );
9292 return res;
9393}
@@ -106,13 +106,32 @@ class csha2p_algo
106106 return server_data.size () == 1u && server_data[0 ] == 3 ;
107107 }
108108
109+ static next_action encrypt_password (
110+ connection_state_data& st,
111+ std::uint8_t & seqnum,
112+ string_view password,
113+ span<const std::uint8_t , scramble_size> scramble,
114+ span<const std::uint8_t > server_key
115+ )
116+ {
117+ container::small_vector<std::uint8_t , 512 > buff;
118+ auto ec = csha2p_encrypt_password (password, scramble, server_key, buff, asio::error::ssl_category);
119+ if (ec)
120+ return ec;
121+ return st.write (
122+ string_eof{string_view (reinterpret_cast <const char *>(buff.data ()), buff.size ())},
123+ seqnum
124+ );
125+ }
126+
109127public:
110128 csha2p_algo () = default ;
111129
112130 next_action resume (
113131 connection_state_data& st,
114132 span<const std::uint8_t > server_data,
115133 string_view password,
134+ span<const std::uint8_t , scramble_size> scramble,
116135 bool secure_channel,
117136 std::uint8_t & seqnum
118137 )
@@ -124,19 +143,34 @@ class csha2p_algo
124143 // or told us to read again because an OK packet or error packet is coming.
125144 if (is_perform_full_auth (server_data))
126145 {
127- // At this point, we don't support full auth over insecure channels
128- if (!secure_channel)
146+ if (secure_channel)
129147 {
130- return make_error_code (client_errc::auth_plugin_requires_ssl);
131- }
148+ // We should send a packet with just the password, as a NULL-terminated string
149+ BOOST_MYSQL_YIELD (resume_point_, 1 , st. write (string_null{password}, seqnum))
132150
133- // We should send a packet with just the password, as a NULL-terminated string
134- BOOST_MYSQL_YIELD (resume_point_, 1 , st.write (string_null{password}, seqnum))
151+ // The server shouldn't send us any more packets
152+ return error_code (client_errc::bad_handshake_packet_type);
153+ }
154+ else
155+ {
156+ // Request the server's public key
157+ BOOST_MYSQL_YIELD (resume_point_, 2 , st.write (int1{2 }, seqnum))
158+
159+ // Encrypt the password with the key we were given
160+ BOOST_MYSQL_YIELD (
161+ resume_point_,
162+ 3 ,
163+ encrypt_password (st, seqnum, password, scramble, server_data)
164+ )
165+
166+ // The server shouldn't send us any more packets
167+ return error_code (client_errc::bad_handshake_packet_type);
168+ }
135169 }
136170 else if (is_fast_auth_ok (server_data))
137171 {
138172 // We should wait for the server to send an OK or an error
139- BOOST_MYSQL_YIELD (resume_point_, 2 , st.read (seqnum))
173+ BOOST_MYSQL_YIELD (resume_point_, 4 , st.read (seqnum))
140174 }
141175 else
142176 {
0 commit comments