Skip to content

Commit aa4c60f

Browse files
Paweł Andruszkiewiczrennox
Paweł Andruszkiewicz
authored andcommitted
BUG#35468541 HEAD request is not retried if it fails with an authorization error
If an AWS HEAD request failed with an authorization error, it was not retried, as response did not contain body, and retry mechanism uses body to recognize AWS authorization errors. As a workaround, whenever an AWS HEAD request fails with a 400 HTTP error code, it is retried. Additionally, the expiration time of AWS credentials (if present) is adjusted, so that credentials are refreshed earlier, five minutes before the required time. Change-Id: I0a622383e01258284ac1f05e8d30f3f956d01e68
1 parent 64189d3 commit aa4c60f

File tree

10 files changed

+99
-31
lines changed

10 files changed

+99
-31
lines changed

mysqlshdk/libs/aws/aws_credentials_provider.cc

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2022, Oracle and/or its affiliates.
2+
* Copyright (c) 2022, 2023, Oracle and/or its affiliates.
33
*
44
* This program is free software; you can redistribute it and/or modify
55
* it under the terms of the GNU General Public License, version 2.0,
@@ -23,10 +23,14 @@
2323

2424
#include "mysqlshdk/libs/aws/aws_credentials_provider.h"
2525

26+
#include <chrono>
2627
#include <mutex>
2728
#include <stdexcept>
2829
#include <utility>
2930

