Index: mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
===================================================================
--- mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -110,11 +110,23 @@
       (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
     ```
   }];
-  let arguments = (ins AnyVector:$matrixA, AnyVector:$matrixB,
-                       AnyVector:$matrixC, I64ArrayAttr:$mmaShape);
+  // `tf32Enabled` selects TF32 tensor cores and is only valid on f32 operands.
+  let arguments = (ins AnyVector:$matrixA,
+                       AnyVector:$matrixB,
+                       AnyVector:$matrixC,
+                       I64ArrayAttr:$mmaShape,
+                       OptionalAttr<UnitAttr>:$tf32Enabled
+                       );
 
   let results = (outs AnyVector:$res);
 
+  let builders = [
+    OpBuilder<(ins "Value":$matrixA,
+                   "Value":$matrixB,
+                   "Value":$matrixC,
+                   "ArrayAttr":$mmaShape)>
+  ];
+
   let assemblyFormat = [{
     `(` $matrixA`,` $matrixB`,` $matrixC `)` attr-dict
     `:` `(` type($matrixA) `,` type($matrixB) `,` type($matrixC) `)` `->` type($res)
Index: mlir/include/mlir/Dialect/NVGPU/Transforms/Transforms.h
===================================================================
--- mlir/include/mlir/Dialect/NVGPU/Transforms/Transforms.h
+++ mlir/include/mlir/Dialect/NVGPU/Transforms/Transforms.h
@@ -19,6 +19,12 @@
 namespace mlir {
+class RewritePatternSet;
+
 namespace nvgpu {
 
+//===----------------------------------------------------------------------===//
+// Transformations
+//===----------------------------------------------------------------------===//
+
 /// Optimizes vectorized accesses to a shared memory buffer specified by
 /// memrefValue. This transformation assumes the following:
 /// 1) All relevant accesses to `memrefValue` are contained with `parentOp`.
@@ -41,6 +47,27 @@
 mlir::LogicalResult optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
                                                        Value memrefValue);
 
+//===----------------------------------------------------------------------===//
+// Rewrite patterns
+//===----------------------------------------------------------------------===//
+
+/// Enum to control the lowering of `nvgpu.mma.sync` on f32 operands.
+enum class MmaSyncF32Lowering { TF32 = 0, TF32x3 = 1, Unknown = 2 };
+
+/// Collect patterns to convert mma.sync on f32 input and rewrite
+/// to use tensor cores with user provided level of accuracy:
+/// (a) tf32   (1 mma.sync per warp-level matrix-multiply-accumulate)
+/// (b) tf32x3 (3 mma.sync per warp-level matrix-multiply-accumulate)
+/// Typically, tf32 tensor core acceleration comes at a cost
+/// of accuracy from missing precision bits. While f32 has 23 precision
+/// bits, tf32 has only 10 precision bits. tf32x3 aims to recover the
+/// precision bits by splitting each operand into two tf32 values
+/// and issuing three mma.sync tensor core operations.
+/// Note: only tf32 is implemented so far; tf32x3 reports an error.
+void populateMmaSyncF32ToTF32Patterns(
+    RewritePatternSet &patterns,
+    nvgpu::MmaSyncF32Lowering precision = nvgpu::MmaSyncF32Lowering::TF32);
+
 } // namespace nvgpu
 } // namespace mlir
 
Index: mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
===================================================================
--- mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -275,10 +275,15 @@
     NVVM::MMATypes ptxTypeB;
     Optional<NVVM::MMATypes> ptxTypeC = NVVM::MmaOp::inferOperandMMAType(
         cType.getElementType(), /*isAccumulator=*/true);
-    if (!ptxTypeC) {
+    if (!ptxTypeC)
       return op->emitError(
           "could not infer the PTX type for the accumulator/result");
-    }
+
+    // Tensor Cores (mma.sync) on F32 work only with TensorFloat32 (TF32).
+    bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
+    if (aType.getElementType().isF32() && !tf32Enabled)
+      return rewriter.notifyMatchFailure(
+          op, "f32 mma.sync requires the tf32Enabled attribute");
 
     Optional<NVVM::MMAIntOverflow> overflow(llvm::None);
     if (aType.getElementType().isInteger(8)) {
Index: mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
===================================================================
--- mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -687,8 +687,8 @@
   int64_t m = op.getLhs().getType().cast<VectorType>().getShape()[0];
   int64_t n = op.getRhs().getType().cast<VectorType>().getShape()[0];
   int64_t k = op.getLhs().getType().cast<VectorType>().getShape()[1];
-  Value matmul = b.create<nvgpu::MmaSyncOp>(
-      op.getLoc(), opC.getType(), opA, opB, opC, b.getI64ArrayAttr({m, n, k}));
+  Value matmul = b.create<nvgpu::MmaSyncOp>(op.getLoc(), opA, opB, opC,
+                                            b.getI64ArrayAttr({m, n, k}));
   valueMapping[op.getResult()] = matmul;
   return success();
 }
Index: mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
===================================================================
--- mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -91,6 +91,14 @@
 //===----------------------------------------------------------------------===//
 // NVGPU_MmaSyncOp
 //===----------------------------------------------------------------------===//
+/// Builds an mma.sync whose result type is the type of `matrixC`, leaving the
+/// optional `tf32Enabled` attribute unset.
+void MmaSyncOp::build(::mlir::OpBuilder &odsBuilder,
+                      ::mlir::OperationState &odsState, Value matrixA,
+                      Value matrixB, Value matrixC, ArrayAttr mmaShape) {
+  build(odsBuilder, odsState, matrixC.getType(), matrixA, matrixB, matrixC,
+        mmaShape, UnitAttr());
+}
 
 LogicalResult MmaSyncOp::verify() {
 
@@ -122,6 +130,9 @@
   // vector element type
   Type aType = aVector.getElementType();
 
+  // tensor float32 (TF32) enabled
+  bool tf32Enabled = getOperation()->hasAttr(getTf32EnabledAttrName());
+
   // nvgpu.mma.sync shape (per 32 threads or per warp)
   int64_t m = getMmaShape()[0].cast<IntegerAttr>().getInt();
   int64_t n = getMmaShape()[1].cast<IntegerAttr>().getInt();
@@ -163,6 +174,10 @@
     return emitOpError() << "expected " << m * n
                          << " warp-wide matrix C elements";
 
+  // verify tf32 tensor cores are enabled for only F32 datatype
+  if (tf32Enabled && !aType.isF32())
+    return emitOpError() << "expected tf32 tensor cores only for F32 operands";
+
   //
   // Extended verification
   //
Index: mlir/lib/Dialect/NVGPU/Transforms/CMakeLists.txt
===================================================================
--- mlir/lib/Dialect/NVGPU/Transforms/CMakeLists.txt
+++ mlir/lib/Dialect/NVGPU/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRNVGPUTransforms
+  MmaSyncTF32Transform.cpp
   OptimizeSharedMemory.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/NVGPU
Index: mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp
===================================================================
--- /dev/null
+++ mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp
@@ -0,0 +1,73 @@
+//===- MmaSyncTF32Transform.cpp - NVGPU mma.sync f32 to tf32 patterns -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements rewrite patterns that enable 1xtf32 and 3xtf32
+// nvgpu.mma.sync operations on f32 input datatype.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
+#include "mlir/Dialect/NVGPU/Transforms/Transforms.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/LogicalResult.h"
+
+using namespace mlir;
+using namespace mlir::nvgpu;
+
+namespace {
+
+/// Marks f32 `nvgpu.mma.sync` ops with the `tf32Enabled` unit attribute so
+/// they can be lowered to TF32 tensor cores. Ops that already carry the
+/// attribute, or whose operands are not f32, are left untouched.
+struct MmaSyncF32ToTF32Pattern : public OpRewritePattern<nvgpu::MmaSyncOp> {
+  MmaSyncF32ToTF32Pattern(MLIRContext *context,
+                          nvgpu::MmaSyncF32Lowering precision)
+      : OpRewritePattern<nvgpu::MmaSyncOp>(context, /*benefit=*/1),
+        precision(precision) {}
+
+  LogicalResult matchAndRewrite(nvgpu::MmaSyncOp op,
+                                PatternRewriter &rewriter) const override {
+    // Nothing to do if TF32 is already enabled on this op.
+    if (op->hasAttr(op.getTf32EnabledAttrName()))
+      return failure();
+
+    // TF32 only applies to f32 operands; the op verifier rejects
+    // `tf32Enabled` on any other element type.
+    Type elementType =
+        op.getMatrixA().getType().cast<VectorType>().getElementType();
+    if (!elementType.isF32())
+      return failure();
+
+    Location location = op->getLoc();
+    if (precision == MmaSyncF32Lowering::Unknown)
+      return emitError(location, "MmaSync F32-to-TF32 cannot be lowered with "
+                                 "unknown precision level");
+
+    if (precision == MmaSyncF32Lowering::TF32x3)
+      return emitError(location, "TF32x3 is not supported at the moment "
+                                 "for nvgpu.mma.sync on f32 datatype");
+
+    // Only TF32 remains: mark the op in place, notifying the rewriter so the
+    // pattern driver observes the modification.
+    rewriter.updateRootInPlace(
+        op, [&]() { op.setTf32EnabledAttr(rewriter.getUnitAttr()); });
+    return success();
+  }
+
+private:
+  /// Precision for F32 Tensor Cores (TF32 or TF32x3).
+  nvgpu::MmaSyncF32Lowering precision;
+};
+
+} // namespace
+
+void mlir::nvgpu::populateMmaSyncF32ToTF32Patterns(
+    RewritePatternSet &patterns, nvgpu::MmaSyncF32Lowering precision) {
+  patterns.add<MmaSyncF32ToTF32Pattern>(patterns.getContext(), precision);
+}
Index: mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
===================================================================
--- mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -219,7 +219,7 @@
   // CHECK-SAME: multiplicandBPtxType = #nvvm.mma_type<tf32>
   // CHECK-SAME: shape = #nvvm.shape<m = 16, n = 8, k = 4>
   // CHECK-SAME: -> !llvm.struct<(f32, f32, f32, f32)>
-  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 4]} : (vector<2x1xf32>, vector<1x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 4], tf32Enabled} : (vector<2x1xf32>, vector<1x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
   // CHECK: [[undef:%.+]] = llvm.mlir.undef : vector<2xf32>
   // CHECK-DAG: llvm.extractvalue [[d]][0] : !llvm.struct<(f32, f32, f32, f32)>
   // CHECK-DAG: llvm.extractvalue [[d]][1] : !llvm.struct<(f32, f32, f32, f32)>
Index: mlir/test/Dialect/NVGPU/invalid.mlir
===================================================================
--- mlir/test/Dialect/NVGPU/invalid.mlir
+++ mlir/test/Dialect/NVGPU/invalid.mlir
@@ -76,6 +76,13 @@
 }
 // -----
 
+func.func @m16n8k16_fp16_tf32Enabled(%arg0: vector<4x2xf16>, %arg1: vector<2x2xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
+  // expected-error @+1 {{expected tf32 tensor cores only for F32 operands}}
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16], tf32Enabled} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+  return %d : vector<2x2xf16>
+}
+// -----
+
 func.func @m16n8k8_fp32_vector_shape_a(%arg0: vector<4x2xf32>, %arg1: vector<2x1xf32>, %arg2: vector<2x2xf32>) -> vector<2x2xf32> {
   // expected-error @+1 {{expected 128 warp-wide matrix A elements}}
   %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 8]} : (vector<4x2xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
Index: mlir/test/Dialect/NVGPU/mma-sync-f32-to-tf32.mlir
===================================================================
--- /dev/null
+++ mlir/test/Dialect/NVGPU/mma-sync-f32-to-tf32.mlir
@@ -0,0 +1,30 @@
+// RUN: mlir-opt %s -test-nvgpu-mmasync-f32-to-tf32-patterns="precision=tf32" -split-input-file | FileCheck %s
+
+// CHECK-LABEL: m16n8k4_tf32
+func.func @m16n8k4_tf32(%arg0: vector<2x1xf32>, %arg1: vector<1x1xf32>, %arg2: vector<2x2xf32>) -> vector<2x2xf32> {
+  // CHECK: nvgpu.mma.sync
+  // CHECK-SAME: tf32Enabled
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 4]} : (vector<2x1xf32>, vector<1x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+  return %d : vector<2x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: m16n8k8_tf32
+func.func @m16n8k8_tf32(%arg0: vector<4x1xf32>, %arg1: vector<2x1xf32>, %arg2: vector<2x2xf32>) -> vector<2x2xf32> {
+  // CHECK: nvgpu.mma.sync
+  // CHECK-SAME: tf32Enabled
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 8]} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+  return %d : vector<2x2xf32>
+}
+
+// -----
+
+// f16 operands are not f32, so the pattern must leave the op untouched.
+// CHECK-LABEL: m16n8k16_fp16
+func.func @m16n8k16_fp16(%arg0: vector<4x2xf16>, %arg1: vector<2x2xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
+  // CHECK: nvgpu.mma.sync
+  // CHECK-NOT: tf32Enabled
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+  return %d : vector<2x2xf16>
+}
Index: mlir/test/Dialect/NVGPU/mma-sync-f32-to-tf32x3.mlir
===================================================================
--- /dev/null
+++ mlir/test/Dialect/NVGPU/mma-sync-f32-to-tf32x3.mlir
@@ -0,0 +1,15 @@
+// RUN: mlir-opt %s -test-nvgpu-mmasync-f32-to-tf32-patterns="precision=tf32x3" -split-input-file -verify-diagnostics
+
+func.func @m16n8k4_tf32x3(%arg0: vector<2x1xf32>, %arg1: vector<1x1xf32>, %arg2: vector<2x2xf32>) -> vector<2x2xf32> {
+  // expected-error @+1 {{TF32x3 is not supported at the moment for nvgpu.mma.sync on f32 datatype}}
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 4]} : (vector<2x1xf32>, vector<1x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+  return %d : vector<2x2xf32>
+}
+
+// -----
+
+func.func @m16n8k8_tf32x3(%arg0: vector<4x1xf32>, %arg1: vector<2x1xf32>, %arg2: vector<2x2xf32>) -> vector<2x2xf32> {
+  // expected-error @+1 {{TF32x3 is not supported at the moment for nvgpu.mma.sync on f32 datatype}}
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 8]} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+  return %d : vector<2x2xf32>
+}
Index: mlir/test/lib/Dialect/CMakeLists.txt
===================================================================
--- mlir/test/lib/Dialect/CMakeLists.txt
+++ mlir/test/lib/Dialect/CMakeLists.txt
@@ -13,3 +13,4 @@
 add_subdirectory(Tosa)
 add_subdirectory(Transform)
 add_subdirectory(Vector)
+add_subdirectory(NVGPU)
Index: mlir/test/lib/Dialect/NVGPU/CMakeLists.txt
===================================================================
--- /dev/null
+++ mlir/test/lib/Dialect/NVGPU/CMakeLists.txt
@@ -0,0 +1,14 @@
+# Exclude tests from libMLIR.so
+add_mlir_library(MLIRNVGPUTestPasses
+  TestNVGPUTransforms.cpp
+
+  EXCLUDE_FROM_LIBMLIR
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRFuncDialect
+  MLIRNVGPUDialect
+  MLIRNVGPUTransforms
+  MLIRPass
+  MLIRTransformUtils
+  )
Index: mlir/test/lib/Dialect/NVGPU/TestNVGPUTransforms.cpp
===================================================================
--- /dev/null
+++ mlir/test/lib/Dialect/NVGPU/TestNVGPUTransforms.cpp
@@ -0,0 +1,67 @@
+//===- TestNVGPUTransforms.cpp - Test NVGPU transforms and lowerings ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/NVGPU/Transforms/Transforms.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/StringSwitch.h"
+
+using namespace mlir;
+using namespace mlir::nvgpu;
+
+namespace {
+
+/// Test pass that applies the nvgpu.mma.sync f32-to-tf32 patterns with the
+/// precision selected by the `precision` pass option.
+struct TestMmaSyncF32ToTF32Patterns
+    : public PassWrapper<TestMmaSyncF32ToTF32Patterns,
+                         OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMmaSyncF32ToTF32Patterns)
+
+  StringRef getArgument() const final {
+    return "test-nvgpu-mmasync-f32-to-tf32-patterns";
+  }
+  StringRef getDescription() const final {
+    return "Test patterns to convert mma.sync on f32 with tf32 precision";
+  }
+  TestMmaSyncF32ToTF32Patterns() = default;
+  TestMmaSyncF32ToTF32Patterns(const TestMmaSyncF32ToTF32Patterns &pass)
+      : PassWrapper(pass) {}
+
+  Option<std::string> precision{
+      *this, "precision",
+      llvm::cl::desc(
+          "Target nvgpu.mma.sync on f32 input with tf32 or tf32x3 precision"),
+      llvm::cl::init("tf32")};
+
+  void runOnOperation() override {
+    // Pass options are parsed after the pass is constructed, so the option
+    // string must be decoded here rather than in a member initializer.
+    MmaSyncF32Lowering tf32Precision =
+        llvm::StringSwitch<MmaSyncF32Lowering>(precision)
+            .Case("tf32", MmaSyncF32Lowering::TF32)
+            .Case("tf32x3", MmaSyncF32Lowering::TF32x3)
+            .Default(MmaSyncF32Lowering::Unknown);
+
+    RewritePatternSet patterns(&getContext());
+    populateMmaSyncF32ToTF32Patterns(patterns, tf32Precision);
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+  }
+};
+
+} // namespace
+
+namespace mlir {
+namespace test {
+/// Registers the NVGPU test passes (mma.sync f32-to-tf32 patterns).
+void registerTestNvgpuLowerings() {
+  PassRegistration<TestMmaSyncF32ToTF32Patterns>();
+}
+} // namespace test
+} // namespace mlir
Index: mlir/tools/mlir-opt/CMakeLists.txt
===================================================================
--- mlir/tools/mlir-opt/CMakeLists.txt
+++ mlir/tools/mlir-opt/CMakeLists.txt
@@ -35,6 +35,7 @@
     MLIRTestTransforms
     MLIRTilingInterfaceTestPasses
     MLIRVectorTestPasses
+    MLIRNVGPUTestPasses
     )
 endif()
 
Index: mlir/tools/mlir-opt/mlir-opt.cpp
===================================================================
--- mlir/tools/mlir-opt/mlir-opt.cpp
+++ mlir/tools/mlir-opt/mlir-opt.cpp
@@ -112,6 +112,7 @@
 void registerTestTilingInterface();
 void registerTestTransformDialectInterpreterPass();
 void registerTestVectorLowerings();
+void registerTestNvgpuLowerings();
 } // namespace test
 } // namespace mlir
 
@@ -206,6 +207,7 @@
   mlir::test::registerTestTilingInterface();
   mlir::test::registerTestTransformDialectInterpreterPass();
   mlir::test::registerTestVectorLowerings();
+  mlir::test::registerTestNvgpuLowerings();
 }
 #endif
 