diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -311,6 +311,7 @@
   // TODO: Migrate to assemblyFormat once `AllTypesMatch` supports optional
   // operands.
   let hasCustomAssemblyFormat = 1;
+  let hasCanonicalizer = 1;
   let hasVerifier = 1;
 }
 
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -488,6 +488,36 @@
   return llvm::to_vector<4>(getVectorType().getShape());
 }
 
+namespace {
+/// Canonicalization pattern: a `vector.reduction` without an accumulator
+/// whose vector operand has size 1 along dimension 0 is replaced by a
+/// `vector.extract` of element 0 of that vector.
+struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ReductionOp reductionOp,
+                                PatternRewriter &rewriter) const override {
+    // Only single-element inputs can be forwarded unchanged.
+    if (reductionOp.getVectorType().getDimSize(0) != 1)
+      return failure();
+    // Reductions carrying an accumulator operand are left untouched.
+    if (reductionOp.getAcc())
+      return failure();
+
+    rewriter.replaceOpWithNewOp<ExtractOp>(reductionOp, reductionOp.getType(),
+                                           reductionOp.getVector(),
+                                           rewriter.getI64ArrayAttr(0));
+    return success();
+  }
+};
+} // namespace
+
+/// Registers the canonicalization patterns for ReductionOp.
+void ReductionOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                              MLIRContext *context) {
+  results.add<ElideSingleElementReduction>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // ContractionOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1528,3 +1528,32 @@
   %1 = vector.extractelement %v[%i : i32] : vector<4xi32>
   return %1 : i32
 }
+
+// -----
+
+// CHECK-LABEL: func @reduce_one_element_vector
+// CHECK-SAME: (%[[V:.+]]: vector<1xf32>)
+// CHECK: %[[S:.+]] = vector.extract %[[V]][0] : vector<1xf32>
+// CHECK: return %[[S]] : f32
+func @reduce_one_element_vector(%a : vector<1xf32>) -> f32 {
+  %s = vector.reduction <add>, %a : vector<1xf32> into f32
+  return %s : f32
+}
+
+// -----
+
+// CHECK-LABEL: func @reduce_one_element_vector
+// CHECK: vector.reduction
+func @reduce_one_element_vector(%a : vector<1xf32>, %b: f32) -> f32 {
+  %s = vector.reduction <add>, %a, %b : vector<1xf32> into f32
+  return %s : f32
+}
+
+// -----
+
+// CHECK-LABEL: func @reduce_one_element_vector
+// CHECK: vector.reduction
+func @reduce_one_element_vector(%a : vector<4xf32>) -> f32 {
+  %s = vector.reduction <add>, %a : vector<4xf32> into f32
+  return %s : f32
+}