Skip to content

Commit ed77891

Browse files
Andrew Grebenisanfacebook-github-bot
authored andcommitted
Rope custom op (#14399)
Summary: Continued support of cadence custom ops Differential Revision: D82702247
1 parent 654e722 commit ed77891

File tree

2 files changed

+127
-0
lines changed

2 files changed

+127
-0
lines changed

backends/cadence/aot/ref_implementations.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1092,3 +1092,19 @@ def rms_norm(
10921092
eps: float,
10931093
) -> torch.Tensor:
10941094
return W * nn.RMSNorm(list(normalized_shape), eps=eps, dtype=X.dtype)(X)
1095+
1096+
1097+
@impl(m, "rope")
1098+
def rope(
1099+
input_tensor: torch.Tensor,
1100+
sin_tensor: torch.Tensor,
1101+
cos_tensor: torch.Tensor,
1102+
pos: torch.Tensor | None,
1103+
) -> torch.Tensor:
1104+
if pos is not None:
1105+
raise ValueError("pos is not supported")
1106+
x0, x1 = input_tensor[..., ::2], input_tensor[..., 1::2]
1107+
rotated = torch.cat(
1108+
[x0 * cos_tensor - x1 * sin_tensor, x0 * sin_tensor + x1 * cos_tensor], dim=-1
1109+
)
1110+
return rotated

backends/cadence/aot/tests/test_ref_implementations.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1145,3 +1145,114 @@ def test_quantized_relu(
11451145
torch.equal(output, expected_output),
11461146
f"Output values don't match expected in {name}. Got {output}, expected {expected_output}",
11471147
)
1148+
1149+
@expand(
1150+
[
1151+
(
1152+
"basic_2d",
1153+
torch.tensor(
1154+
[[1.0, 2.0, 3.0, 4.0]], dtype=torch.float32
1155+
), # input: [1, 4]
1156+
torch.tensor(
1157+
[[0.0, 0.0]], dtype=torch.float32
1158+
), # sin: [1, 2] - broadcasts to [1, 4]
1159+
torch.tensor(
1160+
[[1.0, 1.0]], dtype=torch.float32
1161+
), # cos: [1, 2] - broadcasts to [1, 4]
1162+
torch.tensor(
1163+
[[1.0, 3.0, 2.0, 4.0]], dtype=torch.float32
1164+
), # expected: [1, 3, 2, 4]
1165+
),
1166+
(
1167+
"batch_sequence_3d",
1168+
torch.tensor(
1169+
[[[1.0, 0.0, 2.0, 0.0]]], dtype=torch.float32
1170+
), # input: [1, 1, 4]
1171+
torch.tensor([[[0.5, 0.5]]], dtype=torch.float32), # sin: [1, 1, 2]
1172+
torch.tensor(
1173+
[[[0.866, 0.866]]], dtype=torch.float32
1174+
), # cos: [1, 1, 2] (approx cos(30°))
1175+
torch.tensor(
1176+
[[[0.866, 1.732, 0.5, 1.0]]], dtype=torch.float32
1177+
), # expected: [0.866, 1.732, 0.5, 1.0]
1178+
),
1179+
(
1180+
"multiple_batch",
1181+
torch.tensor(
1182+
[[1.0, 2.0], [3.0, 4.0]], dtype=torch.float32
1183+
), # input: [2, 2]
1184+
torch.tensor(
1185+
[[0.0], [1.0]], dtype=torch.float32
1186+
), # sin: [2, 1] - broadcasts to [2, 2]
1187+
torch.tensor(
1188+
[[1.0], [0.0]], dtype=torch.float32
1189+
), # cos: [2, 1] - broadcasts to [2, 2]
1190+
torch.tensor(
1191+
[[1.0, 2.0], [-4.0, 3.0]], dtype=torch.float32
1192+
), # expected: [[1, 2], [-4, 3]]
1193+
),
1194+
(
1195+
"larger_embedding",
1196+
torch.tensor(
1197+
[[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]], dtype=torch.float32
1198+
), # input: [1, 6]
1199+
torch.tensor([[0.0, 0.5, 1.0]], dtype=torch.float32), # sin: [1, 3]
1200+
torch.tensor([[1.0, 0.866, 0.0]], dtype=torch.float32), # cos: [1, 3]
1201+
torch.tensor(
1202+
[[1.0, 0.598, -6.0, 2.0, 4.964, 5.0]], dtype=torch.float32
1203+
), # expected: [1, 0.598, -6, 2, 4.964, 5]
1204+
),
1205+
(
1206+
"single_pair",
1207+
torch.tensor([[1.0, 2.0]], dtype=torch.float32), # input: [1, 2]
1208+
torch.tensor([[0.707]], dtype=torch.float32), # sin: [1, 1] (sin(45°))
1209+
torch.tensor([[0.707]], dtype=torch.float32), # cos: [1, 1] (cos(45°))
1210+
torch.tensor(
1211+
[[-0.707, 2.121]], dtype=torch.float32
1212+
), # expected: [-0.707, 2.121]
1213+
),
1214+
(
1215+
"pos is not None",
1216+
torch.tensor(0),
1217+
torch.tensor(0),
1218+
torch.tensor(0),
1219+
torch.tensor(0),
1220+
torch.tensor(0), # pos is not None
1221+
),
1222+
]
1223+
)
1224+
def test_rope(
1225+
self,
1226+
name: str,
1227+
input_tensor: torch.Tensor,
1228+
sin_tensor: torch.Tensor,
1229+
cos_tensor: torch.Tensor,
1230+
expected_output: torch.Tensor,
1231+
pos: torch.Tensor | None = None,
1232+
) -> None:
1233+
if pos is not None:
1234+
with self.assertRaises(ValueError) as context:
1235+
torch.ops.cadence.rope(input_tensor, sin_tensor, cos_tensor, pos)
1236+
1237+
self.assertIn("pos is not supported", str(context.exception))
1238+
return
1239+
1240+
output = torch.ops.cadence.rope(input_tensor, sin_tensor, cos_tensor, None)
1241+
1242+
# Verify output properties
1243+
self.assertEqual(
1244+
output.dtype,
1245+
input_tensor.dtype,
1246+
f"Output dtype should match input dtype in {name}",
1247+
)
1248+
self.assertEqual(
1249+
output.shape,
1250+
input_tensor.shape,
1251+
f"Output shape should match input shape in {name}",
1252+
)
1253+
1254+
# Verify output matches expected values
1255+
self.assertTrue(
1256+
torch.allclose(output, expected_output, rtol=1e-4, atol=1e-4),
1257+
f"Output values don't match expected in {name}. Got {output}, expected {expected_output}",
1258+
)

0 commit comments

Comments
 (0)