diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -361,6 +361,12 @@ LogicalResult matchAndRewrite(MultiDimReductionOp reductionOp, PatternRewriter &rewriter) const override { + // Masked reductions can't be folded until we can propagate the mask to the + // resulting operation. + auto maskableOp = cast(reductionOp.getOperation()); + if (maskableOp.isMasked()) + return failure(); + ArrayRef shape = reductionOp.getSourceVectorType().getShape(); for (const auto &dim : enumerate(shape)) { if (reductionOp.isReducedDim(dim.index()) && dim.value() != 1) @@ -518,6 +524,12 @@ LogicalResult matchAndRewrite(ReductionOp reductionOp, PatternRewriter &rewriter) const override { + // Masked reductions can't be folded until we can propagate the mask to the + // resulting operation. + auto maskableOp = cast(reductionOp.getOperation()); + if (maskableOp.isMasked()) + return failure(); + if (reductionOp.getVectorType().getDimSize(0) != 1) return failure(); diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -1371,6 +1371,20 @@ // ----- +// Masked reduction can't be folded. + +// CHECK-LABEL: func @masked_vector_multi_reduction_unit_dimensions +func.func @masked_vector_multi_reduction_unit_dimensions(%source: vector<5x1x4x1x20xf32>, + %acc: vector<5x4x20xf32>, + %mask: vector<5x1x4x1x20xi1>) -> vector<5x4x20xf32> { +// CHECK: vector.mask %{{.*}} { vector.multi_reduction + %0 = vector.mask %mask { vector.multi_reduction , %source, %acc [1, 3] : vector<5x1x4x1x20xf32> to vector<5x4x20xf32> } : + vector<5x1x4x1x20xi1> -> vector<5x4x20xf32> + return %0 : vector<5x4x20xf32> +} + +// ----- + // CHECK-LABEL: func @vector_multi_reduction_unit_dimensions_fail( // CHECK-SAME: %[[SRC:.+]]: vector<5x1x4x1x20xf32>, %[[ACCUM:.+]]: vector<5x1x20xf32> func.func @vector_multi_reduction_unit_dimensions_fail(%source: vector<5x1x4x1x20xf32>, %acc: vector<5x1x20xf32>) -> vector<5x1x20xf32> { @@ -1921,6 +1935,18 @@ // ----- +// CHECK-LABEL: func @masked_reduce_one_element_vector_addf +// CHECK: vector.mask %{{.*}} { vector.reduction +func.func @masked_reduce_one_element_vector_addf(%a: vector<1xf32>, + %b: f32, + %mask: vector<1xi1>) -> f32 { + %s = vector.mask %mask { vector.reduction , %a, %b : vector<1xf32> into f32 } + : vector<1xi1> -> f32 + return %s : f32 +} + +// ----- + // CHECK-LABEL: func @reduce_one_element_vector_mulf // CHECK-SAME: (%[[V:.+]]: vector<1xf32>, %[[B:.+]]: f32) // CHECK: %[[A:.+]] = vector.extract %[[V]][0] : vector<1xf32>