diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1392,4 +1392,42 @@
   let hasVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// NVVM TMA Ops
+//===----------------------------------------------------------------------===//
+
+// Bulk tensor copy from global memory into cluster shared memory, emitted as
+// inline PTX built by `getPtx()`. The number of `coordinates` operands picks
+// the `.1d` to `.5d` form of the instruction.
+def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp : NVVM_Op<"cp.async.bulk.tensor.shared.cluster.global", [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
+  Arguments<(ins LLVM_i64ptr_shared:$dstMem,
+                 LLVM_i64ptr_any:$tmaDescriptor,
+                 LLVM_i64ptr_shared:$mbar,
+                 Variadic<I32>:$coordinates)> {
+  let assemblyFormat = "$dstMem `,` $tmaDescriptor `,` $mbar `,` `box` `[`$coordinates `]` attr-dict `:` type(operands)";
+  let extraClassDefinition = [{
+    // Returns the PTX template for this op. Placeholders `%0`, `%1`, `%2` stand
+    // for dstMem, tmaDescriptor and mbar; `%3` onward are the coordinates
+    // (presumably numbered in operand order -- TODO confirm).
+    std::string $cppClass::getPtx() {
+      const int dim = static_cast<int>(getCoordinates().size());
+      std::string ptx = "cp.async.bulk.tensor.";
+      ptx += std::to_string(dim) + "d.";
+      ptx += "shared::cluster.global.mbarrier::complete_tx::bytes";
+      ptx += " [%0], [%1, {";
+      for (int i = 0; i < dim; ++i) {
+        if (i != 0)
+          ptx += ", ";
+        ptx += "%" + std::to_string(i + 3);
+      }
+      // The closing brace and bracket are separate literals: written
+      // back-to-back they would end this TableGen code block early.
+      ptx += "}";
+      ptx += "], [%2];";
+      return ptx;
+    }
+  }];
+  let hasVerifier = 1;
+}
+
 #endif // NVVMIR_OPS
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -24,6 +24,7 @@
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/OperationSupport.h"
+#include "mlir/Support/LogicalResult.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/AsmParser/Parser.h"
 #include "llvm/IR/Attributes.h"
@@ -32,6 +33,7 @@
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/SourceMgr.h"
 #include <optional>
+#include <string>
 
 using namespace mlir;
 using namespace NVVM;
@@ -67,6 +69,17 @@
 
 void VoteBallotOp::print(OpAsmPrinter &p) { printNVVMIntrinsicOp(p, *this); }
 
+/// Verifies that the op carries between one and five coordinates; `getPtx()`
+/// derives the `.1d` to `.5d` form of the instruction from that count.
+LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
+  const size_t numCoordinates = getCoordinates().size();
+  if (numCoordinates == 0)
+    return emitError("At least 1 coordinate is required.");
+  if (numCoordinates > 5)
+    return emitError("Maximum 5 coordinates and dimension is supported.");
+  return success();
+}
+
 LogicalResult CpAsyncOp::verify() {
   if (getModifier() != LoadCacheModifierKind::CG &&
       getModifier() != LoadCacheModifierKind::CA)
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -45,3 +45,42 @@
   nvvm.cp.async.shared.global %dst, %src, 4, cache = ca, %cpSize : !llvm.ptr<3>, !llvm.ptr<1>, i32
   return
 }
+
+// TMA bulk tensor loads, one function per supported rank (1-D to 5-D).
+// NOTE(review): the constraint strings assume "r" for shared-space pointers
+// and i32 values and "l" for generic pointers -- confirm against the lowering.
+
+// CHECK-LABEL: @tma_load_1d
+func.func @tma_load_1d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32) {
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$3}], [$2];", "r,l,r,r"
+  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0] : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i32
+  return
+}
+
+// CHECK-LABEL: @tma_load_2d
+func.func @tma_load_2d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32) {
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$3, $4}], [$2];", "r,l,r,r,r"
+  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0,%crd1] : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i32, i32
+  return
+}
+
+// CHECK-LABEL: @tma_load_3d
+func.func @tma_load_3d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32) {
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$3, $4, $5}], [$2];", "r,l,r,r,r,r"
+  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0,%crd1,%crd2] : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i32, i32, i32
+  return
+}
+
+// CHECK-LABEL: @tma_load_4d
+func.func @tma_load_4d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32) {
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$3, $4, $5, $6}], [$2];", "r,l,r,r,r,r,r"
+  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0,%crd1,%crd2,%crd3] : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i32, i32, i32, i32
+  return
+}
+
+// CHECK-LABEL: @tma_load_5d
+func.func @tma_load_5d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %crd4: i32) {
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$3, $4, $5, $6, $7}], [$2];", "r,l,r,r,r,r,r,r"
+  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0,%crd1,%crd2,%crd3,%crd4] : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i32, i32, i32, i32, i32
+  return
+}