Index: llvm/lib/Target/AMDGPU/AMDGPULowerKernelAttributes.cpp
===================================================================
--- llvm/lib/Target/AMDGPU/AMDGPULowerKernelAttributes.cpp
+++ llvm/lib/Target/AMDGPU/AMDGPULowerKernelAttributes.cpp
@@ -163,39 +163,29 @@
     if (!GroupSize || !GridSize)
       continue;
 
+    using namespace llvm::PatternMatch;
+    auto GroupIDIntrin =
+        I == 0 ? m_Intrinsic<Intrinsic::amdgcn_workgroup_id_x>()
+               : (I == 1 ? m_Intrinsic<Intrinsic::amdgcn_workgroup_id_y>()
+                         : m_Intrinsic<Intrinsic::amdgcn_workgroup_id_z>());
+
     for (User *U : GroupSize->users()) {
       auto *ZextGroupSize = dyn_cast<ZExtInst>(U);
       if (!ZextGroupSize)
         continue;
 
-      for (User *ZextUser : ZextGroupSize->users()) {
-        auto *SI = dyn_cast<SelectInst>(ZextUser);
-        if (!SI)
-          continue;
-
-        using namespace llvm::PatternMatch;
-        auto GroupIDIntrin = I == 0 ?
-          m_Intrinsic<Intrinsic::amdgcn_workgroup_id_x>() :
-          (I == 1 ? m_Intrinsic<Intrinsic::amdgcn_workgroup_id_y>() :
-                    m_Intrinsic<Intrinsic::amdgcn_workgroup_id_z>());
-
-        auto SubExpr = m_Sub(m_Specific(GridSize),
-                             m_Mul(GroupIDIntrin, m_Specific(ZextGroupSize)));
-
-        ICmpInst::Predicate Pred;
-        if (match(SI,
-                  m_Select(m_ICmp(Pred, SubExpr, m_Specific(ZextGroupSize)),
-                           SubExpr,
-                           m_Specific(ZextGroupSize))) &&
-            Pred == ICmpInst::ICMP_ULT) {
+      for (User *UMin : ZextGroupSize->users()) {
+        if (match(UMin,
+                  m_UMin(m_Sub(m_Specific(GridSize),
+                               m_Mul(GroupIDIntrin, m_Specific(ZextGroupSize))),
+                         m_Specific(ZextGroupSize)))) {
           if (HasReqdWorkGroupSize) {
             ConstantInt *KnownSize =
                 mdconst::extract<ConstantInt>(MD->getOperand(I));
-            SI->replaceAllUsesWith(ConstantExpr::getIntegerCast(KnownSize,
-                                                                SI->getType(),
-                                                                false));
+            UMin->replaceAllUsesWith(ConstantExpr::getIntegerCast(
+                KnownSize, UMin->getType(), false));
           } else {
-            SI->replaceAllUsesWith(ZextGroupSize);
+            UMin->replaceAllUsesWith(ZextGroupSize);
           }
 
           MadeChange = true;
Index: llvm/test/CodeGen/AMDGPU/reqd-work-group-size.ll
===================================================================
--- llvm/test/CodeGen/AMDGPU/reqd-work-group-size.ll
+++ llvm/test/CodeGen/AMDGPU/reqd-work-group-size.ll
@@ -96,9 +96,8 @@
   %group.size.x.zext = zext i16 %group.size.x to i32
   %group.id_x_group.size.x = mul i32 %group.id, %group.size.x.zext
   %sub = sub i32 %grid.size.x, %group.id_x_group.size.x
-  %cmp = icmp ult i32 %sub, %group.size.x.zext
-  %select = select i1 %cmp, i32 %sub, i32 %group.size.x.zext
-  %zext = zext i32 %select to i64
+  %umin = call i32 @llvm.umin.i32(i32 %sub, i32 %group.size.x.zext)
+  %zext = zext i32 %umin to i64
   store i64 %zext, i64 addrspace(1)* %out
   ret void
 }
@@ -117,9 +116,8 @@
   %group.size.y.zext = zext i16 %group.size.y to i32
   %group.id_x_group.size.y = mul i32 %group.id, %group.size.y.zext
   %sub = sub i32 %grid.size.y, %group.id_x_group.size.y
-  %cmp = icmp ult i32 %sub, %group.size.y.zext
-  %select = select i1 %cmp, i32 %sub, i32 %group.size.y.zext
-  %zext = zext i32 %select to i64
+  %umin = call i32 @llvm.umin.i32(i32 %sub, i32 %group.size.y.zext)
+  %zext = zext i32 %umin to i64
   store i64 %zext, i64 addrspace(1)* %out
   ret void
 }
@@ -138,9 +136,8 @@
   %group.size.z.zext = zext i16 %group.size.z to i32
   %group.id_x_group.size.z = mul i32 %group.id, %group.size.z.zext
   %sub = sub i32 %grid.size.z, %group.id_x_group.size.z
-  %cmp = icmp ult i32 %sub, %group.size.z.zext
-  %select = select i1 %cmp, i32 %sub, i32 %group.size.z.zext
-  %zext = zext i32 %select to i64
+  %umin = call i32 @llvm.umin.i32(i32 %sub, i32 %group.size.z.zext)
+  %zext = zext i32 %umin to i64
   store i64 %zext, i64 addrspace(1)* %out
   ret void
 }
@@ -163,9 +160,8 @@
   %group.size.x.zext = zext i16 %group.size.x to i32
   %group.id_x_group.size.x = mul i32 %group.id, %group.size.x.zext
   %sub = sub i32 %grid.size.x, %group.id_x_group.size.x
-  %cmp = icmp ult i32 %sub, %group.size.x.zext
-  %select = select i1 %cmp, i32 %sub, i32 %group.size.x.zext
-  %zext = zext i32 %select to i64
+  %umin = call i32 @llvm.umin.i32(i32 %sub, i32 %group.size.x.zext)
+  %zext = zext i32 %umin to i64
   store i64 %zext, i64 addrspace(1)* %out
   ret void
 }
@@ -186,9 +182,8 @@
   %group.size.x.zext = zext i16 %group.size.x to i32
   %group.id_x_group.size.x = mul i32 %group.id, %group.size.x.zext
   %sub = sub i32 %grid.size.x, %group.id_x_group.size.x
-  %cmp = icmp ult i32 %sub, %group.size.x.zext
-  %select = select i1 %cmp, i32 %sub, i32 %group.size.x.zext
-  %zext = zext i32 %select to i64
+  %umin = call i32 @llvm.umin.i32(i32 %sub, i32 %group.size.x.zext)
+  %zext = zext i32 %umin to i64
   store i64 %zext, i64 addrspace(1)* %out
   ret void
 }
