diff --git a/csrf.go b/csrf.go index 5dda254..d32362a 100644 --- a/csrf.go +++ b/csrf.go @@ -1,6 +1,7 @@ package csrf import ( + "encoding/base64" "context" "errors" "fmt" @@ -100,6 +101,7 @@ type options struct { // http.Cookie field instead of the "correct" HTTPOnly name that golint suggests. HttpOnly bool Secure bool + URLSafe bool SameSite SameSiteMode RequestHeader string FieldName string @@ -248,7 +250,11 @@ func (cs *csrf) ServeHTTP(w http.ResponseWriter, r *http.Request) { } // Save the masked token to the request context - r = contextSave(r, tokenKey, mask(realToken, r)) + encoding := base64.StdEncoding + if cs.opts.URLSafe { + encoding = base64.URLEncoding + } + r = contextSave(r, tokenKey, mask(realToken, r, encoding)) // Save the field name to the request context r = contextSave(r, formKey, cs.opts.FieldName) diff --git a/helpers.go b/helpers.go index 99005ee..3378f0d 100644 --- a/helpers.go +++ b/helpers.go @@ -74,7 +74,7 @@ func TemplateField(r *http.Request) template.HTML { // token and returning them together as a 64-byte slice. This effectively // randomises the token on a per-request basis without breaking multiple browser // tabs/windows. -func mask(realToken []byte, _ *http.Request) string { +func mask(realToken []byte, _ *http.Request, encoding *base64.Encoding) string { otp, err := generateRandomBytes(tokenLength) if err != nil { return "" @@ -83,7 +83,7 @@ func mask(realToken []byte, _ *http.Request) string { // XOR the OTP with the real token to generate a masked token. Append the // OTP to the front of the masked token to allow unmasking in the subsequent // request. - return base64.StdEncoding.EncodeToString(append(otp, xorToken(otp, realToken)...)) + return encoding.EncodeToString(append(otp, xorToken(otp, realToken)...)) } // unmask splits the issued token (one-time-pad + masked token) and returns the @@ -129,7 +129,13 @@ func (cs *csrf) requestToken(r *http.Request) ([]byte, error) { // Decode the "issued" (pad + masked) token sent in the request. Return a // nil byte slice on a decoding error (this will fail upstream). - decoded, err := base64.StdEncoding.DecodeString(issued) + encoding := base64.StdEncoding + + if cs.opts.URLSafe { + encoding = base64.URLEncoding + } + + decoded, err := encoding.DecodeString(issued) if err != nil { return nil, err } diff --git a/helpers_test.go b/helpers_test.go index f40c996..72937ae 100644 --- a/helpers_test.go +++ b/helpers_test.go @@ -136,8 +136,10 @@ func TestMaskUnmaskTokens(t *testing.T) { t.Fatal(err) } - issued := mask(realToken, nil) - decoded, err := base64.StdEncoding.DecodeString(issued) + encoding := base64.StdEncoding + + issued := mask(realToken, nil, encoding) + decoded, err := encoding.DecodeString(issued) if err != nil { t.Fatal(err) } diff --git a/options.go b/options.go index c61d301..eaeabd5 100644 --- a/options.go +++ b/options.go @@ -131,6 +131,13 @@ func TrustedOrigins(origins []string) Option { } } +// URLSafe changes the base64 encoding format ( URL safe ) of the CSRF token. +func URLSafe(s bool) Option { + return func(cs *csrf) { + cs.opts.URLSafe = s + } +} + // setStore sets the store used by the CSRF middleware. // Note: this is private (for now) to allow for internal API changes. func setStore(s store) Option { diff --git a/options_test.go b/options_test.go index d133fd1..4ab5958 100644 --- a/options_test.go +++ b/options_test.go @@ -99,3 +99,23 @@ func TestMaxAge(t *testing.T) { }) } + +func TestURLSafe(t *testing.T) { + t.Run("Ensure the default URLSafe is applied", func(t *testing.T) { + handler := Protect(testKey)(nil) + cs := handler.(*csrf) + + if cs.opts.URLSafe != false { + t.Fatalf("default URLSafe not applied: got %v (want %v)", cs.opts.URLSafe, false) + } + }) + + t.Run("Support an explicit URLSafe of true", func(t *testing.T) { + handler := Protect(testKey, URLSafe(true))(nil) + cs := handler.(*csrf) + + if cs.opts.URLSafe != true { + t.Fatalf("URLSafe not applied: got %v (want %v)", cs.opts.URLSafe, true) + } + }) +}