diff --git a/mlir/include/mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.td b/mlir/include/mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.td --- a/mlir/include/mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.td +++ b/mlir/include/mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.td @@ -164,4 +164,33 @@ }]; } +//===----------------------------------------------------------------------===// +// RewriteCopyAsTmaOp +//===----------------------------------------------------------------------===// + +def RewriteCopyAsTmaOp : + Op { + let description = [{ + Rewrite a copy operation on memref to tma operations that transit through + shared memory. + }]; + + let arguments = (ins TransformHandleTypeInterface:$target); + let results = (outs); + + let assemblyFormat = "$target attr-dict `:` functional-type(operands, results) "; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure apply( + ::mlir::transform::TransformRewriter &rewriter, + ::mlir::transform::TransformResults &transformResults, + ::mlir::transform::TransformState &state); + }]; +} + #endif // NVGPU_TRANSFORM_OPS diff --git a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp --- a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp +++ b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp @@ -13,6 +13,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" @@ -20,20 +21,17 @@ #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/Transforms/Transforms.h" #include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/Operation.h" -#include "mlir/IR/TypeRange.h" -#include "mlir/IR/TypeUtilities.h" -#include "mlir/Support/LogicalResult.h" +#include "mlir/IR/Value.h" #include "llvm/ADT/ArrayRef.h" -#include "llvm/Support/Debug.h" using namespace mlir; using namespace mlir::linalg; using namespace mlir::nvgpu; +using namespace mlir::NVVM; using namespace mlir::transform; #define DEBUG_TYPE "nvgpu-transforms" @@ -517,7 +515,7 @@ /// Build a list of memref.load operations indexed at `(row, col)` indices /// that make sense for a particular MMA instruction and specified via the /// IndexCalculator callback. - SmallVector buildMemrefLoads(OpBuilder &b, Location loc, + SmallVector buildMemRefLoads(OpBuilder &b, Location loc, OpFoldResult laneId, Value memref, IndexCalculator indexFn); @@ -527,7 +525,7 @@ /// data that makes sense for the particular MMA operation. /// The `vectorShape` matches existing NVGPU dialect op specification but /// could also be flattened in the future if needed for simplification. - Value buildMmaSyncMemrefLoadOperand(OpBuilder &b, Location loc, + Value buildMmaSyncMemRefLoadOperand(OpBuilder &b, Location loc, OpFoldResult laneId, Value memref, IndexCalculator indexFn, ArrayRef vectorShape); @@ -535,7 +533,7 @@ /// Build a list of memref.store operations indexed at `(row, col)` indices /// that make sense for a particular MMA instruction and specified via the /// IndexCalculator callback. 
- SmallVector buildMemrefStores(OpBuilder &b, Location loc, + SmallVector buildMemRefStores(OpBuilder &b, Location loc, ValueRange toStore, OpFoldResult laneId, Value memref, IndexCalculator indexFn); @@ -546,7 +544,7 @@ /// data that makes sense for the particular MMA operation. /// The `vectorShape` matches existing NVGPU dialect op specification but /// could also be flattened in the future if needed for simplification. - SmallVector buildMmaSyncMemrefStoreOperand( + SmallVector buildMmaSyncMemRefStoreOperand( OpBuilder &b, Location loc, Value vectorToStore, OpFoldResult laneId, Value memref, IndexCalculator indexFn, ArrayRef vectorShape); @@ -573,7 +571,7 @@ } } -SmallVector MmaSyncBuilder::buildMemrefLoads(OpBuilder &b, Location loc, +SmallVector MmaSyncBuilder::buildMemRefLoads(OpBuilder &b, Location loc, OpFoldResult laneId, Value memref, IndexCalculator indexFn) { @@ -591,10 +589,10 @@ return res; } -Value MmaSyncBuilder::buildMmaSyncMemrefLoadOperand( +Value MmaSyncBuilder::buildMmaSyncMemRefLoadOperand( OpBuilder &b, Location loc, OpFoldResult laneId, Value memref, IndexCalculator indexFn, ArrayRef vectorShape) { - auto loads = buildMemrefLoads(b, loc, laneId, memref, indexFn); + auto loads = buildMemRefLoads(b, loc, laneId, memref, indexFn); Type elementType = getElementTypeOrSelf(memref.getType()); auto vt = VectorType::get(vectorShape, elementType); @@ -614,7 +612,7 @@ } SmallVector -MmaSyncBuilder::buildMemrefStores(OpBuilder &b, Location loc, +MmaSyncBuilder::buildMemRefStores(OpBuilder &b, Location loc, ValueRange toStore, OpFoldResult laneId, Value memref, IndexCalculator indexFn) { auto aff = [&](AffineExpr e) { @@ -632,7 +630,7 @@ return res; } -SmallVector MmaSyncBuilder::buildMmaSyncMemrefStoreOperand( +SmallVector MmaSyncBuilder::buildMmaSyncMemRefStoreOperand( OpBuilder &b, Location loc, Value vectorToStore, OpFoldResult laneId, Value memref, IndexCalculator indexFn, ArrayRef vectorShape) { SmallVector toStore; @@ -647,7 +645,7 @@ [&](Value v, int64_t linearIdx, ArrayRef indices) { toStore.push_back(v); }); - return buildMemrefStores(b, loc, toStore, laneId, memref, indexFn); + return buildMemRefStores(b, loc, toStore, laneId, memref, indexFn); } static std::tuple, SmallVector, @@ -690,22 +688,22 @@ } FailureOr MmaSyncBuilder::buildMmaSync(LinalgOp linalgOp) { - Value lhsMemref = linalgOp.getDpsInputOperand(0)->get(); - Value rhsMemref = linalgOp.getDpsInputOperand(1)->get(); - Value resMemref = linalgOp.getDpsInitOperand(0)->get(); - assert(lhsMemref.getType().cast().getRank() == 2 && + Value lhsMemRef = linalgOp.getDpsInputOperand(0)->get(); + Value rhsMemRef = linalgOp.getDpsInputOperand(1)->get(); + Value resMemRef = linalgOp.getDpsInitOperand(0)->get(); + assert(lhsMemRef.getType().cast().getRank() == 2 && "expected lhs to be a 2D memref"); - assert(rhsMemref.getType().cast().getRank() == 2 && + assert(rhsMemRef.getType().cast().getRank() == 2 && "expected rhs to be a 2D memref"); - assert(resMemref.getType().cast().getRank() == 2 && + assert(resMemRef.getType().cast().getRank() == 2 && "expected res to be a 2D memref"); - int64_t m = cast(lhsMemref.getType()).getShape()[0]; - int64_t n = cast(rhsMemref.getType()).getShape()[1]; - int64_t k = cast(lhsMemref.getType()).getShape()[1]; - Type lhsType = getElementTypeOrSelf(lhsMemref.getType()); - Type rhsType = getElementTypeOrSelf(rhsMemref.getType()); - Type resType = getElementTypeOrSelf(resMemref.getType()); + int64_t m = cast(lhsMemRef.getType()).getShape()[0]; + int64_t n = 
cast(rhsMemRef.getType()).getShape()[1]; + int64_t k = cast(lhsMemRef.getType()).getShape()[1]; + Type lhsType = getElementTypeOrSelf(lhsMemRef.getType()); + Type rhsType = getElementTypeOrSelf(rhsMemRef.getType()); + Type resType = getElementTypeOrSelf(resMemRef.getType()); FailureOr maybeInfo = getIndexCalculators({m, n, k}, {lhsType, rhsType, resType}); @@ -715,15 +713,15 @@ MmaSyncInfo info = *maybeInfo; auto [lhsIndexFn, rhsIndexFn, resIndexFn] = info.indexFns; auto [lhsShape, rhsShape, resShape] = info.vectorShapes; - Value lhs = buildMmaSyncMemrefLoadOperand(b, loc, laneId, lhsMemref, + Value lhs = buildMmaSyncMemRefLoadOperand(b, loc, laneId, lhsMemRef, lhsIndexFn, lhsShape); - Value rhs = buildMmaSyncMemrefLoadOperand(b, loc, laneId, rhsMemref, + Value rhs = buildMmaSyncMemRefLoadOperand(b, loc, laneId, rhsMemRef, rhsIndexFn, rhsShape); - Value res = buildMmaSyncMemrefLoadOperand(b, loc, laneId, resMemref, + Value res = buildMmaSyncMemRefLoadOperand(b, loc, laneId, resMemRef, resIndexFn, resShape); res = b.create(loc, lhs, rhs, res, info.mmaShape, info.tf32Enabled); - buildMmaSyncMemrefStoreOperand(b, loc, res, laneId, resMemref, resIndexFn, + buildMmaSyncMemRefStoreOperand(b, loc, res, laneId, resMemRef, resIndexFn, resShape); return res.getDefiningOp(); } @@ -754,6 +752,284 @@ return DiagnosedSilenceableFailure::success(); } +//===----------------------------------------------------------------------===// +// Hopper builders. +//===----------------------------------------------------------------------===// + +/// Helper to create the base Hopper-specific operations that are reused in +/// various other places. +struct HopperBuilder { + HopperBuilder(RewriterBase &rewriter, Location loc) + : rewriter(rewriter), loc(loc) {} + + TypedValue + buildAndInitBarrierInSharedMemory(OpFoldResult numThreads); + + /// Create tma descriptor op to initiate transfer from global to shared + /// memory. This must be done before the launch op, on the host. + TypedValue + buildGlobalMemRefDescriptor(TypedValue memref, + gpu::LaunchOp launchOp); + + /// Build a tma load from global memory to shared memory using `barrier` to + /// synchronize. Return the number of bytes that will be transferred. + OpFoldResult + buildTmaAsyncLoad(TypedValue globalDesc, + TypedValue sharedMemref, + TypedValue barrier, + SmallVectorImpl &loadOps); + void buildBarrierArriveTx(TypedValue barrier, + ArrayRef sizes); + + /// If threadIdx.x == 0 does TMA request + wait, else just wait. + /// Return the operation that performs the transfer on thread0. + // TODO: In the future, don't hardcode to thread 0 but elect a leader. 
+ SmallVector buildPredicateLoadsOnThread0( + ArrayRef> globalDescriptors, + ArrayRef> sharedMemBuffers, + TypedValue barrier); + + void buildTryWaitParity(TypedValue barrier); + + RewriterBase &rewriter; + Location loc; +}; + +SmallVector HopperBuilder::buildPredicateLoadsOnThread0( + ArrayRef> globalDescriptors, + ArrayRef> sharedMemBuffers, + TypedValue barrier) { + SmallVector loadOps; + Value zero = rewriter.create(loc, 0); + Value tidx = rewriter.create(loc, gpu::Dimension::x); + Value cond = + rewriter.create(loc, arith::CmpIPredicate::eq, tidx, zero); + // clang-format off + rewriter.create( + /*location=*/loc, + /*conditional=*/cond, + /*thenBuilder=*/ + [&](OpBuilder &lb, Location loc) { + SmallVector sizes; + sizes.reserve(globalDescriptors.size()); + for (auto [desc, shmem] : llvm::zip_equal( + globalDescriptors, sharedMemBuffers)) { + OpFoldResult sz = buildTmaAsyncLoad(desc, shmem, barrier, loadOps); + sizes.push_back(sz); + } + // TODO: Note that cutlass predeclares the barrier arrive tx before the tma.async.load. + // This may or may not have perf implications. + buildBarrierArriveTx(barrier, sizes); + rewriter.create(loc); + }, + /*elseBuilder=*/ + [&](OpBuilder &lb, Location loc) { + // TODO: is this for no-thread divergence? + // Should we just yield the size and hoist? + buildBarrierArriveTx(barrier, getAsIndexOpFoldResult(rewriter.getContext(), 0)); + rewriter.create(loc); + }); + // clang-format on + return loadOps; +} + +static Attribute getSharedAddressSpaceAttribute(OpBuilder &b) { + return gpu::AddressSpaceAttr::get( + b.getContext(), gpu::GPUDialect::getWorkgroupAddressSpace()); + // return b.getI64IntegerAttr(static_cast(kSharedMemorySpace)); +} + +TypedValue +HopperBuilder::buildAndInitBarrierInSharedMemory(OpFoldResult numThreads) { + auto sharedMemorySpace = getSharedAddressSpaceAttribute(rewriter); + Value barrier = rewriter.create( + loc, nvgpu::MBarrierType::get(rewriter.getContext(), sharedMemorySpace)); + rewriter.create( + loc, barrier, getValueOrCreateConstantIndexOp(rewriter, loc, numThreads)); + rewriter.create(loc); + return cast>(barrier); +} + +TypedValue +HopperBuilder::buildGlobalMemRefDescriptor(TypedValue memref, + gpu::LaunchOp launchOp) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(launchOp); + Value unrankedMemRef = rewriter.create( + loc, + UnrankedMemRefType::get(memref.getType().getElementType(), + memref.getType().getMemorySpace()), + memref); + SmallVector mixedSizes = + memref::getMixedSizes(rewriter, loc, memref); + SmallVector sizes = + getValueOrCreateConstantIndexOp(rewriter, loc, mixedSizes); + + auto sharedMemorySpace = getSharedAddressSpaceAttribute(rewriter); + Value desc = rewriter.create( + loc, + nvgpu::TensorMapDescriptorType::get( + rewriter.getContext(), + MemRefType::Builder(memref.getType()) + .setMemorySpace(sharedMemorySpace), + TensorMapSwizzleKind::SWIZZLE_NONE, + TensorMapL2PromoKind::L2PROMO_NONE, TensorMapOOBKind::OOB_ZERO, + TensorMapInterleaveKind::INTERLEAVE_NONE), + unrankedMemRef, sizes); + return cast>(desc); +} + +OpFoldResult HopperBuilder::buildTmaAsyncLoad( + TypedValue globalDesc, + TypedValue sharedMemref, + TypedValue barrier, + SmallVectorImpl &loadOps) { + MLIRContext *ctx = rewriter.getContext(); + Value zero = rewriter.create(loc, 0); + Operation *loadOp = rewriter.create( + loc, sharedMemref, barrier, globalDesc, ValueRange{zero, zero}); + loadOps.push_back(loadOp); + auto mixedSizes = memref::getMixedSizes(rewriter, loc, sharedMemref); + SmallVector 
symbols(mixedSizes.size()); + bindSymbolsList(ctx, llvm::MutableArrayRef{symbols}); + AffineExpr prodExprInBytes = + computeProduct(ctx, symbols) * + (sharedMemref.getType().getElementTypeBitWidth() / 8); + auto res = affine::makeComposedFoldedAffineApply(rewriter, loc, + prodExprInBytes, mixedSizes); + return res; +} + +void HopperBuilder::buildBarrierArriveTx( + TypedValue barrier, + ArrayRef mixedSizes) { + assert(!mixedSizes.empty() && "expecte non-empty sizes"); + MLIRContext *ctx = rewriter.getContext(); + SmallVector symbols(mixedSizes.size()); + bindSymbolsList(ctx, llvm::MutableArrayRef{symbols}); + AffineExpr sumExpr = computeSum(ctx, symbols); + OpFoldResult size = + affine::makeComposedFoldedAffineApply(rewriter, loc, sumExpr, mixedSizes); + Value sizeVal = getValueOrCreateConstantIndexOp(rewriter, loc, size); + rewriter.create(loc, barrier, sizeVal); +} + +void HopperBuilder::buildTryWaitParity( + TypedValue barrier) { + Value parity = rewriter.create(loc, 0); + // 10M is an arbitrary, not too small or too big number to specify the number + // of ticks before retry. + // TODO: hoist this in a default dialect constant. + Value ticksBeforeRetry = + rewriter.create(loc, 10000000); + rewriter.create(loc, barrier, parity, + ticksBeforeRetry); +} + +//===----------------------------------------------------------------------===// +// RewriteCopyAsTmaOp +//===----------------------------------------------------------------------===// + +/// Helper to create the tma operations corresponding to `linalg::CopyOp`. +struct CopyBuilder : public HopperBuilder { + CopyBuilder(RewriterBase &rewriter, Location loc) + : HopperBuilder(rewriter, loc) {} + + SmallVector rewrite(ArrayRef copyOps); +}; + +SmallVector CopyBuilder::rewrite(ArrayRef copyOps) { + MLIRContext *ctx = rewriter.getContext(); + if (copyOps.empty()) + return SmallVector(); + + auto launchOp = copyOps.front()->getParentOfType(); + assert(launchOp && "expected launch op"); + + // 1. Init a barrier object in shared memory. + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(copyOps.front()); + AffineExpr bx, by, bz; + bindSymbols(ctx, bx, by, bz); + AffineExpr prod = computeProduct(ctx, ArrayRef{bx, by, bz}); + OpFoldResult numThreads = affine::makeComposedFoldedAffineApply( + rewriter, loc, prod, + ArrayRef{launchOp.getBlockSizeX(), launchOp.getBlockSizeY(), + launchOp.getBlockSizeZ()}); + + TypedValue barrier = + buildAndInitBarrierInSharedMemory(numThreads); + + SmallVector> shmems; + SmallVector> globalDescs; + for (Operation *op : copyOps) { + auto copyOp = cast(op); + auto inMemRef = + cast>(copyOp.getDpsInputOperand(0)->get()); + MemRefType inMemRefType = inMemRef.getType(); + assert(inMemRefType.getRank() == 2 && "expected in to be a 2D memref"); + + // 2. Build global memory descriptor. + TypedValue globalDesc = + buildGlobalMemRefDescriptor(inMemRef, launchOp); + globalDescs.push_back(globalDesc); + + // 3. Shared memory and descriptor for the tmp array. + auto shmem = + cast>(copyOp.getDpsInitOperand(0)->get()); + shmems.push_back(shmem); + } + + // 4. Load in from global memory to shared memory using tma. + OpBuilder::InsertionGuard g2(rewriter); + rewriter.setInsertionPoint(copyOps.front()); + SmallVector results = + buildPredicateLoadsOnThread0(globalDescs, shmems, barrier); + + // 5. Spin-loop until data is ready. + buildTryWaitParity(barrier); + + // 6. Erase the ops that have now been rewritten. 
+ for (Operation *op : copyOps) + rewriter.eraseOp(op); + + return results; +} + +DiagnosedSilenceableFailure +transform::RewriteCopyAsTmaOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, + transform::TransformState &state) { + auto payloadOps = state.getPayloadOps(getTarget()); + gpu::LaunchOp commonLaunchOp; + Operation *firstOp, *failingOp; + if (llvm::any_of(payloadOps, [&](Operation *op) { + if (!commonLaunchOp) { + commonLaunchOp = op->getParentOfType(); + firstOp = op; + } + auto fail = !op->getParentOfType() || + commonLaunchOp != op->getParentOfType() || + !isa(op); + if (fail) + failingOp = op; + return fail; + })) { + DiagnosedSilenceableFailure diag = + emitSilenceableError() + << "target ops must be linalg::CopyOp nested under a common " + "gpu.LaunchOp to be rewritten because the tma descriptors need to " + "be created on the host.\nBut got: " + << *firstOp << "\nand " << *failingOp; + return diag; + } + + // TODO: more robust detection of copy, with transposes etc. + CopyBuilder(rewriter, getLoc()).rewrite(llvm::to_vector(payloadOps)); + + return DiagnosedSilenceableFailure::success(); +} + //===----------------------------------------------------------------------===// // Transform op registration //===----------------------------------------------------------------------===// @@ -767,6 +1043,7 @@ declareGeneratedDialect(); declareGeneratedDialect(); declareGeneratedDialect(); + declareGeneratedDialect(); declareGeneratedDialect(); registerTransformOps< #define GET_OP_LIST diff --git a/mlir/lib/Dialect/Utils/IndexingUtils.cpp b/mlir/lib/Dialect/Utils/IndexingUtils.cpp --- a/mlir/lib/Dialect/Utils/IndexingUtils.cpp +++ b/mlir/lib/Dialect/Utils/IndexingUtils.cpp @@ -161,13 +161,13 @@ if (basis.empty()) return getAffineConstantExpr(0, ctx); return std::accumulate(basis.begin(), basis.end(), - getAffineConstantExpr(1, ctx), + getAffineConstantExpr(0, ctx), std::plus()); } AffineExpr mlir::computeProduct(MLIRContext *ctx, ArrayRef basis) { if (basis.empty()) - return getAffineConstantExpr(0, ctx); + return getAffineConstantExpr(1, ctx); return std::accumulate(basis.begin(), basis.end(), getAffineConstantExpr(1, ctx), std::multiplies()); diff --git a/mlir/test/Dialect/NVGPU/tmaload-transform.mlir b/mlir/test/Dialect/NVGPU/tmaload-transform.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/NVGPU/tmaload-transform.mlir @@ -0,0 +1,84 @@ +// RUN: mlir-opt %s \ +// RUN: -test-transform-dialect-interpreter \ +// RUN: -test-transform-dialect-erase-schedule \ +// RUN: | FileCheck %s + +memref.global "private" @bufferLhsGlobal : memref<64x8xf32, #gpu.address_space> +memref.global "private" @bufferRhsGlobal : memref<8x128xf32, #gpu.address_space> + +// CHECK-LABEL: func.func @main() +func.func @main() { + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + + %0 = gpu.wait async + %memref, %asyncToken = gpu.alloc async [%0] () : memref<64x8xf32> + %memref_1, %asyncToken_2 = gpu.alloc async [%0] () : memref<8x128xf32> + + // CHECK: %[[M1:.*]] = memref.cast %{{.*}} : memref<64x8xf32> to memref<*xf32> + // CHECK: %[[c64:.*]] = arith.constant 64 : index + // CHECK: %[[c8:.*]] = arith.constant 8 : index + // CHECK: %[[D1:.*]] = nvgpu.tma.create.descriptor %[[M1]] box[%[[c64]], %[[c8]]] + // CHECK-SAME: : memref<*xf32> -> >, swizzle = none, l2promo = none, oob = zero, interleave = none> + // CHECK: %[[cast_2:.*]] = memref.cast %memref_0 : memref<8x128xf32> to memref<*xf32> + // CHECK: %[[c8_2:.*]] = arith.constant 8 : index + 
// CHECK: %[[c128_2:.*]] = arith.constant 128 : index + // CHECK: %[[D2:.*]] = nvgpu.tma.create.descriptor %cast_2 box[%[[c8_2]], %[[c128_2]]] + // CHECK-SAME: : memref<*xf32> -> >, swizzle = none, l2promo = none, oob = zero, interleave = none> + // CHECK: gpu.launch + gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1) + threads(%tx, %ty, %tz) in (%block_x = %c128, %block_y = %c1, %block_z = %c1) { + // CHECK: %[[G1:.*]] = memref.get_global @bufferLhsGlobal : memref<64x8xf32, #gpu.address_space> + // CHECK: %[[G2:.*]] = memref.get_global @bufferRhsGlobal : memref<8x128xf32, #gpu.address_space> + %out = memref.get_global @bufferLhsGlobal : memref<64x8xf32, #gpu.address_space> + %out_1 = memref.get_global @bufferRhsGlobal : memref<8x128xf32, #gpu.address_space> + + // CHECK: %[[B:.*]] = nvgpu.mbarrier.create -> + // CHECK: nvgpu.mbarrier.init %[[B]], %{{.*}} : + // CHECK: gpu.barrier + // + // CHECK: %[[c0:.*]] = arith.constant 0 : index + // CHECK: %[[TIDX:.*]] = gpu.thread_id x + // CHECK: %[[CMP:.*]] = arith.cmpi eq, %[[TIDX]], %[[c0]] : index + // + // CHECK: scf.if %[[CMP]] { + // + // CHECK: %[[c0_7:.*]] = arith.constant 0 : index + // CHECK: nvgpu.tma.async.load %[[D1]][%[[c0_7]], %[[c0_7]]], %[[B]] to %[[G1]] + // CHECK-SAME: : >, + // CHECK-SAME: swizzle = none, l2promo = none, oob = zero, interleave = none>, + // CHECK-SAME: -> memref<64x8xf32, #gpu.address_space> + // + // CHECK: %[[c0_8:.*]] = arith.constant 0 : index + // CHECK: nvgpu.tma.async.load %[[D2]][%[[c0_8]], %[[c0_8]]], %[[B]] to %[[G2]] + // CHECK-SAME: : >, + // CHECK-SAME: swizzle = none, l2promo = none, oob = zero, interleave = none>, + // CHECK-SAME: -> memref<8x128xf32, #gpu.address_space> + // + // CHECK: %[[c6144:.*]] = arith.constant 6144 : index + // CHECK: nvgpu.mbarrier.arrive.expect_tx %[[B]], %[[c6144]] : + // CHECK: } else { + // CHECK: %[[c0_7:.*]] = arith.constant 0 : index + // CHECK: nvgpu.mbarrier.arrive.expect_tx %[[B]], %[[c0_7]] : + // CHECK: } + // + // CHECK: %[[c0_6:.*]] = arith.constant 0 : index + // CHECK: %[[c10000000:.*]] = arith.constant 10000000 : index + // CHECK: nvgpu.mbarrier.try_wait.parity %[[B]], %[[c0_6]], %[[c10000000]] : + + /// Both copies are matched and end up in the same async group. 
+ linalg.copy ins(%memref: memref<64x8xf32>) outs(%out: memref<64x8xf32, #gpu.address_space>) + linalg.copy ins(%memref_1: memref<8x128xf32>) outs(%out_1: memref<8x128xf32, #gpu.address_space>) + + gpu.terminator + } + + return +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !transform.any_op): + %copy = transform.structured.match ops{["linalg.copy"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + transform.nvgpu.rewrite_copy_as_tma %copy : (!transform.any_op) -> () +} diff --git a/mlir/test/Integration/GPU/CUDA/sm90/tmaload-transform.mlir b/mlir/test/Integration/GPU/CUDA/sm90/tmaload-transform.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Integration/GPU/CUDA/sm90/tmaload-transform.mlir @@ -0,0 +1,109 @@ +// RUN: mlir-opt %s \ +// RUN: -test-transform-dialect-interpreter \ +// RUN: -test-transform-dialect-erase-schedule \ +// RUN: -convert-nvgpu-to-nvvm -gpu-kernel-outlining \ +// RUN: -convert-scf-to-cf -convert-nvvm-to-llvm \ +// RUN: -convert-vector-to-llvm \ +// RUN: -convert-math-to-llvm \ +// RUN: -expand-strided-metadata \ +// RUN: -lower-affine \ +// RUN: -convert-index-to-llvm=index-bitwidth=32 \ +// RUN: -convert-arith-to-llvm \ +// RUN: -finalize-memref-to-llvm \ +// RUN: -convert-func-to-llvm \ +// RUN: -canonicalize \ +// RUN: | mlir-opt -pass-pipeline='builtin.module(gpu.module(strip-debuginfo,convert-gpu-to-nvvm,convert-nvgpu-to-nvvm{use-opaque-pointers=1},lower-affine,convert-scf-to-cf,convert-vector-to-llvm,convert-math-to-llvm,expand-strided-metadata,lower-affine,convert-index-to-llvm{index-bitwidth=32},convert-arith-to-llvm,reconcile-unrealized-casts,gpu-to-cubin{chip=sm_90 features=+ptx80 dump-ptx}))' \ +// RUN: 2>&1 | FileCheck %s --check-prefixes=CHECK-PTX + +// CHECK-PTX: mbarrier.init.shared {{.*}} !llvm.ptr<3>, i32 +/// If branch +// CHECK-PTX: cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes +// CHECK-PTX: cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes +// CHECK-PTX: mbarrier.arrive.expect_tx.shared +/// Else branch +// CHECK-PTX: mbarrier.arrive.expect_tx.shared +// CHECK-PTX: mbarrier.try_wait.parity.shared + +// TODO: GPU layering does not currently work end-to-end. Activate the following +// when fixed. 
+// R-UN: | mlir-opt -convert-index-to-llvm=index-bitwidth=32 \ +// R-UN: -gpu-to-llvm \ +// R-UN: -convert-func-to-llvm \ +// R-UN: -cse \ +// R-UN: -canonicalize \ +// R-UN: -reconcile-unrealized-casts \ +// R-UN: | mlir-cpu-runner \ +// R-UN: --shared-libs=%mlir_cuda_runtime \ +// R-UN: --shared-libs=%mlir_runner_utils \ +// R-UN: --entry-point-result=void \ +// R-UN: | FileCheck %s + +// C-HECK: [GPU] TMA BEFORE lhs[45][7] 0.000000 +// C-HECK: [GPU] TMA BEFORE rhs[7][0] 0.000000 +// C-HECK: [GPU] TMA LOADED lhs[45][7] 7.000000 +// C-HECK: [GPU] TMA LOADED rhs[7][0] 3.000000 + + +module @mymod { + memref.global "private" @bufferLhsGlobal : memref<64x8xf32, 3> + memref.global "private" @bufferRhsGlobal : memref<8x128xf32, 3> + func.func @main() { + %c10000000 = arith.constant 10000000 : index + %c6144 = arith.constant 6144 : index + %c45 = arith.constant 45 : index + %c7 = arith.constant 7 : index + %c64 = arith.constant 64 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %c8 = arith.constant 8 : index + %c128 = arith.constant 128 : index + %cst = arith.constant 3.000000e+00 : f32 + %alloc = memref.alloc() : memref<64x8xf32> + %alloc_0 = memref.alloc() : memref<8x128xf32> + scf.for %arg0 = %c0 to %c8 step %c1 { + scf.for %arg1 = %c0 to %c128 step %c1 { + memref.store %cst, %alloc_0[%arg0, %arg1] : memref<8x128xf32> + } + } + scf.for %arg0 = %c0 to %c64 step %c1 { + scf.for %arg1 = %c0 to %c8 step %c1 { + %5 = arith.index_cast %arg1 : index to i64 + %6 = arith.uitofp %5 : i64 to f32 + memref.store %6, %alloc[%arg0, %arg1] : memref<64x8xf32> + } + } + %0 = gpu.wait async + %memref, %asyncToken = gpu.alloc async [%0] () : memref<64x8xf32> + %memref_1, %asyncToken_2 = gpu.alloc async [%0] () : memref<8x128xf32> + %1 = gpu.memcpy async [%0] %memref, %alloc : memref<64x8xf32>, memref<64x8xf32> + %2 = gpu.memcpy async [%0] %memref_1, %alloc_0 : memref<8x128xf32>, memref<8x128xf32> + + gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1) + threads(%tx, %ty, %tz) in (%block_x = %c128, %block_y = %c1, %block_z = %c1) { + %out = memref.get_global @bufferLhsGlobal : memref<64x8xf32, 3> + %out_1 = memref.get_global @bufferRhsGlobal : memref<8x128xf32, 3> + linalg.copy ins(%memref: memref<64x8xf32>) outs(%out: memref<64x8xf32, 3>) + linalg.copy ins(%memref_1: memref<8x128xf32>) outs(%out_1: memref<8x128xf32, 3>) + + %6 = gpu.thread_id x + %10 = arith.cmpi eq, %6, %c0 : index + scf.if %10 { + %11 = memref.load %out[%c45, %c7] : memref<64x8xf32, 3> + %12 = memref.load %out_1[%c7, %c0] : memref<8x128xf32, 3> + gpu.printf "[GPU] TMA LOADED lhs[45][7] %f\0A" %11 : f32 + gpu.printf "[GPU] TMA LOADED rhs[7][0] %f\0A" %12 : f32 + } + gpu.terminator + } + + return + } +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !transform.any_op): + %copy = transform.structured.match ops{["linalg.copy"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + transform.nvgpu.rewrite_copy_as_tma %copy + : (!transform.any_op) -> () +}
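Note on the %c6144 constant checked in tmaload-transform.mlir above: buildTmaAsyncLoad computes, per buffer, the product of the shared memref sizes times the element width in bytes, and buildBarrierArriveTx sums those products into the single nvgpu.mbarrier.arrive.expect_tx transaction count issued by thread 0. A minimal standalone C++ sanity check of that arithmetic for the two buffers used in the test (illustrative only, not part of the patch):

#include <cassert>

int main() {
  // memref<64x8xf32>: 64 * 8 elements * 4 bytes = 2048 bytes.
  constexpr int bytesLhs = 64 * 8 * (32 / 8);
  // memref<8x128xf32>: 8 * 128 elements * 4 bytes = 4096 bytes.
  constexpr int bytesRhs = 8 * 128 * (32 / 8);
  // Total transaction count expected by the mbarrier on thread 0.
  static_assert(bytesLhs + bytesRhs == 6144, "expect_tx byte count");
  return 0;
}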
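The IndexingUtils.cpp hunk swaps the std::accumulate seeds: computeSum must start from the additive identity 0 (the old seed of 1 added a spurious constant to every sum), and computeProduct must return the multiplicative identity 1 for an empty basis (the old code returned the constant 0, which collapses any product it feeds into, such as the block-size product built in CopyBuilder::rewrite). A minimal sketch with plain integers standing in for AffineExpr, illustrating the identity-element requirement (illustrative only, not part of the patch):

#include <cassert>
#include <functional>
#include <numeric>
#include <vector>

int main() {
  std::vector<int> basis = {2, 3, 4};
  // Sum seeded with the additive identity 0: 2 + 3 + 4 == 9.
  assert(std::accumulate(basis.begin(), basis.end(), 0, std::plus<int>()) == 9);
  // Seeding with 1, as the old computeSum did, yields 10 instead.
  assert(std::accumulate(basis.begin(), basis.end(), 1, std::plus<int>()) == 10);
  // An empty product must be the multiplicative identity 1, not 0, so it does
  // not zero out whatever expression it is later multiplied into.
  std::vector<int> empty;
  assert(std::accumulate(empty.begin(), empty.end(), 1,
                         std::multiplies<int>()) == 1);
  return 0;
}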