@@ -1145,3 +1145,114 @@ def test_quantized_relu(
             torch.equal(output, expected_output),
             f"Output values don't match expected in {name}. Got {output}, expected {expected_output}",
         )
+
+    @expand(
+        [
+            (
+                "basic_2d",
+                torch.tensor(
+                    [[1.0, 2.0, 3.0, 4.0]], dtype=torch.float32
+                ),  # input: [1, 4]
+                torch.tensor(
+                    [[0.0, 0.0]], dtype=torch.float32
+                ),  # sin: [1, 2] - half the embedding dim, one value per rotation pair
+                torch.tensor(
+                    [[1.0, 1.0]], dtype=torch.float32
+                ),  # cos: [1, 2] - half the embedding dim, one value per rotation pair
+                torch.tensor(
+                    [[1.0, 3.0, 2.0, 4.0]], dtype=torch.float32
+                ),  # expected: [1, 3, 2, 4]
+            ),
+            (
+                "batch_sequence_3d",
+                torch.tensor(
+                    [[[1.0, 0.0, 2.0, 0.0]]], dtype=torch.float32
+                ),  # input: [1, 1, 4]
+                torch.tensor([[[0.5, 0.5]]], dtype=torch.float32),  # sin: [1, 1, 2]
+                torch.tensor(
+                    [[[0.866, 0.866]]], dtype=torch.float32
+                ),  # cos: [1, 1, 2] (approx cos(30°))
+                torch.tensor(
+                    [[[0.866, 1.732, 0.5, 1.0]]], dtype=torch.float32
+                ),  # expected: [0.866, 1.732, 0.5, 1.0]
+            ),
+            (
+                "multiple_batch",
+                torch.tensor(
+                    [[1.0, 2.0], [3.0, 4.0]], dtype=torch.float32
+                ),  # input: [2, 2]
+                torch.tensor(
+                    [[0.0], [1.0]], dtype=torch.float32
+                ),  # sin: [2, 1] - one value per row's single rotation pair
+                torch.tensor(
+                    [[1.0], [0.0]], dtype=torch.float32
+                ),  # cos: [2, 1] - one value per row's single rotation pair
+                torch.tensor(
+                    [[1.0, 2.0], [-4.0, 3.0]], dtype=torch.float32
+                ),  # expected: [[1, 2], [-4, 3]]
+            ),
+            (
+                "larger_embedding",
+                torch.tensor(
+                    [[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]], dtype=torch.float32
+                ),  # input: [1, 6]
+                torch.tensor([[0.0, 0.5, 1.0]], dtype=torch.float32),  # sin: [1, 3]
+                torch.tensor([[1.0, 0.866, 0.0]], dtype=torch.float32),  # cos: [1, 3]
+                torch.tensor(
+                    [[1.0, 0.598, -6.0, 2.0, 4.964, 5.0]], dtype=torch.float32
+                ),  # expected: [1, 0.598, -6, 2, 4.964, 5]
+            ),
+            (
+                "single_pair",
+                torch.tensor([[1.0, 2.0]], dtype=torch.float32),  # input: [1, 2]
+                torch.tensor([[0.707]], dtype=torch.float32),  # sin: [1, 1] (sin(45°))
+                torch.tensor([[0.707]], dtype=torch.float32),  # cos: [1, 1] (cos(45°))
+                torch.tensor(
+                    [[-0.707, 2.121]], dtype=torch.float32
+                ),  # expected: [-0.707, 2.121]
+            ),
+            (
+                "pos is not None",
+                torch.tensor(0),
+                torch.tensor(0),
+                torch.tensor(0),
+                torch.tensor(0),
+                torch.tensor(0),  # pos is not None
+            ),
+        ]
+    )
+    def test_rope(
+        self,
+        name: str,
+        input_tensor: torch.Tensor,
+        sin_tensor: torch.Tensor,
+        cos_tensor: torch.Tensor,
+        expected_output: torch.Tensor,
+        pos: torch.Tensor | None = None,
+    ) -> None:
+        if pos is not None:
+            with self.assertRaises(ValueError) as context:
+                torch.ops.cadence.rope(input_tensor, sin_tensor, cos_tensor, pos)
+
+            self.assertIn("pos is not supported", str(context.exception))
+            return
+
+        output = torch.ops.cadence.rope(input_tensor, sin_tensor, cos_tensor, None)
+
+        # Verify output properties
+        self.assertEqual(
+            output.dtype,
+            input_tensor.dtype,
+            f"Output dtype should match input dtype in {name}",
+        )
+        self.assertEqual(
+            output.shape,
+            input_tensor.shape,
+            f"Output shape should match input shape in {name}",
+        )
+
+        # Verify output matches expected values
+        self.assertTrue(
+            torch.allclose(output, expected_output, rtol=1e-4, atol=1e-4),
+            f"Output values don't match expected in {name}. Got {output}, expected {expected_output}",
+        )
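
Note on the expected values: they are consistent with treating the last dimension of the input as interleaved (even, odd) pairs, rotating each pair by its (sin, cos) entry, and writing the rotated even components followed by the rotated odd components. The sketch below is only an illustration of that assumed convention, not the actual `torch.ops.cadence.rope` kernel; the helper name `reference_rope` is hypothetical.

```python
import torch


def reference_rope(x: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor) -> torch.Tensor:
    # Assumed convention, inferred from the test vectors above (not the real kernel):
    # split the last dim into interleaved even/odd halves, rotate each pair,
    # then concatenate the rotated even half before the rotated odd half.
    x_even, x_odd = x[..., 0::2], x[..., 1::2]
    rot_even = x_even * cos - x_odd * sin  # standard 2D rotation, real part
    rot_odd = x_odd * cos + x_even * sin   # standard 2D rotation, imaginary part
    return torch.cat([rot_even, rot_odd], dim=-1)


# Reproduces the "single_pair" case: input [[1, 2]], sin = cos = 0.707
print(reference_rope(
    torch.tensor([[1.0, 2.0]]),
    torch.tensor([[0.707]]),
    torch.tensor([[0.707]]),
))  # tensor([[-0.7070,  2.1210]])
```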