Skip to content

Commit 12f32be

Browse files
authored
✨ Added a [encoding] module to help encoding/decoding strings (#643)
<!-- Copyright (C) 2020-2022 Arm Limited or its affiliates and Contributors. All rights reserved. SPDX-License-Identifier: Apache-2.0 --> ### Description - Added helpers for performing base64 encoding ### Test Coverage <!-- Please put an `x` in the correct box e.g. `[x]` to indicate the testing coverage of this change. --> - [x] This change is covered by existing or additional automated tests. - [ ] Manual testing has been performed (and evidence provided) as automated testing was not feasible. - [ ] Additional tests are not required for this change (e.g. documentation update).
1 parent b3e9ed4 commit 12f32be

File tree

7 files changed

+294
-28
lines changed

7 files changed

+294
-28
lines changed

.secrets.baseline

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,5 +272,5 @@
272272
}
273273
]
274274
},
275-
"generated_at": "2025-06-27T16:00:59Z"
275+
"generated_at": "2025-07-09T20:27:15Z"
276276
}

changes/20250709190241.feature

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
:sparkles: Added a `[encoding]` module to help encode/decode strings

utils/encoding/base64/decode.go

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
package base64
2+
3+
import (
4+
"context"
5+
"encoding/base64"
6+
"strings"
7+
8+
"github.com/go-ozzo/ozzo-validation/v4/is"
9+
10+
"github.com/ARM-software/golang-utils/utils/commonerrors"
11+
"github.com/ARM-software/golang-utils/utils/parallelisation"
12+
"github.com/ARM-software/golang-utils/utils/reflection"
13+
)
14+
15+
// DecodeString decodes a base64 encoded string. An error is raised if decoding fails.
16+
func DecodeString(ctx context.Context, s string) (decoded string, err error) {
17+
if reflection.IsEmpty(s) {
18+
err = commonerrors.New(commonerrors.ErrEmpty, "the string is empty")
19+
return
20+
}
21+
err = parallelisation.DetermineContextError(ctx)
22+
if err != nil {
23+
return
24+
}
25+
data, err := base64.URLEncoding.DecodeString(s)
26+
if err == nil {
27+
decoded = string(data)
28+
return
29+
}
30+
err = parallelisation.DetermineContextError(ctx)
31+
if err != nil {
32+
return
33+
}
34+
data, err = base64.RawURLEncoding.DecodeString(s)
35+
if err == nil {
36+
decoded = string(data)
37+
return
38+
}
39+
err = parallelisation.DetermineContextError(ctx)
40+
if err != nil {
41+
return
42+
}
43+
data, err = base64.StdEncoding.DecodeString(s)
44+
if err == nil {
45+
decoded = string(data)
46+
return
47+
}
48+
err = parallelisation.DetermineContextError(ctx)
49+
if err != nil {
50+
return
51+
}
52+
data, err = base64.RawStdEncoding.DecodeString(s)
53+
if err == nil {
54+
decoded = string(data)
55+
} else {
56+
trimmed := strings.TrimSuffix(strings.TrimSuffix(s, "="), "=")
57+
if trimmed == s || strings.HasSuffix(trimmed, "=") {
58+
err = commonerrors.WrapError(commonerrors.ErrMarshalling, err, "failed to decode base64 string")
59+
} else {
60+
decoded, err = DecodeString(ctx, trimmed)
61+
}
62+
}
63+
return
64+
}
65+
66+
// DecodeIfEncoded will attempt to decode any string if they are base64 encoded. If not, the string will be returned as is.
67+
// If the string is base64 encoded but the decoding fails, the original string will be returned.
68+
func DecodeIfEncoded(ctx context.Context, s string) (decoded string) {
69+
decoded = s
70+
if IsEncoded(s) {
71+
d, err := DecodeString(ctx, s)
72+
if err == nil {
73+
decoded = d
74+
}
75+
}
76+
return
77+
}
78+
79+
// DecodeRecursively will attempt to decode any string until they are no longer base64 encoded.
80+
func DecodeRecursively(ctx context.Context, s string) (decoded string) {
81+
decoded = s
82+
for {
83+
tmp := DecodeIfEncoded(ctx, decoded)
84+
if decoded == tmp {
85+
return
86+
}
87+
decoded = tmp
88+
}
89+
}
90+
91+
// IsEncoded checks whether a string is encoded or not.
92+
func IsEncoded(s string) bool {
93+
if reflection.IsEmpty(s) {
94+
return false
95+
}
96+
if is.Base64.Validate(s) == nil {
97+
return true
98+
}
99+
_, err := DecodeString(context.Background(), s)
100+
if err == nil {
101+
return true
102+
}
103+
trimmed := strings.TrimSuffix(strings.TrimSuffix(s, "="), "=")
104+
if trimmed == s || strings.HasSuffix(trimmed, "=") {
105+
return false
106+
} else {
107+
return IsEncoded(trimmed)
108+
}
109+
110+
}
111+
112+
func EncodeString(s string) string {
113+
return Encode([]byte(s))
114+
}
115+
116+
func Encode(b []byte) string {
117+
return base64.StdEncoding.EncodeToString(b)
118+
}
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
package base64
2+
3+
import (
4+
"context"
5+
"encoding/base64"
6+
"testing"
7+
8+
"github.com/go-faker/faker/v4"
9+
"github.com/stretchr/testify/assert"
10+
"github.com/stretchr/testify/require"
11+
12+
"github.com/ARM-software/golang-utils/utils/commonerrors"
13+
"github.com/ARM-software/golang-utils/utils/commonerrors/errortest"
14+
)
15+
16+
func TestIsBase64Encoded(t *testing.T) {
17+
random := faker.Sentence()
18+
base641 := base64.RawURLEncoding.EncodeToString([]byte(random))
19+
base642 := base64.RawStdEncoding.EncodeToString([]byte(random))
20+
base643 := base64.URLEncoding.EncodeToString([]byte(random))
21+
base644 := base64.StdEncoding.EncodeToString([]byte(random))
22+
tests := []struct {
23+
input string
24+
expected bool
25+
}{
26+
{"U29tZSBkYXRh", true}, // "Some data"
27+
{"SGVsbG8gd29ybGQ=", true}, // "Hello world"
28+
{"U29tZSBkYXRh===", false},
29+
{"", false}, // Empty string
30+
{"NotBase64", false}, // Plain text
31+
{"!@#$%^&*", false}, // Non-Base64 characters
32+
{"U29tZSBkYXRh\n", true}, // Line break
33+
{"V2l0aCB3aGl0ZXNwYWNl", true}, // "With whitespace" (valid if stripped)
34+
{base641, true},
35+
{base642, true},
36+
{base643, true},
37+
{base644, true},
38+
{"U29tZSBkYXRh=", true},
39+
{"U29tZSBkYXRh==", true},
40+
}
41+
42+
for i := range tests {
43+
test := tests[i]
44+
t.Run(test.input, func(t *testing.T) {
45+
if test.expected {
46+
assert.True(t, IsEncoded(test.input))
47+
} else {
48+
assert.False(t, IsEncoded(test.input))
49+
}
50+
})
51+
}
52+
}
53+
54+
func TestDecodeIfBase64(t *testing.T) {
55+
random := faker.Sentence()
56+
base641 := base64.RawURLEncoding.EncodeToString([]byte(random))
57+
base642 := base64.RawStdEncoding.EncodeToString([]byte(random))
58+
base643 := base64.URLEncoding.EncodeToString([]byte(random))
59+
base644 := base64.StdEncoding.EncodeToString([]byte(random))
60+
61+
tests := []struct {
62+
input string
63+
expected string
64+
errors bool
65+
}{
66+
{input: "U29tZSBkYXRh", expected: "Some data"},
67+
{input: "SGVsbG8gd29ybGQ=", expected: "Hello world"},
68+
{input: "VGVzdCBzdHJpbmc=", expected: "Test string"},
69+
{input: "MTIzNDU2", expected: "123456"},
70+
{input: base641, expected: random},
71+
{input: base642, expected: random},
72+
{input: base643, expected: random},
73+
{input: base644, expected: random},
74+
75+
{input: "NotBase64", expected: "NotBase64", errors: true},
76+
{input: "Invalid===", expected: "Invalid===", errors: true},
77+
{input: "", expected: "", errors: true},
78+
{input: "!@#$%^&*", expected: "!@#$%^&*", errors: true},
79+
80+
{input: "U29tZSBkYXRh\n", expected: "Some data"}, // newline is not part of valid base64
81+
{input: "U29tZSBkYXRh=", expected: "Some data"}, // valid with single padding
82+
{input: "U29tZSBkYXRh==", expected: "Some data"}, // valid with double padding
83+
}
84+
85+
for i := range tests {
86+
test := tests[i]
87+
t.Run(test.input, func(t *testing.T) {
88+
result, err := DecodeString(context.Background(), test.input)
89+
assert.Equal(t, test.expected, DecodeIfEncoded(context.Background(), test.input))
90+
if test.errors {
91+
errortest.AssertError(t, err, commonerrors.ErrMarshalling, commonerrors.ErrInvalid, commonerrors.ErrEmpty)
92+
} else {
93+
require.NoError(t, err)
94+
assert.Equal(t, test.expected, result)
95+
}
96+
})
97+
}
98+
99+
t.Run("cancellation", func(t *testing.T) {
100+
ctx, cancel := context.WithCancel(context.Background())
101+
cancel()
102+
_, err := DecodeString(ctx, random)
103+
errortest.AssertError(t, err, commonerrors.ErrCancelled)
104+
assert.Equal(t, random, DecodeIfEncoded(ctx, random))
105+
106+
})
107+
}
108+
109+
func TestDecodeRecursively(t *testing.T) {
110+
randomText := faker.Paragraph()
111+
random, err := faker.RandomInt(1, 10, 1)
112+
require.NoError(t, err)
113+
114+
encodedText := randomText
115+
for i := 0; i < random[0]; i++ {
116+
encodedText = EncodeString(encodedText)
117+
}
118+
119+
assert.NotEqual(t, randomText, encodedText)
120+
assert.Equal(t, randomText, DecodeRecursively(context.Background(), encodedText))
121+
}

