Index: mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
===================================================================
--- mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -122,6 +122,10 @@
                                                bool force32BitVectorIndices,
                                                PatternBenefit benefit = 1);
 
+/// Collect patterns that fold floating-point extensions (`arith.extf`) into
+/// `vector.contract`, for backends with native mixed-precision support.
+void populateFoldArithExtensionPatterns(RewritePatternSet &patterns);
+
 /// Returns the integer type required for subscripts in the vector dialect.
 IntegerType getVectorSubscriptType(Builder &builder);
 
Index: mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
===================================================================
--- mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1212,8 +1212,54 @@
   FilterConstraintType filter;
 };
 
+/// Pattern to fold arithmetic extensions on floating point data types into
+/// vector contraction operations. `linalg.matmul` introduces arithmetic
+/// extensions on its operands. Please see the MLIR snippet below for details.
+/// ```mlir
+///   "linalg.matmul"(%lhs, %rhs, %acc) ({
+///      ^bb0(%arg1: f16, %arg2: f16, %arg3: f32):
+///        %lhs_f32 = "arith.extf"(%arg1) : (f16) -> f32
+///        %rhs_f32 = "arith.extf"(%arg2) : (f16) -> f32
+///        %mul = "arith.mulf"(%lhs_f32, %rhs_f32) : (f32, f32) -> f32
+///        %sum = "arith.addf"(%arg3, %mul) : (f32, f32) -> f32
+///        "linalg.yield"(%sum) : (f32) -> ()
+///     })
+/// ```
+/// Left in place, these extensions prevent the use of native mixed-precision
+/// NVIDIA Ampere Tensor Core instructions, i.e., `mma.sync.*.f32.f16.f16.f32`
+/// and `mma.sync.*.f32.bf16.bf16.f32`. This pattern folds the extensions into
+/// the vector contraction so those instructions can be selected directly.
+struct FoldArithExtIntoContractionOp
+    : public OpRewritePattern<vector::ContractionOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
+                                PatternRewriter &rewriter) const override {
+    // Fold only if both multiplicands come from `arith.extf`: vector.contract
+    // itself promotes narrower operands to the accumulator element type.
+    auto lhsDefOp = contractOp.getLhs().getDefiningOp<arith::ExtFOp>();
+    auto rhsDefOp = contractOp.getRhs().getDefiningOp<arith::ExtFOp>();
+    if (!lhsDefOp || !rhsDefOp) {
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "no defining op on contract operands");
+    }
+    // Forward the combining kind explicitly: the builder overload without a
+    // kind argument falls back to the default kind (`add`).
+    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
+        contractOp, lhsDefOp->getOperand(0), rhsDefOp->getOperand(0),
+        contractOp.getAcc(), contractOp.getIndexingMapsAttr(),
+        contractOp.getIteratorTypesAttr(), contractOp.getKind());
+    return success();
+  }
+};
+
 } // namespace
 
+void mlir::vector::populateFoldArithExtensionPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<FoldArithExtIntoContractionOp>(patterns.getContext());
+}
+
 void mlir::vector::populateVectorMaskMaterializationPatterns(
     RewritePatternSet &patterns, bool force32BitVectorIndices,
     PatternBenefit benefit) {
Index: mlir/test/Dialect/Vector/fold-arith-extf-into-vector-contract.mlir
===================================================================
--- /dev/null
+++ mlir/test/Dialect/Vector/fold-arith-extf-into-vector-contract.mlir
@@ -0,0 +1,33 @@
+// RUN: mlir-opt -split-input-file -test-fold-arith-extf-into-vector-contract-patterns %s | FileCheck %s
+
+// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func.func @fold_arith_extf_into_contract
+// CHECK-SAME: (%[[ARG0:.*]]: vector<64x64xf16>, %[[ARG1:.*]]: vector<64x64xf16>, %[[ARG2:.*]]: vector<64x64xf32>)
+// CHECK-NEXT: %[[R:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]],
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]] : vector<64x64xf16>, vector<64x64xf16> into vector<64x64xf32>
+// CHECK-NEXT: return %[[R]] : vector<64x64xf32>
+func.func @fold_arith_extf_into_contract(%arg0: vector<64x64xf16>, %arg1: vector<64x64xf16>, %arg2: vector<64x64xf32>) -> vector<64x64xf32> {
+  %lhs_f32 = arith.extf %arg0 : vector<64x64xf16> to vector<64x64xf32>
+  %rhs_f32 = arith.extf %arg1 : vector<64x64xf16> to vector<64x64xf32>
+  %result = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_f32, %rhs_f32, %arg2 : vector<64x64xf32>, vector<64x64xf32> into vector<64x64xf32>
+  return %result : vector<64x64xf32>
+}
+
+// -----
+
+// The fold must not change the combining kind of the contraction.
+// CHECK-LABEL: func.func @fold_arith_extf_into_contract_preserve_kind
+// CHECK-SAME: (%[[ARG0:.*]]: vector<64x64xf16>, %[[ARG1:.*]]: vector<64x64xf16>, %[[ARG2:.*]]: vector<64x64xf32>)
+// CHECK-NEXT: %[[R:.+]] = vector.contract
+// CHECK-SAME: kind = #vector.kind<mul>}
+// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]] : vector<64x64xf16>, vector<64x64xf16> into vector<64x64xf32>
+// CHECK-NEXT: return %[[R]] : vector<64x64xf32>
+func.func @fold_arith_extf_into_contract_preserve_kind(%arg0: vector<64x64xf16>, %arg1: vector<64x64xf16>, %arg2: vector<64x64xf32>) -> vector<64x64xf32> {
+  %lhs_f32 = arith.extf %arg0 : vector<64x64xf16> to vector<64x64xf32>
+  %rhs_f32 = arith.extf %arg1 : vector<64x64xf16> to vector<64x64xf32>
+  %result = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<mul>} %lhs_f32, %rhs_f32, %arg2 : vector<64x64xf32>, vector<64x64xf32> into vector<64x64xf32>
+  return %result : vector<64x64xf32>
+}
Index: mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
===================================================================
--- mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -709,6 +709,32 @@
   }
 };
 
+struct TestFoldArithExtensionIntoVectorContractPatterns
+    : public PassWrapper<TestFoldArithExtensionIntoVectorContractPatterns,
+                         OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+      TestFoldArithExtensionIntoVectorContractPatterns)
+
+  StringRef getArgument() const final {
+    return "test-fold-arith-extf-into-vector-contract-patterns";
+  }
+  StringRef getDescription() const final {
+    return "Test patterns that fold arithmetic extension ops into vector "
+           "contract ops";
+  }
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    // The patterns only create `vector.contract` ops.
+    registry.insert<vector::VectorDialect>();
+  }
+
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateFoldArithExtensionPatterns(patterns);
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+  }
+};
+
 } // namespace
 
 namespace mlir {
@@ -745,6 +771,8 @@
   PassRegistration();
 
   PassRegistration();
+
+  PassRegistration<TestFoldArithExtensionIntoVectorContractPatterns>();
 }
 } // namespace test
 } // namespace mlir