@@ -198,7 +193,7 @@
 ; CHECK: %group.id = tail call i32 @llvm.amdgcn.workgroup.id.x()
 ; CHECK: %group.id_x_group.size.x.neg = mul i32 %group.id, -8
 ; CHECK: %sub = add i32 %group.id_x_group.size.x.neg, %grid.size.x
-; CHECK: %1 = call i32 @llvm.smin.i32(i32 %sub, i32 8)
+; CHECK: %smin = call i32 @llvm.smin.i32(i32 %sub, i32 8)
 define amdgpu_kernel void @local_size_x_8_16_2_wrong_cmp_type(i64 addrspace(1)* %out) #0 !reqd_work_group_size !0 {
   %dispatch.ptr = tail call i8 addrspace(4)* @llvm.amdgcn.dispatch.ptr()
   %gep.group.size.x = getelementptr inbounds i8, i8 addrspace(4)* %dispatch.ptr, i64 4
@@ -211,9 +206,8 @@
   %group.size.x.zext = zext i16 %group.size.x to i32
   %group.id_x_group.size.x = mul i32 %group.id, %group.size.x.zext
   %sub = sub i32 %grid.size.x, %group.id_x_group.size.x
-  %cmp = icmp slt i32 %sub, %group.size.x.zext
-  %select = select i1 %cmp, i32 %sub, i32 %group.size.x.zext
-  %zext = zext i32 %select to i64
+  %smin = call i32 @llvm.smin.i32(i32 %sub, i32 %group.size.x.zext)
+  %zext = zext i32 %smin to i64
   store i64 %zext, i64 addrspace(1)* %out
   ret void
 }
