diff --git a/cpp/metadata/model.cc b/cpp/metadata/model.cc index 0cf0186336..f4ef42e848 100644 --- a/cpp/metadata/model.cc +++ b/cpp/metadata/model.cc @@ -87,7 +87,7 @@ ModelMetadata ModelMetadata::FromJSON(const picojson::object& metadata, if (metadata.count("attention_sink_size")) // remove after sink is decoupled from model lib result.attention_sink_size = json::Lookup(metadata, "attention_sink_size"); result.seqlen_padding_factor = - json::LookupOrDefault(metadata, "seqlen_padding_factor", 1); + json::LookupOrDefault(metadata, "seqlen_padding_factor", 16); result.tensor_parallel_shards = json::Lookup(metadata, "tensor_parallel_shards"); result.pipeline_parallel_stages = json::LookupOrDefault(metadata, "pipeline_parallel_stages", 1); diff --git a/cpp/serve/model.cc b/cpp/serve/model.cc index 568501dd7e..6bf1c99c77 100644 --- a/cpp/serve/model.cc +++ b/cpp/serve/model.cc @@ -262,7 +262,7 @@ class ModelImpl : public ModelObj { // Begin forward with the sequence ids and new lengths. IntTuple seq_ids_tuple(seq_ids); IntTuple lengths_tuple(lengths.begin(), lengths.end()); - ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple); + ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple, seqlen_padding_factor_); ObjectRef embeddings_dref_or_nd; if (!embeddings->IsInstance()) { @@ -372,7 +372,7 @@ class ModelImpl : public ModelObj { // Begin forward with the sequence ids and new lengths. IntTuple seq_ids_tuple(seq_ids); IntTuple lengths_tuple(lengths.begin(), lengths.end()); - ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple); + ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple, seqlen_padding_factor_); // args: embeddings, logit_pos, kv_cache, params ObjectRef result{nullptr}; @@ -410,6 +410,12 @@ class ModelImpl : public ModelObj { Tensor BatchDecode(const ObjectRef& embeddings, const std::vector& seq_ids) final { NVTXScopedRange nvtx_scope("BatchDecode num_seqs=" + std::to_string(seq_ids.size())); int num_sequence = seq_ids.size(); + + bool padded = num_sequence % seqlen_padding_factor_ != 0; + if (padded) { + num_sequence = (num_sequence + seqlen_padding_factor_ - 1) / seqlen_padding_factor_ * + seqlen_padding_factor_; + } CHECK(ft_.decode_func_.defined()) << "`decode_with_embed` function is not found in the model. Please make sure the model is " @@ -422,7 +428,7 @@ class ModelImpl : public ModelObj { // Begin forward with the sequence ids and new lengths. IntTuple seq_ids_tuple(seq_ids); IntTuple lengths_tuple(std::vector(/*n=*/seq_ids.size(), /*v=*/1)); - ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple); + ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple, seqlen_padding_factor_); ObjectRef embeddings_dref_or_nd; if (!embeddings->IsInstance()) { @@ -443,7 +449,7 @@ class ModelImpl : public ModelObj { // args: embeddings, kv_cache, params ObjectRef ret; - if (seq_ids.size() == 1) { + if (num_sequence == 1) { ret = ft_.single_batch_decode_func_(embeddings_dref_or_nd, kv_cache_, params_) .cast(); } else { @@ -468,9 +474,14 @@ class ModelImpl : public ModelObj { } ft_.kv_cache_end_forward_func_(kv_cache_); + if (padded) { + // logits shape: (padded_batch, 1, vocab_size_) + // Slice to (real_batch, 1, vocab_size_) + logits = logits.CreateView({seq_ids.size(), 1, vocab_size_}, logits->dtype); + } // logits: (b, 1, v) ICHECK_EQ(logits->ndim, 3); - ICHECK_EQ(logits->shape[0], num_sequence); + ICHECK_EQ(logits->shape[0], seq_ids.size()); ICHECK_EQ(logits->shape[1], 1); return logits; } @@ -501,7 +512,7 @@ class ModelImpl : public ModelObj { IntTuple seq_ids_tuple(seq_ids); IntTuple lengths_tuple(lengths.begin(), lengths.end()); IntTuple token_tree_parent_ptr_tuple(token_tree_parent_ptr); - ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple, + ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple, seqlen_padding_factor_, token_tree_parent_ptr_tuple); ObjectRef embeddings_dref_or_nd; @@ -564,7 +575,7 @@ class ModelImpl : public ModelObj { // Begin forward with the sequence ids and new lengths. IntTuple seq_ids_tuple(seq_ids); IntTuple lengths_tuple(std::vector(/*n=*/seq_ids.size(), /*v=*/1)); - ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple); + ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple, seqlen_padding_factor_); // args: embeddings, kv_cache, params ObjectRef result{nullptr}; @@ -624,7 +635,7 @@ class ModelImpl : public ModelObj { IntTuple seq_ids_tuple(seq_ids); IntTuple lengths_tuple(lengths.begin(), lengths.end()); IntTuple token_tree_parent_ptr_tuple(token_tree_parent_ptr); - ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple, + ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple, seqlen_padding_factor_, token_tree_parent_ptr_tuple); ObjectRef embeddings_dref_or_nd; @@ -712,7 +723,7 @@ class ModelImpl : public ModelObj { IntTuple seq_ids_tuple(seq_ids); IntTuple lengths_tuple(lengths.begin(), lengths.end()); IntTuple token_tree_parent_ptr_tuple(token_tree_parent_ptr); - ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple, + ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple, seqlen_padding_factor_, token_tree_parent_ptr_tuple); // args: embeddings, logit_pos, kv_cache, params @@ -827,7 +838,7 @@ class ModelImpl : public ModelObj { // Run KV receive preparation. ObjectRef ret; - ret = ft_.kv_cache_disagg_prepare_recv_func_(kv_cache_, seq_id, length).cast(); + ret = ft_.kv_cache_disagg_prepare_recv_func_(kv_cache_, seq_id, length, seqlen_padding_factor_).cast(); IntTuple compressed_kv_append_metadata; if (ft_.use_disco) { compressed_kv_append_metadata = Downcast(ret)->DebugGetFromRemote(0).cast();