diff --git a/mlir/include/mlir/Dialect/NVGPU/CMakeLists.txt b/mlir/include/mlir/Dialect/NVGPU/CMakeLists.txt
--- a/mlir/include/mlir/Dialect/NVGPU/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/NVGPU/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_subdirectory(IR)
+add_subdirectory(TransformOps)
 
 set(LLVM_TARGET_DEFINITIONS Passes.td)
 mlir_tablegen(Passes.h.inc -gen-pass-decls -name NVGPU)
diff --git a/mlir/include/mlir/Dialect/NVGPU/TransformOps/CMakeLists.txt b/mlir/include/mlir/Dialect/NVGPU/TransformOps/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/NVGPU/TransformOps/CMakeLists.txt
@@ -0,0 +1,4 @@
+set(LLVM_TARGET_DEFINITIONS NVGPUTransformOps.td)
+mlir_tablegen(NVGPUTransformOps.h.inc -gen-op-decls)
+mlir_tablegen(NVGPUTransformOps.cpp.inc -gen-op-defs)
+add_public_tablegen_target(MLIRNVGPUTransformOpsIncGen)
diff --git a/mlir/include/mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.h b/mlir/include/mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.h
@@ -0,0 +1,44 @@
+//===- NVGPUTransformOps.h - NVGPU transform ops ----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_NVGPU_TRANSFORMOPS_NVGPUTRANSFORMOPS_H
+#define MLIR_DIALECT_NVGPU_TRANSFORMOPS_NVGPUTRANSFORMOPS_H
+
+#include "mlir/Dialect/Transform/IR/TransformAttrs.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/RegionKindInterface.h"
+
+namespace mlir {
+namespace transform {
+class TransformHandleTypeInterface;
+} // namespace transform
+} // namespace mlir
+
+namespace mlir {
+class DialectRegistry;
+
+namespace linalg {
+class LinalgOp;
+} // namespace linalg
+
+namespace nvgpu {
+/// Register the NVGPU transform dialect extension (the `transform.nvgpu.*` ops).
+void registerTransformDialectExtension(DialectRegistry &registry);
+} // namespace nvgpu
+} // namespace mlir
+
+//===----------------------------------------------------------------------===//
+// NVGPU Transform Operations
+//===----------------------------------------------------------------------===//
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.h.inc"
+
+#endif // MLIR_DIALECT_NVGPU_TRANSFORMOPS_NVGPUTRANSFORMOPS_H
diff --git a/mlir/include/mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.td b/mlir/include/mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.td
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.td
@@ -0,0 +1,47 @@
+//===- NVGPUTransformOps.td - NVGPU transform ops ----------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef NVGPU_TRANSFORM_OPS
+#define NVGPU_TRANSFORM_OPS
+
+include "mlir/Dialect/Transform/IR/TransformAttrs.td"
+include "mlir/Dialect/Transform/IR/TransformDialect.td"
+include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
+include "mlir/Dialect/Transform/IR/TransformTypes.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+
+//===----------------------------------------------------------------------===//
+// RewriteMatmulAsMmaSyncOp
+//===----------------------------------------------------------------------===//
+
+def RewriteMatmulAsMmaSyncOp :
+  Op<Transform_Dialect, "nvgpu.rewrite_matmul_as_mma_sync",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     TransformEachOpTrait, TransformOpInterface]> {
+  let description = [{
+    Rewrite a matmul operation on memref to an mma.sync operation on vectors.
+
+    Memory copies with the required access patterns are automatically inserted.
+    Operations that do not have a 1-1 mapping to mma.sync operations are left
+    unchanged.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target);
+  let results = (outs);
+
+  let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)";
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::linalg::LinalgOp linalgOp,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
+#endif // NVGPU_TRANSFORM_OPS
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -55,6 +55,7 @@
 #include "mlir/Dialect/MemRef/Transforms/BufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h"
 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
+#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.h"
 #include "mlir/Dialect/OpenACC/OpenACC.h"
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
 #include "mlir/Dialect/PDL/IR/PDL.h"
@@ -137,6 +138,7 @@
   gpu::registerTransformDialectExtension(registry);
   linalg::registerTransformDialectExtension(registry);
   memref::registerTransformDialectExtension(registry);
+  nvgpu::registerTransformDialectExtension(registry);
   scf::registerTransformDialectExtension(registry);
   tensor::registerTransformDialectExtension(registry);
   transform::registerPDLExtension(registry);
diff --git a/mlir/lib/Dialect/NVGPU/CMakeLists.txt b/mlir/lib/Dialect/NVGPU/CMakeLists.txt
--- a/mlir/lib/Dialect/NVGPU/CMakeLists.txt
+++ b/mlir/lib/Dialect/NVGPU/CMakeLists.txt
@@ -1,3 +1,4 @@
 add_subdirectory(IR)
 add_subdirectory(Utils)
+add_subdirectory(TransformOps)
 add_subdirectory(Transforms)
diff --git a/mlir/lib/Dialect/NVGPU/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/NVGPU/TransformOps/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/NVGPU/TransformOps/CMakeLists.txt
@@ -0,0 +1,21 @@
+add_mlir_dialect_library(MLIRNVGPUTransformOps
+  NVGPUTransformOps.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/NVGPU/TransformOps
+
+  DEPENDS
+  MLIRNVGPUTransformOpsIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRAffineDialect
+  MLIRArithDialect
+  MLIRIR
+  MLIRLinalgDialect
+  MLIRNVGPUDialect
+  MLIRParser
+  MLIRSideEffectInterfaces
+  MLIRTransformDialect
+  MLIRTransformDialectUtils
+  MLIRVectorTransforms
+  )
diff --git a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
@@ -0,0 +1,406 @@
+//===- NVGPUTransformOps.cpp - Implementation of NVGPU transform ops ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.h"
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/TypeRange.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/ErrorHandling.h"
+
+using namespace mlir;
+using namespace mlir::linalg;
+using namespace mlir::nvgpu;
+using namespace mlir::transform;
+
+#define DEBUG_TYPE "nvgpu-transforms"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+#define DBGSNL() (llvm::dbgs() << "\n")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
+
+//===----------------------------------------------------------------------===//
+// RewriteMatmulAsMmaSyncOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Helper struct to encode a pair of row/column indexings in the form of
+/// affine expressions.
+struct RowColIndexing : private std::pair<AffineExpr, AffineExpr> {
+  RowColIndexing(AffineExpr row, AffineExpr col)
+      : std::pair<AffineExpr, AffineExpr>(row, col) {}
+
+  AffineExpr row() const { return first; }
+  AffineExpr col() const { return second; }
+
+  void print(llvm::raw_ostream &os) const {
+    os << "- indexing: " << first << ", " << second;
+  }
+};
+
+/// Helper struct to provide a simple mapping from matmul operations to the
+/// corresponding mma.sync operation. This is currently constrained to memref
+/// operands whose {m, n, k} shape and element types are recognized by
+/// `getIndexCalculators` (today only 16x8x4 with f32 operands, tf32 mode).
+struct MmaSyncBuilder {
+  MmaSyncBuilder(OpBuilder &b, Location loc, OpFoldResult laneId)
+      : b(b), loc(loc), laneId(laneId) {}
+
+  /// Callback producing the list of (row, col) affine expressions, written in
+  /// terms of the single dimension `d0` (bound to the lane id), at which a lane
+  /// accesses the elements of one mma.sync operand.
+  using IndexCalculator =
+      std::function<SmallVector<RowColIndexing>(MLIRContext *)>;
+
+  /// Create the mma.sync operation corresponding to `linalgOp` along with all
+  /// the supporting load/store and vector operations. Returns failure, before
+  /// creating any IR, when an operand is not a memref or when the shape and
+  /// element types are not supported.
+  FailureOr<Operation *> buildMmaSync(LinalgOp linalgOp);
+
+private:
+  struct MmaSyncInfo {
+    std::tuple<IndexCalculator, IndexCalculator, IndexCalculator> indexFns;
+    std::tuple<SmallVector<int64_t>, SmallVector<int64_t>, SmallVector<int64_t>>
+        vectorShapes;
+    SmallVector<int64_t> mmaShape;
+    bool tf32Enabled;
+  };
+
+  /// Return the (lhs, rhs, res) index calculators, per-operand vector shapes
+  /// and mma.sync attributes for an op of shape `opShape` = {m, n, k} with
+  /// (lhs, rhs, res) `elementalTypes`, or failure if that combination is not
+  /// supported. This is the toplevel switch that should just be Tablegen'd in
+  /// the future.
+  FailureOr<MmaSyncInfo> getIndexCalculators(ArrayRef<int64_t> opShape,
+                                             TypeRange elementalTypes);
+
+  //===--------------------------------------------------------------------===//
+  // Instruction-specific row, column indexing expression builders.
+  // These should all be declaratively specified via Tablegen in the future.
+  // The Tablegen specification should be as straightforward as possible to
+  // only model the existing size and type combinations.
+  //===--------------------------------------------------------------------===//
+  //
+  // TODO: Tablegen all this.
+  //===--------------------------------------------------------------------===//
+  // m16n8k4 tf32 case.
+  //===--------------------------------------------------------------------===//
+  /// From the NVIDIA doc:
+  /// groupID           = %laneid >> 2
+  /// threadIDInGroup = %laneid % 4
+  /// row =      groupID            for a0
+  ///            groupID + 8        for a1
+  /// col =  threadIDInGroup
+  static SmallVector<RowColIndexing> m16n8k4tf32Lhs(MLIRContext *ctx);
+  /// From the NVIDIA doc:
+  /// groupID           = %laneid >> 2
+  /// threadIDInGroup = %laneid % 4
+  /// row =  threadIDInGroup
+  /// col =  groupID
+  static SmallVector<RowColIndexing> m16n8k4tf32Rhs(MLIRContext *ctx);
+  /// From the NVIDIA doc:
+  /// groupID          = %laneid >> 2
+  /// threadIDInGroup = %laneid % 4
+  /// row =      groupID                            for c0 and c1
+  ///          groupID + 8                          for c2 and c3
+  /// col =  (threadIDInGroup * 2) + (i & 0x1)    for ci   where i = {0,..,3}
+  static SmallVector<RowColIndexing> m16n8k4tf32Res(MLIRContext *ctx);
+
+  //===--------------------------------------------------------------------===//
+  /// Helper functions to create customizable load and stores operations. The
+  /// specific shapes of each MMA instruction are passed via the
+  /// IndexCalculator callback.
+  //===--------------------------------------------------------------------===//
+  /// Build a list of memref.load operations indexed at `(row, col)` indices
+  /// that make sense for a particular MMA instruction and specified via the
+  /// IndexCalculator callback.
+  SmallVector<Value> buildMemrefLoads(OpBuilder &b, Location loc,
+                                      OpFoldResult laneId, Value memref,
+                                      IndexCalculator indexFn);
+
+  /// Perform a distributed load of a vector operand of `vectorShape` for a
+  /// particular MMA instruction whose `(row, col)` indices are specified via
+  /// the IndexCalculator callback. Each `laneId` loads the subportion of the
+  /// data that makes sense for the particular MMA operation.
+  /// The `vectorShape` matches existing NVGPU dialect op specification but
+  /// could also be flattened in the future if needed for simplification.
+  Value buildMmaSyncMemrefLoadOperand(OpBuilder &b, Location loc,
+                                      OpFoldResult laneId, Value memref,
+                                      IndexCalculator indexFn,
+                                      ArrayRef<int64_t> vectorShape);
+
+  /// Build a list of memref.store operations indexed at `(row, col)` indices
+  /// that make sense for a particular MMA instruction and specified via the
+  /// IndexCalculator callback.
+  SmallVector<Operation *> buildMemrefStores(OpBuilder &b, Location loc,
+                                             ValueRange toStore,
+                                             OpFoldResult laneId, Value memref,
+                                             IndexCalculator indexFn);
+
+  /// Perform a distributed store of a vector operand of `vectorShape` for a
+  /// particular MMA instruction whose `(row, col)` indices are specified via
+  /// the IndexCalculator callback. Each `laneId` stores the subportion of the
+  /// data that makes sense for the particular MMA operation.
+  /// The `vectorShape` matches existing NVGPU dialect op specification but
+  /// could also be flattened in the future if needed for simplification.
+  SmallVector<Operation *> buildMmaSyncMemrefStoreOperand(
+      OpBuilder &b, Location loc, Value vectorToStore, OpFoldResult laneId,
+      Value memref, IndexCalculator indexFn, ArrayRef<int64_t> vectorShape);
+
+  OpBuilder &b;
+  Location loc;
+  OpFoldResult laneId;
+};
+} // namespace
+
+SmallVector<RowColIndexing> MmaSyncBuilder::m16n8k4tf32Lhs(MLIRContext *ctx) {
+  auto dim = getAffineDimExpr(0, ctx);
+  AffineExpr groupID = dim.floorDiv(4);
+  AffineExpr threadIDInGroup = dim % 4;
+  return {RowColIndexing{groupID, threadIDInGroup},
+          RowColIndexing{groupID + 8, threadIDInGroup}};
+}
+
+SmallVector<RowColIndexing> MmaSyncBuilder::m16n8k4tf32Rhs(MLIRContext *ctx) {
+  auto dim = getAffineDimExpr(0, ctx);
+  AffineExpr groupID = dim.floorDiv(4);
+  AffineExpr threadIDInGroup = dim % 4;
+  return {RowColIndexing{threadIDInGroup, groupID}};
+}
+
+SmallVector<RowColIndexing> MmaSyncBuilder::m16n8k4tf32Res(MLIRContext *ctx) {
+  auto dim = getAffineDimExpr(0, ctx);
+  AffineExpr groupID = dim.floorDiv(4);
+  AffineExpr threadIDInGroup = dim % 4;
+  return {RowColIndexing{groupID, threadIDInGroup * 2 + 0},
+          RowColIndexing{groupID, threadIDInGroup * 2 + 1},
+          RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0},
+          RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1}};
+}
+
+//===--------------------------------------------------------------------===//
+/// Helper functions to create customizable load and stores operations. The
+/// specific shapes of each MMA instruction are passed via the
+/// IndexCalculator callback.
+//===--------------------------------------------------------------------===//
+
+/// Visit every element of the vector-typed `vector` in row-major order: call
+/// `applyFn(vector, linearIdx, indices)` and hand its result to
+/// `reduceFn(result, linearIdx, indices)`.
+template <typename ApplyFn, typename ReduceFn>
+static void foreachIndividualVectorElement(Value vector, ApplyFn applyFn,
+                                           ReduceFn reduceFn) {
+  VectorType vectorType = cast<VectorType>(vector.getType());
+  auto vectorShape = vectorType.getShape();
+  auto strides = computeStrides(vectorShape);
+  for (int64_t idx = 0, e = vectorShape[0] * strides[0]; idx < e; ++idx) {
+    auto indices = delinearize(idx, strides);
+    reduceFn(applyFn(vector, idx, indices), idx, indices);
+  }
+}
+
+SmallVector<Value> MmaSyncBuilder::buildMemrefLoads(OpBuilder &b, Location loc,
+                                                    OpFoldResult laneId,
+                                                    Value memref,
+                                                    IndexCalculator indexFn) {
+  auto aff = [&](AffineExpr e) {
+    return affine::makeComposedFoldedAffineApply(b, loc, e, laneId);
+  };
+  SmallVector<RowColIndexing> indexings = indexFn(b.getContext());
+  SmallVector<Value> res;
+  res.reserve(indexings.size());
+  for (const RowColIndexing &indexing : indexings) {
+    Value row = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.row()));
+    Value col = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.col()));
+    auto load = b.create<memref::LoadOp>(loc, memref, ValueRange{row, col});
+    res.push_back(load);
+  }
+  return res;
+}
+
+Value MmaSyncBuilder::buildMmaSyncMemrefLoadOperand(
+    OpBuilder &b, Location loc, OpFoldResult laneId, Value memref,
+    IndexCalculator indexFn, ArrayRef<int64_t> vectorShape) {
+  auto loads = buildMemrefLoads(b, loc, laneId, memref, indexFn);
+
+  // Seed the operand vector with a splat of the first loaded value, then
+  // insert each scalar load at its row-major position.
+  Type elementType = getElementTypeOrSelf(memref.getType());
+  auto vt = VectorType::get(vectorShape, elementType);
+  Value res = b.create<vector::SplatOp>(loc, vt, loads[0]);
+  foreachIndividualVectorElement(
+      res,
+      /*applyFn=*/
+      [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
+        return loads[linearIdx];
+      },
+      /*reduceFn=*/
+      [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
+        res = b.create<vector::InsertOp>(loc, v, res, indices);
+      });
+
+  return res;
+}
+
+SmallVector<Operation *>
+MmaSyncBuilder::buildMemrefStores(OpBuilder &b, Location loc,
+                                  ValueRange toStore, OpFoldResult laneId,
+                                  Value memref, IndexCalculator indexFn) {
+  auto aff = [&](AffineExpr e) {
+    return affine::makeComposedFoldedAffineApply(b, loc, e, laneId);
+  };
+  SmallVector<Operation *> res;
+  for (auto [indexing, val] :
+       llvm::zip_equal(indexFn(b.getContext()), toStore)) {
+    Value row = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.row()));
+    Value col = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.col()));
+    Operation *store =
+        b.create<memref::StoreOp>(loc, val, memref, ValueRange{row, col});
+    res.push_back(store);
+  }
+  return res;
+}
+
+SmallVector<Operation *> MmaSyncBuilder::buildMmaSyncMemrefStoreOperand(
+    OpBuilder &b, Location loc, Value vectorToStore, OpFoldResult laneId,
+    Value memref, IndexCalculator indexFn, ArrayRef<int64_t> vectorShape) {
+  SmallVector<Value> toStore;
+  toStore.reserve(cast<VectorType>(vectorToStore.getType()).getNumElements());
+  foreachIndividualVectorElement(
+      vectorToStore,
+      /*applyFn=*/
+      [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
+        return b.create<vector::ExtractOp>(loc, vectorToStore, indices);
+      },
+      /*reduceFn=*/
+      [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
+        toStore.push_back(v);
+      });
+  return buildMemrefStores(b, loc, toStore, laneId, memref, indexFn);
+}
+
+static std::tuple<SmallVector<int64_t>, SmallVector<int64_t>,
+                  SmallVector<int64_t>>
+makeVectorShapes(ArrayRef<int64_t> lhs, ArrayRef<int64_t> rhs,
+                 ArrayRef<int64_t> res) {
+  SmallVector<int64_t> vlhs{lhs.begin(), lhs.end()};
+  SmallVector<int64_t> vrhs{rhs.begin(), rhs.end()};
+  SmallVector<int64_t> vres{res.begin(), res.end()};
+  return std::make_tuple(vlhs, vrhs, vres);
+}
+
+FailureOr<MmaSyncBuilder::MmaSyncInfo>
+MmaSyncBuilder::getIndexCalculators(ArrayRef<int64_t> opShape,
+                                    TypeRange elementalTypes) {
+  // TODO: Tablegen all this.
+  Type f32 = b.getF32Type();
+  if (opShape == ArrayRef<int64_t>{16, 8, 4} &&
+      elementalTypes == TypeRange{f32, f32, f32}) {
+    return MmaSyncInfo{std::make_tuple(&MmaSyncBuilder::m16n8k4tf32Lhs,
+                                       &MmaSyncBuilder::m16n8k4tf32Rhs,
+                                       &MmaSyncBuilder::m16n8k4tf32Res),
+                       makeVectorShapes({2, 1}, {1, 1}, {2, 2}),
+                       SmallVector<int64_t>{opShape.begin(), opShape.end()},
+                       /*tf32Enabled=*/true};
+  }
+  return failure();
+}
+
+FailureOr<Operation *> MmaSyncBuilder::buildMmaSync(LinalgOp linalgOp) {
+  Value lhsMemref = linalgOp.getDpsInputOperand(0)->get();
+  Value rhsMemref = linalgOp.getDpsInputOperand(1)->get();
+  Value resMemref = linalgOp.getDpsInitOperand(0)->get();
+
+  // linalg.matmul also accepts tensor operands; only memrefs are supported
+  // here, so bail out (before creating any IR) instead of asserting in cast<>.
+  auto lhsMemrefType = dyn_cast<MemRefType>(lhsMemref.getType());
+  auto rhsMemrefType = dyn_cast<MemRefType>(rhsMemref.getType());
+  auto resMemrefType = dyn_cast<MemRefType>(resMemref.getType());
+  if (!lhsMemrefType || !rhsMemrefType || !resMemrefType)
+    return failure();
+
+  int64_t m = lhsMemrefType.getShape()[0];
+  int64_t n = rhsMemrefType.getShape()[1];
+  int64_t k = lhsMemrefType.getShape()[1];
+  Type lhsType = lhsMemrefType.getElementType();
+  Type rhsType = rhsMemrefType.getElementType();
+  Type resType = resMemrefType.getElementType();
+
+  FailureOr<MmaSyncInfo> maybeInfo =
+      getIndexCalculators({m, n, k}, {lhsType, rhsType, resType});
+  if (failed(maybeInfo))
+    return failure();
+
+  MmaSyncInfo info = *maybeInfo;
+  auto [lhsIndexFn, rhsIndexFn, resIndexFn] = info.indexFns;
+  auto [lhsShape, rhsShape, resShape] = info.vectorShapes;
+  Value lhs = buildMmaSyncMemrefLoadOperand(b, loc, laneId, lhsMemref,
+                                            lhsIndexFn, lhsShape);
+  Value rhs = buildMmaSyncMemrefLoadOperand(b, loc, laneId, rhsMemref,
+                                            rhsIndexFn, rhsShape);
+  Value res = buildMmaSyncMemrefLoadOperand(b, loc, laneId, resMemref,
+                                            resIndexFn, resShape);
+  res = b.create<nvgpu::MmaSyncOp>(loc, lhs, rhs, res, info.mmaShape,
+                                   info.tf32Enabled);
+  buildMmaSyncMemrefStoreOperand(b, loc, res, laneId, resMemref, resIndexFn,
+                                 resShape);
+  return res.getDefiningOp();
+}
+
+DiagnosedSilenceableFailure transform::RewriteMatmulAsMmaSyncOp::applyToOne(
+    LinalgOp linalgOp, transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  TrackingListener listener(state, *this);
+  IRRewriter rewriter(linalgOp->getContext(), &listener);
+  rewriter.setInsertionPoint(linalgOp);
+
+  bool rewritten = false;
+  // TODO: more robust detection of matmulOp, with transposes etc.
+  if (isa<linalg::MatmulOp>(linalgOp.getOperation())) {
+    Location loc = linalgOp.getLoc();
+    // TODO: more robust computation of laneId, for now assume a single warp.
+    Value laneId = rewriter.create<gpu::ThreadIdOp>(
+        loc, rewriter.getIndexType(), gpu::Dimension::x);
+    rewritten = succeeded(
+        MmaSyncBuilder(rewriter, loc, laneId).buildMmaSync(linalgOp));
+    // `buildMmaSync` fails before creating any op, so the thread id is unused
+    // here; erase it so that a silenceable failure leaves the payload as-is.
+    if (!rewritten)
+      rewriter.eraseOp(laneId.getDefiningOp());
+  }
+
+  if (!rewritten) {
+    DiagnosedSilenceableFailure diag = emitSilenceableError()
+                                       << "unsupported target op: " << linalgOp;
+    diag.attachNote(linalgOp->getLoc()) << "target op";
+    return diag;
+  }
+
+  rewriter.eraseOp(linalgOp);
+  return DiagnosedSilenceableFailure::success();
+}
+
+//===----------------------------------------------------------------------===//
+// Transform op registration
+//===----------------------------------------------------------------------===//
+
+namespace {
+class NVGPUTransformDialectExtension
+    : public transform::TransformDialectExtension<
+          NVGPUTransformDialectExtension> {
+public:
+  NVGPUTransformDialectExtension() {
+    // Every dialect whose ops `rewrite_matmul_as_mma_sync` creates:
+    // affine.apply, arith.constant, gpu.thread_id, memref.load/store,
+    // nvgpu.mma.sync and vector.splat/insert/extract.
+    declareGeneratedDialect<affine::AffineDialect>();
+    declareGeneratedDialect<arith::ArithDialect>();
+    declareGeneratedDialect<gpu::GPUDialect>();
+    declareGeneratedDialect<memref::MemRefDialect>();
+    declareGeneratedDialect<nvgpu::NVGPUDialect>();
+    declareGeneratedDialect<vector::VectorDialect>();
+    registerTransformOps<
+#define GET_OP_LIST
+#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp.inc"
+        >();
+  }
+};
+} // namespace
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp.inc"
+
+void mlir::nvgpu::registerTransformDialectExtension(DialectRegistry &registry) {
+  registry.addExtensions<NVGPUTransformDialectExtension>();
+}
diff --git a/mlir/test/Dialect/NVGPU/transform-matmul-to-nvvm.mlir b/mlir/test/Dialect/NVGPU/transform-matmul-to-nvvm.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/NVGPU/transform-matmul-to-nvvm.mlir
@@ -0,0 +1,30 @@
+// RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file | FileCheck %s
+
+// CHECK-LABEL: 
func.func @matmul_16x8x4xf32_global
+func.func @matmul_16x8x4xf32_global(
+    %A: memref<16x4xf32>, %B: memref<4x8xf32>, %C: memref<16x8xf32>) {
+
+  // CHECK-COUNT-2: memref.load {{.*}} : memref<16x4xf32>
+  // CHECK-COUNT-2: vector.insert {{.*}} : f32 into vector<2x1xf32>
+  // CHECK-COUNT-1: memref.load {{.*}} : memref<4x8xf32>
+  // CHECK-COUNT-1: vector.insert {{.*}} : f32 into vector<1x1xf32>
+  // CHECK-COUNT-4: memref.load {{.*}} : memref<16x8xf32>
+  // CHECK-COUNT-4: vector.insert {{.*}} : f32 into vector<2x2xf32>
+  //
+  // CHECK: nvgpu.mma.sync(%{{.*}}) {mmaShape = [16, 8, 4], tf32Enabled}
+  // CHECK-SAME: : (vector<2x1xf32>, vector<1x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+  //
+  // CHECK-COUNT-4: vector.extract %{{.*}} : vector<2x2xf32>
+  // CHECK-COUNT-4: memref.store %{{.*}} : memref<16x8xf32>
+  linalg.matmul ins(%A, %B: memref<16x4xf32>, memref<4x8xf32>)
+            outs(%C: memref<16x8xf32>)
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1
+    : (!transform.any_op) -> !transform.any_op
+  transform.nvgpu.rewrite_matmul_as_mma_sync %matmul
+    : (!transform.any_op) -> ()
+}
diff --git a/mlir/test/Integration/GPU/CUDA/TensorCore/transform-mma-sync-matmul-f32.mlir b/mlir/test/Integration/GPU/CUDA/TensorCore/transform-mma-sync-matmul-f32.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Integration/GPU/CUDA/TensorCore/transform-mma-sync-matmul-f32.mlir
@@ -0,0 +1,168 @@
+// RUN: mlir-opt %s \
+// RUN:   | mlir-opt -test-transform-dialect-interpreter \
+// RUN:   | mlir-opt -test-transform-dialect-erase-schedule \
+// RUN:   | mlir-opt -gpu-kernel-outlining \
+// RUN:   | mlir-opt -convert-scf-to-cf \
+// RUN:   | mlir-opt -convert-vector-to-llvm \
+// RUN:   | mlir-opt -convert-math-to-llvm \
+// RUN:   | mlir-opt -expand-strided-metadata \
+// RUN:   | mlir-opt -lower-affine \
+// RUN:   | mlir-opt -convert-index-to-llvm=index-bitwidth=32 \
+// RUN:   | mlir-opt -convert-arith-to-llvm \
+// RUN:   | mlir-opt -finalize-memref-to-llvm \
+// RUN:   | mlir-opt -convert-func-to-llvm \
+// RUN:   | mlir-opt -canonicalize \
+// RUN:   | mlir-opt -pass-pipeline='builtin.module(gpu.module(strip-debuginfo,convert-gpu-to-nvvm,convert-nvgpu-to-nvvm{use-opaque-pointers=1},lower-affine,convert-scf-to-cf,convert-vector-to-llvm,convert-math-to-llvm,expand-strided-metadata,lower-affine,convert-index-to-llvm{index-bitwidth=32},convert-arith-to-llvm,reconcile-unrealized-casts,gpu-to-cubin{chip=sm_80}))' \
+// RUN:   | mlir-opt -convert-index-to-llvm=index-bitwidth=32 \
+// RUN:   | mlir-opt -gpu-to-llvm \
+// RUN:   | mlir-opt -convert-func-to-llvm \
+// RUN:   | mlir-opt -reconcile-unrealized-casts \
+// RUN:   | mlir-cpu-runner \
+// RUN:     --shared-libs=%mlir_cuda_runtime \
+// RUN:     --shared-libs=%mlir_runner_utils \
+// RUN:     --entry-point-result=void \
+// RUN:   | FileCheck %s
+
+!lhs_memref_type = memref<16x4xf32>
+!rhs_memref_type = memref<4x8xf32>
+!res_memref_type = memref<16x8xf32>
+
+func.func @compute_linspace_val(%ridx: index, %cidx: index, %strideCidx: index) -> f32 {
+  %r = arith.index_cast %ridx : index to i32
+  %c = arith.index_cast %cidx : index to i32
+  %strideC = arith.index_cast %strideCidx : index to i32
+  %2 = arith.muli %r, %strideC : i32
+  %3 = arith.addi %c, %2 : i32
+  %4 = arith.sitofp %3 : i32 to f32
+  return %4: f32
+}
+
+func.func @main() {
+  %lhs = memref.alloc() : !lhs_memref_type
+  %rhs = memref.alloc() : !rhs_memref_type
+  %res = memref.alloc() : !res_memref_type
+
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %M = memref.dim %res, %c0 : !res_memref_type
+  %N = memref.dim %res, %c1 : !res_memref_type
+  %K = memref.dim %lhs, %c1 : !lhs_memref_type
+
+  %f1 = arith.constant 1.0e+00 : f32
+  %f0 = arith.constant 0.0e+00 : f32
+  %c32 = arith.constant 32 : index
+
+  // Initialize the lhs matrix with a linspace function.
+  scf.for %r = %c0 to %M step %c1 {
+    scf.for %c = %c0 to %K step %c1 {
+      %idx = func.call @compute_linspace_val(%r, %c, %K) : (index, index, index) -> f32
+      memref.store %idx, %lhs[%r, %c] : !lhs_memref_type
+    }
+  }
+  // Initialize the rhs matrix with a linspace function.
+  scf.for %r = %c0 to %K step %c1 {
+    scf.for %c = %c0 to %N step %c1 {
+      %idx = func.call @compute_linspace_val(%r, %c, %N) : (index, index, index) -> f32
+      memref.store %idx, %rhs[%r, %c] : !rhs_memref_type
+    }
+  }
+  // Initialize the res matrix with a linspace function.
+  scf.for %r = %c0 to %M step %c1 {
+    scf.for %c = %c0 to %N step %c1 {
+      %idx = func.call @compute_linspace_val(%r, %c, %N) : (index, index, index) -> f32
+      memref.store %idx, %res[%r, %c] : !res_memref_type
+    }
+  }
+
+  %ulhs = memref.cast %lhs : !lhs_memref_type to memref<*xf32>
+  %urhs = memref.cast %rhs : !rhs_memref_type to memref<*xf32>
+  %ures = memref.cast %res : !res_memref_type to memref<*xf32>
+  gpu.host_register %ulhs : memref<*xf32>
+  gpu.host_register %urhs : memref<*xf32>
+  gpu.host_register %ures : memref<*xf32>
+
+  // Print the memrefs before computation.
+  call @printMemrefF32(%ulhs) : (memref<*xf32>) -> ()
+  // CHECK: [0,   1,   2,   3],
+  // CHECK: [4,   5,   6,   7],
+  // CHECK: [8,   9,   10,   11],
+  // CHECK: [12,   13,   14,   15],
+  // CHECK: [16,   17,   18,   19],
+  // CHECK: [20,   21,   22,   23],
+  // CHECK: [24,   25,   26,   27],
+  // CHECK: [28,   29,   30,   31],
+  // CHECK: [32,   33,   34,   35],
+  // CHECK: [36,   37,   38,   39],
+  // CHECK: [40,   41,   42,   43],
+  // CHECK: [44,   45,   46,   47],
+  // CHECK: [48,   49,   50,   51],
+  // CHECK: [52,   53,   54,   55],
+  // CHECK: [56,   57,   58,   59],
+  // CHECK: [60,   61,   62,   63]
+
+  call @printMemrefF32(%urhs) : (memref<*xf32>) -> ()
+  // CHECK: [0,   1,   2,   3,   4,   5,   6,   7],
+  // CHECK: [8,   9,   10,   11,   12,   13,   14,   15],
+  // CHECK: [16,   17,   18,   19,   20,   21,   22,   23],
+  // CHECK: [24,   25,   26,   27,   28,   29,   30,   31]
+
+  call @printMemrefF32(%ures) : (memref<*xf32>) -> ()
+  // CHECK: [0,   1,   2,   3,   4,   5,   6,   7],
+  // CHECK: [8,   9,   10,   11,   12,   13,   14,   15],
+  // CHECK: [16,   17,   18,   19,   20,   21,   22,   23],
+  // CHECK: [24,   25,   26,   27,   28,   29,   30,   31],
+  // CHECK: [32,   33,   34,   35,   36,   37,   38,   39],
+  // CHECK: [40,   41,   42,   43,   44,   45,   46,   47],
+  // CHECK: [48,   49,   50,   51,   52,   53,   54,   55],
+  // CHECK: [56,   57,   58,   59,   60,   61,   62,   63],
+  // CHECK: [64,   65,   66,   67,   68,   69,   70,   71],
+  // CHECK: [72,   73,   74,   75,   76,   77,   78,   79],
+  // CHECK: [80,   81,   82,   83,   84,   85,   86,   87],
+  // CHECK: [88,   89,   90,   91,   92,   93,   94,   95],
+  // CHECK: [96,   97,   98,   99,   100,   101,   102,   103],
+  // CHECK: [104,   105,   106,   107,   108,   109,   110,   111],
+  // CHECK: [112,   113,   114,   115,   116,   117,   118,   119],
+  // CHECK: [120,   121,   122,   123,   124,   125,   126,   127]
+
+  gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
+             threads(%tx, %ty, %tz) in (%block_x = %c32, %block_y = %c1, %block_z = %c1) {
+
+    linalg.matmul ins(%lhs, %rhs: !lhs_memref_type, !rhs_memref_type)
+                 outs(%res: !res_memref_type)
+
+    gpu.terminator
+  }
+
+
+  // Print the result memref after computation.
+  call @printMemrefF32(%ures) : (memref<*xf32>) -> ()
+
+  // CHECK: [112,   119,   126,   133,   140,   147,   154,   161],
+  // CHECK: [312,   335,   358,   381,   404,   427,   450,   473],
+  // CHECK: [512,   551,   590,   629,   668,   707,   746,   785],
+  // CHECK: [712,   767,   822,   877,   932,   987,   1042,   1097],
+  // CHECK: [912,   983,   1054,   1125,   1196,   1267,   1338,   1409],
+  // CHECK: [1112,   1199,   1286,   1373,   1460,   1547,   1634,   1721],
+  // CHECK: [1312,   1415,   1518,   1621,   1724,   1827,   1930,   2033],
+  // CHECK: [1512,   1631,   1750,   1869,   1988,   2107,   2226,   2345],
+  // CHECK: [1712,   1847,   1982,   2117,   2252,   2387,   2522,   2657],
+  // CHECK: [1912,   2063,   2214,   2365,   2516,   2667,   2818,   2969],
+  // CHECK: [2112,   2279,   2446,   2613,   2780,   2947,   3114,   3281],
+  // CHECK: [2312,   2495,   2678,   2861,   3044,   3227,   3410,   3593],
+  // CHECK: [2512,   2711,   2910,   3109,   3308,   3507,   3706,   3905],
+  // CHECK: [2712,   2927,   3142,   3357,   3572,   3787,   4002,   4217],
+  // CHECK: [2912,   3143,   3374,   3605,   3836,   4067,   4298,   4529],
+  // CHECK: [3112,   3359,   3606,   3853,   4100,   4347,   4594,   4841]
+
+  return
+}
+
+func.func private @printMemrefF32(memref<*xf32>)
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1
+    : (!transform.any_op) -> !transform.any_op
+  transform.nvgpu.rewrite_matmul_as_mma_sync %matmul
+    : (!transform.any_op) -> ()
+}