2 changes: 1 addition & 1 deletion cpp/metadata/model.cc
@@ -87,7 +87,7 @@ ModelMetadata ModelMetadata::FromJSON(const picojson::object& metadata,
if (metadata.count("attention_sink_size")) // remove after sink is decoupled from model lib
result.attention_sink_size = json::Lookup<int64_t>(metadata, "attention_sink_size");
result.seqlen_padding_factor =
-    json::LookupOrDefault<int64_t>(metadata, "seqlen_padding_factor", 1);
+    json::LookupOrDefault<int64_t>(metadata, "seqlen_padding_factor", 16);
result.tensor_parallel_shards = json::Lookup<int64_t>(metadata, "tensor_parallel_shards");
result.pipeline_parallel_stages =
json::LookupOrDefault<int64_t>(metadata, "pipeline_parallel_stages", 1);
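The only change in this file is the fallback value: model libraries whose metadata omit `seqlen_padding_factor` are now treated as requiring padding to a multiple of 16 rather than 1. A minimal sketch of that lookup-with-default behaviour, using a hypothetical `std::map` and `LookupOrDefaultInt` as stand-ins for the picojson object and `json::LookupOrDefault<int64_t>` (not the actual helpers):

```cpp
// Sketch only: illustrates the fallback the hunk above changes.
#include <cstdint>
#include <map>
#include <string>

// Hypothetical stand-in for json::LookupOrDefault<int64_t>.
int64_t LookupOrDefaultInt(const std::map<std::string, int64_t>& metadata,
                           const std::string& key, int64_t default_value) {
  auto it = metadata.find(key);
  return it == metadata.end() ? default_value : it->second;
}

int main() {
  // Metadata from an older model lib that never emitted the field.
  std::map<std::string, int64_t> metadata = {{"tensor_parallel_shards", 1}};
  // Before this change the runtime assumed 1 (no padding); now it assumes 16.
  int64_t factor = LookupOrDefaultInt(metadata, "seqlen_padding_factor", 16);
  return factor == 16 ? 0 : 1;
}
```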
31 changes: 21 additions & 10 deletions cpp/serve/model.cc
@@ -262,7 +262,7 @@ class ModelImpl : public ModelObj {
// Begin forward with the sequence ids and new lengths.
IntTuple seq_ids_tuple(seq_ids);
IntTuple lengths_tuple(lengths.begin(), lengths.end());
-    ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple);
+    ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple, seqlen_padding_factor_);

ObjectRef embeddings_dref_or_nd;
if (!embeddings->IsInstance<DRefObj>()) {
@@ -372,7 +372,7 @@ class ModelImpl : public ModelObj {
// Begin forward with the sequence ids and new lengths.
IntTuple seq_ids_tuple(seq_ids);
IntTuple lengths_tuple(lengths.begin(), lengths.end());
-    ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple);
+    ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple, seqlen_padding_factor_);

// args: embeddings, logit_pos, kv_cache, params
ObjectRef result{nullptr};
@@ -410,6 +410,12 @@ class ModelImpl : public ModelObj {
Tensor BatchDecode(const ObjectRef& embeddings, const std::vector<int64_t>& seq_ids) final {
NVTXScopedRange nvtx_scope("BatchDecode num_seqs=" + std::to_string(seq_ids.size()));
int num_sequence = seq_ids.size();

+    bool padded = num_sequence % seqlen_padding_factor_ != 0;
+    if (padded) {
+      num_sequence = (num_sequence + seqlen_padding_factor_ - 1) / seqlen_padding_factor_ *
+                     seqlen_padding_factor_;
+    }

CHECK(ft_.decode_func_.defined())
<< "`decode_with_embed` function is not found in the model. Please make sure the model is "
@@ -422,7 +428,7 @@
// Begin forward with the sequence ids and new lengths.
IntTuple seq_ids_tuple(seq_ids);
IntTuple lengths_tuple(std::vector<int64_t>(/*n=*/seq_ids.size(), /*v=*/1));
-    ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple);
+    ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple, seqlen_padding_factor_);

ObjectRef embeddings_dref_or_nd;
if (!embeddings->IsInstance<DRefObj>()) {
@@ -443,7 +449,7 @@

// args: embeddings, kv_cache, params
ObjectRef ret;
-    if (seq_ids.size() == 1) {
+    if (num_sequence == 1) {
ret = ft_.single_batch_decode_func_(embeddings_dref_or_nd, kv_cache_, params_)
.cast<ObjectRef>();
} else {
@@ -468,9 +474,14 @@
}
ft_.kv_cache_end_forward_func_(kv_cache_);

+    if (padded) {
+      // logits shape: (padded_batch, 1, vocab_size_)
+      // Slice to (real_batch, 1, vocab_size_)
+      logits = logits.CreateView({seq_ids.size(), 1, vocab_size_}, logits->dtype);
+    }
// logits: (b, 1, v)
ICHECK_EQ(logits->ndim, 3);
-    ICHECK_EQ(logits->shape[0], num_sequence);
+    ICHECK_EQ(logits->shape[0], seq_ids.size());
ICHECK_EQ(logits->shape[1], 1);
return logits;
}
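For readers skimming the hunks above, this is the shape of the new padding logic in `BatchDecode`: round the batch up to a multiple of `seqlen_padding_factor_`, run decode on the padded batch, then view the logits back down to the real batch before the shape checks. The sketch below is simplified and self-contained (plain structs instead of the runtime's tensor types), not the actual implementation:

```cpp
#include <cstdint>
#include <vector>

// Round `value` up to the next multiple of `factor` (factor > 0).
int64_t RoundUpToMultiple(int64_t value, int64_t factor) {
  return (value + factor - 1) / factor * factor;
}

// Stand-in for the (batch, 1, vocab_size) logits tensor.
struct LogitsShape {
  int64_t batch, one, vocab;
};

LogitsShape DecodeWithPadding(const std::vector<int64_t>& seq_ids,
                              int64_t seqlen_padding_factor, int64_t vocab_size) {
  int64_t num_sequence = static_cast<int64_t>(seq_ids.size());
  bool padded = num_sequence % seqlen_padding_factor != 0;
  if (padded) {
    num_sequence = RoundUpToMultiple(num_sequence, seqlen_padding_factor);
  }
  // ... decode kernel runs on num_sequence rows; the extra rows are padding ...
  LogitsShape logits{num_sequence, 1, vocab_size};
  if (padded) {
    // Analogue of logits.CreateView({seq_ids.size(), 1, vocab_size_}, dtype):
    // keep only the first seq_ids.size() rows of real outputs.
    logits.batch = static_cast<int64_t>(seq_ids.size());
  }
  return logits;
}
```

With the new default factor of 16, a batch of 5 decodes as 16 rows and a batch of 17 as 32, while batch sizes that are already multiples of 16 pass through unchanged.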
@@ -501,7 +512,7 @@ class ModelImpl : public ModelObj {
IntTuple seq_ids_tuple(seq_ids);
IntTuple lengths_tuple(lengths.begin(), lengths.end());
IntTuple token_tree_parent_ptr_tuple(token_tree_parent_ptr);
-    ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple,
+    ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple, seqlen_padding_factor_,
token_tree_parent_ptr_tuple);

ObjectRef embeddings_dref_or_nd;
@@ -564,7 +575,7 @@ class ModelImpl : public ModelObj {
// Begin forward with the sequence ids and new lengths.
IntTuple seq_ids_tuple(seq_ids);
IntTuple lengths_tuple(std::vector<int64_t>(/*n=*/seq_ids.size(), /*v=*/1));
-    ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple);
+    ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple, seqlen_padding_factor_);

// args: embeddings, kv_cache, params
ObjectRef result{nullptr};
@@ -624,7 +635,7 @@ class ModelImpl : public ModelObj {
IntTuple seq_ids_tuple(seq_ids);
IntTuple lengths_tuple(lengths.begin(), lengths.end());
IntTuple token_tree_parent_ptr_tuple(token_tree_parent_ptr);
-    ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple,
+    ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple, seqlen_padding_factor_,
token_tree_parent_ptr_tuple);

ObjectRef embeddings_dref_or_nd;
@@ -712,7 +723,7 @@ class ModelImpl : public ModelObj {
IntTuple seq_ids_tuple(seq_ids);
IntTuple lengths_tuple(lengths.begin(), lengths.end());
IntTuple token_tree_parent_ptr_tuple(token_tree_parent_ptr);
-    ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple,
+    ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple, seqlen_padding_factor_,
token_tree_parent_ptr_tuple);

// args: embeddings, logit_pos, kv_cache, params
@@ -827,7 +838,7 @@ class ModelImpl : public ModelObj {

// Run KV receive preparation.
ObjectRef ret;
-    ret = ft_.kv_cache_disagg_prepare_recv_func_(kv_cache_, seq_id, length).cast<ObjectRef>();
+    ret = ft_.kv_cache_disagg_prepare_recv_func_(kv_cache_, seq_id, length, seqlen_padding_factor_).cast<ObjectRef>();
IntTuple compressed_kv_append_metadata;
if (ft_.use_disco) {
compressed_kv_append_metadata = Downcast<DRef>(ret)->DebugGetFromRemote(0).cast<IntTuple>();
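Across all hunks in this file the pattern is the same: every call into `kv_cache_begin_forward_func_` (and the disaggregation prepare-receive path) now passes `seqlen_padding_factor_` as an extra trailing argument, placed before the token-tree parents in the tree-attention variants. A rough sketch of the updated call shape, with `std::function` standing in for the TVM packed functions; the real signatures live on the KV cache side and are not part of this diff:

```cpp
#include <cstdint>
#include <functional>
#include <vector>

using IntVec = std::vector<int64_t>;

// Hypothetical stand-ins for the packed functions held in the function table.
struct FunctionTableSketch {
  // Plain prefill/decode: kv_cache, seq_ids, lengths, padding factor.
  std::function<void(void* kv_cache, const IntVec& seq_ids, const IntVec& lengths,
                     int64_t seqlen_padding_factor)>
      kv_cache_begin_forward;
  // Tree-attention variants additionally pass the token-tree parents last.
  std::function<void(void* kv_cache, const IntVec& seq_ids, const IntVec& lengths,
                     int64_t seqlen_padding_factor, const IntVec& token_tree_parent_ptr)>
      kv_cache_begin_forward_with_tree;
};
```

Threading the factor through these calls lets the KV cache see the same rounding the model side applies when it pads decode batches.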