From cc7b78c35dee01c483386100186f248bb19a1b35 Mon Sep 17 00:00:00 2001 From: Joshua Jiahua Hong Date: Mon, 25 Aug 2025 04:00:25 -0400 Subject: [PATCH 1/2] Add sequence padding to BeginForward --- cpp/serve/model.cc | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/cpp/serve/model.cc b/cpp/serve/model.cc index 11c9f03995..22686588bf 100644 --- a/cpp/serve/model.cc +++ b/cpp/serve/model.cc @@ -262,7 +262,7 @@ class ModelImpl : public ModelObj { // Begin forward with the sequence ids and new lengths. IntTuple seq_ids_tuple(seq_ids); IntTuple lengths_tuple(lengths.begin(), lengths.end()); - ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple); + ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple, seqlen_padding_factor_); ObjectRef embeddings_dref_or_nd; if (!embeddings->IsInstance()) { @@ -372,7 +372,7 @@ class ModelImpl : public ModelObj { // Begin forward with the sequence ids and new lengths. IntTuple seq_ids_tuple(seq_ids); IntTuple lengths_tuple(lengths.begin(), lengths.end()); - ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple); + ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple, seqlen_padding_factor_); // args: embeddings, logit_pos, kv_cache, params ObjectRef result{nullptr}; @@ -422,7 +422,7 @@ class ModelImpl : public ModelObj { // Begin forward with the sequence ids and new lengths. IntTuple seq_ids_tuple(seq_ids); IntTuple lengths_tuple(std::vector(/*n=*/seq_ids.size(), /*v=*/1)); - ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple); + ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple, seqlen_padding_factor_); ObjectRef embeddings_dref_or_nd; if (!embeddings->IsInstance()) { @@ -501,7 +501,7 @@ class ModelImpl : public ModelObj { IntTuple seq_ids_tuple(seq_ids); IntTuple lengths_tuple(lengths.begin(), lengths.end()); IntTuple token_tree_parent_ptr_tuple(token_tree_parent_ptr); - ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple, + ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple, seqlen_padding_factor_, token_tree_parent_ptr_tuple); ObjectRef embeddings_dref_or_nd; @@ -564,7 +564,7 @@ class ModelImpl : public ModelObj { // Begin forward with the sequence ids and new lengths. IntTuple seq_ids_tuple(seq_ids); IntTuple lengths_tuple(std::vector(/*n=*/seq_ids.size(), /*v=*/1)); - ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple); + ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple, seqlen_padding_factor_); // args: embeddings, kv_cache, params ObjectRef result{nullptr}; @@ -624,7 +624,7 @@ class ModelImpl : public ModelObj { IntTuple seq_ids_tuple(seq_ids); IntTuple lengths_tuple(lengths.begin(), lengths.end()); IntTuple token_tree_parent_ptr_tuple(token_tree_parent_ptr); - ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple, + ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple, seqlen_padding_factor_, token_tree_parent_ptr_tuple); ObjectRef embeddings_dref_or_nd; @@ -712,7 +712,7 @@ class ModelImpl : public ModelObj { IntTuple seq_ids_tuple(seq_ids); IntTuple lengths_tuple(lengths.begin(), lengths.end()); IntTuple token_tree_parent_ptr_tuple(token_tree_parent_ptr); - ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple, + ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple, seqlen_padding_factor_, token_tree_parent_ptr_tuple); // args: embeddings, logit_pos, kv_cache, params @@ -827,7 +827,7 @@ class ModelImpl : public ModelObj { // Run KV receive preparation. ObjectRef ret; - ret = ft_.kv_cache_disagg_prepare_recv_func_(kv_cache_, seq_id, length).cast(); + ret = ft_.kv_cache_disagg_prepare_recv_func_(kv_cache_, seq_id, length, seqlen_padding_factor_).cast(); IntTuple compressed_kv_append_metadata; if (ft_.use_disco) { compressed_kv_append_metadata = Downcast(ret)->DebugGetFromRemote(0).cast(); From 64a4bcb6dac7a0269d074ac63abc2b739f70c624 Mon Sep 17 00:00:00 2001 From: Joshua Jiahua Hong Date: Sun, 21 Sep 2025 01:44:13 -0400 Subject: [PATCH 2/2] BatchDecode padding --- cpp/metadata/model.cc | 2 +- cpp/serve/model.cc | 15 +++++++++++++-- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/cpp/metadata/model.cc b/cpp/metadata/model.cc index 9dbfb5bad6..5cd939ed3f 100644 --- a/cpp/metadata/model.cc +++ b/cpp/metadata/model.cc @@ -86,7 +86,7 @@ ModelMetadata ModelMetadata::FromJSON(const picojson::object& metadata, if (metadata.count("attention_sink_size")) // remove after sink is decoupled from model lib result.attention_sink_size = json::Lookup(metadata, "attention_sink_size"); result.seqlen_padding_factor = - json::LookupOrDefault(metadata, "seqlen_padding_factor", 1); + json::LookupOrDefault(metadata, "seqlen_padding_factor", 16); result.tensor_parallel_shards = json::Lookup(metadata, "tensor_parallel_shards"); result.pipeline_parallel_stages = json::LookupOrDefault(metadata, "pipeline_parallel_stages", 1); diff --git a/cpp/serve/model.cc b/cpp/serve/model.cc index 22686588bf..0024a601dd 100644 --- a/cpp/serve/model.cc +++ b/cpp/serve/model.cc @@ -410,6 +410,12 @@ class ModelImpl : public ModelObj { NDArray BatchDecode(const ObjectRef& embeddings, const std::vector& seq_ids) final { NVTXScopedRange nvtx_scope("BatchDecode num_seqs=" + std::to_string(seq_ids.size())); int num_sequence = seq_ids.size(); + + bool padded = num_sequence % seqlen_padding_factor_ != 0; + if (padded) { + num_sequence = (num_sequence + seqlen_padding_factor_ - 1) / seqlen_padding_factor_ * + seqlen_padding_factor_; + } CHECK(ft_.decode_func_.defined()) << "`decode_with_embed` function is not found in the model. Please make sure the model is " @@ -443,7 +449,7 @@ class ModelImpl : public ModelObj { // args: embeddings, kv_cache, params ObjectRef ret; - if (seq_ids.size() == 1) { + if (num_sequence == 1) { ret = ft_.single_batch_decode_func_(embeddings_dref_or_nd, kv_cache_, params_) .cast(); } else { @@ -468,9 +474,14 @@ class ModelImpl : public ModelObj { } ft_.kv_cache_end_forward_func_(kv_cache_); + if (padded) { + // logits shape: (padded_batch, 1, vocab_size_) + // Slice to (real_batch, 1, vocab_size_) + logits = logits.CreateView({seq_ids.size(), 1, vocab_size_}, logits->dtype); + } // logits: (b, 1, v) ICHECK_EQ(logits->ndim, 3); - ICHECK_EQ(logits->shape[0], num_sequence); + ICHECK_EQ(logits->shape[0], seq_ids.size()); ICHECK_EQ(logits->shape[1], 1); return logits; }