diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -396,6 +396,7 @@
   let assemblyFormat =
       "$kind `,` $source `,` $acc attr-dict $reduction_dims `:` type($source) `to` type($dest)";
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
   let hasVerifier = 1;
 }
 
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -309,6 +309,40 @@
   return success();
 }
 
+namespace {
+// Only unit dimensions that are being reduced are folded. If a dimension is
+// unit but not reduced, it is not folded, thereby keeping the output type the
+// same. If not all reduced dimensions are unit dimensions, this
+// transformation does nothing. This is just a generalization of
+// ElideSingleElementReduction for ReduceOp.
+struct ElideUnitDimsInMultiDimReduction
+    : public OpRewritePattern<MultiDimReductionOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(MultiDimReductionOp reductionOp,
+                                PatternRewriter &rewriter) const override {
+    ArrayRef<int64_t> shape = reductionOp.getSourceVectorType().getShape();
+    for (auto dim : enumerate(shape)) {
+      if (reductionOp.isReducedDim(dim.index()) && dim.value() != 1)
+        return failure();
+    }
+    Location loc = reductionOp.getLoc();
+    Value acc = reductionOp.getAcc();
+    Value cast = rewriter.create<vector::ShapeCastOp>(
+        loc, reductionOp.getDestType(), reductionOp.getSource());
+    Value result = vector::makeArithReduction(rewriter, loc,
+                                              reductionOp.getKind(), acc, cast);
+    rewriter.replaceOp(reductionOp, result);
+    return success();
+  }
+};
+} // namespace
+
+void MultiDimReductionOp::getCanonicalizationPatterns(
+    RewritePatternSet &results, MLIRContext *context) {
+  results.add<ElideUnitDimsInMultiDimReduction>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // ReductionOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1348,6 +1348,31 @@
 
 // -----
 
+// CHECK-LABEL: func @vector_multi_reduction_unit_dimensions(
+//  CHECK-SAME: %[[SOURCE:.+]]: vector<5x1x4x1x20xf32>, %[[ACC:.+]]: vector<5x4x20xf32>
+func.func @vector_multi_reduction_unit_dimensions(%source: vector<5x1x4x1x20xf32>, %acc: vector<5x4x20xf32>) -> vector<5x4x20xf32> {
+//       CHECK:   %[[CAST:.+]] = vector.shape_cast %[[SOURCE]] : vector<5x1x4x1x20xf32> to vector<5x4x20xf32>
+//       CHECK:   %[[RESULT:.+]] = arith.mulf %[[ACC]], %[[CAST]] : vector<5x4x20xf32>
+  %0 = vector.multi_reduction <mul>, %source, %acc [1, 3] : vector<5x1x4x1x20xf32> to vector<5x4x20xf32>
+
+//       CHECK:   return %[[RESULT]] : vector<5x4x20xf32>
+  return %0 : vector<5x4x20xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @vector_multi_reduction_unit_dimensions_fail(
+//  CHECK-SAME: %[[SRC:.+]]: vector<5x1x4x1x20xf32>, %[[ACCUM:.+]]: vector<5x1x1x20xf32>
+func.func @vector_multi_reduction_unit_dimensions_fail(%source: vector<5x1x4x1x20xf32>, %acc: vector<5x1x1x20xf32>) -> vector<5x1x1x20xf32> {
+//       CHECK:   %[[RES:.+]] = vector.multi_reduction <mul>, %[[SRC]], %[[ACCUM]] [2] : vector<5x1x4x1x20xf32> to vector<5x1x1x20xf32>
+  %0 = vector.multi_reduction <mul>, %source, %acc [2] : vector<5x1x4x1x20xf32> to vector<5x1x1x20xf32>
+
+//       CHECK:   return %[[RES]] : vector<5x1x1x20xf32>
+  return %0 : vector<5x1x1x20xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @insert_strided_slice_full_range
 //  CHECK-SAME: %[[SOURCE:.+]]: vector<16x16xf16>, %{{.+}}: vector<16x16xf16>
 func.func @insert_strided_slice_full_range(%source: vector<16x16xf16>, %dest: vector<16x16xf16>) -> vector<16x16xf16> {
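
For reference only (not part of the patch): a minimal sketch of the rewrite that `ElideUnitDimsInMultiDimReduction` performs when every reduced dimension has unit extent. The shapes and SSA names below are invented for illustration; the behavior mirrors the `@vector_multi_reduction_unit_dimensions` test above.

```mlir
// Before canonicalization: both reduced dims (1 and 3) have unit extent.
%red = vector.multi_reduction <add>, %src, %acc [1, 3]
    : vector<8x1x4x1xf32> to vector<8x4xf32>

// After canonicalization: the reduction collapses to a shape_cast feeding the
// elementwise combining op built by vector::makeArithReduction.
%cast = vector.shape_cast %src : vector<8x1x4x1xf32> to vector<8x4xf32>
%red2 = arith.addf %acc, %cast : vector<8x4xf32>
```

A reduction in which some reduced dimension is not unit-sized (as in the `_fail` test) is left untouched.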