Index: mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
===================================================================
--- mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -122,6 +122,10 @@
                                        bool force32BitVectorIndices,
                                        PatternBenefit benefit = 1);
 
+/// Collect a set of patterns that fold arithmetic extension on floating point
+/// into vector contract for the backends with native support.
+void populateFoldArithExtensionPatterns(RewritePatternSet &patterns);
+
 /// Returns the integer type required for subscripts in the vector dialect.
 IntegerType getVectorSubscriptType(Builder &builder);
 
Index: mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
===================================================================
--- mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1212,8 +1212,54 @@
   FilterConstraintType filter;
 };
 
+/// Pattern to fold arithmetic extensions on floating point data types into
+/// vector contraction operations. linalg.matmul introduces arithmetic
+/// extensions on its operands. See the MLIR snippet below for more details.
+/// ```mlir
+/// "linalg.matmul"(%lhs, %rhs, %acc) ({
+///   ^bb0(%arg1: f16, %arg2: f16, %arg3: f32):
+///     %lhs_f32 = "arith.extf"(%arg1) : (f16) -> f32
+///     %rhs_f32 = "arith.extf"(%arg2) : (f16) -> f32
+///     %mul = "arith.mulf"(%lhs_f32, %rhs_f32) : (f32, f32) -> f32
+///     %acc = "arith.addf"(%arg3, %mul) : (f32, f32) -> f32
+///     "linalg.yield"(%acc) : (f32) -> ()
+/// })
+/// ```
+/// This restricts the native usage of mixed precision NVIDIA Ampere Tensor
+/// Cores, i.e. `mma.sync.*.f32.f16.f16.f32` and `mma.sync.*.f32.bf16.bf16.f32`.
+/// This pattern folds the arithmetic extensions into the vector contraction and
+/// enables the usage of native mixed precision Tensor Core instructions.
+struct FoldArithExtIntoContractionOp + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::ContractionOp contractOp, + PatternRewriter &rewriter) const override { + + auto lhsDefOp = contractOp.getLhs().getDefiningOp(); + auto rhsDefOp = contractOp.getRhs().getDefiningOp(); + + if (!lhsDefOp || !rhsDefOp) { + return rewriter.notifyMatchFailure(contractOp, + "no defining op on contract operands"); + } + + rewriter.replaceOpWithNewOp( + contractOp, lhsDefOp->getOperand(0), rhsDefOp->getOperand(0), + contractOp.getAcc(), contractOp.getIndexingMapsAttr(), + contractOp.getIteratorTypesAttr()); + + return success(); + } +}; + } // namespace +void mlir::vector::populateFoldArithExtensionPatterns( + RewritePatternSet &patterns) { + patterns.add(patterns.getContext()); +} + void mlir::vector::populateVectorMaskMaterializationPatterns( RewritePatternSet &patterns, bool force32BitVectorIndices, PatternBenefit benefit) { Index: mlir/test/Conversion/VectorToGPU/fold-arith-vector-to-mma-ops-mma-sync.mlir =================================================================== --- /dev/null +++ mlir/test/Conversion/VectorToGPU/fold-arith-vector-to-mma-ops-mma-sync.mlir @@ -0,0 +1,46 @@ +// RUN: mlir-opt %s -split-input-file -pass-pipeline="builtin.module(func.func(test-fold-arith-extf-into-vector-contract-patterns,convert-vector-to-gpu{use-nvgpu=true},cse))" | FileCheck %s + +//############################################################################################### +// FP16 input, F32 accumulation row-row-row (ldmatrix x4 for matrixA and ldmatrix x4 for matrixB) +//############################################################################################### + +#map0 = affine_map<(d0, d1) -> (d1, d0)> +#map1 = affine_map<(d0, d1, d2) -> (d0, d2)> +#map2 = affine_map<(d0, d1, d2) -> (d1, d2)> +#map3 = affine_map<(d0, d1, d2) -> (d0, d1)> + +// CHECK-LABEL: func @m16n8k16_mmasync16816_f16_f16_f32_row_row_row 
+func.func @m16n8k16_mmasync16816_f16_f16_f32_row_row_row(%arg0: memref<42x32xf16, #gpu.address_space<workgroup>>, %arg1: memref<32x64xf16, #gpu.address_space<workgroup>>, %arg2: memref<42x64xf32, #gpu.address_space<workgroup>>) {
+  %c0 = arith.constant 0 : index
+  %c8 = arith.constant 8 : index
+  %cst_f16 = arith.constant 0.000000e+00 : f16
+  %cst_f32 = arith.constant 0.000000e+00 : f32
+
+  // CHECK-DAG: nvgpu.ldmatrix %arg0[%{{.*}}, %{{.*}}] {numTiles = 4 : i32, transpose = false}
+  %A = vector.transfer_read %arg0[%c0, %c0], %cst_f16 {in_bounds = [true, true]} : memref<42x32xf16, #gpu.address_space<workgroup>>, vector<16x16xf16>
+  %A_f32 = arith.extf %A : vector<16x16xf16> to vector<16x16xf32>
+
+
+  // CHECK-DAG: nvgpu.ldmatrix %arg1[%{{.*}}, %{{.*}}] {numTiles = 4 : i32, transpose = true}
+  %B = vector.transfer_read %arg1[%c0, %c0], %cst_f16 {permutation_map = #map0, in_bounds = [true, true]} : memref<32x64xf16, #gpu.address_space<workgroup>>, vector<16x16xf16>
+  %C = vector.transfer_read %arg2[%c0, %c0], %cst_f32 {in_bounds = [true, true]} : memref<42x64xf32, #gpu.address_space<workgroup>>, vector<16x16xf32>
+
+  %B0 = vector.extract_strided_slice %B {offsets = [0, 0], sizes = [8, 16], strides = [1, 1]} : vector<16x16xf16> to vector<8x16xf16>
+  %B0_f32 = arith.extf %B0 : vector<8x16xf16> to vector<8x16xf32>
+  %C0 = vector.extract_strided_slice %C {offsets = [0, 0], sizes = [16, 8], strides = [1, 1]} : vector<16x16xf32> to vector<16x8xf32>
+
+  // CHECK-DAG: nvgpu.mma.sync({{.*}}) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf32>) -> vector<2x2xf32>
+  %D0 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A_f32, %B0_f32, %C0 : vector<16x16xf32>, vector<8x16xf32> into vector<16x8xf32>
+  vector.transfer_write %D0, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x8xf32>, memref<42x64xf32, #gpu.address_space<workgroup>>
+
+
+  %B1 = vector.extract_strided_slice %B {offsets = [8, 0], sizes = [8, 16], strides = [1, 1]} : vector<16x16xf16> to vector<8x16xf16>
+  %B1_f32 = arith.extf %B1 : vector<8x16xf16> to vector<8x16xf32>
+  %C1 = vector.extract_strided_slice %C {offsets = [0, 8], sizes = [16, 8], strides = [1, 1]} : vector<16x16xf32> to vector<16x8xf32>
+
+  // CHECK-DAG: nvgpu.mma.sync({{.*}}) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf32>) -> vector<2x2xf32>
+  %D1 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A_f32, %B1_f32, %C1 : vector<16x16xf32>, vector<8x16xf32> into vector<16x8xf32>
+  vector.transfer_write %D1, %arg2[%c0, %c8] {in_bounds = [true, true]} : vector<16x8xf32>, memref<42x64xf32, #gpu.address_space<workgroup>>
+
+  return
+}
Index: mlir/test/Dialect/Vector/fold-arith-extf-into-vector-contract.mlir
===================================================================
--- /dev/null
+++ mlir/test/Dialect/Vector/fold-arith-extf-into-vector-contract.mlir
@@ -0,0 +1,18 @@
+// RUN: mlir-opt -split-input-file -test-fold-arith-extf-into-vector-contract-patterns %s | FileCheck %s
+
+
+// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func.func @fold_arith_extf_into_contract
+// CHECK-SAME: (%[[ARG0:.*]]: vector<64x64xf16>, %[[ARG1:.*]]: vector<64x64xf16>, %[[ARG2:.*]]: vector<64x64xf32>)
+// CHECK-NEXT: %[[R:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]],
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]] : vector<64x64xf16>, vector<64x64xf16> into vector<64x64xf32>
+// CHECK-NEXT: return %[[R]] : vector<64x64xf32>
+func.func @fold_arith_extf_into_contract(%arg0: vector<64x64xf16>, %arg1: vector<64x64xf16>, %arg2: vector<64x64xf32>) -> vector<64x64xf32> {
+  %lhs_f32 = arith.extf %arg0 : vector<64x64xf16> to vector<64x64xf32>
+  %rhs_f32 = arith.extf %arg1 : vector<64x64xf16> to vector<64x64xf32>
+  %result = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_f32, %rhs_f32, %arg2 : vector<64x64xf32>, vector<64x64xf32> into vector<64x64xf32>
+  return %result : vector<64x64xf32>
+}
Index: mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
===================================================================
--- mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/Linalg/Passes.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -709,6 +710,32 @@
   }
 };
 
+struct TestFoldArithExtensionIntoVectorContractPatterns
+    : public PassWrapper<TestFoldArithExtensionIntoVectorContractPatterns,
+                         OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+      TestFoldArithExtensionIntoVectorContractPatterns)
+
+  StringRef getArgument() const final {
+    return "test-fold-arith-extf-into-vector-contract-patterns";
+  }
+  StringRef getDescription() const final {
+    return "Test patterns that fold arithmetic extension ops into vector "
+           "contract ops";
+  }
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<arith::ArithDialect, func::FuncDialect, nvgpu::NVGPUDialect,
+                    memref::MemRefDialect, scf::SCFDialect,
+                    tensor::TensorDialect, vector::VectorDialect>();
+  }
+
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateFoldArithExtensionPatterns(patterns);
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+  }
+};
 } // namespace
 
 namespace mlir {
@@ -745,6 +772,8 @@
   PassRegistration();
 
   PassRegistration();
+
+  PassRegistration<TestFoldArithExtensionIntoVectorContractPatterns>();
 }
 } // namespace test
 } // namespace mlir