diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -246,6 +246,8 @@
       return CombiningKind::ADD;
     }
   }];
+
+  let hasCanonicalizer = 1;
 }
 
 def Vector_ReductionOp :
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/Vector/VectorUtils.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
+#include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/DialectImplementation.h"
@@ -658,6 +659,81 @@
   return shape;
 }
 
+/// Canonicalization pattern that folds an addition whose operand is a
+/// vector::ContractionOp with a constant-zero accumulator into the
+/// contraction itself. It rewrites:
+///
+/// ```mlir
+/// %c0 = constant dense<0.0> : vector<...>
+/// %c = vector.contract %a, %b, %c0: ...
+/// %e = addf %c, %d: ...
+/// ```
+///
+/// into:
+///
+/// ```mlir
+/// %e = vector.contract %a, %b, %d: ...
+/// ```
+///
+/// Addition is commutative, so the contraction may be either operand of
+/// `AddOpType`. The pattern fails, leaving the IR untouched, when neither
+/// operand matches.
+// TODO: This should be a folding of Add into Contract in core but while they
+// live in different dialects, it is not possible without unnatural
+// dependencies.
+template <typename AddOpType>
+struct CanonicalizeContractAdd : public OpRewritePattern<AddOpType> {
+  using OpRewritePattern<AddOpType>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(AddOpType addOp,
+                                PatternRewriter &rewriter) const override {
+    // Try to rewrite `addOp` as a contraction accumulating into
+    // `otherOperand`. Returns the new contraction, or a null op (without
+    // touching the IR) if `maybeContraction` is not produced by a
+    // vector.contract whose accumulator is a constant zero.
+    auto canonicalize = [&](Value maybeContraction,
+                            Value otherOperand) -> vector::ContractionOp {
+      vector::ContractionOp contractionOp =
+          dyn_cast_or_null<vector::ContractionOp>(
+              maybeContraction.getDefiningOp());
+      if (!contractionOp)
+        return vector::ContractionOp();
+      if (auto maybeZero = dyn_cast_or_null<ConstantOp>(
+              contractionOp.acc().getDefiningOp())) {
+        if (maybeZero.value() ==
+            rewriter.getZeroAttr(contractionOp.acc().getType())) {
+          // Clone the contraction with `otherOperand` substituted for the
+          // zero accumulator. The original contraction is left in place for
+          // its other users, if any.
+          BlockAndValueMapping bvm;
+          bvm.map(contractionOp.acc(), otherOperand);
+          auto newContraction =
+              cast<vector::ContractionOp>(rewriter.clone(*contractionOp, bvm));
+          rewriter.replaceOp(addOp, newContraction.getResult());
+          return newContraction;
+        }
+      }
+      return vector::ContractionOp();
+    };
+
+    Value a = addOp->getOperand(0), b = addOp->getOperand(1);
+    vector::ContractionOp contract = canonicalize(a, b);
+    contract = contract ? contract : canonicalize(b, a);
+    // A pattern must only report success when it actually rewrote the IR;
+    // a spurious success tells the rewrite driver that progress was made
+    // even though nothing changed.
+    return contract ? success() : failure();
+  }
+};
+
+/// Registers the add-into-contraction canonicalization for both integer
+/// (AddIOp) and floating-point (AddFOp) additions.
+void ContractionOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.insert<CanonicalizeContractAdd<AddIOp>,
+                 CanonicalizeContractAdd<AddFOp>>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // ExtractElementOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -710,3 +710,68 @@
     memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
   return
 }
+
+// -----
+
+#contraction_accesses0 = [
+  affine_map<(i, j, k) -> (i, k)>,
+  affine_map<(i, j, k) -> (k, j)>,
+  affine_map<(i, j, k) -> (i, j)>
+]
+#contraction_trait0 = {
+  indexing_maps = #contraction_accesses0,
+  iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+// CHECK-LABEL: func @contractions
+// CHECK-SAME: %[[A:[0-9a-zA-Z]+]]: vector<2x3xf32>
+// CHECK-SAME: %[[B:[0-9a-zA-Z]+]]: vector<3x4xf32>
+// CHECK-SAME: %[[C:[0-9a-zA-Z]+]]: vector<2x4xf32>
+// CHECK-SAME: %[[A_I8:[0-9a-zA-Z]+]]: vector<2x3xi8>
+// CHECK-SAME: %[[B_I8:[0-9a-zA-Z]+]]: vector<3x4xi8>
+// CHECK-SAME: %[[C_I8:[0-9a-zA-Z]+]]: vector<2x4xi8>
+func @contractions(%a: vector<2x3xf32>, %b: vector<3x4xf32>, %c: vector<2x4xf32>,
+                   %a_i8: vector<2x3xi8>, %b_i8: vector<3x4xi8>, %c_i8: vector<2x4xi8>)
+  -> (vector<2x4xf32>, vector<2x4xi8>)
+{
+  // CHECK-NOT: constant
+  %vf_0 = constant dense <0.0>: vector<2x4xf32>
+  // CHECK-NOT: addf
+  //     CHECK: %[[D:.*]] = vector.contract {{.*}} %[[A]], %[[B]], %[[C]]
+  %0 = vector.contract #contraction_trait0 %a, %b, %vf_0:
+    vector<2x3xf32>, vector<3x4xf32> into vector<2x4xf32>
+  // CHECK-NOT: addf
+  %1 = addf %0, %c: vector<2x4xf32>
+
+  // CHECK-NOT: constant
+  %vi8_0 = constant dense <0>: vector<2x4xi8>
+  // CHECK-NOT: addi
+  //     CHECK: %[[D_I8:.*]] = vector.contract {{.*}} %[[A_I8]], %[[B_I8]], %[[C_I8]]
+  %i8_0 = vector.contract #contraction_trait0 %a_i8, %b_i8, %vi8_0:
+    vector<2x3xi8>, vector<3x4xi8> into vector<2x4xi8>
+  // CHECK-NOT: addi
+  %i8_1 = addi %i8_0, %c_i8: vector<2x4xi8>
+
+  // CHECK: return %[[D]], %[[D_I8]]
+  return %1, %i8_1: vector<2x4xf32>, vector<2x4xi8>
+}
+
+// CHECK-LABEL: func @contraction_nonzero_acc
+// CHECK-SAME: %[[A:[0-9a-zA-Z]+]]: vector<2x3xf32>
+// CHECK-SAME: %[[B:[0-9a-zA-Z]+]]: vector<3x4xf32>
+// CHECK-SAME: %[[C:[0-9a-zA-Z]+]]: vector<2x4xf32>
+// CHECK-SAME: %[[D:[0-9a-zA-Z]+]]: vector<2x4xf32>
+func @contraction_nonzero_acc(%a: vector<2x3xf32>, %b: vector<3x4xf32>,
+                              %c: vector<2x4xf32>, %d: vector<2x4xf32>)
+  -> vector<2x4xf32>
+{
+  // The accumulator is not a constant zero: the addition must be kept.
+  // CHECK: %[[R:.*]] = vector.contract {{.*}} %[[A]], %[[B]], %[[C]]
+  %0 = vector.contract #contraction_trait0 %a, %b, %c:
+    vector<2x3xf32>, vector<3x4xf32> into vector<2x4xf32>
+  // CHECK: %[[S:.*]] = addf %[[R]], %[[D]]
+  %1 = addf %0, %d: vector<2x4xf32>
+  // CHECK: return %[[S]]
+  return %1: vector<2x4xf32>
+}
+