diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -311,6 +311,7 @@
   // TODO: Migrate to assemblyFormat once `AllTypesMatch` supports optional
   // operands.
   let hasCustomAssemblyFormat = 1;
+  let hasCanonicalizer = 1;
   let hasVerifier = 1;
 }
 
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -488,6 +488,51 @@
   return llvm::to_vector<4>(getVectorType().getShape());
 }
 
+namespace {
+/// Elides a `vector.reduction` over a single-element vector: the reduction
+/// becomes a `vector.extract` of that element, combined with the optional
+/// accumulator.
+struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ReductionOp reductionOp,
+                                PatternRewriter &rewriter) const override {
+    if (reductionOp.getVectorType().getDimSize(0) != 1)
+      return failure();
+
+    // Only floating-point ADD/MUL accumulators can be combined here; bail out
+    // on anything else instead of silently dropping the accumulator. This is
+    // checked before any IR is created because a failed match must leave the
+    // IR untouched.
+    Value acc = reductionOp.getAcc();
+    CombiningKind kind = reductionOp.getKind();
+    if (acc && (!reductionOp.getType().isa<FloatType>() ||
+                (kind != CombiningKind::ADD && kind != CombiningKind::MUL)))
+      return failure();
+
+    Location loc = reductionOp.getLoc();
+    Value result = rewriter.create<vector::ExtractOp>(
+        loc, reductionOp.getType(), reductionOp.getVector(),
+        rewriter.getI64ArrayAttr(0));
+
+    if (acc) {
+      if (kind == CombiningKind::ADD)
+        result = rewriter.create<arith::AddFOp>(loc, result, acc);
+      else
+        result = rewriter.create<arith::MulFOp>(loc, result, acc);
+    }
+
+    rewriter.replaceOp(reductionOp, result);
+    return success();
+  }
+};
+} // namespace
+
+void ReductionOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                              MLIRContext *context) {
+  results.add<ElideSingleElementReduction>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // ContractionOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
--- 
a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1561,3 +1561,58 @@
   %1 = vector.extractelement %v[%i : i32] : vector<4xi32>
   return %1 : i32
 }
+
+// -----
+
+// CHECK-LABEL: func @reduce_one_element_vector_extract
+// CHECK-SAME: (%[[V:.+]]: vector<1xf32>)
+// CHECK: %[[S:.+]] = vector.extract %[[V]][0] : vector<1xf32>
+// CHECK: return %[[S]] : f32
+func @reduce_one_element_vector_extract(%a : vector<1xf32>) -> f32 {
+  %s = vector.reduction <add>, %a : vector<1xf32> into f32
+  return %s : f32
+}
+
+// -----
+
+// CHECK-LABEL: func @reduce_one_element_vector_extract_int
+// CHECK-SAME: (%[[V:.+]]: vector<1xi32>)
+// CHECK: %[[S:.+]] = vector.extract %[[V]][0] : vector<1xi32>
+// CHECK: return %[[S]] : i32
+func @reduce_one_element_vector_extract_int(%a : vector<1xi32>) -> i32 {
+  %s = vector.reduction <add>, %a : vector<1xi32> into i32
+  return %s : i32
+}
+
+// -----
+
+// CHECK-LABEL: func @reduce_one_element_vector_addf
+// CHECK-SAME: (%[[V:.+]]: vector<1xf32>, %[[B:.+]]: f32)
+// CHECK: %[[A:.+]] = vector.extract %[[V]][0] : vector<1xf32>
+// CHECK: %[[S:.+]] = arith.addf %[[A]], %[[B]] : f32
+// CHECK: return %[[S]]
+func @reduce_one_element_vector_addf(%a : vector<1xf32>, %b: f32) -> f32 {
+  %s = vector.reduction <add>, %a, %b : vector<1xf32> into f32
+  return %s : f32
+}
+
+// -----
+
+// CHECK-LABEL: func @reduce_one_element_vector_mulf
+// CHECK-SAME: (%[[V:.+]]: vector<1xf32>, %[[B:.+]]: f32)
+// CHECK: %[[A:.+]] = vector.extract %[[V]][0] : vector<1xf32>
+// CHECK: %[[S:.+]] = arith.mulf %[[A]], %[[B]] : f32
+// CHECK: return %[[S]]
+func @reduce_one_element_vector_mulf(%a : vector<1xf32>, %b: f32) -> f32 {
+  %s = vector.reduction <mul>, %a, %b : vector<1xf32> into f32
+  return %s : f32
+}
+
+// -----
+
+// CHECK-LABEL: func @dont_reduce_one_element_vector
+// CHECK: vector.reduction
+func @dont_reduce_one_element_vector(%a : vector<4xf32>) -> f32 {
+  %s = vector.reduction <add>, %a : vector<4xf32> into f32
+  return %s : f32
+}