6 changes: 3 additions & 3 deletions include/flexflow/batch_config.h
@@ -88,11 +88,11 @@ class BatchConfig {
// Maximum possible values for different parameters
// These maximum values are used for copying BatchConfig
// across workers
- inline static int const MAX_NUM_REQUESTS = 64;
+ inline static int const MAX_NUM_REQUESTS = 96;
inline static int const MAX_NUM_TOKENS = 1024;
inline static int const MAX_SPECULATIVE_TREE_BRANCHES = 8;
- inline static int const MAX_TREE_DEPTH = 8;
- inline static int const MAX_TREE_WIDTH = 16;
+ inline static int const MAX_TREE_DEPTH = 10;
+ inline static int const MAX_TREE_WIDTH = 12;
inline static int const MAX_SPEC_TREE_TOKEN_NUM =
MAX_TREE_DEPTH * MAX_TREE_WIDTH;
inline static int const MAX_K_LOGITS = 16;
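For reference, the per-request speculative-token budget is derived from these limits (MAX_SPEC_TREE_TOKEN_NUM = MAX_TREE_DEPTH * MAX_TREE_WIDTH), so the new values trade a slightly smaller budget (10 * 12 = 120 tokens, down from 8 * 16 = 128) for deeper trees and a larger request capacity (96 instead of 64). A minimal standalone check of that arithmetic, with constant names local to this sketch:

    #include <cstdio>

    int main() {
      // New limits from this diff; the names below are local to the sketch.
      constexpr int max_num_requests = 96;
      constexpr int max_tree_depth = 10;
      constexpr int max_tree_width = 12;
      constexpr int max_spec_tree_token_num = max_tree_depth * max_tree_width;
      static_assert(max_spec_tree_token_num == 120,
                    "deeper, narrower trees give a 120-token budget (was 128)");
      std::printf("requests=%d, spec tokens per request=%d\n",
                  max_num_requests, max_spec_tree_token_num);
      return 0;
    }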
3 changes: 1 addition & 2 deletions include/flexflow/flexflow_c.h
@@ -285,8 +285,7 @@ flexflow_tensor_t *flexflow_model_add_add_bias_residual_layer_norm(

flexflow_tensor_t
flexflow_model_add_sigmoid_silu_multi(flexflow_model_t handle,
- flexflow_tensor_t const input1,
- flexflow_tensor_t const input2,
+ flexflow_tensor_t const input,
int intermediate_size,
char const *name);

13 changes: 5 additions & 8 deletions include/flexflow/model.h
@@ -579,8 +579,7 @@ class FFModel {
DataType data_type = DT_NONE,
char const *name = NULL);
// Add a sigmoid_silu_multi layer
- Tensor sigmoid_silu_multi(Tensor const input1,
- Tensor const input2,
+ Tensor sigmoid_silu_multi(Tensor const input,
int intermediate_size,
DataType data_type = DT_NONE,
char const *name = NULL);
@@ -711,7 +710,7 @@ class FFModel {
Initializer *kernel_initializer = NULL,
char const *name = NULL);
Tensor inc_multihead_self_attention(
- const Tensor input,
+ Tensor const input,
int embed_dim,
int num_heads,
int kdim = 0,
@@ -730,7 +729,7 @@
bool streaming_cache = false,
char const *name = NULL);
Tensor spec_inc_multihead_self_attention(
- const Tensor input,
+ Tensor const input,
int embed_dim,
int num_heads,
int kdim = 0,
@@ -1211,10 +1210,8 @@ class FFModel {
std::pair<std::pair<ParallelTensorShape, ParallelTensorShape>,
AddBiasResidualLayerNormParams>,
AddBiasResidualLayerNorm *>,
- std::unordered_map<
- std::pair<std::pair<ParallelTensorShape, ParallelTensorShape>,
- SigmoidSiluMultiParams>,
- SigmoidSiluMulti *>,
+ std::unordered_map<std::pair<ParallelTensorShape, SigmoidSiluMultiParams>,
+ SigmoidSiluMulti *>,
std::unordered_map<std::pair<ParallelTensorShape, LinearParams>,
Linear *>,
std::unordered_map<std::pair<ParallelTensorShape, Pool2DParams>,
2 changes: 1 addition & 1 deletion include/flexflow/ops/inc_multihead_self_attention.h
@@ -178,7 +178,7 @@ class IncMultiHeadSelfAttentionMeta : public OpMeta {
quantized_weightSize;
int hidden_size, qk_dim, v_dim, o_dim;
int global_num_q_heads, global_num_kv_heads, num_q_heads, num_kv_heads,
- local_hidden_size;
+ local_hidden_size, total_heads_dim;
bool *has_load_weights;
RotaryEmbeddingMeta *rotary_embedding_meta;
bool *qkv_bias;
13 changes: 5 additions & 8 deletions include/flexflow/ops/sigmoid_silu_multi.h
@@ -10,15 +10,14 @@ class SigmoidSiluMultiMeta;
class SigmoidSiluMulti : public Op {
public:
using Params = SigmoidSiluMultiParams;
- using Input = std::pair<ParallelTensor, ParallelTensor>;
+ using Input = ParallelTensor;
SigmoidSiluMulti(FFModel &model,
Params const &params,
- Input const &inputs,
+ Input const &input,
char const *name = nullptr);
SigmoidSiluMulti(FFModel &model,
LayerID const &_layer_guid,
- const ParallelTensor _input1,
- const ParallelTensor _input2,
+ ParallelTensor const _input,
int _intermediate_size,
int _tensor_parallelism_degree,
char const *name = nullptr);
@@ -63,13 +62,11 @@ class SigmoidSiluMulti : public Op {
template <typename T>
static void inference_kernel(SigmoidSiluMultiMeta const *m,
int num_elements,
- T const *input1_ptr,
- T const *input2_ptr,
+ T const *input_ptr,
T *output_ptr,
ffStream_t stream);
static void inference_kernel_wrapper(SigmoidSiluMultiMeta const *m,
- GenericTensorAccessorR const &input1,
- GenericTensorAccessorR const &input2,
+ GenericTensorAccessorR const &input,
GenericTensorAccessorW const &output,
int token_size);

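The kernel and its wrapper now take one fused activation instead of two. A plain-C++ reference for what the single-input version presumably computes, assuming each token's fused row stores the gate half followed by the up half (the actual layout is defined by the device kernel, which is not shown in this diff):

    #include <cmath>
    #include <cstddef>
    #include <vector>

    // Reference (CPU) version of the fused SigmoidSiluMulti: for each token,
    // out[j] = silu(gate[j]) * up[j], where gate and up are the two halves of
    // the fused input row. The half-and-half layout is an assumption made for
    // illustration only.
    std::vector<float> sigmoid_silu_multi_ref(std::vector<float> const &fused,
                                              std::size_t num_tokens,
                                              std::size_t intermediate_size) {
      std::vector<float> out(num_tokens * intermediate_size);
      for (std::size_t t = 0; t < num_tokens; ++t) {
        float const *gate = fused.data() + t * 2 * intermediate_size;
        float const *up = gate + intermediate_size;
        for (std::size_t j = 0; j < intermediate_size; ++j) {
          float silu = gate[j] / (1.0f + std::exp(-gate[j]));  // x * sigmoid(x)
          out[t * intermediate_size + j] = silu * up[j];
        }
      }
      return out;
    }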
3 changes: 1 addition & 2 deletions include/flexflow/ops/sigmoid_silu_multi_params.h
@@ -10,8 +10,7 @@ struct SigmoidSiluMultiParams {
LayerID layer_guid;
int intermediate_size, tensor_parallelism_degree;
char name[MAX_OPNAME];
- bool is_valid(
- std::pair<ParallelTensorShape, ParallelTensorShape> const &) const;
+ bool is_valid(ParallelTensorShape const &) const;
};

bool operator==(SigmoidSiluMultiParams const &, SigmoidSiluMultiParams const &);
5 changes: 5 additions & 0 deletions include/flexflow/request_manager.h
@@ -269,9 +269,12 @@ struct ProfileInfo {
std::vector<double> tree_operation_step_times;
// Number of generated tokens at each step
std::vector<int> generated_tokens_per_step;
+ // Number of proposed tokens at each step
+ std::vector<int> tokens_in_verification_per_step;
// To calculate the E2E time of serving
long long server_start_time = 0;
long long server_end_time = 0;
+ int prefilling_steps = 0;
};

class RequestManager {
@@ -444,6 +447,7 @@ class RequestManager {
// configuration parameters
int max_requests_per_batch;
int max_tokens_per_batch;
+ int config_max_token_per_batch;
int max_tokens_per_ssm_batch;
int max_tokens_per_prefilling_batch;
int max_spec_tree_token_num;
@@ -586,6 +590,7 @@ class RequestManager {
void prune_token_tree_greedy();
void add_tokens_toward_slo(RequestGuid guid,
int &budget,
+ double num_tokens_to_decode,
int num_req_with_slo);
void add_tokens_toward_memory_occupancy(int budget);
void add_tokens_toward_goodput(int budget);
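The new ProfileInfo counters track, per step, how many tokens were generated and how many were sent through verification. A hypothetical post-processing helper (not part of this PR) showing one way the two vectors could be combined into a per-step yield metric; the struct below is a stand-in for the real ProfileInfo:

    #include <algorithm>
    #include <cstddef>
    #include <vector>

    // Stand-in for the two per-step counters added to ProfileInfo.
    struct StepCounters {
      std::vector<int> generated_tokens_per_step;
      std::vector<int> tokens_in_verification_per_step;
    };

    // Tokens emitted per token sent to verification, averaged over all steps.
    // How "generated" relates to "accepted" is up to the request manager's
    // accounting; this only forms the ratio of the two recorded counters.
    double verification_yield(StepCounters const &p) {
      double generated = 0.0, proposed = 0.0;
      std::size_t n = std::min(p.generated_tokens_per_step.size(),
                               p.tokens_in_verification_per_step.size());
      for (std::size_t i = 0; i < n; ++i) {
        generated += p.generated_tokens_per_step[i];
        proposed += p.tokens_in_verification_per_step[i];
      }
      return proposed > 0.0 ? generated / proposed : 0.0;
    }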
13 changes: 9 additions & 4 deletions inference/incr_decoding/incr_decoding.cc
@@ -306,10 +306,14 @@ void FlexFlow::top_level_task(Task const *task,
/*ignore_comments */ true);
ModelType model_type = ModelType::UNKNOWN;
auto architectures = model_config["architectures"];
+ bool qwen = false;
for (auto const &str : architectures) {
if (str == "LlamaForCausalLM" || str == "LLaMAForCausalLM" ||
- str == "MistralForCausalLM") {
+ str == "MistralForCausalLM" || str == "Qwen2ForCausalLM") {
model_type = ModelType::LLAMA;
+ if (str == "Qwen2ForCausalLM") {
+   qwen = true;
+ }
break;
} else if (str == "OPTForCausalLM") {
model_type = ModelType::OPT;
@@ -361,8 +365,8 @@ void FlexFlow::top_level_task(Task const *task,
rm->set_baseline_latency(baseline_latency_ms);
rm->set_ssm_spec_latency(ssm_spec_latency_ms);
rm->set_llm_verify_latency(llm_verify_latency_ms);
- rm->set_max_tree_depth(8);
- rm->set_max_tree_width(16);
+ rm->set_max_tree_depth(2);
+ rm->set_max_tree_width(2);
rm->set_verbose(verbose);
rm->set_streaming_cache(streaming_cache);
rm->set_fcfs_slo(fcfs_slo);
@@ -379,7 +383,8 @@
INC_DECODING_MODE,
generationConfig,
streaming_cache,
- use_full_precision);
+ use_full_precision,
+ /*qkv_bias*/ qwen);
} else if (model_type == ModelType::OPT) {
OPT::create_opt_model(model,
config_filepath,
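The architecture check now maps Qwen2 checkpoints onto the LLaMA model graph while recording that their attention projections carry a bias term. A hypothetical helper (not in this PR) condensing the detection pattern used here and in spec_infer.cc:

    #include <string>
    #include <vector>

    struct ArchInfo {
      bool is_llama_family = false;
      bool qkv_bias = false;
    };

    // Qwen2 reuses the LLaMA graph but needs bias on the QKV projections.
    ArchInfo classify_architectures(std::vector<std::string> const &archs) {
      ArchInfo info;
      for (auto const &arch : archs) {
        if (arch == "LlamaForCausalLM" || arch == "LLaMAForCausalLM" ||
            arch == "MistralForCausalLM" || arch == "Qwen2ForCausalLM") {
          info.is_llama_family = true;
          info.qkv_bias = (arch == "Qwen2ForCausalLM");
          break;
        }
      }
      return info;
    }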
51 changes: 20 additions & 31 deletions inference/models/llama.cc
@@ -26,7 +26,8 @@ void LLAMA::create_llama_model(FFModel &ff,
InferenceMode mode,
GenerationConfig generation_config,
bool streaming_cache,
- bool use_full_precision) {
+ bool use_full_precision,
+ bool qkv_bias) {
// do not apply cpu offload in beam search model.
LLAMAConfig llama_config(model_config_file_path);
llama_config.print();
@@ -104,7 +105,7 @@ void LLAMA::create_llama_model(FFModel &ff,
llama_config.hidden_size / llama_config.num_attention_heads,
llama_config.hidden_size / llama_config.num_attention_heads,
0.0f, /*dropout*/
- false, /*qkv_bias*/
+ qkv_bias, /*qkv_bias*/
false, /*final_bias*/
false, /*add_zero_attn*/
DT_NONE, /*data_type*/
@@ -129,7 +130,7 @@
llama_config.hidden_size / llama_config.num_attention_heads,
llama_config.hidden_size / llama_config.num_attention_heads,
0.0f, /*dropout*/
- false, /*qkv_bias*/
+ qkv_bias, /*qkv_bias*/
false, /*final_bias*/
false, /*add_zero_attn*/
DT_NONE, /*data_type*/
@@ -153,7 +154,7 @@
llama_config.hidden_size / llama_config.num_attention_heads,
llama_config.hidden_size / llama_config.num_attention_heads,
0.0f, /*dropout*/
- false, /*qkv_bias*/
+ qkv_bias, /*qkv_bias*/
false, /*final_bias*/
false, /*add_zero_attn*/
DT_NONE, /*data_type*/
@@ -188,34 +189,22 @@ void LLAMA::create_llama_model(FFModel &ff,
token = token_ff_norm[0];
Tensor ff_norm = token_ff_norm[1];

- Tensor w1 = ff.dense(
- ff_norm,
- llama_config.intermediate_size,
- AC_MODE_NONE,
- false,
- DT_NONE,
- nullptr,
- nullptr,
- nullptr,
- REG_MODE_NONE,
- 0.0f,
- std::string("layers." + std::to_string(i) + ".mlp.gate_proj").c_str());
-
- Tensor w3 = ff.dense(
- ff_norm,
- llama_config.intermediate_size,
- AC_MODE_NONE,
- false,
- DT_NONE,
- nullptr,
- nullptr,
- nullptr,
- REG_MODE_NONE,
- 0.0f,
- std::string("layers." + std::to_string(i) + ".mlp.up_proj").c_str());
+ Tensor hidden_gate_and_up =
+ ff.dense(ff_norm,
+ llama_config.intermediate_size * 2,
+ AC_MODE_NONE,
+ false,
+ DT_NONE,
+ nullptr,
+ nullptr,
+ nullptr,
+ REG_MODE_NONE,
+ 0.0f,
+ std::string("layers." + std::to_string(i) + ".mlp.gate_and_up")
+ .c_str());

- Tensor multi =
- ff.sigmoid_silu_multi(w1, w3, llama_config.intermediate_size);
+ Tensor multi = ff.sigmoid_silu_multi(hidden_gate_and_up,
+ llama_config.intermediate_size);

w2 = ff.dense(
multi,
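The gate and up projections are folded into a single dense layer of width 2 * intermediate_size named "layers.N.mlp.gate_and_up", and the single-input sigmoid_silu_multi splits that activation internally. A toy sketch of the corresponding offline weight fusion, assuming gate_proj rows are stacked before up_proj rows (the real ordering is fixed by the weight loader, which is outside this excerpt):

    #include <vector>

    // Toy fusion for the renamed "gate_and_up" projection: stack the row-major
    // gate_proj weights on top of the up_proj weights so one matmul produces
    // both halves. The gate-first ordering is an assumption for illustration.
    std::vector<float> fuse_gate_and_up(std::vector<float> const &gate_proj,
                                        std::vector<float> const &up_proj) {
      std::vector<float> fused;
      fused.reserve(gate_proj.size() + up_proj.size());
      fused.insert(fused.end(), gate_proj.begin(), gate_proj.end());
      fused.insert(fused.end(), up_proj.begin(), up_proj.end());
      return fused;
    }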
3 changes: 2 additions & 1 deletion inference/models/llama.h
@@ -109,7 +109,8 @@ class LLAMA {
InferenceMode mode,
GenerationConfig generation_config,
bool streaming_cache,
- bool use_full_precision = false);
+ bool use_full_precision = false,
+ bool qkv_bias = false);
};

}; // namespace FlexFlow
15 changes: 11 additions & 4 deletions inference/spec_infer/spec_infer.cc
@@ -55,6 +55,8 @@ struct ModelMeta {
std::vector<ModelType> ssm_model_types;
std::vector<std::string> ssm_model_config_paths;
std::vector<std::string> ssm_model_weights_paths;

+ bool qkv_bias = false;
};

void parse_input_args(char **argv,
@@ -288,8 +290,11 @@
auto architectures = llm_model_config["architectures"];
for (auto const &str : architectures) {
if (str == "LlamaForCausalLM" || str == "LLaMAForCausalLM" ||
- str == "MistralForCausalLM") {
+ str == "MistralForCausalLM" || str == "Qwen2ForCausalLM") {
model_metadata.llm_model_type = ModelType::LLAMA;
+ if (str == "Qwen2ForCausalLM") {
+   model_metadata.qkv_bias = true;
+ }
break;
} else if (str == "OPTForCausalLM") {
model_metadata.llm_model_type = ModelType::OPT;
@@ -350,7 +355,7 @@
auto architectures = ssm_model_config["architectures"];
for (auto const &str : architectures) {
if (str == "LlamaForCausalLM" || str == "LLaMAForCausalLM" ||
- str == "MistralForCausalLM") {
+ str == "MistralForCausalLM" || str == "Qwen2ForCausalLM") {
ssm_model_type = ModelType::LLAMA;
break;
} else if (str == "OPTForCausalLM") {
@@ -525,7 +530,8 @@ void FlexFlow::top_level_task(Task const *task,
TREE_VERIFY_MODE,
generationConfig,
false,
- use_full_precision);
+ use_full_precision,
+ /*qkv_bias*/ model_metadata.qkv_bias);
} else if (model_metadata.llm_model_type == ModelType::OPT) {
OPT::create_opt_model(tree_model,
model_metadata.llm_model_config_path,
@@ -574,7 +580,8 @@
TREE_SEARCH_MODE,
generationConfig,
streaming_cache,
- use_full_precision);
+ use_full_precision,
+ /*qkv_bias*/ model_metadata.qkv_bias);
} else if (model_metadata.ssm_model_types[ssm_id] == ModelType::OPT) {
OPT::create_opt_model(beam_model,
model_metadata.ssm_model_config_paths[ssm_id],
10 changes: 4 additions & 6 deletions inference/trace_generator/trace_generator.cc
@@ -213,7 +213,7 @@ void get_model_meta(FilePaths &file_paths,
auto architectures = llm_model_config["architectures"];
for (auto const &str : architectures) {
if (str == "LlamaForCausalLM" || str == "LLaMAForCausalLM" ||
- str == "MistralForCausalLM") {
+ str == "MistralForCausalLM" || str == "Qwen2ForCausalLM") {
model_metadata.llm_model_type = ModelType::LLAMA;
break;
} else if (str == "OPTForCausalLM") {
@@ -275,7 +275,7 @@ void get_model_meta(FilePaths &file_paths,
auto architectures = ssm_model_config["architectures"];
for (auto const &str : architectures) {
if (str == "LlamaForCausalLM" || str == "LLaMAForCausalLM" ||
- str == "MistralForCausalLM") {
+ str == "MistralForCausalLM" || str == "Qwen2ForCausalLM") {
ssm_model_type = ModelType::LLAMA;
break;
} else if (str == "OPTForCausalLM") {
@@ -336,8 +336,6 @@ void FlexFlow::top_level_task(Task const *task,
int max_tokens_per_ssm_batch = -1;
int max_tokens_per_prefilling_batch = -1;
int expansion_degree = 3;
- int max_tree_depth = 8;
- int max_tree_width = 16;
RequestManager::DecodingMode decoding_mode =
RequestManager::SPECULATIVE_DECODING;
bool spec_sampling = false;
@@ -405,8 +403,8 @@
rm->set_max_tokens_per_prefilling_batch(max_tokens_per_prefilling_batch);
rm->set_max_sequence_length(max_sequence_length);
rm->set_max_output_length(max_output_length);
- rm->set_max_tree_depth(max_tree_depth);
- rm->set_max_tree_width(max_tree_width);
+ rm->set_max_tree_depth(2);
+ rm->set_max_tree_width(2);
rm->set_verbose(verbose);
rm->set_streaming_cache(streaming_cache);
rm->register_tokenizer(model_metadata.llm_model_type,