@@ -11,15 +11,15 @@ struct llama_ubatch {
    bool equal_seqs;
    // TODO: whole_seqs for embeddings?

-   uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs)
+   uint32_t n_tokens;     // total tokens (n_seq_tokens * n_seqs)
    uint32_t n_seq_tokens; // tokens per sequence
    uint32_t n_seqs;

    llama_token  *  token;    // [n_tokens]
    float        *  embd;     // [n_embd, n_tokens]
    llama_pos    *  pos;      // [n_tokens]
-   int32_t      *  n_seq_id; // [n_seqs]
-   llama_seq_id ** seq_id;   // [n_seqs]
+   int32_t      *  n_seq_id; // [n_seqs] // TODO: remove, should belong to only 1 sequence
+   llama_seq_id ** seq_id;   // [n_seqs] // TODO: become llama_seq_id * seq_id;
    int8_t       *  output;   // [n_tokens]
};

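As a reading aid (not part of the patch): a minimal sketch of how a consumer could walk these fields, assuming the seq-major layout implied by the comments, i.e. n_tokens == n_seq_tokens * n_seqs with each sequence's tokens stored contiguously. The helper name dump_ubatch is hypothetical.

#include <cstdio>

static void dump_ubatch(const llama_ubatch & ub) {
    for (uint32_t s = 0; s < ub.n_seqs; ++s) {
        // row s carries n_seq_id[s] sequence ids, listed in seq_id[s];
        // the two TODOs above aim to collapse this to one id per row
        for (uint32_t i = 0; i < ub.n_seq_tokens; ++i) {
            const uint32_t t = s*ub.n_seq_tokens + i; // flat index into token/pos/output
            printf("seq %d: token %d at pos %d (output? %d)\n",
                ub.seq_id[s][0], ub.token[t], ub.pos[t], ub.output[t]);
        }
    }
}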
@@ -49,13 +49,18 @@ struct llama_sbatch {

    const llama_batch * batch = nullptr;

-   // buffers for the ubatch
-   std::vector<llama_token>    ubatch_token;
-   std::vector<float>          ubatch_embd;
-   std::vector<llama_pos>      ubatch_pos;
-   std::vector<int32_t>        ubatch_n_seq_id;
-   std::vector<llama_seq_id *> ubatch_seq_id;
-   std::vector<int8_t>         ubatch_output;
+   // buffers for the ubatches
+   // TODO: very hacky, this needs a complete rework
+   struct ubatch_data {
+       std::vector<llama_token>    token;
+       std::vector<float>          embd;
+       std::vector<llama_pos>      pos;
+       std::vector<int32_t>        n_seq_id;
+       std::vector<llama_seq_id *> seq_id;
+       std::vector<int8_t>         output;
+   };
+
+   std::vector<ubatch_data> udatas;

    llama_ubatch reserve_ubatch(size_t n_ubatch, bool has_embd = false);

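The hunk's rationale, as implied by the removed lines: the sbatch used to own a single set of ubatch_* buffers, so every llama_ubatch handed out by reserve_ubatch pointed into the same storage, and reserving a second ubatch invalidated the first. Grouping the buffers into ubatch_data and keeping one entry per reservation in udatas lets several ubatches stay alive at once (hence the TODO admitting it is still hacky). A caller-side sketch, assuming sbatch is an initialized llama_sbatch and n_ubatch a chosen micro-batch size:

llama_ubatch ub0 = sbatch.reserve_ubatch(n_ubatch, /*has_embd =*/ false);
llama_ubatch ub1 = sbatch.reserve_ubatch(n_ubatch, /*has_embd =*/ false);
// before this change ub0 and ub1 would alias the same ubatch_* vectors,
// so filling ub1 clobbered ub0; now each reservation can own a udatas entry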