Add the rocdl.lane_id op and lower gpu.lane_id to it.

* ROCDLOps.td: define ROCDL_LaneIdOp. Its llvmBuilder emits
  llvm.amdgcn.mbcnt.lo(-1, 0) and passes that result as the second operand of
  llvm.amdgcn.mbcnt.hi(-1, <lo>); the mbcnt.hi call is the op result.
* LowerGpuOpsToROCDLOps.cpp: add GPULaneIdOpToROCDL, which creates an i32
  rocdl.lane_id and then sign-extends (index bitwidth > 32) or truncates
  (index bitwidth < 32) it to the converter's index type, and register the
  pattern alongside the existing ones.
* Tests: cover the GPU-to-ROCDL conversion (gpu-to-rocdl.mlir) and the
  translation to LLVM IR (rocdl.mlir).

NOTE(review): leading whitespace and blank lines in this patch were re-derived
from the hunk line counts and the usual MLIR formatting. The unchanged context
lines `patterns.add(converter, /*addressSpace=*/4)` and
`populateOpPatterns(...)` appear to have lost their template arguments; they
are not recoverable from this patch text, so confirm them against the target
tree before applying.

diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -77,6 +77,24 @@
   let assemblyFormat = "attr-dict `:` type($res)";
 }
 
+//===----------------------------------------------------------------------===//
+// Warp-level primitives
+
+def ROCDL_LaneIdOp :
+  ROCDL_Op<"lane_id">,
+  Results<(outs LLVM_Type:$res)> {
+  string llvmBuilder = [{
+    auto i32type = llvm::Type::getInt32Ty(moduleTranslation.getLLVMContext());
+    auto cst__1 = llvm::ConstantInt::get(i32type, -1, /*isSigned=*/true);
+    auto cst_0 = llvm::ConstantInt::get(i32type, 0);
+    auto mbcnt_lo = createIntrinsicCall(builder, llvm::Intrinsic::amdgcn_mbcnt_lo, {cst__1, cst_0});
+    $res = createIntrinsicCall(builder, llvm::Intrinsic::amdgcn_mbcnt_hi, {cst__1, mbcnt_lo});
+  }];
+  let assemblyFormat = [{
+    attr-dict `:` type($res)
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // Thread index and Block index
 
diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -60,6 +60,29 @@
 }
 
 namespace {
+struct GPULaneIdOpToROCDL : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
+  using ConvertOpToLLVMPattern<gpu::LaneIdOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::LaneIdOp op, gpu::LaneIdOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op->getLoc();
+    MLIRContext *context = rewriter.getContext();
+    Value newOp = rewriter.create<ROCDL::LaneIdOp>(loc, rewriter.getI32Type());
+    // Truncate or extend the result depending on the index bitwidth specified
+    // by the LLVMTypeConverter options.
+    const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
+    if (indexBitwidth > 32) {
+      newOp = rewriter.create<LLVM::SExtOp>(
+          loc, IntegerType::get(context, indexBitwidth), newOp);
+    } else if (indexBitwidth < 32) {
+      newOp = rewriter.create<LLVM::TruncOp>(
+          loc, IntegerType::get(context, indexBitwidth), newOp);
+    }
+    rewriter.replaceOp(op, {newOp});
+    return success();
+  }
+};
 
 /// Import the GPU Ops to ROCDL Patterns.
 #include "GPUToROCDL.cpp.inc"
@@ -240,6 +263,8 @@
     patterns.add(converter, /*addressSpace=*/4);
   }
 
+  patterns.add<GPULaneIdOpToROCDL>(converter);
+
   populateOpPatterns(converter, patterns, "__ocml_fabs_f32",
                      "__ocml_fabs_f64");
   populateOpPatterns(converter, patterns, "__ocml_atan_f32",
diff --git a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
--- a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
@@ -6,7 +6,8 @@
   // CHECK32-LABEL: func @gpu_index_ops()
   func.func @gpu_index_ops()
       -> (index, index, index, index, index, index,
-          index, index, index, index, index, index) {
+          index, index, index, index, index, index,
+          index) {
     // CHECK32-NOT: = llvm.sext %{{.*}} : i32 to i64
 
     // CHECK: rocdl.workitem.id.x : i32
@@ -49,10 +50,16 @@
     // CHECK: = llvm.sext %{{.*}} : i32 to i64
     %gDimZ = gpu.grid_dim z
 
+    // CHECK: = rocdl.lane_id : i32
+    // CHECK: = llvm.sext %{{.*}} : i32 to i64
+    %laneId = gpu.lane_id
+
     func.return %tIdX, %tIdY, %tIdZ, %bDimX, %bDimY, %bDimZ,
-               %bIdX, %bIdY, %bIdZ, %gDimX, %gDimY, %gDimZ
+               %bIdX, %bIdY, %bIdZ, %gDimX, %gDimY, %gDimZ,
+               %laneId
         : index, index, index, index, index, index,
-          index, index, index, index, index, index
+          index, index, index, index, index, index,
+          index
   }
 }
 
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -56,6 +56,13 @@
   llvm.return
 }
 
+llvm.func @rocdl.lane_id() -> i32 {
+  // CHECK: call i32 @llvm.amdgcn.mbcnt.lo(i32 -1, i32 0)
+  // CHECK-NEXT: call i32 @llvm.amdgcn.mbcnt.hi(i32 -1, i32 %{{.*}})
+  %0 = rocdl.lane_id : i32
+  llvm.return %0 : i32
+}
+
 llvm.func @rocdl.barrier() {
   // CHECK: fence syncscope("workgroup") release
   // CHECK-NEXT: call void @llvm.amdgcn.s.barrier()