Summary: support masked ops in the rank-1 vector.multi_reduction lowering.

Review notes for this patch (written from the hunks below):
* VectorMultiDimReductionTransforms.cpp, matchAndRewrite:
  * Drops the early `failure()` for masked ops ("TODO: Support masking").
  * The rank-1 check now runs first; the mask setup follows it.
  * Adds an OpBuilder::InsertionGuard. When the op is masked, the insertion
    point moves to the masking op, `rootOp` is that masking op and `mask` is
    its mask; otherwise `rootOp` is the multi_reduction op itself.
  * Renames the local {false, true} reduction-dims vector from `mask` to
    `reductionMask` so it does not clash with the new `mask` Value.
  * When masked, builds `castMask` with vector type 1 x N (N = last dim of
    the mask type, same element type) from `mask`.
  * Creates the new multi_reduction, passes it through
    vector::maskOperation(rewriter, newOp, castMask), and replaces `rootOp`
    (the masking op when masked) with newOp's result extracted at {0}.
    `castMask` stays null when unmasked; presumably maskOperation then
    returns newOp unchanged -- TODO confirm.
* vector-multi-reduction-lowering.mlir: adds @vectorize_1d_dynamic_reduction,
  whose CHECK lines expect a vector.mask { vector.reduction } over
  vector<8xf32> guarded by the vector<8xi1> create_mask.

NOTE(review): this copy of the patch is damaged. Newlines inside the hunks
are collapsed and text between angle brackets has been stripped (e.g.
`cast(...)`, `SmallVector mask`, `rewriter.create(`, `replaceOpWithNewOp(`,
`tensor` without its shape, `vector.multi_reduction ,` without its kind),
so it will not apply with `git apply` as-is. Regenerate it from the source
commit rather than hand-editing it.

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp @@ -385,17 +385,25 @@ LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp, PatternRewriter &rewriter) const override { - auto maskableOp = - cast(multiReductionOp.getOperation()); - if (maskableOp.isMasked()) - // TODO: Support masking. - return failure(); - auto srcRank = multiReductionOp.getSourceVectorType().getRank(); // Rank-1 or bail. if (srcRank != 1) return failure(); + // Vector mask setup. + OpBuilder::InsertionGuard guard(rewriter); + auto maskableOp = + cast(multiReductionOp.getOperation()); + Operation *rootOp; + Value mask; + if (maskableOp.isMasked()) { + rewriter.setInsertionPoint(maskableOp.getMaskingOp()); + rootOp = maskableOp.getMaskingOp(); + mask = maskableOp.getMaskingOp().getMask(); + } else { + rootOp = multiReductionOp; + } + auto loc = multiReductionOp.getLoc(); auto srcVectorType = multiReductionOp.getSourceVectorType(); auto srcShape = srcVectorType.getShape(); @@ -408,16 +416,27 @@ // If the unique dim is reduced and we insert a parallel in front, we need a // {false, true} mask. 
- SmallVector mask{false, true}; + SmallVector reductionMask{false, true}; /// vector.extract(vector.multi_reduce(vector.shape_cast(v, 1xk)), 0) Value cast = rewriter.create( loc, castedType, multiReductionOp.getSource()); Value castAcc = rewriter.create( loc, accType, multiReductionOp.getAcc()); - Value reduced = rewriter.create( - loc, cast, castAcc, mask, multiReductionOp.getKind()); - rewriter.replaceOpWithNewOp(multiReductionOp, reduced, + Value castMask; + if (maskableOp.isMasked()) { + auto maskType = mask.getType().cast(); + auto castMaskType = + VectorType::get(ArrayRef{1, maskType.getShape().back()}, + maskType.getElementType()); + castMask = rewriter.create(loc, castMaskType, mask); + } + + Operation *newOp = rewriter.create( + loc, cast, castAcc, reductionMask, multiReductionOp.getKind()); + newOp = vector::maskOperation(rewriter, newOp, castMask); + + rewriter.replaceOpWithNewOp(rootOp, newOp->getResult(0), ArrayRef{0}); return success(); } diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir --- a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir +++ b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir @@ -189,6 +189,26 @@ // ----- +func.func @vectorize_1d_dynamic_reduction(%arg0: tensor) -> f32 { + %c0 = arith.constant 0 : index + %dim = tensor.dim %arg0, %c0 : tensor + %c0_1 = arith.constant 0 : index + %cst = arith.constant 0.000000e+00 : f32 + %0 = vector.create_mask %dim : vector<8xi1> + %1 = vector.mask %0 { vector.transfer_read %arg0[%c0_1], %cst {in_bounds = [true]} : tensor, vector<8xf32> } : vector<8xi1> -> vector<8xf32> + %4 = vector.mask %0 { vector.multi_reduction , %1, %cst [0] : vector<8xf32> to f32 } : vector<8xi1> -> f32 + return %4 : f32 +} + +// Verify that a 1-D vector.multi_reduction is transformed into a vector.reduction. +// This transform expands 1-D vectors into 2-D. 
+ +// CHECK-LABEL: func.func @vectorize_1d_dynamic_reduction( +// CHECK: %[[VAL_5:.*]] = vector.create_mask {{.*}} : vector<8xi1> +// CHECK: %[[VAL_7:.*]] = vector.mask %[[VAL_5]] { vector.reduction , %{{.*}} : vector<8xf32> into f32 } : vector<8xi1> -> f32 + +// ----- + func.func @vectorize_dynamic_transpose_reduction(%arg0: tensor, %arg1: tensor) -> tensor { %c0 = arith.constant 0 : index %dim = tensor.dim %arg0, %c0 : tensor