diff --git a/mlir/include/mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.td b/mlir/include/mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.td
--- a/mlir/include/mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.td
+++ b/mlir/include/mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.td
@@ -164,4 +164,33 @@
 }];
 }
 
+//===----------------------------------------------------------------------===//
+// RewriteCopyAsTmaOp
+//===----------------------------------------------------------------------===//
+
+def RewriteCopyAsTmaOp :
+  Op<Transform_Dialect, "nvgpu.rewrite_copy_as_tma",
+    [FunctionalStyleTransformOpTrait,
+     MemoryEffectsOpInterface,
+     TransformOpInterface]> {
+  let description = [{
+    Rewrite a copy operation on memref to tma operations that transit through
+    shared memory.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target);
+  let results = (outs);
+
+  let assemblyFormat = "$target attr-dict `:` functional-type(operands, results) ";
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure apply(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::transform::TransformResults &transformResults,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
 #endif // NVGPU_TRANSFORM_OPS
diff --git a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
--- a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
@@ -21,19 +22,12 @@
 #include "mlir/Dialect/SCF/Transforms/Transforms.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
-#include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/MLIRContext.h"
-#include "mlir/IR/Operation.h"
-#include "mlir/IR/TypeRange.h"
-#include "mlir/IR/TypeUtilities.h"
-#include "mlir/Support/LogicalResult.h"
-#include "llvm/ADT/ArrayRef.h"
-#include "llvm/Support/Debug.h"
 
 using namespace mlir;
 using namespace mlir::linalg;
 using namespace mlir::nvgpu;
+using namespace mlir::NVVM;
 using namespace mlir::transform;
 
 #define DEBUG_TYPE "nvgpu-transforms"
@@ -517,7 +511,7 @@
   /// Build a list of memref.load operations indexed at `(row, col)` indices
   /// that make sense for a particular MMA instruction and specified via the
   /// IndexCalculator callback.
-  SmallVector<Value> buildMemrefLoads(OpBuilder &b, Location loc,
+  SmallVector<Value> buildMemRefLoads(OpBuilder &b, Location loc,
                                       OpFoldResult laneId, Value memref,
                                       IndexCalculator indexFn);
@@ -527,7 +521,7 @@
   /// data that makes sense for the particular MMA operation.
   /// The `vectorShape` matches existing NVGPU dialect op specification but
   /// could also be flattened in the future if needed for simplification.
-  Value buildMmaSyncMemrefLoadOperand(OpBuilder &b, Location loc,
+  Value buildMmaSyncMemRefLoadOperand(OpBuilder &b, Location loc,
                                       OpFoldResult laneId, Value memref,
                                       IndexCalculator indexFn,
                                       ArrayRef<int64_t> vectorShape);
@@ -535,7 +529,7 @@
   /// Build a list of memref.store operations indexed at `(row, col)` indices
   /// that make sense for a particular MMA instruction and specified via the
   /// IndexCalculator callback.
-  SmallVector<Operation *> buildMemrefStores(OpBuilder &b, Location loc,
+  SmallVector<Operation *> buildMemRefStores(OpBuilder &b, Location loc,
                                              ValueRange toStore,
                                              OpFoldResult laneId, Value memref,
                                              IndexCalculator indexFn);
@@ -546,7 +540,7 @@
   /// data that makes sense for the particular MMA operation.
   /// The `vectorShape` matches existing NVGPU dialect op specification but
   /// could also be flattened in the future if needed for simplification.
-  SmallVector<Operation *> buildMmaSyncMemrefStoreOperand(
+  SmallVector<Operation *> buildMmaSyncMemRefStoreOperand(
       OpBuilder &b, Location loc, Value vectorToStore, OpFoldResult laneId,
       Value memref, IndexCalculator indexFn, ArrayRef<int64_t> vectorShape);
@@ -573,7 +567,7 @@
   }
 }
 
-SmallVector<Value> MmaSyncBuilder::buildMemrefLoads(OpBuilder &b, Location loc,
+SmallVector<Value> MmaSyncBuilder::buildMemRefLoads(OpBuilder &b, Location loc,
                                                     OpFoldResult laneId,
                                                     Value memref,
                                                     IndexCalculator indexFn) {
@@ -591,10 +585,10 @@
   return res;
 }
 
-Value MmaSyncBuilder::buildMmaSyncMemrefLoadOperand(
+Value MmaSyncBuilder::buildMmaSyncMemRefLoadOperand(
     OpBuilder &b, Location loc, OpFoldResult laneId, Value memref,
     IndexCalculator indexFn, ArrayRef<int64_t> vectorShape) {
-  auto loads = buildMemrefLoads(b, loc, laneId, memref, indexFn);
+  auto loads = buildMemRefLoads(b, loc, laneId, memref, indexFn);
 
   Type elementType = getElementTypeOrSelf(memref.getType());
   auto vt = VectorType::get(vectorShape, elementType);
@@ -614,7 +608,7 @@
 }
 
 SmallVector<Operation *>
-MmaSyncBuilder::buildMemrefStores(OpBuilder &b, Location loc,
+MmaSyncBuilder::buildMemRefStores(OpBuilder &b, Location loc,
                                   ValueRange toStore, OpFoldResult laneId,
                                   Value memref, IndexCalculator indexFn) {
   auto aff = [&](AffineExpr e) {
@@ -632,7 +626,7 @@
   return res;
 }
 
-SmallVector<Operation *> MmaSyncBuilder::buildMmaSyncMemrefStoreOperand(
+SmallVector<Operation *> MmaSyncBuilder::buildMmaSyncMemRefStoreOperand(
    OpBuilder &b, Location loc, Value vectorToStore, OpFoldResult laneId,
    Value memref, IndexCalculator indexFn, ArrayRef<int64_t> vectorShape) {
   SmallVector<Value> toStore;
@@ -647,7 +641,7 @@
       [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
         toStore.push_back(v);
       });
-  return buildMemrefStores(b, loc, toStore, laneId, memref, indexFn);
+  return buildMemRefStores(b, loc, toStore, laneId, memref, indexFn);
 }
 
 static std::tuple<SmallVector<int64_t>, SmallVector<int64_t>,
@@ -690,22 +684,22 @@
 }
 
 FailureOr<Operation *> MmaSyncBuilder::buildMmaSync(LinalgOp linalgOp) {
-  Value lhsMemref = linalgOp.getDpsInputOperand(0)->get();
-  Value rhsMemref = linalgOp.getDpsInputOperand(1)->get();
-  Value resMemref = linalgOp.getDpsInitOperand(0)->get();
-  assert(lhsMemref.getType().cast<MemRefType>().getRank() == 2 &&
+  Value lhsMemRef = linalgOp.getDpsInputOperand(0)->get();
+  Value rhsMemRef = linalgOp.getDpsInputOperand(1)->get();
+  Value resMemRef = linalgOp.getDpsInitOperand(0)->get();
+  assert(lhsMemRef.getType().cast<MemRefType>().getRank() == 2 &&
          "expected lhs to be a 2D memref");
-  assert(rhsMemref.getType().cast<MemRefType>().getRank() == 2 &&
+  assert(rhsMemRef.getType().cast<MemRefType>().getRank() == 2 &&
          "expected rhs to be a 2D memref");
-  assert(resMemref.getType().cast<MemRefType>().getRank() == 2 &&
+  assert(resMemRef.getType().cast<MemRefType>().getRank() == 2 &&
          "expected res to be a 2D memref");
-  int64_t m = cast<MemRefType>(lhsMemref.getType()).getShape()[0];
-  int64_t n = cast<MemRefType>(rhsMemref.getType()).getShape()[1];
-  int64_t k = cast<MemRefType>(lhsMemref.getType()).getShape()[1];
-  Type lhsType = getElementTypeOrSelf(lhsMemref.getType());
-  Type rhsType = getElementTypeOrSelf(rhsMemref.getType());
-  Type resType = getElementTypeOrSelf(resMemref.getType());
+  int64_t m = cast<MemRefType>(lhsMemRef.getType()).getShape()[0];
+  int64_t n = cast<MemRefType>(rhsMemRef.getType()).getShape()[1];
+  int64_t k = cast<MemRefType>(lhsMemRef.getType()).getShape()[1];
+  Type lhsType = getElementTypeOrSelf(lhsMemRef.getType());
+  Type rhsType = getElementTypeOrSelf(rhsMemRef.getType());
+  Type resType = getElementTypeOrSelf(resMemRef.getType());
 
   FailureOr<MmaSyncInfo> maybeInfo =
       getIndexCalculators({m, n, k}, {lhsType, rhsType, resType});
@@ -715,15 +709,15 @@
   MmaSyncInfo info = *maybeInfo;
   auto [lhsIndexFn, rhsIndexFn, resIndexFn] = info.indexFns;
   auto [lhsShape, rhsShape, resShape] = info.vectorShapes;
-  Value lhs = buildMmaSyncMemrefLoadOperand(b, loc, laneId, lhsMemref,
+  Value lhs = buildMmaSyncMemRefLoadOperand(b, loc, laneId, lhsMemRef,
                                             lhsIndexFn, lhsShape);
-  Value rhs = buildMmaSyncMemrefLoadOperand(b, loc, laneId, rhsMemref,
+  Value rhs = buildMmaSyncMemRefLoadOperand(b, loc, laneId, rhsMemRef,
                                             rhsIndexFn, rhsShape);
-  Value res = buildMmaSyncMemrefLoadOperand(b, loc, laneId, resMemref,
+  Value res = buildMmaSyncMemRefLoadOperand(b, loc, laneId, resMemRef,
                                             resIndexFn, resShape);
   res = b.create<nvgpu::MmaSyncOp>(loc, lhs, rhs, res, info.mmaShape,
                                    info.tf32Enabled);
-  buildMmaSyncMemrefStoreOperand(b, loc, res, laneId, resMemref, resIndexFn,
+  buildMmaSyncMemRefStoreOperand(b, loc, res, laneId, resMemRef, resIndexFn,
                                  resShape);
   return res.getDefiningOp();
 }
@@ -754,6 +748,318 @@
   return DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// Hopper builders.
+//===----------------------------------------------------------------------===//
+
+template <typename T>
+struct ValueOfType {
+  ValueOfType(Value v) : value(v) {
+    if (!isa<T>(value.getType())) {
+      llvm::errs() << "wrong type for value, got: " << value << "\n"
+                   << " of type: " << value.getType() << "\n"
+                   << " but expected type: " << llvm::getTypeName<T>() << "\n";
+      assert(false && "wrong type for value");
+    }
+  }
+  operator T() { return cast<T>(value.getType()); }
+  T getType() { return cast<T>(value.getType()); }
+  operator Value() { return value; }
+  Value value;
+};
+
+/// Helper to create the base Hopper-specific operations that are reused in
+/// various other places.
+struct HopperBuilder {
+  HopperBuilder(OpBuilder &b, Location loc) : b(b), loc(loc) {}
+
+  ValueOfType<nvgpu::MBarrierType>
+  buildAndInitBarrierInSharedMemory(OpFoldResult numThreads);
+
+  /// Create tma descriptor op to initiate transfer from global to shared
+  /// memory. This must be done before the launch op, on the host.
+  ValueOfType<nvgpu::TensorMapDescriptorType>
+  buildGlobalMemRefDescriptor(ValueOfType<MemRefType> memref,
+                              gpu::LaunchOp launchOp);
+
+  /// Create wgmma matrix descriptor in shared memory given a memref in global
+  /// memory. Need a better way to do this. Currently, numbers are constant
+  /// https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-matrix-shared-memory-layout-matrix-descriptor
+  // https://gist.github.com/grypp/b3c0815feca87cc7337e4057dcd4876a#file-m64n128k16_f16n_f16n_f16n-txt-L181-L194
+  /// %descA = nvgpu.wgmma.generate.descriptor %lhsShmem, 16, 256, 0, 0 :
+  /// !shmemlhs
+  std::pair<ValueOfType<MemRefType>, ValueOfType<MemRefType>>
+  buildSharedAllocAndDescriptor(ValueOfType<MemRefType> memref);
+
+  /// Build a tma load from global memory to shared memory using `barrier` to
+  /// synchronize. Return the number of bytes that will be transferred.
+  OpFoldResult
+  buildTmaAsyncLoad(ValueOfType<nvgpu::TensorMapDescriptorType> globalDesc,
+                    ValueOfType<MemRefType> sharedMemref,
+                    ValueOfType<nvgpu::MBarrierType> barrier,
+                    SmallVectorImpl<Operation *> &loadOps);
+  void buildBarrierArriveTx(ValueOfType<nvgpu::MBarrierType> barrier,
+                            ArrayRef<OpFoldResult> sizes);
+
+  /// If threadIdx.x == 0 does TMA request + wait, else just wait.
+  /// Return the operation that performs the transfer on thread0.
+  SmallVector<Operation *> buildPredicateLoadsOnThread0(
+      ArrayRef<ValueOfType<nvgpu::TensorMapDescriptorType>> globalDescriptors,
+      ArrayRef<ValueOfType<MemRefType>> sharedMemBuffers,
+      ValueOfType<nvgpu::MBarrierType> barrier);
+
+  void buildTryWaitParity(ValueOfType<nvgpu::MBarrierType> barrier);
+
+  OpBuilder &b;
+  Location loc;
+};
+
+SmallVector<Operation *> HopperBuilder::buildPredicateLoadsOnThread0(
+    ArrayRef<ValueOfType<nvgpu::TensorMapDescriptorType>> globalDescriptors,
+    ArrayRef<ValueOfType<MemRefType>> sharedMemBuffers,
+    ValueOfType<nvgpu::MBarrierType> barrier) {
+  SmallVector<Operation *> loadOps;
+  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
+  Value tidx = b.create<gpu::ThreadIdOp>(loc, gpu::Dimension::x);
+  Value cond =
+      b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, tidx, zero);
+  // clang-format off
+  b.create<scf::IfOp>(
+    /*location=*/loc,
+    /*conditional=*/cond,
+    /*thenBuilder=*/
+    [&](OpBuilder &lb, Location loc) {
+      OpBuilder::InsertionGuard g(b);
+      b.setInsertionPoint(lb.getInsertionBlock(), lb.getInsertionPoint());
+      SmallVector<OpFoldResult> sizes;
+      sizes.reserve(globalDescriptors.size());
+      for (auto [desc, shmem] : llvm::zip_equal(
+              globalDescriptors, sharedMemBuffers)) {
+        OpFoldResult sz = buildTmaAsyncLoad(desc, shmem, barrier, loadOps);
+        sizes.push_back(sz);
+      }
+      // TODO: Note that cutlass predeclares the barrier arrive tx before the tma.async.load.
+      // This may or may not have perf implications.
+      buildBarrierArriveTx(barrier, sizes);
+      b.create<scf::YieldOp>(loc);
+    },
+    /*elseBuilder=*/
+    [&](OpBuilder &lb, Location loc) {
+      OpBuilder::InsertionGuard g(b);
+      b.setInsertionPoint(lb.getInsertionBlock(), lb.getInsertionPoint());
+      // TODO: is this for no-thread divergence?
+      // Should we just yield the size and hoist?
+      buildBarrierArriveTx(barrier, {});
+      b.create<scf::YieldOp>(loc);
+    });
+  // clang-format on
+  return loadOps;
+}
+
+ValueOfType<nvgpu::MBarrierType>
+HopperBuilder::buildAndInitBarrierInSharedMemory(OpFoldResult numThreads) {
+  // TODO: workgroup address space is not the same constant as shared memory
+  // space and there is no such thing as a shared memory space in the gpu
+  // dialect. It is unclear what the globally correct thing to use is atm.
+  // auto sharedMemorySpace = b.getI64IntegerAttr(
+  //     static_cast<int64_t>(gpu::GPUDialect::getWorkgroupAddressSpace()));
+  // auto sharedMemorySpace = gpu::GPUMemorySpaceMappingAttr::get(
+  //     b.getContext(), gpu::GPUDialect::getWorkgroupAddressSpace());
+  auto sharedMemorySpace =
+      b.getI64IntegerAttr(static_cast<int64_t>(kSharedMemorySpace));
+  Value barrier = b.create<nvgpu::MBarrierCreateOp>(
+      loc, nvgpu::MBarrierType::get(b.getContext(), sharedMemorySpace));
+  b.create<nvgpu::MBarrierInitOp>(
+      loc, barrier, getValueOrCreateConstantIndexOp(b, loc, numThreads));
+  b.create<gpu::BarrierOp>(loc);
+  return barrier;
+}
+
+ValueOfType<nvgpu::TensorMapDescriptorType>
+HopperBuilder::buildGlobalMemRefDescriptor(ValueOfType<MemRefType> memref,
+                                           gpu::LaunchOp launchOp) {
+  OpBuilder::InsertionGuard guard(b);
+  b.setInsertionPoint(launchOp);
+  Value unrankedMemRef = b.create<memref::CastOp>(
+      loc,
+      UnrankedMemRefType::get(memref.getType().getElementType(),
+                              memref.getType().getMemorySpace()),
+      memref);
+  auto mixedSizes = memref::getMixedSizes(b, loc, memref);
+  auto sizes = getValueOrCreateConstantIndexOp(b, loc, mixedSizes);
+  // TODO: workgroup address space is not the same constant as shared memory
+  // space and there is no such thing as a shared memory space in the gpu
+  // dialect. It is unclear what the globally correct thing to use is atm.
+  // auto sharedMemorySpace = b.getI64IntegerAttr(
+  //     static_cast<int64_t>(gpu::GPUDialect::getWorkgroupAddressSpace()));
+  // auto sharedMemorySpace = gpu::GPUMemorySpaceMappingAttr::get(
+  //     b.getContext(), gpu::GPUDialect::getWorkgroupAddressSpace());
+  auto sharedMemorySpace =
+      b.getI64IntegerAttr(static_cast<int64_t>(kSharedMemorySpace));
+  Value desc = b.create<nvgpu::TmaCreateDescriptorOp>(
+      loc,
+      nvgpu::TensorMapDescriptorType::get(
+          b.getContext(),
+          MemRefType::Builder(memref).setMemorySpace(sharedMemorySpace),
+          TensorMapSwizzleKind::SWIZZLE_NONE,
+          TensorMapL2PromoKind::L2PROMO_NONE, TensorMapOOBKind::OOB_ZERO,
+          TensorMapInterleaveKind::INTERLEAVE_NONE),
+      unrankedMemRef, sizes);
+  return desc;
+}
+
+std::pair<ValueOfType<MemRefType>, ValueOfType<MemRefType>>
+HopperBuilder::buildSharedAllocAndDescriptor(ValueOfType<MemRefType> memref) {
+  // TODO: workgroup address space is not the same constant as shared memory
+  // space and there is no such thing as a shared memory space in the gpu
+  // dialect. It is unclear what the globally correct thing to use is atm.
+  // auto sharedMemorySpace = b.getI64IntegerAttr(
+  //     static_cast<int64_t>(gpu::GPUDialect::getWorkgroupAddressSpace()));
+  // auto sharedMemorySpace = gpu::GPUMemorySpaceMappingAttr::get(
+  //     b.getContext(), gpu::GPUDialect::getWorkgroupAddressSpace());
+  auto sharedMemorySpace =
+      b.getI64IntegerAttr(static_cast<int64_t>(kSharedMemorySpace));
+  Value shmem = b.create<memref::AllocOp>(
+      loc,
+      MemRefType::Builder(memref.getType()).setMemorySpace(sharedMemorySpace));
+  return std::make_pair<ValueOfType<MemRefType>, ValueOfType<MemRefType>>(
+      shmem, shmem);
+}
+
+OpFoldResult HopperBuilder::buildTmaAsyncLoad(
+    ValueOfType<nvgpu::TensorMapDescriptorType> globalDesc,
+    ValueOfType<MemRefType> sharedMemref,
+    ValueOfType<nvgpu::MBarrierType> barrier,
+    SmallVectorImpl<Operation *> &loadOps) {
+  MLIRContext *ctx = b.getContext();
+  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
+  Operation *loadOp = b.create<nvgpu::TmaAsyncLoadOp>(
+      loc, sharedMemref, barrier, globalDesc, ValueRange{zero, zero});
+  loadOps.push_back(loadOp);
+  auto mixedSizes = memref::getMixedSizes(b, loc, sharedMemref);
+  SmallVector<AffineExpr> symbols(mixedSizes.size());
+  bindSymbolsList(ctx, llvm::MutableArrayRef{symbols});
+  AffineExpr prodExprInBytes =
+      computeProduct(ctx, symbols) *
+      (sharedMemref.getType().getElementTypeBitWidth() / 8);
+  auto res = affine::makeComposedFoldedAffineApply(b, loc, prodExprInBytes,
+                                                   mixedSizes);
+  return res;
+}
+
+void HopperBuilder::buildBarrierArriveTx(
+    ValueOfType<nvgpu::MBarrierType> barrier,
+    ArrayRef<OpFoldResult> mixedSizes) {
+  MLIRContext *ctx = b.getContext();
+  Value sizeVal;
+  if (!mixedSizes.empty()) {
+    SmallVector<AffineExpr> symbols(mixedSizes.size());
+    bindSymbolsList(ctx, llvm::MutableArrayRef{symbols});
+    AffineExpr sumExpr = computeSum(ctx, symbols);
+    OpFoldResult size =
+        affine::makeComposedFoldedAffineApply(b, loc, sumExpr, mixedSizes);
+    sizeVal = getValueOrCreateConstantIndexOp(b, loc, size);
+  } else {
+    sizeVal = b.create<arith::ConstantIndexOp>(loc, 0);
+  }
+  b.create<nvgpu::MBarrierArriveExpectTxOp>(loc, barrier, sizeVal);
+}
+
+void HopperBuilder::buildTryWaitParity(
+    ValueOfType<nvgpu::MBarrierType> barrier) {
+  Value c0 = b.create<arith::ConstantIndexOp>(loc, 0);
+  Value c10M = b.create<arith::ConstantIndexOp>(loc, 10000000);
+  b.create<nvgpu::MBarrierTryWaitParityOp>(loc, barrier, c0, c10M);
+}
+
+//===----------------------------------------------------------------------===//
+// RewriteCopyAsTmaOp
+//===----------------------------------------------------------------------===//
+
+/// Helper to create the tma operations corresponding to `linalg::CopyOp`.
+struct CopyBuilder : public HopperBuilder {
+  CopyBuilder(OpBuilder &b, Location loc) : HopperBuilder(b, loc) {}
+
+  FailureOr<SmallVector<Operation *>> build(ArrayRef<Operation *> copyOps);
+};
+
+FailureOr<SmallVector<Operation *>>
+CopyBuilder::build(ArrayRef<Operation *> copyOps) {
+  MLIRContext *ctx = b.getContext();
+  if (copyOps.empty())
+    return SmallVector<Operation *>();
+
+  auto launchOp = copyOps.front()->getParentOfType<gpu::LaunchOp>();
+  assert(launchOp && "expected launch op");
+
+  // 1. Init a barrier object in shared memory.
+  OpBuilder::InsertionGuard g(b);
+  b.setInsertionPoint(copyOps.front());
+  Value barrier = buildAndInitBarrierInSharedMemory(
+      /*numThreads=*/getAsIndexOpFoldResult(ctx, 128));
+
+  SmallVector<ValueOfType<MemRefType>> shmems;
+  SmallVector<ValueOfType<nvgpu::TensorMapDescriptorType>> globalDescs;
+  for (Operation *op : copyOps) {
+    auto copyOp = cast<linalg::CopyOp>(op);
+    Value inMemRef = copyOp.getDpsInputOperand(0)->get();
+    auto inMemRefType = cast<MemRefType>(inMemRef.getType());
+    assert(inMemRefType.getRank() == 2 && "expected in to be a 2D memref");
+
+    // 2. Build global memory descriptor.
+    ValueOfType<nvgpu::TensorMapDescriptorType> globalDesc =
+        buildGlobalMemRefDescriptor(inMemRef, launchOp);
+    globalDescs.push_back(globalDesc);
+
+    // 3. Shared memory and descriptor for the tmp array.
+    ValueOfType<MemRefType> shmem = copyOp.getDpsInitOperand(0)->get();
+    shmems.push_back(shmem);
+  }
+
+  // 4. Load in from global memory to shared memory using tma.
+  OpBuilder::InsertionGuard g2(b);
+  b.setInsertionPoint(copyOps.front());
+  SmallVector<Operation *> results =
+      buildPredicateLoadsOnThread0(globalDescs, shmems, barrier);
+
+  // 5. Spin-loop until data is ready.
+  buildTryWaitParity(barrier);
+
+  return results;
+}
+
+DiagnosedSilenceableFailure
+transform::RewriteCopyAsTmaOp::apply(transform::TransformRewriter &rewriter,
+                                     transform::TransformResults &results,
+                                     transform::TransformState &state) {
+  auto payloadOps = state.getPayloadOps(getTarget());
+  gpu::LaunchOp commonLaunchOp;
+  if (llvm::any_of(payloadOps, [&commonLaunchOp](Operation *op) {
+        if (!commonLaunchOp)
+          commonLaunchOp = op->getParentOfType<gpu::LaunchOp>();
+        return !op->getParentOfType<gpu::LaunchOp>() ||
+               commonLaunchOp != op->getParentOfType<gpu::LaunchOp>() ||
+               !isa<linalg::CopyOp>(op);
+      })) {
+    DiagnosedSilenceableFailure diag =
+        emitSilenceableError()
+        << "target op must be a linalg::CopyOp nested under a gpu.LaunchOp to "
+           "be rewritten because the tma descriptors need to be created on the "
+           "host";
+    return diag;
+  }
+
+  // TODO: more robust detection of copy, with transposes etc.
+  if (!succeeded(
+          CopyBuilder(rewriter, getLoc()).build(llvm::to_vector(payloadOps)))) {
+    return emitSilenceableError("some copy op failed to lower to tma");
+  }
+
+  for (Operation *target : state.getPayloadOps(getTarget()))
+    rewriter.eraseOp(target);
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===----------------------------------------------------------------------===//
 // Transform op registration
 //===----------------------------------------------------------------------===//
@@ -767,6 +1073,7 @@
   declareGeneratedDialect<arith::ArithDialect>();
   declareGeneratedDialect<affine::AffineDialect>();
   declareGeneratedDialect<nvgpu::NVGPUDialect>();
+  declareGeneratedDialect<NVVM::NVVMDialect>();
   declareGeneratedDialect<vector::VectorDialect>();
   registerTransformOps<
 #define GET_OP_LIST
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/tmaload-transform.mlir b/mlir/test/Integration/GPU/CUDA/sm90/tmaload-transform.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Integration/GPU/CUDA/sm90/tmaload-transform.mlir
@@ -0,0 +1,109 @@
+// RUN: mlir-opt %s \
+// RUN:     -test-transform-dialect-interpreter \
+// RUN:     -test-transform-dialect-erase-schedule \
+// RUN:     -convert-nvgpu-to-nvvm -gpu-kernel-outlining \
+// RUN:     -convert-scf-to-cf -convert-nvvm-to-llvm \
+// RUN:     -convert-vector-to-llvm \
+// RUN:     -convert-math-to-llvm \
+// RUN:     -expand-strided-metadata \
+// RUN:     -lower-affine \
+// RUN:     -convert-index-to-llvm=index-bitwidth=32 \
+// RUN:     -convert-arith-to-llvm \
+// RUN:     -finalize-memref-to-llvm \
+// RUN:     -convert-func-to-llvm \
+// RUN:     -canonicalize \
+// RUN: | mlir-opt -pass-pipeline='builtin.module(gpu.module(strip-debuginfo,convert-gpu-to-nvvm,convert-nvgpu-to-nvvm{use-opaque-pointers=1},lower-affine,convert-scf-to-cf,convert-vector-to-llvm,convert-math-to-llvm,expand-strided-metadata,lower-affine,convert-index-to-llvm{index-bitwidth=32},convert-arith-to-llvm,reconcile-unrealized-casts,gpu-to-cubin{chip=sm_90 features=+ptx80 dump-ptx}))' \
+// RUN: 2>&1 | FileCheck %s --check-prefixes=CHECK-PTX
+
+// CHECK-PTX: mbarrier.init.shared {{.*}} !llvm.ptr<3>, i32
+/// If branch
+// CHECK-PTX: cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes
+// CHECK-PTX: cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes
+// CHECK-PTX: mbarrier.arrive.expect_tx.shared
+/// Else branch
+// CHECK-PTX: mbarrier.arrive.expect_tx.shared
+// CHECK-PTX: mbarrier.try_wait.parity.shared
+
+// TODO: GPU layering does not currently work end-to-end. Activate the following
+// when fixed.
+// R-UN: | mlir-opt -convert-index-to-llvm=index-bitwidth=32 \
+// R-UN:     -gpu-to-llvm \
+// R-UN:     -convert-func-to-llvm \
+// R-UN:     -cse \
+// R-UN:     -canonicalize \
+// R-UN:     -reconcile-unrealized-casts \
+// R-UN: | mlir-cpu-runner \
+// R-UN:     --shared-libs=%mlir_cuda_runtime \
+// R-UN:     --shared-libs=%mlir_runner_utils \
+// R-UN:     --entry-point-result=void \
+// R-UN: | FileCheck %s
+
+// C-HECK: [GPU] TMA BEFORE lhs[45][7] 0.000000
+// C-HECK: [GPU] TMA BEFORE rhs[7][0] 0.000000
+// C-HECK: [GPU] TMA LOADED lhs[45][7] 7.000000
+// C-HECK: [GPU] TMA LOADED rhs[7][0] 3.000000
+
+
+module @mymod {
+  memref.global "private" @bufferLhsGlobal : memref<64x8xf32, 3>
+  memref.global "private" @bufferRhsGlobal : memref<8x128xf32, 3>
+  func.func @main() {
+    %c10000000 = arith.constant 10000000 : index
+    %c6144 = arith.constant 6144 : index
+    %c45 = arith.constant 45 : index
+    %c7 = arith.constant 7 : index
+    %c64 = arith.constant 64 : index
+    %c1 = arith.constant 1 : index
+    %c0 = arith.constant 0 : index
+    %c8 = arith.constant 8 : index
+    %c128 = arith.constant 128 : index
+    %cst = arith.constant 3.000000e+00 : f32
+    %alloc = memref.alloc() : memref<64x8xf32>
+    %alloc_0 = memref.alloc() : memref<8x128xf32>
+    scf.for %arg0 = %c0 to %c8 step %c1 {
+      scf.for %arg1 = %c0 to %c128 step %c1 {
+        memref.store %cst, %alloc_0[%arg0, %arg1] : memref<8x128xf32>
+      }
+    }
+    scf.for %arg0 = %c0 to %c64 step %c1 {
+      scf.for %arg1 = %c0 to %c8 step %c1 {
+        %5 = arith.index_cast %arg1 : index to i64
+        %6 = arith.uitofp %5 : i64 to f32
+        memref.store %6, %alloc[%arg0, %arg1] : memref<64x8xf32>
+      }
+    }
+    %0 = gpu.wait async
+    %memref, %asyncToken = gpu.alloc async [%0] () : memref<64x8xf32>
+    %memref_1, %asyncToken_2 = gpu.alloc async [%0] () : memref<8x128xf32>
+    %1 = gpu.memcpy async [%0] %memref, %alloc : memref<64x8xf32>, memref<64x8xf32>
+    %2 = gpu.memcpy async [%0] %memref_1, %alloc_0 : memref<8x128xf32>, memref<8x128xf32>
+
+    gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
+               threads(%tx, %ty, %tz) in (%block_x = %c128, %block_y = %c1, %block_z = %c1) {
+      %out = memref.get_global @bufferLhsGlobal : memref<64x8xf32, 3>
+      %out_1 = memref.get_global @bufferRhsGlobal : memref<8x128xf32, 3>
+      linalg.copy ins(%memref: memref<64x8xf32>) outs(%out: memref<64x8xf32, 3>)
+      linalg.copy ins(%memref_1: memref<8x128xf32>) outs(%out_1: memref<8x128xf32, 3>)
+
+      %6 = gpu.thread_id x
+      %10 = arith.cmpi eq, %6, %c0 : index
+      scf.if %10 {
+        %11 = memref.load %out[%c45, %c7] : memref<64x8xf32, 3>
+        %12 = memref.load %out_1[%c7, %c0] : memref<8x128xf32, 3>
+        gpu.printf "[GPU] TMA LOADED lhs[45][7] %f\0A" %11 : f32
+        gpu.printf "[GPU] TMA LOADED rhs[7][0] %f\0A" %12 : f32
+      }
+      gpu.terminator
+    }
+
+    return
+  }
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %copy = transform.structured.match ops{["linalg.copy"]} in %arg1
+    : (!transform.any_op) -> !transform.any_op
+  transform.nvgpu.rewrite_copy_as_tma %copy
+    : (!transform.any_op) -> ()
+}
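
// -----------------------------------------------------------------------------
// Rough shape of the IR that transform.nvgpu.rewrite_copy_as_tma is expected to
// produce for the two linalg.copy ops above, following the builders in
// NVGPUTransformOps.cpp (buildGlobalMemRefDescriptor, then
// buildAndInitBarrierInSharedMemory, buildPredicateLoadsOnThread0 and
// buildTryWaitParity). This is an illustrative sketch only: the nvgpu op names
// come from the builders, but the exact assembly syntax, types, attributes and
// SSA names are approximations, not the dialect's verified forms.
//
//   // On the host, above gpu.launch: one TMA descriptor per source buffer.
//   %lhs_unranked = memref.cast %memref : memref<64x8xf32> to memref<*xf32>
//   %lhs_desc = nvgpu.tma.create.descriptor %lhs_unranked box[%c64, %c8] : ...
//   %rhs_unranked = memref.cast %memref_1 : memref<8x128xf32> to memref<*xf32>
//   %rhs_desc = nvgpu.tma.create.descriptor %rhs_unranked box[%c8, %c128] : ...
//
//   // Inside gpu.launch, replacing the two linalg.copy ops.
//   %barrier = nvgpu.mbarrier.create ...
//   nvgpu.mbarrier.init %barrier, %c128   // one arrival per thread in the block
//   gpu.barrier
//   %tidx = gpu.thread_id x
//   %is_thread0 = arith.cmpi eq, %tidx, %c0 : index
//   scf.if %is_thread0 {
//     // Thread 0 issues both TMA loads and expects the total byte count:
//     // 64*8*4 + 8*128*4 = 6144 bytes, hence the %c6144 constant in the test.
//     nvgpu.tma.async.load %lhs_desc[%c0, %c0], %barrier to %out
//     nvgpu.tma.async.load %rhs_desc[%c0, %c0], %barrier to %out_1
//     nvgpu.mbarrier.arrive.expect_tx %barrier, %c6144
//   } else {
//     // All other threads only arrive, expecting 0 bytes.
//     nvgpu.mbarrier.arrive.expect_tx %barrier, %c0
//   }
//   // Every thread then spins until the transfer completes, with %c10000000
//   // as the try_wait ticks bound.
//   nvgpu.mbarrier.try_wait.parity %barrier, %c0, %c10000000
//
// This matches the CHECK-PTX lines above: two cp.async.bulk.tensor transfers,
// an mbarrier.arrive.expect_tx in each branch, and a final
// mbarrier.try_wait.parity.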