utils/http/headers/headers.go

Lines changed: 4 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
package headers
22

33
import (
4-
"encoding/base64"
4+
"context"
55
"fmt"
66
"net/http"
77
"strings"
@@ -10,6 +10,7 @@ import (
1010

1111
"github.com/ARM-software/golang-utils/utils/collection"
1212
"github.com/ARM-software/golang-utils/utils/commonerrors"
13+
"github.com/ARM-software/golang-utils/utils/encoding/base64"
1314
"github.com/ARM-software/golang-utils/utils/http/headers/useragent"
1415
"github.com/ARM-software/golang-utils/utils/http/schemes"
1516
"github.com/ARM-software/golang-utils/utils/reflection"
@@ -259,7 +260,7 @@ func FetchWebsocketAuthorisation(r *http.Request) (authorisationHeader string) {
259260
if found {
260261
if i < len(subProtocols)-1 {
261262
authorisationHeader = subProtocols[i+1]
262-
if decoded, err := decodeBase64Token(authorisationHeader); err == nil {
263+
if decoded, err := base64.DecodeString(context.Background(), authorisationHeader); err == nil {
263264
authorisationHeader = decoded
264265
}
265266
_ = SetAuthorisationIfNotPresent(r, authorisationHeader)
@@ -272,7 +273,7 @@ func FetchWebsocketAuthorisation(r *http.Request) (authorisationHeader string) {
272273
for j := range subProtocols {
273274
token := strings.TrimPrefix(subProtocols[j], "base64url.bearer.authorization.k8s.io.")
274275
if token != subProtocols[j] {
275-
data, err := decodeBase64Token(token)
276+
data, err := base64.DecodeString(context.Background(), token)
276277
if err == nil {
277278
authorisationHeader = data
278279
_ = SetAuthorisationIfNotPresent(r, authorisationHeader)
@@ -285,29 +286,6 @@ func FetchWebsocketAuthorisation(r *http.Request) (authorisationHeader string) {
285286
return
286287
}
287288

288-
func decodeBase64Token(token string) (decoded string, err error) {
289-
data, err := base64.URLEncoding.DecodeString(token)
290-
if err == nil {
291-
decoded = string(data)
292-
return
293-
}
294-
data, err = base64.RawURLEncoding.DecodeString(token)
295-
if err == nil {
296-
decoded = string(data)
297-
return
298-
}
299-
data, err = base64.StdEncoding.DecodeString(token)
300-
if err == nil {
301-
decoded = string(data)
302-
return
303-
}
304-
data, err = base64.RawStdEncoding.DecodeString(token)
305-
if err == nil {
306-
decoded = string(data)
307-
}
308-
return
309-
}
310-
311289
// SetAuthorisationIfNotPresent sets the value of the `Authorization` header if not already set.
312290
func SetAuthorisationIfNotPresent(r *http.Request, authorisation string) (err error) {
313291
if strings.TrimSpace(FetchAuthorisation(r)) == "" {

utils/validation/rules.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,11 @@ import (
88
"github.com/go-ozzo/ozzo-validation/v4/is"
99

1010
"github.com/ARM-software/golang-utils/utils/commonerrors"
11+
"github.com/ARM-software/golang-utils/utils/encoding/base64"
1112
)
1213

1314
// IsPort validates whether a value is a port using is.Port from github.com/go-ozzo/ozzo-validation/v4.
14-
// However it supports all base go integer types not just strings.
15+
// However, it supports all base go integer types not just strings.
1516
var IsPort = validation.By(isPort)
1617

1718
func isPort(vRaw any) (err error) {
@@ -37,3 +38,6 @@ func isPort(vRaw any) (err error) {
3738

3839
return
3940
}
41+
42+
// IsBase64 validates whether a value is a base64 encoded string. It is similar to is.Base64 but more generic and robust although less performant.
43+
var IsBase64 = validation.NewStringRuleWithError(base64.IsEncoded, is.ErrBase64)

utils/validation/rules_test.go

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
11
package validation
22

33
import (
4+
"encoding/base64"
45
"testing"
56

7+
"github.com/go-faker/faker/v4"
8+
validation "github.com/go-ozzo/ozzo-validation/v4"
9+
"github.com/go-ozzo/ozzo-validation/v4/is"
610
"github.com/stretchr/testify/assert"
11+
"github.com/stretchr/testify/require"
712

813
"github.com/ARM-software/golang-utils/utils/commonerrors"
914
"github.com/ARM-software/golang-utils/utils/commonerrors/errortest"
@@ -53,3 +58,42 @@ func TestCastingToInt(t *testing.T) {
5358
})
5459
}
5560
}
61+
62+
func TestIsBase64Encoded(t *testing.T) {
63+
random := faker.Sentence()
64+
base641 := base64.RawURLEncoding.EncodeToString([]byte(random))
65+
base642 := base64.RawStdEncoding.EncodeToString([]byte(random))
66+
base643 := base64.URLEncoding.EncodeToString([]byte(random))
67+
base644 := base64.StdEncoding.EncodeToString([]byte(random))
68+
tests := []struct {
69+
input string
70+
expected bool
71+
}{
72+
{"U29tZSBkYXRh", true}, // "Some data"
73+
{"SGVsbG8gd29ybGQ=", true}, // "Hello world"
74+
{"U29tZSBkYXRh===", false},
75+
{"", true}, // Empty string
76+
{"NotBase64", false}, // Plain text
77+
{"!@#$%^&*", false}, // Non-Base64 characters
78+
{"U29tZSBkYXRh\n", true}, // Line break
79+
{"V2l0aCB3aGl0ZXNwYWNl", true}, // "With whitespace" (valid if stripped)
80+
{base641, true},
81+
{base642, true},
82+
{base643, true},
83+
{base644, true},
84+
{"U29tZSBkYXRh=", true},
85+
{"U29tZSBkYXRh==", true},
86+
}
87+
88+
for i := range tests {
89+
test := tests[i]
90+
t.Run(test.input, func(t *testing.T) {
91+
err := validation.Validate(test.input, IsBase64)
92+
if test.expected {
93+
require.NoError(t, err)
94+
} else {
95+
errortest.AssertErrorDescription(t, err, is.ErrBase64.Error())
96+
}
97+
})
98+
}
99+
}

0 commit comments

Comments
 (0)