Commit e6b9111

Add logical and op (pytorch#13342)
Summary: As title, add the logical and op for an internal model.

Rollback Plan:

Differential Revision: D80122607
1 parent 378c700 commit e6b9111

5 files changed (+26, -2 lines)

backends/qualcomm/_passes/layout_transform.py
Lines changed: 1 addition & 0 deletions

@@ -93,6 +93,7 @@ class LayoutTransform(ExportPass):
         exir_ops.edge.aten.le.Tensor,
         exir_ops.edge.aten.linear.default,
         exir_ops.edge.aten.log.default,
+        exir_ops.edge.aten.logical_and.default,
         exir_ops.edge.aten.logical_not.default,
         exir_ops.edge.aten.lt.Scalar,
         exir_ops.edge.aten.lt.Tensor,
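
Note (not part of the commit): ops in this list are treated as layout-agnostic, i.e. they behave the same whether their inputs arrive in NCHW or NHWC order. A minimal eager-mode sketch of why an elementwise op such as logical_and qualifies: permuting the inputs first and the output last yields the same tensor.

import torch

x = torch.rand(1, 2, 3, 4) > 0.5  # boolean, NCHW-shaped input
y = torch.rand(1, 2, 3, 4) > 0.5

nchw = torch.logical_and(x, y)
nhwc = torch.logical_and(x.permute(0, 2, 3, 1), y.permute(0, 2, 3, 1))

# Elementwise ops commute with layout permutations.
assert torch.equal(nchw.permute(0, 2, 3, 1), nhwc)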

backends/qualcomm/builders/op_and.py
Lines changed: 1 addition & 1 deletion

@@ -16,7 +16,7 @@

 @register_node_visitor
 class OpAnd(NodeVisitor):
-    target = ["aten.bitwise_and.Tensor"]
+    target = ["aten.bitwise_and.Tensor", "aten.logical_and.default"]

     def __init__(self, *args) -> None:
         super().__init__(*args)
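
Note (not from the diff): reusing the existing OpAnd visitor for the new target rests on the fact that, for boolean operands, bitwise AND and logical AND coincide, so both ATen targets can map to the same element-wise And builder. A quick eager-mode check of that assumption:

import torch

a = torch.tensor([True, False, True, False])
b = torch.tensor([True, True, False, False])

# On bool tensors the two ops agree elementwise.
assert torch.equal(torch.bitwise_and(a, b), torch.logical_and(a, b))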

backends/qualcomm/quantizer/annotators.py
Lines changed: 1 addition & 1 deletion

@@ -839,7 +839,7 @@ def annotate_sigmoid(node: Node, quantization_config: QuantizationConfig) -> Non
     )


-@register_annotator([torch.ops.aten.__and__.Tensor])
+@register_annotator([torch.ops.aten.__and__.Tensor, torch.ops.aten.logical_and.default])
 def annotate_and(node: Node, quantization_config: QuantizationConfig) -> None:
     annotate_binary(node, quantization_config)

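Note (illustrative sketch, not part of the commit): the annotator is keyed on the ATen overloads that appear in the captured graph. The Python `&` operator on tensors typically traces to aten.__and__.Tensor, which was already registered, while torch.logical_and traces to aten.logical_and.default, so registering both overloads lets annotate_and cover either spelling. Assuming torch.export captures logical_and without decomposing it, one can confirm which overload the LogicalAnd test module added below produces:

import torch

class LogicalAnd(torch.nn.Module):  # same shape as the test module added below
    def forward(self, x, y):
        return torch.logical_and(x != 0, y != 0).float()

ep = torch.export.export(
    LogicalAnd(), (torch.tensor([0.0, 1.0]), torch.tensor([1.0, 1.0]))
)
targets = {n.target for n in ep.graph.nodes if n.op == "call_function"}
# The captured graph contains the overload the annotator now registers.
assert torch.ops.aten.logical_and.default in targets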

backends/qualcomm/tests/models.py
Lines changed: 8 additions & 0 deletions

@@ -1223,6 +1223,14 @@ def forward(self, x):
         return torch.log(x)


+class LogicalAnd(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x, y):
+        return torch.logical_and(x != 0, y != 0).float()
+
+
 class LogicalNot(torch.nn.Module):
     def __init__(self):
         super().__init__()
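
Usage note (not part of the diff): because of the `!= 0` comparisons, the LogicalAnd test module accepts boolean or numeric inputs alike and always returns a float tensor of 0.0/1.0 values, which matches the inputs used by the two tests below. A small eager-mode sketch, repeating the class from models.py for self-containment:

import torch

class LogicalAnd(torch.nn.Module):  # same as the module added above
    def forward(self, x, y):
        return torch.logical_and(x != 0, y != 0).float()

m = LogicalAnd()
print(m(torch.tensor([True, False, True, False]),
        torch.tensor([True, True, False, False])))  # tensor([1., 0., 0., 0.])
print(m(torch.tensor([0.0]), torch.tensor([1.0])))  # tensor([0.])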

backends/qualcomm/tests/test_qnn_delegate.py
Lines changed: 15 additions & 0 deletions

@@ -924,6 +924,13 @@ def test_qnn_backend_log(self):
         sample_input = (torch.rand([1, 2, 3, 4]),)
         self.lower_module_and_test_output(module, sample_input)

+    def test_qnn_backend_logical_and(self):
+        module = LogicalAnd()  # noqa: F405
+        input1 = torch.tensor([True, False, True, False])
+        input2 = torch.tensor([True, True, False, False])
+        sample_input = (input1, input2)
+        self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_logical_not(self):
         module = LogicalNot()  # noqa: F405
         sample_input = (torch.rand([1, 2, 3, 4]),)

@@ -2484,6 +2491,14 @@ def test_qnn_backend_log(self):
         module = self.get_qdq_module(module, sample_input)
         self.lower_module_and_test_output(module, sample_input)

+    def test_qnn_backend_logical_and(self):
+        module = LogicalAnd()  # noqa: F405
+        input1 = torch.tensor([0.0])
+        input2 = torch.tensor([1.0])
+        sample_input = (input1, input2)
+        module = self.get_qdq_module(module, sample_input)
+        self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_logical_not(self):
         module = LogicalNot()  # noqa: F405
         sample_input = (torch.rand([1, 2, 3, 4]),)
