diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -530,13 +530,19 @@ if (maskableOp.isMasked()) return failure(); - if (reductionOp.getVectorType().getDimSize(0) != 1) + auto vectorType = reductionOp.getVectorType(); + if (vectorType.getRank() != 0 && vectorType.getDimSize(0) != 1) return failure(); Location loc = reductionOp.getLoc(); - Value result = rewriter.create(loc, reductionOp.getType(), - reductionOp.getVector(), - rewriter.getI64ArrayAttr(0)); + Value result; + if (vectorType.getRank() == 0) { + result = rewriter.create(loc, reductionOp.getVector()); + } else { + result = rewriter.create(loc, reductionOp.getType(), + reductionOp.getVector(), + rewriter.getI64ArrayAttr(0)); + } if (Value acc = reductionOp.getAcc()) result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(), diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -2157,3 +2157,13 @@ %1 = vector.extractelement %0 [%c5 : index] : vector<15xf32> return %1 : f32 } + +// ----- + +// CHECK-LABEL: func.func @fold_0d_vector_reduction +func.func @fold_0d_vector_reduction(%arg0: vector) -> f32 { + // CHECK-NEXT: %[[RES:.*]] = vector.extractelement %arg{{.*}}[] : vector + // CHECK-NEXT: return %[[RES]] : f32 + %0 = vector.reduction , %arg0 : vector into f32 + return %0 : f32 +}