31+
#include "mysqlshdk/libs/utils/logger.h"
32+
#include "mysqlshdk/libs/utils/utils_time.h"
33+
3034
namespace mysqlshdk {
3135
namespace aws {
3236

@@ -96,10 +100,33 @@ std::shared_ptr<Aws_credentials> Aws_credentials_provider::get_credentials() {
96100

97101
if (credentials.access_key_id.has_value() &&
98102
credentials.secret_access_key.has_value()) {
103+
Aws_credentials::Time_point expiration = Aws_credentials::NO_EXPIRATION;
104+
105+
if (credentials.expiration.has_value()) {
106+
log_info("The AWS credentials are set to expire at: %s",
107+
credentials.expiration->c_str());
108+
109+
try {
110+
expiration = shcore::rfc3339_to_time_point(*credentials.expiration);
111+
} catch (const std::exception &e) {
112+
throw std::runtime_error("Failed to parse 'Expiration' value '" +
113+
*credentials.expiration + "' in " + name());
114+
}
115+
116+
// we adjust the expiration time so credentials are refreshed before they
117+
// actually expire; the minimum duration specified by AWS docs is 15
118+
// minutes, but we check if we don't end up with a value that's in the
119+
// past, as some of our tests use values smaller than that
120+
if (const auto adjusted_expiration = expiration - std::chrono::minutes(5);
121+
Aws_credentials::Clock::now() < adjusted_expiration) {
122+
log_info("The expiration of AWS credentials has been adjusted");
123+
expiration = adjusted_expiration;
124+
}
125+
}
126+
99127
return std::make_shared<Aws_credentials>(
100128
*credentials.access_key_id, *credentials.secret_access_key,
101-
credentials.session_token.value_or(std::string{}),
102-
credentials.expiration);
129+
credentials.session_token.value_or(std::string{}), expiration);
103130
}
104131

105132
return {};

mysqlshdk/libs/aws/aws_credentials_provider.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2022, Oracle and/or its affiliates.
2+
* Copyright (c) 2022, 2023, Oracle and/or its affiliates.
33
*
44
* This program is free software; you can redistribute it and/or modify
55
* it under the terms of the GNU General Public License, version 2.0,
@@ -63,7 +63,7 @@ class Aws_credentials_provider {
6363
std::optional<std::string> access_key_id;
6464
std::optional<std::string> secret_access_key;
6565
std::optional<std::string> session_token;
66-
Aws_credentials::Time_point expiration = Aws_credentials::NO_EXPIRATION;
66+
std::optional<std::string> expiration;
6767
};
6868

6969
struct Context {

mysqlshdk/libs/aws/aws_signer.cc

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -271,8 +271,15 @@ bool Aws_signer::auth_data_expired(time_t now) const {
271271
return m_credentials->expired(now);
272272
}
273273

274-
bool Aws_signer::is_authorization_error(const rest::Response &response) const {
274+
bool Aws_signer::is_authorization_error(const rest::Signed_request &request,
275+
const rest::Response &response) const {
275276
if (rest::Response::Status_code::BAD_REQUEST == response.status) {
277+
if (rest::Type::HEAD == request.type) {
278+
// if this was a HEAD request, then we won't get the body and the error
279+
// code, retry just in case, it's a lightweight request
280+
return true;
281+
}
282+
276283
if (const auto error = response.get_error(); error.has_value()) {
277284
if ("ExpiredToken" == error->code() ||
278285
"TokenRefreshRequired" == error->code()) {
@@ -283,7 +290,7 @@ bool Aws_signer::is_authorization_error(const rest::Response &response) const {
283290
}
284291
}
285292

286-
return Signer::is_authorization_error(response);
293+
return Signer::is_authorization_error(request, response);
287294
}
288295

289296
bool Aws_signer::update_credentials() {

mysqlshdk/libs/aws/aws_signer.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,8 @@ class Aws_signer : public rest::Signer {
6868

6969
bool auth_data_expired(time_t now) const override;
7070

71-
bool is_authorization_error(const rest::Response &response) const override;
71+
bool is_authorization_error(const rest::Signed_request &request,
72+
const rest::Response &response) const override;
7273

7374
private:
7475
#ifdef FRIEND_TEST

mysqlshdk/libs/aws/process_credentials_provider.cc

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2022, Oracle and/or its affiliates.
2+
* Copyright (c) 2022, 2023, Oracle and/or its affiliates.
33
*
44
* This program is free software; you can redistribute it and/or modify
55
* it under the terms of the GNU General Public License, version 2.0,
@@ -35,7 +35,6 @@
3535
#include "mysqlshdk/libs/utils/utils_general.h"
3636
#include "mysqlshdk/libs/utils/utils_lexing.h"
3737
#include "mysqlshdk/libs/utils/utils_string.h"
38-
#include "mysqlshdk/libs/utils/utils_time.h"
3938

4039
#ifdef _WIN32
4140
#define pclose _pclose
@@ -235,15 +234,7 @@ Aws_credentials_provider::Credentials Process_credentials_provider::parse_json(
235234
creds.access_key_id = required(access_key_id());
236235
creds.secret_access_key = required(secret_access_key());
237236
creds.session_token = optional("SessionToken");
238-
239-
if (const auto expiration = optional("Expiration"); expiration.has_value()) {
240-
try {
241-
creds.expiration = shcore::rfc3339_to_time_point(*expiration);
242-
} catch (const std::exception &e) {
243-
handle_error(std::string{"failed to parse 'Expiration' value: "} +
244-
e.what());
245-
}
246-
}
237+
creds.expiration = optional("Expiration");
247238

248239
return creds;
249240
}

mysqlshdk/libs/rest/signed_rest_service.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ Response::Status_code Signed_rest_service::execute(Signed_request *request,
243243
do {
244244
code = rest->execute(request, response);
245245
retry = Response::is_error(code) &&
246-
m_signer->is_authorization_error(*response) &&
246+
m_signer->is_authorization_error(*request, *response) &&
247247
++retries <= k_authorization_retry_limit;
248248

249249
if (retry) {

mysqlshdk/libs/rest/signed_rest_service.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,8 @@ class Signer {
7777

7878
virtual bool auth_data_expired(time_t now) const = 0;
7979

80-
virtual bool is_authorization_error(const Response &response) const {
80+
virtual bool is_authorization_error(const Signed_request &,
81+
const Response &response) const {
8182
return Response::Status_code::UNAUTHORIZED == response.status;
8283
}
8384
};

unittest/mysqlshdk/libs/aws/aws_credentials_provider_t.cc

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2022, Oracle and/or its affiliates.
2+
* Copyright (c) 2022, 2023, Oracle and/or its affiliates.
33
*
44
* This program is free software; you can redistribute it and/or modify
55
* it under the terms of the GNU General Public License, version 2.0,
@@ -27,10 +27,12 @@
2727

2828
#include "unittest/gtest_clean.h"
2929

30+
#include <chrono>
3031
#include <thread>
3132
#include <vector>
3233

3334
#include "mysqlshdk/libs/utils/utils_general.h"
35+
#include "mysqlshdk/libs/utils/utils_time.h"
3436

3537
namespace mysqlshdk {
3638
namespace aws {
@@ -51,14 +53,30 @@ TEST(Aws_credentials_provider_test, temporary_credentials) {
5153

5254
result.access_key_id = access_key_id() + std::to_string(m_called);
5355
result.secret_access_key = secret_access_key() + std::to_string(m_called);
54-
result.expiration =
55-
Aws_credentials::Clock::now() + std::chrono::milliseconds(100);
56+
result.expiration = expiration();
5657

5758
++m_called;
5859

5960
return result;
6061
}
6162

63+
static std::string expiration() {
64+
using std::chrono::milliseconds;
65+
using std::chrono::time_point_cast;
66+
67+
const auto tp = time_point_cast<milliseconds>(
68+
Aws_credentials::Clock::now() + milliseconds(100));
69+
auto result = shcore::time_point_to_rfc3339(tp);
70+
if (std::string::npos == result.find('.')) {
71+
// add milliseconds
72+
auto ms = std::to_string(tp.time_since_epoch().count() %
73+
milliseconds::period::den);
74+
ms = "." + std::string(3 - ms.length(), '0') + ms;
75+
result = result.substr(0, 19) + ms + result.substr(19);
76+
}
77+
return result;
78+
}
79+
6280
int m_called = 0;
6381
};
6482

unittest/mysqlshdk/libs/aws/s3_bucket_config_t.cc

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2022, Oracle and/or its affiliates.
2+
* Copyright (c) 2022, 2023, Oracle and/or its affiliates.
33
*
44
* This program is free software; you can redistribute it and/or modify
55
* it under the terms of the GNU General Public License, version 2.0,
@@ -875,12 +875,14 @@ TEST_F(Aws_s3_bucket_config_test, process_credentials) {
875875
}
876876

877877
const auto EXPECT_EXPIRATION = [](Aws_credentials::Time_point expected,
878-
Aws_credentials::Time_point actual) {
878+
Aws_credentials::Time_point actual,
879+
bool adjusted = false) {
879880
// shcore::time_point_to_rfc3339() does not include information about
880881
// fractions of seconds, these need to be converted to seconds before we can
881882
// compare them
882883
EXPECT_EQ(std::chrono::floor<std::chrono::seconds>(expected),
883-
std::chrono::floor<std::chrono::seconds>(actual));
884+
std::chrono::floor<std::chrono::seconds>(actual) +
885+
std::chrono::minutes(adjusted ? 5 : 0));
884886
};
885887

886888
{
@@ -947,7 +949,7 @@ TEST_F(Aws_s3_bucket_config_test, process_credentials) {
947949
EXPECT_EQ(k_session_token, credentials->session_token());
948950
EXPECT_TRUE(credentials->temporary());
949951
EXPECT_FALSE(credentials->expired());
950-
EXPECT_EXPIRATION(expiration, credentials->expiration());
952+
EXPECT_EXPIRATION(expiration, credentials->expiration(), true);
951953
}
952954

953955
{
@@ -969,7 +971,7 @@ TEST_F(Aws_s3_bucket_config_test, process_credentials) {
969971
EXPECT_EQ("", credentials->session_token());
970972
EXPECT_TRUE(credentials->temporary());
971973
EXPECT_FALSE(credentials->expired());
972-
EXPECT_EXPIRATION(expiration, credentials->expiration());
974+
EXPECT_EXPIRATION(expiration, credentials->expiration(), true);
973975
}
974976

975977
{

unittest/scripts/auto/py_aws/scripts/util_dump_and_load_aws_norecord.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -516,17 +516,38 @@ def TEST_BUG35027093(http_server, expected_error):
516516
EXPECT_SHELL_LOG_CONTAINS_COUNT("Refreshing authentication data", 2)
517517
EXPECT_SHELL_LOG_CONTAINS_COUNT("Retrying a request which failed due to an authorization error", 2)
518518

519+
# BUG#35468541 - expired HEAD request
520+
class FailHeadRequest(BaseHTTPRequestHandler):
521+
protocol_version = "HTTP/1.1"
522+
def do_HEAD(self):
523+
self.send_response(400)
524+
self.send_header('Content-Type', 'application/xml')
525+
self.send_header('Content-Length', 123)
526+
self.end_headers()
527+
self.close_connection = True
528+
def log_message(self, format, *args):
529+
pass
530+
531+
def TEST_BUG35468541(http_server, expected_error):
532+
WIPE_SHELL_LOG()
533+
EXPECT_THROWS(lambda: util.load_dump(dump_dir, get_options({ "s3EndpointOverride": test_server_url(http_server), "s3Profile": local_aws_profile, "s3ConfigFile": local_aws_config_file })), expected_error)
534+
EXPECT_SHELL_LOG_CONTAINS_COUNT("Refreshing authentication data", 2)
535+
EXPECT_SHELL_LOG_CONTAINS_COUNT("Retrying a request which failed due to an authorization error", 2)
536+
519537
expired_token_server = start_test_server(ExpiredToken)
520538
token_refresh_server = start_test_server(TokenRefreshRequired)
539+
fail_head_server = start_test_server(FailHeadRequest)
521540

522-
#@<> BUG#35027093 - test
541+
#@<> BUG#35027093 + BUG#35468541 - test
523542
with write_profile(local_aws_config_file, "profile " + local_aws_profile, { "credential_process": f"{__mysqlsh} --py --file {creds_script}" }):
524543
TEST_BUG35027093(expired_token_server, "The provided token has expired. (400)")
525544
TEST_BUG35027093(token_refresh_server, "The provided token must be refreshed. (400)")
545+
TEST_BUG35468541(fail_head_server, "Bad Request (400)")
526546

527-
#@<> BUG#35027093 - cleanup
547+
#@<> BUG#35027093 + BUG#35468541 - cleanup
528548
stop_test_server(expired_token_server)
529549
stop_test_server(token_refresh_server)
550+
stop_test_server(fail_head_server)
530551

531552
if os.path.exists(creds_script):
532553
os.remove(creds_script)

0 commit comments

Comments
 (0)