23
23
24
24
namespace gcpp {
25
25
26
+ void KVCache::ZeroGriffinCache () {
27
+ if (conv1d_cache_size != 0 ) {
28
+ hwy::ZeroBytes (conv1d_cache.get (),
29
+ conv1d_cache_size * sizeof (conv1d_cache[0 ]));
30
+ }
31
+ if (rglru_cache_size != 0 ) {
32
+ hwy::ZeroBytes (rglru_cache.get (),
33
+ rglru_cache_size * sizeof (rglru_cache[0 ]));
34
+ }
35
+ }
36
+
26
37
// prefill_tbatch_size is the maximum number of tokens from one query to
27
38
// prefill at a time.
28
39
KVCache KVCache::Create (const ModelConfig& weights_config,
@@ -37,9 +48,9 @@ KVCache KVCache::Create(const ModelConfig& weights_config,
37
48
kv_cache.kv_cache =
38
49
hwy::AllocateAligned<float >(kv_cache.seq_len * size_cache_pos);
39
50
}
40
- size_t num_griffin_layers = weights_config.NumLayersOfType (
41
- LayerAttentionType::kGriffinRecurrentBlock );
42
51
52
+ const size_t num_griffin_layers = weights_config.NumLayersOfType (
53
+ LayerAttentionType::kGriffinRecurrentBlock );
43
54
// TODO(patrickms): Add query batching support for Griffin.
44
55
if (num_griffin_layers > 0 ) {
45
56
size_t conv1d_width = 0 ;
@@ -49,20 +60,18 @@ KVCache KVCache::Create(const ModelConfig& weights_config,
49
60
const size_t conv1d_cache_size =
50
61
num_griffin_layers * (conv1d_width == 0 ? 0 : conv1d_width - 1 ) *
51
62
weights_config.model_dim ;
63
+ kv_cache.conv1d_cache_size = conv1d_cache_size;
52
64
if (conv1d_cache_size != 0 ) {
53
65
kv_cache.conv1d_cache = hwy::AllocateAligned<float >(conv1d_cache_size);
54
- hwy::ZeroBytes (kv_cache.conv1d_cache .get (),
55
- conv1d_cache_size * sizeof (kv_cache.conv1d_cache [0 ]));
56
66
}
57
67
58
68
const size_t rglru_cache_size =
59
69
num_griffin_layers * weights_config.model_dim ;
70
+ kv_cache.rglru_cache_size = rglru_cache_size;
60
71
if (rglru_cache_size != 0 ) {
61
72
kv_cache.rglru_cache = hwy::AllocateAligned<float >(rglru_cache_size);
62
- hwy::ZeroBytes (kv_cache.rglru_cache .get (),
63
- rglru_cache_size * sizeof (kv_cache.rglru_cache [0 ]));
64
73
}
65
- } // kGriffinLayers
74
+ } // num_griffin_layers
66
75
67
76
return kv_cache;
68
77
}
0 commit comments