Skip to content

Commit 21e996f

Browse files
authored
Merge pull request #85 from gizmodata/main
Adding "requester pays" mode for S3 buckets
2 parents b17b647 + 2898da2 commit 21e996f

File tree

5 files changed

+90
-3
lines changed

5 files changed

+90
-3
lines changed

extension/httpfs/create_secret_functions.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,12 @@ unique_ptr<BaseSecret> CreateS3SecretFunctions::CreateSecretFunctionInternal(Cli
108108
}
109109
refresh = true;
110110
secret->secret_map["refresh_info"] = MapToStruct(named_param.second);
111+
} else if (lower_name == "requester_pays") {
112+
if (named_param.second.type() != LogicalType::BOOLEAN) {
113+
throw InvalidInputException("Invalid type past to secret option: '%s', found '%s', expected: 'BOOLEAN'",
114+
lower_name, named_param.second.type().ToString());
115+
}
116+
secret->secret_map["requester_pays"] = Value::BOOLEAN(named_param.second.GetValue<bool>());
111117
} else {
112118
throw InvalidInputException("Unknown named parameter passed to CreateSecretFunctionInternal: " +
113119
lower_name);
@@ -185,6 +191,7 @@ void CreateS3SecretFunctions::SetBaseNamedParams(CreateSecretFunction &function,
185191
function.named_parameters["use_ssl"] = LogicalType::BOOLEAN;
186192
function.named_parameters["kms_key_id"] = LogicalType::VARCHAR;
187193
function.named_parameters["url_compatibility_mode"] = LogicalType::BOOLEAN;
194+
function.named_parameters["requester_pays"] = LogicalType::BOOLEAN;
188195

189196
// Whether a secret refresh attempt should be made when the secret appears to be incorrect
190197
function.named_parameters["refresh"] = LogicalType::VARCHAR;

extension/httpfs/httpfs_extension.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ static void LoadInternal(DatabaseInstance &instance) {
7070
config.AddExtensionOption("s3_kms_key_id", "S3 KMS Key ID", LogicalType::VARCHAR);
7171
config.AddExtensionOption("s3_url_compatibility_mode", "Disable Globs and Query Parameters on S3 URLs",
7272
LogicalType::BOOLEAN, Value(false));
73+
config.AddExtensionOption("s3_requester_pays", "S3 use requester pays mode", LogicalType::BOOLEAN, Value(false));
7374

7475
// S3 Uploader config
7576
config.AddExtensionOption("s3_uploader_max_filesize", "S3 Uploader max filesize (between 50GB and 5TB)",

extension/httpfs/include/s3fs.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ struct S3AuthParams {
3030
string url_style;
3131
bool use_ssl = true;
3232
bool s3_url_compatibility_mode = false;
33+
bool requester_pays = false;
3334

3435
static S3AuthParams ReadFrom(optional_ptr<FileOpener> opener, FileOpenerInfo &info);
3536
};
@@ -43,6 +44,8 @@ struct AWSEnvironmentCredentialsProvider {
4344
static constexpr const char *DUCKDB_ENDPOINT_ENV_VAR = "DUCKDB_S3_ENDPOINT";
4445
static constexpr const char *DUCKDB_USE_SSL_ENV_VAR = "DUCKDB_S3_USE_SSL";
4546
static constexpr const char *DUCKDB_KMS_KEY_ID_ENV_VAR = "DUCKDB_S3_KMS_KEY_ID";
47+
static constexpr const char *DUCKDB_REQUESTER_PAYS_ENV_VAR = "DUCKDB_S3_REQUESTER_PAYS";
48+
4649

4750
explicit AWSEnvironmentCredentialsProvider(DBConfig &config) : config(config) {};
4851

extension/httpfs/s3fs.cpp

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,12 @@ static HTTPHeaders create_s3_header(string url, string query, string host, strin
6262
res["x-amz-server-side-encryption-aws-kms-key-id"] = auth_params.kms_key_id;
6363
}
6464

65-
string signed_headers = "";
65+
bool use_requester_pays = auth_params.requester_pays;
66+
if (use_requester_pays) {
67+
res["x-amz-request-payer"] = "requester";
68+
}
69+
70+
string signed_headers = "";
6671
hash_bytes canonical_request_hash;
6772
hash_str canonical_request_hash_str;
6873
if (content_type.length() > 0) {
@@ -75,7 +80,10 @@ static HTTPHeaders create_s3_header(string url, string query, string host, strin
7580
if (use_sse_kms) {
7681
signed_headers += ";x-amz-server-side-encryption;x-amz-server-side-encryption-aws-kms-key-id";
7782
}
78-
auto canonical_request = method + "\n" + S3FileSystem::UrlEncode(url) + "\n" + query;
83+
if (use_requester_pays) {
84+
signed_headers += ";x-amz-request-payer";
85+
}
86+
auto canonical_request = method + "\n" + S3FileSystem::UrlEncode(url) + "\n" + query;
7987
if (content_type.length() > 0) {
8088
canonical_request += "\ncontent-type:" + content_type;
8189
}
@@ -87,6 +95,9 @@ static HTTPHeaders create_s3_header(string url, string query, string host, strin
8795
canonical_request += "\nx-amz-server-side-encryption:aws:kms";
8896
canonical_request += "\nx-amz-server-side-encryption-aws-kms-key-id:" + auth_params.kms_key_id;
8997
}
98+
if (use_requester_pays) {
99+
canonical_request += "\nx-amz-request-payer:requester";
100+
}
90101

91102
canonical_request += "\n\n" + signed_headers + "\n" + payload_hash;
92103
sha256(canonical_request.c_str(), canonical_request.length(), canonical_request_hash);
@@ -143,6 +154,7 @@ void AWSEnvironmentCredentialsProvider::SetAll() {
143154
this->SetExtensionOptionValue("s3_endpoint", DUCKDB_ENDPOINT_ENV_VAR);
144155
this->SetExtensionOptionValue("s3_use_ssl", DUCKDB_USE_SSL_ENV_VAR);
145156
this->SetExtensionOptionValue("s3_kms_key_id", DUCKDB_KMS_KEY_ID_ENV_VAR);
157+
this->SetExtensionOptionValue("s3_requester_pays", DUCKDB_REQUESTER_PAYS_ENV_VAR);
146158
}
147159

148160
S3AuthParams AWSEnvironmentCredentialsProvider::CreateParams() {
@@ -156,6 +168,7 @@ S3AuthParams AWSEnvironmentCredentialsProvider::CreateParams() {
156168
params.endpoint = DUCKDB_ENDPOINT_ENV_VAR;
157169
params.kms_key_id = DUCKDB_KMS_KEY_ID_ENV_VAR;
158170
params.use_ssl = DUCKDB_USE_SSL_ENV_VAR;
171+
params.requester_pays = DUCKDB_REQUESTER_PAYS_ENV_VAR;
159172

160173
return params;
161174
}
@@ -181,6 +194,8 @@ S3AuthParams S3AuthParams::ReadFrom(optional_ptr<FileOpener> opener, FileOpenerI
181194
secret_reader.TryGetSecretKeyOrSetting("kms_key_id", "s3_kms_key_id", result.kms_key_id);
182195
secret_reader.TryGetSecretKeyOrSetting("s3_url_compatibility_mode", "s3_url_compatibility_mode",
183196
result.s3_url_compatibility_mode);
197+
secret_reader.TryGetSecretKeyOrSetting("requester_pays", "s3_requester_pays",
198+
result.requester_pays);
184199

185200
// Endpoint and url style are slightly more complex and require special handling for gcs and r2
186201
auto endpoint_result = secret_reader.TryGetSecretKeyOrSetting("endpoint", "s3_endpoint", result.endpoint);
@@ -219,6 +234,7 @@ unique_ptr<KeyValueSecret> CreateSecret(vector<string> &prefix_paths_p, string &
219234
return_value->secret_map["use_ssl"] = params.use_ssl;
220235
return_value->secret_map["kms_key_id"] = params.kms_key_id;
221236
return_value->secret_map["s3_url_compatibility_mode"] = params.s3_url_compatibility_mode;
237+
return_value->secret_map["requester_pays"] = params.requester_pays;
222238

223239
//! Set redact keys
224240
return_value->redact_keys = {"secret", "session_token"};
@@ -531,9 +547,21 @@ void S3FileSystem::ReadQueryParams(const string &url_query_param, S3AuthParams &
531547
}
532548
query_params.erase(found_param);
533549
}
550+
auto found_requester_pays_param = query_params.find("s3_requester_pays");
551+
if (found_requester_pays_param != query_params.end()) {
552+
if (found_requester_pays_param->second == "true") {
553+
params.requester_pays = true;
554+
} else if (found_requester_pays_param->second == "false") {
555+
params.requester_pays = false;
556+
} else {
557+
throw IOException("Incorrect setting found for s3_requester_pays, allowed values are: 'true' or 'false'");
558+
}
559+
query_params.erase(found_requester_pays_param);
560+
}
534561
if (!query_params.empty()) {
535562
throw IOException("Invalid query parameters found. Supported parameters are:\n's3_region', 's3_access_key_id', "
536-
"'s3_secret_access_key', 's3_session_token',\n's3_endpoint', 's3_url_style', 's3_use_ssl'");
563+
"'s3_secret_access_key', 's3_session_token',\n's3_endpoint', 's3_url_style', 's3_use_ssl', "
564+
"'s3_requester_pays'");
537565
}
538566
}
539567

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# name: test/sql/secret/secret_aws_requester_pays.test
2+
# description: Tests secret refreshing with AWS requester pays mode
3+
# group: [secrets]
4+
5+
require-env S3_TEST_SERVER_AVAILABLE 1
6+
7+
require-env AWS_DEFAULT_REGION
8+
9+
require-env AWS_ACCESS_KEY_ID
10+
11+
require-env AWS_SECRET_ACCESS_KEY
12+
13+
require-env DUCKDB_S3_ENDPOINT
14+
15+
require-env DUCKDB_S3_USE_SSL
16+
17+
require httpfs
18+
19+
require parquet
20+
21+
statement ok
22+
SET enable_logging=true
23+
24+
statement ok
25+
set s3_use_ssl='${DUCKDB_S3_USE_SSL}'
26+
27+
statement ok
28+
set s3_endpoint='${DUCKDB_S3_ENDPOINT}'
29+
30+
statement ok
31+
set s3_region='${AWS_DEFAULT_REGION}'
32+
33+
# Create some test data
34+
statement ok
35+
CREATE SECRET s1 (
36+
TYPE S3,
37+
KEY_ID '${AWS_ACCESS_KEY_ID}',
38+
SECRET '${AWS_SECRET_ACCESS_KEY}',
39+
REQUESTER_PAYS true
40+
)
41+
42+
statement ok
43+
copy (select 1 as a) to 's3://test-bucket/test-file.parquet'
44+
45+
query I
46+
FROM "s3://test-bucket/test-file.parquet"
47+
----
48+
1

0 commit comments

Comments
 (0)