# Reviewer summary. Text ahead of the first "diff --git" header is skipped by
# patch tools, so these lines do not change what the patch applies.
#
# Files touched by this patch:
#
#   mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
#     - Adds the ROCDL_IntrPure1Op class (derived from LLVM_IntrOpBase).
#     - Adds the ROCDL_MbcntOp class, built on ROCDL_IntrPure1Op with the name
#       "mbcnt." # mnemonic, taking `$args` and using the assembly format
#       "$args attr-dict `:` functional-type($args, $res)".
#     - Defines ROCDL_MbcntLoOp ("lo") and ROCDL_MbcntHiOp ("hi").
#
#   mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
#     - Adds the GPULaneIdOpToROCDL conversion pattern for gpu::LaneIdOp.
#       Per its own in-code comment it converts to
#       mbcnt.lo(-1, 0) followed by mbcnt.hi(-1, %mlo), using 32-bit constants
#       0 and -1 and a 32-bit integer result type.
#     - When getIndexTypeBitwidth() is greater or smaller than 32, the result
#       is rebuilt at that bitwidth (the in-code comment says "truncate or
#       extend"); at exactly 32 it is used unchanged.
#     - Registers a pattern through patterns.add(converter).
#
#   mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
#     - Extends @gpu_index_ops with one more index result fed by gpu.lane_id,
#       and checks for rocdl.mbcnt.lo, rocdl.mbcnt.hi and an
#       llvm.sext from i32 to i64.
#
#   mlir/test/Target/LLVMIR/rocdl.mlir
#     - Adds @rocdl.lane_id, checking that rocdl.mbcnt.lo / rocdl.mbcnt.hi
#       become calls to @llvm.amdgcn.mbcnt.lo(i32 -1, i32 0) and
#       @llvm.amdgcn.mbcnt.hi(i32 -1, ...).
#
# NOTE(review): this copy of the patch looks damaged - confirm against the
# upstream change before applying:
#   - Angle-bracketed text appears to be missing: `ConvertOpToLLVMPattern`
#     names no op type, `rewriter.create(` / `rewriter.createOrFold(` /
#     `patterns.add(` name no op or pattern class, `Variadic:$args` has no
#     element type, and `LLVM_IntrOpBase;` has no arguments.
#   - ROCDL_IntrPure1Op and ROCDL_MbcntOp show no parameter list even though
#     `mnemonic` is used and both are instantiated with an argument.
#   - The original line breaks are collapsed, so the hunk line counts in the
#     "@@" headers cannot be checked against the text below.
#
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td @@ -54,6 +54,10 @@ LLVM_OpBase { } +class ROCDL_IntrPure1Op : + LLVM_IntrOpBase; + //===----------------------------------------------------------------------===// // ROCDL special register op definitions //===----------------------------------------------------------------------===// @@ -77,6 +81,19 @@ let assemblyFormat = "attr-dict `:` type($res)"; } +//===----------------------------------------------------------------------===// +// Wave-level primitives + +class ROCDL_MbcntOp : + ROCDL_IntrPure1Op<"mbcnt." # mnemonic>, + Arguments<(ins Variadic:$args)> { + let assemblyFormat = + "$args attr-dict `:` functional-type($args, $res)"; +} + +def ROCDL_MbcntLoOp : ROCDL_MbcntOp<"lo">; +def ROCDL_MbcntHiOp : ROCDL_MbcntOp<"hi">; + //===----------------------------------------------------------------------===// // Thread index and Block index diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp --- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp +++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp @@ -60,6 +60,38 @@ } namespace { +struct GPULaneIdOpToROCDL : ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(gpu::LaneIdOp op, gpu::LaneIdOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + MLIRContext *context = rewriter.getContext(); + // convert to: %mlo = call @llvm.amdgcn.mbcnt.lo(-1, 0) + // followed by: %lid = call @llvm.amdgcn.mbcnt.hi(-1, %mlo) + + Type intTy = IntegerType::get(context, 32); + Value zero = rewriter.createOrFold(loc, 0, 32); + Value minus1 = rewriter.createOrFold(loc, -1, 32); + Value mbcntLo = + rewriter.create(loc, intTy, 
ValueRange{minus1, zero}); + Value laneId = rewriter.create( + loc, intTy, ValueRange{minus1, mbcntLo}); + // Truncate or extend the result depending on the index bitwidth specified + // by the LLVMTypeConverter options. + const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth(); + if (indexBitwidth > 32) { + laneId = rewriter.create( + loc, IntegerType::get(context, indexBitwidth), laneId); + } else if (indexBitwidth < 32) { + laneId = rewriter.create( + loc, IntegerType::get(context, indexBitwidth), laneId); + } + rewriter.replaceOp(op, {laneId}); + return success(); + } +}; /// Import the GPU Ops to ROCDL Patterns. #include "GPUToROCDL.cpp.inc" @@ -240,6 +272,8 @@ patterns.add(converter, /*addressSpace=*/4); } + patterns.add(converter); + populateOpPatterns(converter, patterns, "__ocml_fabs_f32", "__ocml_fabs_f64"); populateOpPatterns(converter, patterns, "__ocml_atan_f32", diff --git a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir --- a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir +++ b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir @@ -6,7 +6,8 @@ // CHECK32-LABEL: func @gpu_index_ops() func.func @gpu_index_ops() -> (index, index, index, index, index, index, - index, index, index, index, index, index) { + index, index, index, index, index, index, + index) { // CHECK32-NOT: = llvm.sext %{{.*}} : i32 to i64 // CHECK: rocdl.workitem.id.x : i32 @@ -49,10 +50,17 @@ // CHECK: = llvm.sext %{{.*}} : i32 to i64 %gDimZ = gpu.grid_dim z + // CHECK: = rocdl.mbcnt.lo %{{.*}}, %{{.*}} : (i32, i32) -> i32 + // CHECK: = rocdl.mbcnt.hi %{{.*}}, %{{.*}} : (i32, i32) -> i32 + // CHECK: = llvm.sext %{{.*}} : i32 to i64 + %laneId = gpu.lane_id + func.return %tIdX, %tIdY, %tIdZ, %bDimX, %bDimY, %bDimZ, - %bIdX, %bIdY, %bIdZ, %gDimX, %gDimY, %gDimZ + %bIdX, %bIdY, %bIdZ, %gDimX, %gDimY, %gDimZ, + %laneId : index, index, index, index, index, index, - index, index, index, index, index, index + index, 
index, index, index, index, index, + index } } diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir --- a/mlir/test/Target/LLVMIR/rocdl.mlir +++ b/mlir/test/Target/LLVMIR/rocdl.mlir @@ -56,6 +56,16 @@ llvm.return } +llvm.func @rocdl.lane_id() -> i32 { + // CHECK: call i32 @llvm.amdgcn.mbcnt.lo(i32 -1, i32 0) + // CHECK-NEXT: call i32 @llvm.amdgcn.mbcnt.hi(i32 -1, i32 %{{.*}}) + %0 = llvm.mlir.constant(-1 : i32) : i32 + %1 = llvm.mlir.constant(0 : i32) : i32 + %2 = rocdl.mbcnt.lo %0, %1 : (i32, i32) -> i32 + %3 = rocdl.mbcnt.hi %0, %2 : (i32, i32) -> i32 + llvm.return %3 : i32 +} + llvm.func @rocdl.barrier() { // CHECK: fence syncscope("workgroup") release // CHECK-NEXT: call void @llvm.amdgcn.s.barrier()