@@ -30,6 +30,8 @@ void FusedRopeKernel(const Context& dev_ctx,
3030                     const  paddle::optional<DenseTensor>& v,
3131                     const  paddle::optional<DenseTensor>& sin,
3232                     const  paddle::optional<DenseTensor>& cos,
33+                      const  paddle::optional<DenseTensor>& position_ids,
34+                      bool  use_neox_rotary_style,
3335                     DenseTensor* out_q,
3436                     DenseTensor* out_k,
3537                     DenseTensor* out_v) {
@@ -59,6 +61,7 @@ void FusedRopeKernel(const Context& dev_ctx,
5961  phi::Array<T*, 3 > outs_data;
6062  phi::Array<const  T*, 3 > ins_data;
6163  phi::Array<const  T*, 2 > sin_cos_data;
64+   const  int64_t * position_ids_data = NULL ;
6265
6366  ins_data[0 ] = q.data <T>();
6467  outs_data[0 ] = out_q->data <T>();
@@ -109,15 +112,52 @@ void FusedRopeKernel(const Context& dev_ctx,
109112              " The batch_size and num_heads of sin and cos must be 1." 
110113    }
111114    int  sin_seq_len_dim = (dims_size) == 4  ? 1  : 0 ;
112-     PADDLE_ENFORCE_EQ ((sin_dims[dims_size - 1 ] == head_dim &&
113-                        sin_dims[sin_seq_len_dim] == seq_len),
114-                       true ,
115-                       phi::errors::InvalidArgument (
116-                           " The seq_len and head_dim of sin and cos " 
117-                           " must be the same as those of q. But recieved sin's " 
118-                           " shape is {%s}, q's shape is {%s}." 
119-                           sin_dims,
120-                           q.dims ()));
115+ 
116+     if  (position_ids.get_ptr ()) {
117+       PADDLE_ENFORCE_EQ (
118+           (sin_dims[dims_size - 1 ] == head_dim &&
119+            sin_dims[sin_seq_len_dim] >= seq_len),
120+           true ,
121+           phi::errors::InvalidArgument (
122+               " The seq_len of sin and cos must be greater than or equal to " 
123+               " this of q. The head_dim of sin and cos must be the same as this " 
124+               " of q. But recieved sin's " 
125+               " shape is {%s}, q's shape is {%s}." 
126+               sin_dims,
127+               q.dims ()));
128+ 
129+       auto  position_ids_dims = position_ids.get_ptr ()->dims ();
130+       PADDLE_ENFORCE_EQ (position_ids_dims.size (),
131+                         2 ,
132+                         phi::errors::InvalidArgument (
133+                             " The dims of position_ids is expected to " 
134+                             " be 2, but recieved %d." 
135+                             position_ids_dims.size ()));
136+ 
137+       PADDLE_ENFORCE_EQ (
138+           (position_ids_dims[0 ] == batch_size &&
139+            position_ids_dims[1 ] == seq_len),
140+           true ,
141+           phi::errors::InvalidArgument (
142+               " The batch_size and seq_len of position_ids must be the same as " 
143+               " those of q. But recieved position_ids's " 
144+               " shape is {%s}, q's shape is {%s}." 
145+               position_ids_dims,
146+               q.dims ()));
147+ 
148+       position_ids_data = position_ids->data <int64_t >();
149+     } else  {
150+       PADDLE_ENFORCE_EQ (
151+           (sin_dims[dims_size - 1 ] == head_dim &&
152+            sin_dims[sin_seq_len_dim] == seq_len),
153+           true ,
154+           phi::errors::InvalidArgument (
155+               " The seq_len and head_dim of sin and cos " 
156+               " must be the same as those of q. But recieved sin's " 
157+               " shape is {%s}, q's shape is {%s}." 
158+               sin_dims,
159+               q.dims ()));
160+     }
121161
122162    sin_cos_data[0 ] = sin->data <T>();
123163    sin_cos_data[1 ] = cos->data <T>();
@@ -126,18 +166,35 @@ void FusedRopeKernel(const Context& dev_ctx,
126166  }
127167
128168  int  sign = 1 ;
129-   VectorizedFusedRopeKernel<T, MPType, vec_size>
130-       <<<grid, block, 0 , stream>>> (ins_data,
131-                                    sin_cos_data,
132-                                    flag_sin_cos,
133-                                    sign,
134-                                    batch_size,
135-                                    seq_len,
136-                                    num_heads,
137-                                    head_dim,
138-                                    outs_data,
139-                                    num_inputs,
140-                                    div_c);
169+   if  (use_neox_rotary_style) {
170+     VectorizedFusedRopeWithRotateEveryTwoKernel<T, MPType, vec_size>
171+         <<<grid, block, 0 , stream>>> (ins_data,
172+                                      sin_cos_data,
173+                                      position_ids_data,
174+                                      flag_sin_cos,
175+                                      sign,
176+                                      batch_size,
177+                                      seq_len,
178+                                      num_heads,
179+                                      head_dim,
180+                                      outs_data,
181+                                      num_inputs,
182+                                      div_c);
183+   } else  {
184+     VectorizedFusedRopeWithRotateHalfKernel<T, MPType, vec_size>
185+         <<<grid, block, 0 , stream>>> (ins_data,
186+                                      sin_cos_data,
187+                                      position_ids_data,
188+                                      flag_sin_cos,
189+                                      sign,
190+                                      batch_size,
191+                                      seq_len,
192+                                      num_heads,
193+                                      head_dim,
194+                                      outs_data,
195+                                      num_inputs,
196+                                      div_c);
197+   }
141198}
142199}  //  namespace fusion
143200}  //  namespace phi
0 commit comments