diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -64,8 +64,10 @@ /// the highest lane which participates in the shuffle). /// /// %one = llvm.constant(1 : i32) : i32 - /// %shl = llvm.shl %one, %width : i32 - /// %active_mask = llvm.sub %shl, %one : i32 + /// %minus_one = llvm.constant(-1 : i32) : i32 + /// %thirty_two = llvm.constant(32 : i32) : i32 + /// %num_lanes = llvm.sub %thirty_two, %width : i32 + /// %active_mask = llvm.lshr %minus_one, %num_lanes : i32 /// %mask_and_clamp = llvm.sub %width, %one : i32 /// %shfl = nvvm.shfl.sync.bfly %active_mask, %value, %offset, /// %mask_and_clamp : !llvm<"{ float, i1 }"> @@ -86,14 +88,24 @@ Value one = rewriter.create( loc, int32Type, rewriter.getI32IntegerAttr(1)); - // Bit mask of active lanes: `(1 << activeWidth) - 1`. - Value activeMask = rewriter.create( - loc, int32Type, - rewriter.create(loc, int32Type, one, adaptor.width()), - one); - // Clamp lane: `activeWidth - 1` - Value maskAndClamp = - rewriter.create(loc, int32Type, adaptor.width(), one); + Value minusOne = rewriter.create( + loc, int32Type, rewriter.getI32IntegerAttr(-1)); + Value thirtyTwo = rewriter.create( + loc, int32Type, rewriter.getI32IntegerAttr(32)); + Value numActiveLanes = rewriter.create( + loc, int32Type, thirtyTwo, adaptor.width()); + // Bit mask of active lanes: `(-1) >> (32 - activeWidth)`. + Value activeMask = + rewriter.create(loc, int32Type, minusOne, numActiveLanes); + Value maskAndClamp; + if (op.mode() == gpu::ShuffleMode::UP) { + // Clamp lane: `32 - activeWidth` + maskAndClamp = numActiveLanes; + } else { + // Clamp lane: `activeWidth - 1` + maskAndClamp = + rewriter.create(loc, int32Type, adaptor.width(), one); + } auto returnValueAndIsValidAttr = rewriter.getUnitAttr(); Value shfl = rewriter.create( diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir --- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir +++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir @@ -117,14 +117,23 @@ // CHECK: %[[#WIDTH:]] = llvm.mlir.constant(23 : i32) : i32 %arg2 = arith.constant 23 : i32 // CHECK: %[[#ONE:]] = llvm.mlir.constant(1 : i32) : i32 - // CHECK: %[[#SHL:]] = llvm.shl %[[#ONE]], %[[#WIDTH]] : i32 - // CHECK: %[[#MASK:]] = llvm.sub %[[#SHL]], %[[#ONE]] : i32 + // CHECK: %[[#MINUS_ONE:]] = llvm.mlir.constant(-1 : i32) : i32 + // CHECK: %[[#THIRTY_TWO:]] = llvm.mlir.constant(32 : i32) : i32 + // CHECK: %[[#NUM_LANES:]] = llvm.sub %[[#THIRTY_TWO]], %[[#WIDTH]] : i32 + // CHECK: %[[#MASK:]] = llvm.lshr %[[#MINUS_ONE]], %[[#NUM_LANES]] : i32 // CHECK: %[[#CLAMP:]] = llvm.sub %[[#WIDTH]], %[[#ONE]] : i32 // CHECK: %[[#SHFL:]] = nvvm.shfl.sync bfly %[[#MASK]], %[[#VALUE]], %[[#OFFSET]], %[[#CLAMP]] {return_value_and_is_valid} : f32 -> !llvm.struct<(f32, i1)> // CHECK: llvm.extractvalue %[[#SHFL]][0 : index] : !llvm.struct<(f32, i1)> // CHECK: llvm.extractvalue %[[#SHFL]][1 : index] : !llvm.struct<(f32, i1)> %shfl, %pred = gpu.shuffle xor %arg0, %arg1, %arg2 : f32 - // CHECK: nvvm.shfl.sync up {{.*}} {return_value_and_is_valid} : f32 -> !llvm.struct<(f32, i1)> + // CHECK: %[[#ONE:]] = llvm.mlir.constant(1 : i32) : i32 + // CHECK: %[[#MINUS_ONE:]] = llvm.mlir.constant(-1 : i32) : i32 + // CHECK: %[[#THIRTY_TWO:]] = llvm.mlir.constant(32 : i32) : i32 + // CHECK: %[[#NUM_LANES:]] = llvm.sub %[[#THIRTY_TWO]], %[[#WIDTH]] : i32 + // CHECK: %[[#MASK:]] = llvm.lshr %[[#MINUS_ONE]], %[[#NUM_LANES]] : i32 + // CHECK: %[[#SHFL:]] = nvvm.shfl.sync up %[[#MASK]], %[[#VALUE]], %[[#OFFSET]], %[[#NUM_LANES]] {return_value_and_is_valid} : f32 -> !llvm.struct<(f32, i1)> + // CHECK: llvm.extractvalue %[[#SHFL]][0 : index] : !llvm.struct<(f32, i1)> + // CHECK: llvm.extractvalue %[[#SHFL]][1 : index] : !llvm.struct<(f32, i1)> %shflu, %predu = gpu.shuffle up %arg0, %arg1, %arg2 : f32 // CHECK: nvvm.shfl.sync down {{.*}} {return_value_and_is_valid} : f32 -> !llvm.struct<(f32, i1)> %shfld, %predd = gpu.shuffle down %arg0, %arg1, %arg2 : f32