Skip to content

Commit d5d0ee3

Browse files
committed
WL14658: Implement MFA (multi factor authentication)
1 parent fbc0ea0 commit d5d0ee3

File tree

8 files changed

+205
-1
lines changed

8 files changed

+205
-1
lines changed

cppconn/connection.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,9 @@
5050
#define OPT_HOSTNAME "hostName"
5151
#define OPT_USERNAME "userName"
5252
#define OPT_PASSWORD "password"
53+
#define OPT_PASSWORD1 "password1"
54+
#define OPT_PASSWORD2 "password2"
55+
#define OPT_PASSWORD3 "password3"
5356
#define OPT_PORT "port"
5457
#define OPT_SOCKET "socket"
5558
#define OPT_PIPE "pipe"

driver/mysql_connection.cpp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -573,6 +573,42 @@ void MySQL_Connection::init(ConnectOptionsMap & properties)
573573
} else {
574574
throw sql::InvalidArgumentException("No string value passed for password");
575575
}
576+
} else if (!it->first.compare(OPT_PASSWORD1)) {
577+
try {
578+
p_s = (it->second).get< sql::SQLString >();
579+
} catch (sql::InvalidArgumentException&) {
580+
throw sql::InvalidArgumentException("Wrong type passed for password1 expected sql::SQLString");
581+
}
582+
if (p_s) {
583+
int num = 1;
584+
proxy->options(sql::mysql::MYSQL_OPT_USER_PASSWORD, num, *p_s);
585+
} else {
586+
throw sql::InvalidArgumentException("No string value passed for password1");
587+
}
588+
} else if (!it->first.compare(OPT_PASSWORD2)) {
589+
try {
590+
p_s = (it->second).get< sql::SQLString >();
591+
} catch (sql::InvalidArgumentException&) {
592+
throw sql::InvalidArgumentException("Wrong type passed for password2 expected sql::SQLString");
593+
}
594+
if (p_s) {
595+
int num = 2;
596+
proxy->options(sql::mysql::MYSQL_OPT_USER_PASSWORD, num, *p_s);
597+
} else {
598+
throw sql::InvalidArgumentException("No string value passed for password2");
599+
}
600+
} else if (!it->first.compare(OPT_PASSWORD3)) {
601+
try {
602+
p_s = (it->second).get< sql::SQLString >();
603+
} catch (sql::InvalidArgumentException&) {
604+
throw sql::InvalidArgumentException("Wrong type passed for password3 expected sql::SQLString");
605+
}
606+
if (p_s) {
607+
int num = 3;
608+
proxy->options(sql::mysql::MYSQL_OPT_USER_PASSWORD, num, *p_s);
609+
} else {
610+
throw sql::InvalidArgumentException("No string value passed for password3");
611+
}
576612
} else if (!it->first.compare(OPT_PORT)) {
577613
try {
578614
p_i = (it->second).get< int >();

driver/mysql_connection_options.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,8 @@ enum MySQL_Connection_Options
6868
MYSQL_OPT_TLS_CIPHERSUITES,
6969
MYSQL_OPT_COMPRESSION_ALGORITHMS,
7070
MYSQL_OPT_ZSTD_COMPRESSION_LEVEL,
71-
MYSQL_OPT_LOAD_DATA_LOCAL_DIR
71+
MYSQL_OPT_LOAD_DATA_LOCAL_DIR,
72+
MYSQL_OPT_USER_PASSWORD
7273
#else
7374
MYSQL_OPT_CONNECT_TIMEOUT, MYSQL_OPT_COMPRESS, MYSQL_OPT_NAMED_PIPE,
7475
MYSQL_INIT_COMMAND, MYSQL_READ_DEFAULT_FILE, MYSQL_READ_DEFAULT_GROUP,

driver/nativeapi/mysql_native_connection_wrapper.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ get_mysql_option(sql::mysql::MySQL_Connection_Options opt)
110110
case sql::mysql::MYSQL_OPT_COMPRESSION_ALGORITHMS: return ::MYSQL_OPT_COMPRESSION_ALGORITHMS;
111111
case sql::mysql::MYSQL_OPT_ZSTD_COMPRESSION_LEVEL: return ::MYSQL_OPT_ZSTD_COMPRESSION_LEVEL;
112112
case sql::mysql::MYSQL_OPT_LOAD_DATA_LOCAL_DIR: return ::MYSQL_OPT_LOAD_DATA_LOCAL_DIR;
113+
case sql::mysql::MYSQL_OPT_USER_PASSWORD: return ::MYSQL_OPT_USER_PASSWORD;
113114
#else
114115
case sql::mysql::MYSQL_OPT_SSL_VERIFY_SERVER_CERT: return ::MYSQL_OPT_SSL_VERIFY_SERVER_CERT;
115116
case sql::mysql::MYSQL_OPT_USE_REMOTE_CONNECTION: return ::MYSQL_OPT_USE_REMOTE_CONNECTION;
@@ -340,6 +341,7 @@ MySQL_NativeConnectionWrapper::options(::sql::mysql::MySQL_Connection_Options op
340341
my_bool dummy= option_val ? '\1' : '\0';
341342
return api->options(mysql, get_mysql_option(option), &dummy);
342343
}
344+
/* }}} */
343345

344346

345347
/* {{{ MySQL_NativeConnectionWrapper::options(int &) */
@@ -349,6 +351,7 @@ MySQL_NativeConnectionWrapper::options(::sql::mysql::MySQL_Connection_Options op
349351
{
350352
return api->options(mysql, get_mysql_option(option), &option_val);
351353
}
354+
/* }}} */
352355

353356

354357
/* {{{ MySQL_NativeConnectionWrapper::options(SQLString &, SQLString &) */
@@ -361,6 +364,16 @@ MySQL_NativeConnectionWrapper::options(::sql::mysql::MySQL_Connection_Options op
361364
/* }}} */
362365

363366

367+
/* {{{ MySQL_NativeConnectionWrapper::options(int &, SQLString &) */
368+
int
369+
MySQL_NativeConnectionWrapper::options(::sql::mysql::MySQL_Connection_Options option,
370+
const int &factor, const ::sql::SQLString &value)
371+
{
372+
return api->options(mysql, get_mysql_option(option), &factor, value.c_str());
373+
}
374+
/* }}} */
375+
376+
364377
/* {{{ MySQL_NativeConnectionWrapper::get_option() */
365378
int
366379
MySQL_NativeConnectionWrapper::get_option(::sql::mysql::MySQL_Connection_Options option, const void * value)

driver/nativeapi/mysql_native_connection_wrapper.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,8 @@ struct st_mysql* mysql;
127127
int options(::sql::mysql::MySQL_Connection_Options, const int &) override;
128128
int options(::sql::mysql::MySQL_Connection_Options,
129129
const ::sql::SQLString &, const ::sql::SQLString &) override;
130+
int options(::sql::mysql::MySQL_Connection_Options,
131+
const int &, const ::sql::SQLString &) override;
130132

131133
int get_option(::sql::mysql::MySQL_Connection_Options, const void * ) override;
132134
int get_option(::sql::mysql::MySQL_Connection_Options,

driver/nativeapi/native_connection_wrapper.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,9 @@ class NativeConnectionWrapper : public boost::noncopyable
127127
virtual int options(::sql::mysql::MySQL_Connection_Options,
128128
const ::sql::SQLString &,
129129
const ::sql::SQLString &) = 0;
130+
virtual int options(::sql::mysql::MySQL_Connection_Options,
131+
const int &,
132+
const ::sql::SQLString &) = 0;
130133

131134
virtual int get_option(::sql::mysql::MySQL_Connection_Options, const void *) = 0;
132135
virtual int get_option(::sql::mysql::MySQL_Connection_Options,

test/unit/classes/connection.cpp

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3741,5 +3741,143 @@ void connection::dns_srv()
37413741
}
37423742
}
37433743

3744+
void connection::mfa()
3745+
{
3746+
logMsg("connection::mfa - multi factor authentication");
3747+
3748+
try {
3749+
stmt->execute("UNINSTALL PLUGIN cleartext_plugin_server");
3750+
} catch (...) {
3751+
}
3752+
3753+
try {
3754+
stmt->execute("INSTALL PLUGIN cleartext_plugin_server SONAME 'auth_test_plugin.so'");
3755+
} catch (...) {
3756+
SKIP("Server doesn't support auth test plugin cleartext_plugin_server");
3757+
}
3758+
3759+
struct MFA_TEST_DATA
3760+
{
3761+
const char* user;
3762+
const char* pwd;
3763+
const char* pwd1;
3764+
const char* pwd2;
3765+
const char* pwd3;
3766+
bool succeed;
3767+
};
3768+
3769+
3770+
MFA_TEST_DATA test_data[] =
3771+
{
3772+
// user1 tests
3773+
{"user_1f", "pass1", nullptr, nullptr, nullptr, true },
3774+
{"user_1f", "pass1", "pass1", nullptr, nullptr, true },
3775+
{"user_1f", "badp1", "pass1", nullptr, nullptr, true },
3776+
{"user_1f", nullptr, "pass1", nullptr, nullptr, true },
3777+
3778+
{"user_1f", nullptr, nullptr, nullptr, nullptr, false },
3779+
{"user_1f", "badp1", "badp1", "pass1", nullptr, false },
3780+
{"user_1f", nullptr, nullptr, "pass1", nullptr, false },
3781+
{"user_1f", nullptr, nullptr, nullptr, "pass1", false },
3782+
3783+
// user2 tests
3784+
{"user_2f", "pass1", nullptr, "pass2", nullptr, true },
3785+
{"user_2f", "pass1", "pass1", "pass2", nullptr, true },
3786+
{"user_2f", "badp1", "pass1", "pass2", nullptr, true },
3787+
{"user_2f", nullptr, "pass1", "pass2", nullptr, true },
3788+
3789+
{"user_2f", "pass2", nullptr, "pass1", nullptr, false },
3790+
{"user_2f", "pass2", "pass2", "pass1", nullptr, false },
3791+
{"user_2f", "pass2", "badp2", "pass1", nullptr, false },
3792+
{"user_2f", nullptr, "pass2", "pass1", nullptr, false },
3793+
3794+
{"user_2f", "pass1", nullptr, nullptr, "pass2", false },
3795+
{"user_2f", "pass1", "pass1", nullptr, "pass2", false },
3796+
{"user_2f", "badp1", "pass1", nullptr, "pass2", false },
3797+
{"user_2f", nullptr, "pass1", nullptr, "pass2", false },
3798+
3799+
{"user_2f", "pass1", nullptr , "badp1", "pass2", false },
3800+
{"user_2f", "pass1", "pass1", "badp1", "pass2", false },
3801+
{"user_2f", "badp1", "pass1", "badp1", "pass2", false },
3802+
{"user_2f", nullptr , "pass1", "badp1", "pass2", false },
3803+
3804+
// user3 tests
3805+
{"user_3f", "pass1", nullptr , "pass2", "pass3", true },
3806+
{"user_3f", "pass1", "pass1", "pass2", "pass3", true },
3807+
{"user_3f", "badp1", "pass1", "pass2", "pass3", true },
3808+
{"user_3f", nullptr , "pass1", "pass2", "pass3", true },
3809+
3810+
{"user_3f", "pass1", nullptr , "pass3", "pass2", false },
3811+
{"user_3f", "pass1", "pass1", "pass3", "pass2", false },
3812+
{"user_3f", "badp1", "pass1", "pass3", "pass2", false },
3813+
{"user_3f", nullptr , "pass1", "pass3", "pass2", false },
3814+
3815+
{"user_3f", "pass3", nullptr , "badp1", "pass2", false },
3816+
{"user_3f", "pass3", "pass3", "badp1", "pass2", false },
3817+
{"user_3f", "pass3", "badp3", "badp1", "pass2", false },
3818+
{"user_3f", nullptr , "pass3", "badp1", "pass2", false },
3819+
3820+
{"user_3f", "pass1", nullptr , "pass2", "badp3", false },
3821+
{"user_3f", "pass1", "pass1", "pass2", "badp3", false },
3822+
{"user_3f", "badp1", "pass1", "pass2", "badp3", false },
3823+
{"user_3f", nullptr , "pass1", "pass2", "badp3", false },
3824+
3825+
};
3826+
3827+
3828+
stmt->execute("drop user if exists user_1f");
3829+
stmt->execute("drop user if exists user_2f");
3830+
stmt->execute("drop user if exists user_3f");
3831+
3832+
stmt->execute("create user user_1f IDENTIFIED WITH cleartext_plugin_server BY 'pass1'");
3833+
stmt->execute("create user user_2f IDENTIFIED WITH cleartext_plugin_server BY 'pass1' "
3834+
"AND IDENTIFIED WITH cleartext_plugin_server BY 'pass2'; ");
3835+
stmt->execute("create user user_3f IDENTIFIED WITH cleartext_plugin_server by 'pass1' "
3836+
"AND IDENTIFIED WITH cleartext_plugin_server BY 'pass2' "
3837+
"AND IDENTIFIED WITH cleartext_plugin_server BY 'pass3'; ");
3838+
3839+
3840+
auto check_connection = [this] (sql::Connection* conn) -> void
3841+
{
3842+
std::unique_ptr<sql::Statement> my_stmt(conn->createStatement());
3843+
std::unique_ptr<sql::ResultSet> my_res(my_stmt->executeQuery("select @@version"));
3844+
my_res->next();
3845+
std::string version = my_res->getString(1);
3846+
3847+
logMsg(std::string("Server Version ")+version);
3848+
3849+
delete conn;
3850+
};
3851+
3852+
for(auto &data : test_data)
3853+
{
3854+
sql::ConnectOptionsMap opt;
3855+
opt[OPT_ENABLE_CLEARTEXT_PLUGIN] = true;
3856+
opt[OPT_USERNAME] = data.user;
3857+
if(data.pwd)
3858+
opt[OPT_PASSWORD] = data.pwd;
3859+
if(data.pwd1)
3860+
opt[OPT_PASSWORD1] = data.pwd1;
3861+
if(data.pwd2)
3862+
opt[OPT_PASSWORD2] = data.pwd2;
3863+
if(data.pwd3)
3864+
opt[OPT_PASSWORD3] = data.pwd3;
3865+
3866+
if(data.succeed)
3867+
{
3868+
check_connection(getConnection(&opt));
3869+
}
3870+
else
3871+
{
3872+
try {
3873+
getConnection(&opt);
3874+
FAIL("Should fail to connect");
3875+
} catch (sql::SQLException&) {
3876+
}
3877+
}
3878+
}
3879+
3880+
}
3881+
37443882
} /* namespace connection */
37453883
} /* namespace testsuite */

test/unit/classes/connection.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ class connection : public unit_fixture
9191
TEST_CASE(cached_sha2_auth);
9292
TEST_CASE(socket);
9393
TEST_CASE(dns_srv);
94+
TEST_CASE(mfa);
9495
}
9596

9697
/**
@@ -283,6 +284,13 @@ class connection : public unit_fixture
283284
*/
284285
void dns_srv();
285286

287+
/*
288+
* Test of MySQL_Connection::mfa()
289+
*
290+
*/
291+
void mfa();
292+
293+
286294
};
287295

288296

0 commit comments

Comments
 (0)