@@ -221,8 +215,8 @@
 ; CHECK-LABEL: @local_size_x_8_16_2_wrong_select(
 ; CHECK: %group.id_x_group.size.x.neg = mul i32 %group.id, -8
 ; CHECK: %sub = add i32 %group.id_x_group.size.x.neg, %grid.size.x
-; CHECK: %1 = call i32 @llvm.umax.i32(i32 %sub, i32 8)
-; CHECK: %zext = zext i32 %1 to i64
+; CHECK: %umax = call i32 @llvm.umax.i32(i32 %sub, i32 8)
+; CHECK: %zext = zext i32 %umax to i64
 define amdgpu_kernel void @local_size_x_8_16_2_wrong_select(i64 addrspace(1)* %out) #0 !reqd_work_group_size !0 {
   %dispatch.ptr = tail call i8 addrspace(4)* @llvm.amdgcn.dispatch.ptr()
   %gep.group.size.x = getelementptr inbounds i8, i8 addrspace(4)* %dispatch.ptr, i64 4
@@ -235,9 +229,8 @@
   %group.size.x.zext = zext i16 %group.size.x to i32
   %group.id_x_group.size.x = mul i32 %group.id, %group.size.x.zext
   %sub = sub i32 %grid.size.x, %group.id_x_group.size.x
-  %cmp = icmp ult i32 %sub, %group.size.x.zext
-  %select = select i1 %cmp, i32 %group.size.x.zext, i32 %sub
-  %zext = zext i32 %select to i64
+  %umax = call i32 @llvm.umax.i32(i32 %sub, i32 %group.size.x.zext)
+  %zext = zext i32 %umax to i64
   store i64 %zext, i64 addrspace(1)* %out
   ret void
 }
@@ -261,9 +254,8 @@
   %group.size.x.zext = zext i16 %group.size.x to i32
   %group.id_x_group.size.x = mul i32 %group.id, %group.size.x.zext
   %sub = sub i32 %grid.size.x.zext, %group.id_x_group.size.x
-  %cmp = icmp ult i32 %sub, %group.size.x.zext
-  %select = select i1 %cmp, i32 %sub, i32 %group.size.x.zext
-  %zext = zext i32 %select to i64
+  %umin = call i32 @llvm.umin.i32(i32 %sub, i32 %group.size.x.zext)
+  %zext = zext i32 %umin to i64
   store i64 %zext, i64 addrspace(1)* %out
   ret void
 }
@@ -327,9 +319,8 @@
   %tmp29 = zext i16 %group.size to i32
   %tmp30 = mul i32 %tmp28, %tmp29
   %tmp31 = sub i32 %tmp26, %tmp30
-  %tmp32 = icmp ult i32 %tmp31, %tmp29
-  %tmp33 = select i1 %tmp32, i32 %tmp31, i32 %tmp29
-  %tmp34 = zext i32 %tmp33 to i64
+  %umin = call i32 @llvm.umin.i32(i32 %tmp31, i32 %tmp29)
+  %tmp34 = zext i32 %umin to i64
   ret i64 %tmp34
 }
 
@@ -349,9 +340,8 @@
   %tmp29.i = zext i16 %tmp8.i to i32
   %tmp30.i = mul i32 %tmp2.i, %tmp29.i
   %tmp31.i = sub i32 %tmp5.i, %tmp30.i
-  %tmp32.i = icmp ult i32 %tmp31.i, %tmp29.i
-  %tmp33.i = select i1 %tmp32.i, i32 %tmp31.i, i32 %tmp29.i
-  %tmp34.i = zext i32 %tmp33.i to i64
+  %umin0 = call i32 @llvm.umin.i32(i32 %tmp31.i, i32 %tmp29.i)
+  %tmp34.i = zext i32 %umin0 to i64
   %tmp10.i = tail call i32 @llvm.amdgcn.workgroup.id.y() #0
   %tmp11.i = getelementptr inbounds i8, i8 addrspace(4)* %tmp.i, i64 16
   %tmp12.i = bitcast i8 addrspace(4)* %tmp11.i to i32 addrspace(4)*
