Skip to content

Commit 066b5e3

Browse files
jayfoad authored and Pravin Jagtap committed
[EarlyCSE] Do not CSE convergent calls with memory effects
D149348 did this for readnone calls, which are handled by SimpleValue. This patch does the same for all other CSEable calls, which are handled by CallValue. Differential Revision: https://reviews.llvm.org/D153151 Change-Id: Ied78587d48f12d8735f789a73f75c1c1c010618d
1 parent ac328a9 commit 066b5e3

File tree

2 files changed

+110
-10
lines changed

2 files changed

+110
-10
lines changed

llvm/lib/Transforms/Scalar/EarlyCSE.cpp

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,19 @@ static bool matchSelectWithOptionalNotCond(Value *V, Value *&Cond, Value *&A,
218218
return true;
219219
}
220220

221+
/// Compute the hash for a call instruction that is a CSE candidate.
///
/// The opcode and every value operand are always mixed into the hash. For a
/// convergent call the parent basic block is mixed in as well, because such a
/// call implicitly depends on the set of threads that is currently executing
/// and therefore must not be CSE'd with a call in a different basic block.
///
/// \param CI the call instruction to hash.
/// \returns the combined hash value.
static unsigned hashCallInst(CallInst *CI) {
  // The operand hash is needed on both paths, so compute it once up front.
  const auto OperandsHash =
      hash_combine_range(CI->value_op_begin(), CI->value_op_end());

  // Ordinary calls: opcode plus operands only.
  if (!CI->isConvergent())
    return hash_combine(CI->getOpcode(), OperandsHash);

  // Convergent calls: additionally key on the containing basic block.
  return hash_combine(CI->getOpcode(), CI->getParent(), OperandsHash);
}
233+
221234
static unsigned getHashValueImpl(SimpleValue Val) {
222235
Instruction *Inst = Val.Inst;
223236
// Hash in all of the operands as pointers.
@@ -320,11 +333,8 @@ static unsigned getHashValueImpl(SimpleValue Val) {
320333

321334
// Don't CSE convergent calls in different basic blocks, because they
322335
// implicitly depend on the set of threads that is currently executing.
323-
if (CallInst *CI = dyn_cast<CallInst>(Inst); CI && CI->isConvergent()) {
324-
return hash_combine(
325-
Inst->getOpcode(), Inst->getParent(),
326-
hash_combine_range(Inst->value_op_begin(), Inst->value_op_end()));
327-
}
336+
if (CallInst *CI = dyn_cast<CallInst>(Inst))
337+
return hashCallInst(CI);
328338

329339
// Mix in the opcode.
330340
return hash_combine(
@@ -524,15 +534,21 @@ unsigned DenseMapInfo<CallValue>::getHashValue(CallValue Val) {
524534
Instruction *Inst = Val.Inst;
525535

526536
// Hash all of the operands as pointers and mix in the opcode.
527-
return hash_combine(
528-
Inst->getOpcode(),
529-
hash_combine_range(Inst->value_op_begin(), Inst->value_op_end()));
537+
return hashCallInst(cast<CallInst>(Inst));
530538
}
531539

532540
/// Decide whether two CallValue keys refer to equivalent calls for CSE.
///
/// \param LHS first key.
/// \param RHS second key.
/// \returns true when the two keys should be treated as the same map entry.
bool DenseMapInfo<CallValue>::isEqual(CallValue LHS, CallValue RHS) {
  // If either key is a sentinel, only the raw instruction pointers are
  // compared; the pointers are not cast or dereferenced on this path.
  if (LHS.isSentinel() || RHS.isSentinel())
    return LHS.Inst == RHS.Inst;

  CallInst *LHSI = cast<CallInst>(LHS.Inst);
  CallInst *RHSI = cast<CallInst>(RHS.Inst);

  // Convergent calls implicitly depend on the set of threads that is
  // currently executing, so conservatively return false if they are in
  // different basic blocks. This mirrors hashCallInst, which mixes the parent
  // block into the hash of a convergent call.
  if (LHSI->isConvergent() && LHSI->getParent() != RHSI->getParent())
    return false;

  return LHSI->isIdenticalTo(RHSI);
}
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 2
2+
; RUN: opt < %s -S -mtriple=amdgcn-- -passes=early-cse -earlycse-debug-hash | FileCheck %s
3+
4+
; Should not CSE calls marked as convergent, even if the callee is not convergent.
5+
6+
define i32 @test_read_register(i32 %cond) {
7+
; CHECK-LABEL: define i32 @test_read_register
8+
; CHECK-SAME: (i32 [[COND:%.*]]) {
9+
; CHECK-NEXT: entry:
10+
; CHECK-NEXT: [[X1:%.*]] = call i32 @llvm.read_register.i32(metadata [[META0:![0-9]+]]) #[[ATTR2:[0-9]+]]
11+
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[COND]], 0
12+
; CHECK-NEXT: br i1 [[CMP]], label [[IF:%.*]], label [[END:%.*]]
13+
; CHECK: if:
14+
; CHECK-NEXT: [[Y1:%.*]] = call i32 @llvm.read_register.i32(metadata [[META0]]) #[[ATTR2]]
15+
; CHECK-NEXT: br label [[END]]
16+
; CHECK: end:
17+
; CHECK-NEXT: [[Y2:%.*]] = phi i32 [ 0, [[ENTRY:%.*]] ], [ [[Y1]], [[IF]] ]
18+
; CHECK-NEXT: [[RET:%.*]] = add i32 [[X1]], [[Y2]]
19+
; CHECK-NEXT: ret i32 [[RET]]
20+
;
21+
entry:
22+
; %x = ballot operation over all lanes.
23+
%x1 = call i32 @llvm.read_register.i32(metadata !{!"exec_lo"}) convergent
24+
%cmp = icmp eq i32 %cond, 0
25+
br i1 %cmp, label %if, label %end
26+
27+
if:
28+
; %y = ballot operation over lanes satisfying %cond.
29+
%y1 = call i32 @llvm.read_register.i32(metadata !{!"exec_lo"}) convergent
30+
br label %end
31+
32+
end:
33+
%y2 = phi i32 [0, %entry], [%y1, %if]
34+
%ret = add i32 %x1, %y2
35+
ret i32 %ret
36+
}
37+
38+
define i32 @test_read_register_samebb(i32 %cond) {
39+
; CHECK-LABEL: define i32 @test_read_register_samebb
40+
; CHECK-SAME: (i32 [[COND:%.*]]) {
41+
; CHECK-NEXT: entry:
42+
; CHECK-NEXT: [[X:%.*]] = call i32 @llvm.read_register.i32(metadata [[META0]]) #[[ATTR2]]
43+
; CHECK-NEXT: [[RET:%.*]] = add i32 [[X]], [[X]]
44+
; CHECK-NEXT: ret i32 [[RET]]
45+
;
46+
entry:
47+
%x = call i32 @llvm.read_register.i32(metadata !{!"exec_lo"}) convergent
48+
%y = call i32 @llvm.read_register.i32(metadata !{!"exec_lo"}) convergent
49+
%ret = add i32 %x, %y
50+
ret i32 %ret
51+
}
52+
53+
define i1 @test_live_mask(i32 %cond) {
54+
; CHECK-LABEL: define i1 @test_live_mask
55+
; CHECK-SAME: (i32 [[COND:%.*]]) {
56+
; CHECK-NEXT: entry:
57+
; CHECK-NEXT: [[X1:%.*]] = call i1 @llvm.amdgcn.live.mask() #[[ATTR2]]
58+
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[COND]], 0
59+
; CHECK-NEXT: br i1 [[CMP]], label [[IF:%.*]], label [[END:%.*]]
60+
; CHECK: if:
61+
; CHECK-NEXT: [[Y1:%.*]] = call i1 @llvm.amdgcn.live.mask() #[[ATTR2]]
62+
; CHECK-NEXT: br label [[END]]
63+
; CHECK: end:
64+
; CHECK-NEXT: [[Y2:%.*]] = phi i1 [ false, [[ENTRY:%.*]] ], [ [[Y1]], [[IF]] ]
65+
; CHECK-NEXT: [[RET:%.*]] = add i1 [[X1]], [[Y2]]
66+
; CHECK-NEXT: ret i1 [[RET]]
67+
;
68+
entry:
69+
%x1 = call i1 @llvm.amdgcn.live.mask() convergent
70+
%cmp = icmp eq i32 %cond, 0
71+
br i1 %cmp, label %if, label %end
72+
73+
if:
74+
%y1 = call i1 @llvm.amdgcn.live.mask() convergent
75+
br label %end
76+
77+
end:
78+
%y2 = phi i1 [0, %entry], [%y1, %if]
79+
%ret = add i1 %x1, %y2
80+
ret i1 %ret
81+
}
82+
83+
declare i32 @llvm.read_register.i32(metadata)
84+
declare i1 @llvm.amdgcn.live.mask()

0 commit comments

Comments
 (0)