Skip to content

Commit e3b3b39

Browse files
committed
Make storage layer context aware
1 parent 083bc9f commit e3b3b39

File tree

9 files changed

+367
-66
lines changed

9 files changed

+367
-66
lines changed

gateway/auth_manager.go

Lines changed: 65 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -161,35 +161,80 @@ func (b *DefaultSessionManager) RemoveSession(orgID string, keyName string, hash
161161

162162
// SessionDetail returns the session detail using the storage engine (either in memory or Redis)
163163
func (b *DefaultSessionManager) SessionDetail(orgID string, keyName string, hashed bool) (user.SessionState, bool) {
164-
return b.fetchSessionDetail(nil, orgID, keyName, hashed)
164+
var jsonKeyVal string
165+
var err error
166+
keyId := keyName
167+
168+
// get session by key
169+
if hashed {
170+
jsonKeyVal, err = b.store.GetRawKey(b.store.GetKeyPrefix() + keyName)
171+
} else {
172+
if storage.TokenOrg(keyName) != orgID {
173+
// try to get legacy and new format key at once
174+
toSearchList := []string{}
175+
if !b.Gw.GetConfig().DisableKeyActionsByUsername {
176+
toSearchList = append(toSearchList, b.Gw.generateToken(orgID, keyName))
177+
}
178+
179+
toSearchList = append(toSearchList, keyName)
180+
for _, fallback := range b.Gw.GetConfig().HashKeyFunctionFallback {
181+
if !b.Gw.GetConfig().DisableKeyActionsByUsername {
182+
toSearchList = append(toSearchList, b.Gw.generateToken(orgID, keyName, fallback))
183+
}
184+
}
185+
186+
var jsonKeyValList []string
187+
188+
jsonKeyValList, err = b.store.GetMultiKey(toSearchList)
189+
// pick the 1st non empty from the returned list
190+
for idx, val := range jsonKeyValList {
191+
if val != "" {
192+
jsonKeyVal = val
193+
keyId = toSearchList[idx]
194+
break
195+
}
196+
}
197+
} else {
198+
// key is not an imported one
199+
jsonKeyVal, err = b.store.GetKey(keyName)
200+
}
201+
}
202+
203+
if err != nil {
204+
log.WithFields(logrus.Fields{
205+
"prefix": "auth-mgr",
206+
"inbound-key": b.Gw.obfuscateKey(keyName),
207+
"err": err,
208+
}).Debug("Could not get session detail, key not found")
209+
return user.SessionState{}, false
210+
}
211+
session := &user.SessionState{}
212+
if err := json.Unmarshal([]byte(jsonKeyVal), &session); err != nil {
213+
log.Error("Couldn't unmarshal session object (may be cache miss): ", err)
214+
return user.SessionState{}, false
215+
}
216+
session.KeyID = keyId
217+
return session.Clone(), true
165218
}
166219

167220
// SessionDetailContext returns the session detail using the storage engine with context support for cancellation
168221
func (b *DefaultSessionManager) SessionDetailContext(ctx context.Context, orgID string, keyName string, hashed bool) (user.SessionState, bool) {
169-
return b.fetchSessionDetail(ctx, orgID, keyName, hashed)
170-
}
171-
172-
// fetchSessionDetail is the internal implementation shared by SessionDetail and SessionDetailContext
173-
func (b *DefaultSessionManager) fetchSessionDetail(ctx context.Context, orgID string, keyName string, hashed bool) (user.SessionState, bool) {
174-
if ctx != nil {
175-
select {
176-
case <-ctx.Done():
177-
log.WithFields(logrus.Fields{
178-
"prefix": "auth-mgr",
179-
"inbound-key": b.Gw.obfuscateKey(keyName),
180-
}).Debug("Context cancelled")
181-
return user.SessionState{}, false
182-
default:
183-
}
222+
select {
223+
case <-ctx.Done():
224+
log.WithFields(logrus.Fields{
225+
"prefix": "auth-mgr",
226+
"inbound-key": b.Gw.obfuscateKey(keyName),
227+
}).Debug("Context cancelled")
228+
return user.SessionState{}, false
229+
default:
184230
}
185231

186232
var jsonKeyVal string
187233
var err error
188234
keyId := keyName
189235

190-
// get session by key
191236
if hashed {
192-
jsonKeyVal, err = b.store.GetRawKey(b.store.GetKeyPrefix() + keyName)
237+
jsonKeyVal, err = b.store.GetRawKeyContext(ctx, b.store.GetKeyPrefix()+keyName)
193238
} else {
194239
if storage.TokenOrg(keyName) != orgID {
195240
// try to get legacy and new format key at once
@@ -207,7 +252,7 @@ func (b *DefaultSessionManager) fetchSessionDetail(ctx context.Context, orgID st
207252

208253
var jsonKeyValList []string
209254

210-
jsonKeyValList, err = b.store.GetMultiKey(toSearchList)
255+
jsonKeyValList, err = b.store.GetMultiKeyContext(ctx, toSearchList)
211256
// pick the 1st non empty from the returned list
212257
for idx, val := range jsonKeyValList {
213258
if val != "" {
@@ -218,7 +263,7 @@ func (b *DefaultSessionManager) fetchSessionDetail(ctx context.Context, orgID st
218263
}
219264
} else {
220265
// key is not an imported one
221-
jsonKeyVal, err = b.store.GetKey(keyName)
266+
jsonKeyVal, err = b.store.GetKeyContext(ctx, keyName)
222267
}
223268
}
224269

gateway/auth_manager_test.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package gateway
22

33
import (
4+
"context"
45
"crypto/x509"
56
"encoding/json"
67
"fmt"
@@ -578,3 +579,15 @@ func (c *countingStorageHandler) AppendToSet(s string, s2 string) {}
578579
func (c *countingStorageHandler) Exists(s string) (bool, error) {
579580
return false, nil
580581
}
582+
583+
func (c *countingStorageHandler) GetKeyContext(ctx context.Context, s string) (string, error) {
584+
return "", nil
585+
}
586+
587+
func (c *countingStorageHandler) GetRawKeyContext(ctx context.Context, key string) (string, error) {
588+
return "", nil
589+
}
590+
591+
func (c *countingStorageHandler) GetMultiKeyContext(ctx context.Context, keys []string) ([]string, error) {
592+
return nil, nil
593+
}

gateway/ldap_auth_handler.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package gateway
22

33
import (
4+
"context"
45
"errors"
56
"strings"
67

@@ -242,3 +243,27 @@ func (l LDAPStorageHandler) Exists(keyName string) (bool, error) {
242243
log.Error("Not implemented")
243244
return false, nil
244245
}
246+
247+
// GetKeyContext retrieves a key with context support.
248+
func (l *LDAPStorageHandler) GetKeyContext(ctx context.Context, keyName string) (string, error) {
249+
if err := ctx.Err(); err != nil {
250+
return "", err
251+
}
252+
return l.GetKey(keyName)
253+
}
254+
255+
// GetRawKeyContext retrieves a raw key with context support.
256+
func (l *LDAPStorageHandler) GetRawKeyContext(ctx context.Context, keyName string) (string, error) {
257+
if err := ctx.Err(); err != nil {
258+
return "", err
259+
}
260+
return l.GetRawKey(keyName)
261+
}
262+
263+
// GetMultiKeyContext retrieves multiple keys with context support.
264+
func (l *LDAPStorageHandler) GetMultiKeyContext(ctx context.Context, keyNames []string) ([]string, error) {
265+
if err := ctx.Err(); err != nil {
266+
return nil, err
267+
}
268+
return l.GetMultiKey(keyNames)
269+
}

gateway/rpc_storage_handler.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
package gateway
33

44
import (
5+
"context"
56
"encoding/json"
67
"fmt"
78
"net/http"
@@ -1274,3 +1275,27 @@ func (r *RPCStorageHandler) Exists(keyName string) (bool, error) {
12741275
log.Error("Not implemented")
12751276
return false, nil
12761277
}
1278+
1279+
// GetKeyContext retrieves a key with context support.
1280+
func (r *RPCStorageHandler) GetKeyContext(ctx context.Context, keyName string) (string, error) {
1281+
if err := ctx.Err(); err != nil {
1282+
return "", err
1283+
}
1284+
return r.GetKey(keyName)
1285+
}
1286+
1287+
// GetRawKeyContext retrieves a raw key with context support.
1288+
func (r *RPCStorageHandler) GetRawKeyContext(ctx context.Context, keyName string) (string, error) {
1289+
if err := ctx.Err(); err != nil {
1290+
return "", err
1291+
}
1292+
return r.GetRawKey(keyName)
1293+
}
1294+
1295+
// GetMultiKeyContext retrieves multiple keys with context support.
1296+
func (r *RPCStorageHandler) GetMultiKeyContext(ctx context.Context, keyNames []string) ([]string, error) {
1297+
if err := ctx.Err(); err != nil {
1298+
return nil, err
1299+
}
1300+
return r.GetMultiKey(keyNames)
1301+
}

storage/dummy.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package storage
22

33
import (
4+
"context"
45
"errors"
56
"fmt"
67
)
@@ -256,3 +257,18 @@ func (s *DummyStorage) GetKeys(pattern string) (keys []string) {
256257

257258
return keys
258259
}
260+
261+
// GetKeyContext retrieves the value for a given key with context support (delegates to GetKey).
262+
func (s *DummyStorage) GetKeyContext(_ context.Context, key string) (string, error) {
263+
return s.GetKey(key)
264+
}
265+
266+
// GetRawKeyContext retrieves a raw key value with context support (delegates to GetRawKey).
267+
func (s *DummyStorage) GetRawKeyContext(_ context.Context, key string) (string, error) {
268+
return s.GetRawKey(key)
269+
}
270+
271+
// GetMultiKeyContext retrieves multiple keys with context support (delegates to GetMultiKey).
272+
func (s *DummyStorage) GetMultiKeyContext(_ context.Context, keys []string) ([]string, error) {
273+
return s.GetMultiKey(keys)
274+
}

storage/mdcb_storage.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package storage
22

33
import (
4+
"context"
45
"errors"
56
"strings"
67

@@ -289,3 +290,27 @@ func (m MdcbStorage) getFromRPCAndCache(key string) (string, error) {
289290
func (m MdcbStorage) getFromLocal(key string) (string, error) {
290291
return m.local.GetKey(key)
291292
}
293+
294+
// GetKeyContext retrieves a key with context support.
295+
func (m MdcbStorage) GetKeyContext(ctx context.Context, key string) (string, error) {
296+
if err := ctx.Err(); err != nil {
297+
return "", err
298+
}
299+
return m.GetKey(key)
300+
}
301+
302+
// GetRawKeyContext retrieves a raw key with context support.
303+
func (m MdcbStorage) GetRawKeyContext(ctx context.Context, key string) (string, error) {
304+
if err := ctx.Err(); err != nil {
305+
return "", err
306+
}
307+
return m.GetRawKey(key)
308+
}
309+
310+
// GetMultiKeyContext retrieves multiple keys with context support.
311+
func (m MdcbStorage) GetMultiKeyContext(ctx context.Context, keys []string) ([]string, error) {
312+
if err := ctx.Err(); err != nil {
313+
return nil, err
314+
}
315+
return m.GetMultiKey(keys)
316+
}

0 commit comments

Comments
 (0)