@@ -362,9 +352,8 @@
   %tmp29.i9 = zext i16 %tmp16.i to i32
   %tmp30.i10 = mul i32 %tmp10.i, %tmp29.i9
   %tmp31.i11 = sub i32 %tmp13.i, %tmp30.i10
-  %tmp32.i12 = icmp ult i32 %tmp31.i11, %tmp29.i9
-  %tmp33.i13 = select i1 %tmp32.i12, i32 %tmp31.i11, i32 %tmp29.i9
-  %tmp34.i14 = zext i32 %tmp33.i13 to i64
+  %umin1 = call i32 @llvm.umin.i32(i32 %tmp31.i11, i32 %tmp29.i9)
+  %tmp34.i14 = zext i32 %umin1 to i64
   %tmp18.i = tail call i32 @llvm.amdgcn.workgroup.id.z() #0
   %tmp19.i = getelementptr inbounds i8, i8 addrspace(4)* %tmp.i, i64 20
   %tmp20.i = bitcast i8 addrspace(4)* %tmp19.i to i32 addrspace(4)*
@@ -375,9 +364,8 @@
   %tmp29.i2 = zext i16 %tmp24.i to i32
   %tmp30.i3 = mul i32 %tmp18.i, %tmp29.i2
   %tmp31.i4 = sub i32 %tmp21.i, %tmp30.i3
-  %tmp32.i5 = icmp ult i32 %tmp31.i4, %tmp29.i2
-  %tmp33.i6 = select i1 %tmp32.i5, i32 %tmp31.i4, i32 %tmp29.i2
-  %tmp34.i7 = zext i32 %tmp33.i6 to i64
+  %umin2 = call i32 @llvm.umin.i32(i32 %tmp31.i4, i32 %tmp29.i2)
+  %tmp34.i7 = zext i32 %umin2 to i64
   store volatile i64 %tmp34.i, i64 addrspace(1)* %out, align 4
   store volatile i64 %tmp34.i14, i64 addrspace(1)* %out, align 4
   store volatile i64 %tmp34.i7, i64 addrspace(1)* %out, align 4
@@ -462,9 +450,8 @@
   %group.size.x.zext = zext i16 %group.size.x to i32
   %group.id_x_group.size.x = mul i32 %group.id, %group.size.x.zext
   %sub = sub i32 %grid.size.x, %group.id_x_group.size.x
-  %cmp = icmp ult i32 %sub, %group.size.x.zext
-  %select = select i1 %cmp, i32 %sub, i32 %group.size.x.zext
-  %zext = zext i32 %select to i64
+  %umin = call i32 @llvm.umin.i32(i32 %sub, i32 %group.size.x.zext)
+  %zext = zext i32 %umin to i64
   store i64 %zext, i64 addrspace(1)* %out
   ret void
 }
@@ -483,9 +470,8 @@
   %group.size.x.zext = zext i16 %group.size.x to i32
   %group.id_x_group.size.x = mul i32 %group.id, %group.size.x.zext
   %sub = sub i32 %grid.size.x, %group.id_x_group.size.x
-  %cmp = icmp ult i32 %sub, %group.size.x.zext
-  %select = select i1 %cmp, i32 %sub, i32 %group.size.x.zext
-  %zext = zext i32 %select to i64
+  %umin = call i32 @llvm.umin.i32(i32 %sub, i32 %group.size.x.zext)
+  %zext = zext i32 %umin to i64
   store i64 %zext, i64 addrspace(1)* %out
   ret void
 }
@@ -501,6 +487,9 @@
 declare i32 @llvm.amdgcn.workgroup.id.x() #1
 declare i32 @llvm.amdgcn.workgroup.id.y() #1
 declare i32 @llvm.amdgcn.workgroup.id.z() #1
+declare i32 @llvm.umin.i32(i32, i32) #1
+declare i32 @llvm.smin.i32(i32, i32) #1
+declare i32 @llvm.umax.i32(i32, i32) #1
 
 attributes #0 = { nounwind "uniform-work-group-size"="true" }
 attributes #1 = { nounwind readnone speculatable }