diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -203,6 +203,20 @@
   return attr.cast<IteratorTypeAttr>().getValue() == IteratorType::reduction;
 }
 
+//===----------------------------------------------------------------------===//
+// Vector Masking Utilities
+//===----------------------------------------------------------------------===//
+
+/// Create the vector.yield-ended region of a vector.mask op with `maskableOp`
+/// as the masked operation.
+void createMaskOpRegion(OpBuilder &builder, Operation *maskableOp);
+
+/// Creates a vector.mask operation around a maskable operation. Returns the
+/// vector.mask operation if the mask provided is valid. Otherwise, returns the
+/// maskable operation itself.
+Operation *maskOperation(RewriterBase &rewriter, Operation *maskableOp,
+                         Value mask);
+
 } // namespace vector
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -340,6 +340,7 @@
        PredOpTrait<"source operand and result have same element type",
                    TCresVTEtIsSameAsOpBase<0, 0>>,
        DeclareOpInterfaceMethods<InferTypeOpInterface>,
+       DeclareOpInterfaceMethods<MaskableOpInterface, ["getExpectedMaskType"]>,
        DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>]>,
     Arguments<(ins Vector_CombiningKindAttr:$kind,
@@ -2338,16 +2339,13 @@
 
   let skipDefaultBuilders = 1;
   let builders = [
-    OpBuilder<(ins "Value":$mask,
-      CArg<"function_ref<void(OpBuilder &, Location)>",
-           "buildTerminatedBody">:$maskRegion)>,
-    OpBuilder<(ins "TypeRange":$resultTypes, "Value":$mask,
-      CArg<"function_ref<void(OpBuilder &, Location)>",
-           "buildTerminatedBody">:$maskRegion)>,
-    OpBuilder<(ins "TypeRange":$resultTypes, "Value":$mask,
-      "Value":$passthru,
-      CArg<"function_ref<void(OpBuilder &, Location)>",
-           "buildTerminatedBody">:$maskRegion)>
+    OpBuilder<(ins "Value":$mask, "Operation *":$maskableOp,
+      CArg<"function_ref<void(OpBuilder &, Operation *)>">:$maskRegion)>,
+    OpBuilder<(ins "TypeRange":$resultTypes, "Value":$mask,
+      "Operation *":$maskableOp,
+      CArg<"function_ref<void(OpBuilder &, Operation *)>">:$maskRegion)>,
+    OpBuilder<(ins "TypeRange":$resultTypes, "Value":$mask, "Value":$passthru,
+      "Operation *":$maskableOp,
+      CArg<"function_ref<void(OpBuilder &, Operation *)>">:$maskRegion)>
   ];
 
   let extraClassDeclaration = [{
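Note: the MaskableOpInterface declaration above lets a vector.multi_reduction
appear as the single masked operation inside a vector.mask region. A minimal
IR sketch of that form (value names and types are illustrative; this is the
shape the tests at the end of this patch check):

  %red = vector.mask %m { vector.multi_reduction <add>, %src, %acc [1]
             : vector<4x8xf32> to vector<4xf32> } : vector<4x8xi1> -> vector<4xf32>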
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -292,25 +292,8 @@
 
   // Wrap the operation with a new `vector.mask` and update D-U chain.
   assert(opToMask && "Expected a valid operation to mask");
-  auto opResults = opToMask->getResultTypes();
-  auto createRegionMask = [opToMask](OpBuilder &builder, Location loc) {
-    Block *insBlock = builder.getInsertionBlock();
-    // Create a block, put an op in that block. Look for a utility.
-    // Maybe in conversion pattern rewriter. Way to avoid splice.
-    // Set insertion point.
-    insBlock->getOperations().splice(
-        insBlock->begin(), opToMask->getBlock()->getOperations(), opToMask);
-    builder.create<vector::YieldOp>(loc, opToMask->getResults());
-  };
-  // TODO: Allow multiple results in vector.mask.
-  auto maskOp =
-      opResults.empty()
-          ? rewriter.create<vector::MaskOp>(opToMask->getLoc(), mask,
-                                            createRegionMask)
-          : rewriter.create<vector::MaskOp>(opToMask->getLoc(),
-                                            opToMask->getResultTypes().front(),
-                                            mask, createRegionMask);
-
+  auto maskOp = cast<vector::MaskOp>(
+      mlir::vector::maskOperation(rewriter, opToMask, mask));
 
   Operation *maskOpTerminator = &maskOp.getMaskRegion().front().back();
   for (auto [resIdx, resVal] : llvm::enumerate(opToMask->getResults()))
@@ -440,17 +423,16 @@
 /// initial value.
 // Note: this is a true builder that notifies the OpBuilder listener.
 // TODO: Consider moving as a static helper on the ReduceOp.
-static Operation *buildMultiDimReduce(OpBuilder &b,
-                                      Operation *reduceOp, Value valueToReduce,
-                                      Value acc,
-                                      const SmallVector<bool> &reductionMask) {
+static Operation *buildMultiDimReduce(OpBuilder &b, Operation *reduceOp,
+                                      Value valueToReduce, Value acc,
+                                      ArrayRef<bool> dimsToMask) {
   auto maybeKind = getCombinerOpKind(reduceOp);
   assert(maybeKind && "Failed precondition: could not get reduction kind");
   return b.create<vector::MultiDimReductionOp>(
-      reduceOp->getLoc(), valueToReduce, acc, reductionMask, *maybeKind);
+      reduceOp->getLoc(), valueToReduce, acc, dimsToMask, *maybeKind);
 }
 
-static SmallVector<bool> getReductionMask(LinalgOp linalgOp) {
+static SmallVector<bool> getDimsToReduce(LinalgOp linalgOp) {
   return llvm::to_vector(
       llvm::map_range(linalgOp.getIteratorTypesArray(), isReductionIterator));
 }
@@ -701,8 +683,8 @@
   if (!reduceType ||
       (outputType && reduceType.getShape() == outputType.getShape()))
     return nullptr;
-  SmallVector<bool> reductionMask = getReductionMask(linalgOp);
-  return buildMultiDimReduce(b, op, reduceVec, outputVec, reductionMask);
+  SmallVector<bool> dimsToMask = getDimsToReduce(linalgOp);
+  return buildMultiDimReduce(b, op, reduceVec, outputVec, dimsToMask);
 }
 
 /// Generic vectorization for a single operation `op`, given already vectorized
@@ -972,11 +954,8 @@
 }
 
 static LogicalResult vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op) {
-  // TODO: Masking only supports dynamic generic ops without reductions for now.
-  if (!isElementwise(op) &&
-      llvm::any_of(op.getIteratorTypesArray(), [](utils::IteratorType itType) {
-        return itType != utils::IteratorType::parallel;
-      }))
+  // TODO: Masking only supports dynamic generic ops for now.
+  if (!isa<linalg::GenericOp>(op))
     return failure();
 
   // TODO: 0-d vectors are not supported yet.
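Note: with maskOperation as the single entry point, masked vectorization of a
dynamic linalg reduction produces IR of the following shape. This is a
condensed sketch of the vectorization.mlir test added below; value names are
illustrative:

  %mask = vector.create_mask %d0, %d1 : vector<4x8xi1>
  %in = vector.mask %mask { vector.transfer_read %src[%c0, %c0], %pad
            : tensor<?x?xf32>, vector<4x8xf32> } : vector<4x8xi1> -> vector<4x8xf32>
  %red = vector.mask %mask { vector.multi_reduction <add>, %in, %acc [1]
            : vector<4x8xf32> to vector<4xf32> } : vector<4x8xi1> -> vector<4xf32>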
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -342,6 +342,13 @@
   return success();
 }
 
+/// Returns the mask type expected by this operation.
+Type MultiDimReductionOp::getExpectedMaskType() {
+  auto vecType = getSourceVectorType();
+  return VectorType::get(vecType.getShape(),
+                         IntegerType::get(vecType.getContext(), /*width=*/1));
+}
+
 namespace {
 // Only unit dimensions that are being reduced are folded. If the dimension is
 // unit, but not reduced, it is not folded, thereby keeping the output type the
@@ -5276,7 +5283,8 @@
 
 void MaskOp::build(
     OpBuilder &builder, OperationState &result, Value mask,
-    function_ref<void(OpBuilder &, Location)> maskRegionBuilder) {
+    Operation *maskableOp,
+    function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
   assert(maskRegionBuilder &&
          "builder callback for 'maskRegion' must be present");
 
@@ -5284,21 +5292,22 @@
   OpBuilder::InsertionGuard guard(builder);
   Region *maskRegion = result.addRegion();
   builder.createBlock(maskRegion);
-  maskRegionBuilder(builder, result.location);
+  maskRegionBuilder(builder, maskableOp);
 }
 
 void MaskOp::build(
     OpBuilder &builder, OperationState &result, TypeRange resultTypes,
-    Value mask, function_ref<void(OpBuilder &, Location)> maskRegionBuilder) {
-  build(builder, result, resultTypes, mask, /*passthru=*/Value(),
+    Value mask, Operation *maskableOp,
+    function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
+  build(builder, result, resultTypes, mask, /*passthru=*/Value(), maskableOp,
         maskRegionBuilder);
 }
 
 void MaskOp::build(
-    OpBuilder &builder, OperationState &result, TypeRange resultTypes,
-    Value mask, Value passthru,
-    function_ref<void(OpBuilder &, Location)> maskRegionBuilder) {
-  build(builder, result, mask, maskRegionBuilder);
+    OpBuilder &builder, OperationState &result, TypeRange resultTypes, Value mask,
+    Value passthru, Operation *maskableOp,
+    function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
+  build(builder, result, mask, maskableOp, maskRegionBuilder);
   if (passthru)
     result.addOperands(passthru);
   result.addTypes(resultTypes);
@@ -5738,6 +5747,34 @@
   llvm_unreachable("unknown CombiningKind");
 }
 
+//===----------------------------------------------------------------------===//
+// Vector Masking Utilities
+//===----------------------------------------------------------------------===//
+
+/// Create the vector.yield-ended region of a vector.mask op with `maskableOp`
+/// as the masked operation.
+void mlir::vector::createMaskOpRegion(OpBuilder &builder,
+                                      Operation *maskableOp) {
+  assert(maskableOp->getBlock() && "MaskableOp must be inserted into a block");
+  Block *insBlock = builder.getInsertionBlock();
+  // Create a block and move the op to that block.
+  insBlock->getOperations().splice(
+      insBlock->begin(), maskableOp->getBlock()->getOperations(), maskableOp);
+  builder.create<vector::YieldOp>(maskableOp->getLoc(), maskableOp->getResults());
+}
+
+/// Creates a vector.mask operation around a maskable operation. Returns the
+/// vector.mask operation if the mask provided is valid. Otherwise, returns
+/// the maskable operation itself.
+Operation *mlir::vector::maskOperation(RewriterBase &rewriter,
+                                       Operation *maskableOp, Value mask) {
+  if (!mask)
+    return maskableOp;
+  return rewriter.create<vector::MaskOp>(maskableOp->getLoc(),
+                                         maskableOp->getResultTypes(), mask,
+                                         maskableOp, createMaskOpRegion);
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
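Note: getExpectedMaskType returns an i1 vector matching the source shape, not
the result shape, because every source element participates in the reduction.
For example (illustrative types, mirroring the transpose test below), reducing
dimension 0 of a vector<4x8x16xf32> still requires a vector<4x8x16xi1> mask:

  %r = vector.mask %m { vector.multi_reduction <add>, %src, %acc [0]
           : vector<4x8x16xf32> to vector<8x16xf32> } : vector<4x8x16xi1> -> vector<8x16xf32>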
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp
@@ -12,9 +12,7 @@
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
-#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/IR/Builders.h"
-#include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/IR/TypeUtilities.h"
 
 #define DEBUG_TYPE "vector-multi-reduction"
@@ -40,6 +38,18 @@
 
   LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                 PatternRewriter &rewriter) const override {
+    // Vector mask setup.
+    OpBuilder::InsertionGuard guard(rewriter);
+    auto maskableOp =
+        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
+    Operation *rootOp;
+    if (maskableOp.isMasked()) {
+      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
+      rootOp = maskableOp.getMaskingOp();
+    } else {
+      rootOp = multiReductionOp;
+    }
+
     auto src = multiReductionOp.getSource();
     auto loc = multiReductionOp.getLoc();
     auto srcRank = multiReductionOp.getSourceVectorType().getRank();
@@ -79,6 +89,15 @@
       indices.append(reductionDims.begin(), reductionDims.end());
       indices.append(parallelDims.begin(), parallelDims.end());
     }
+
+    // If masked, transpose the original mask.
+    Value transposedMask;
+    if (maskableOp.isMasked()) {
+      transposedMask = rewriter.create<vector::TransposeOp>(
+          loc, maskableOp.getMaskingOp().getMask(), indices);
+    }
+
+    // Transpose reduction source.
     auto transposeOp = rewriter.create<vector::TransposeOp>(loc, src, indices);
     SmallVector<bool> reductionMask(srcRank, false);
     for (int i = 0; i < reductionSize; ++i) {
       if (useInnerDimsForReduction)
         reductionMask[srcRank - i - 1] = true;
       else
         reductionMask[i] = true;
     }
-    rewriter.replaceOpWithNewOp<vector::MultiDimReductionOp>(
-        multiReductionOp, transposeOp.getResult(), multiReductionOp.getAcc(),
-        reductionMask, multiReductionOp.getKind());
+
+    Operation *newMultiRedOp = rewriter.create<vector::MultiDimReductionOp>(
+        multiReductionOp.getLoc(), transposeOp.getResult(),
+        multiReductionOp.getAcc(), reductionMask, multiReductionOp.getKind());
+    newMultiRedOp =
+        mlir::vector::maskOperation(rewriter, newMultiRedOp, transposedMask);
+
+    rewriter.replaceOp(rootOp, newMultiRedOp->getResult(0));
     return success();
   }
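Note: when the pattern above transposes the reduction source to move the
reduced dimensions inner- or outermost, the mask must follow the same
permutation so each mask bit keeps guarding the same element. A sketch with
illustrative values (the lowering test below checks this transpose):

  %tmask = vector.transpose %mask, [1, 2, 0] : vector<4x8x16xi1> to vector<8x16x4xi1>
  %res = vector.mask %tmask { vector.multi_reduction <add>, %tsrc, %acc [2]
             : vector<8x16x4xf32> to vector<8x16xf32> } : vector<8x16x4xi1> -> vector<8x16xf32>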
@@ -113,6 +137,18 @@
 
   LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                 PatternRewriter &rewriter) const override {
+    // Vector mask setup.
+    OpBuilder::InsertionGuard guard(rewriter);
+    auto maskableOp =
+        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
+    Operation *rootOp;
+    if (maskableOp.isMasked()) {
+      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
+      rootOp = maskableOp.getMaskingOp();
+    } else {
+      rootOp = multiReductionOp;
+    }
+
     auto srcRank = multiReductionOp.getSourceVectorType().getRank();
     auto srcShape = multiReductionOp.getSourceVectorType().getShape();
     auto loc = multiReductionOp.getLoc();
@@ -186,10 +222,22 @@
       std::swap(mask.front(), mask.back());
       std::swap(vectorShape.front(), vectorShape.back());
     }
+
+    Value newVectorMask;
+    if (maskableOp.isMasked()) {
+      Value vectorMask = maskableOp.getMaskingOp().getMask();
+      auto maskCastedType = VectorType::get(
+          vectorShape,
+          vectorMask.getType().cast<VectorType>().getElementType());
+      newVectorMask =
+          rewriter.create<vector::ShapeCastOp>(loc, maskCastedType, vectorMask);
+    }
+
     auto castedType = VectorType::get(
         vectorShape, multiReductionOp.getSourceVectorType().getElementType());
     Value cast = rewriter.create<vector::ShapeCastOp>(
         loc, castedType, multiReductionOp.getSource());
+
     Value acc = multiReductionOp.getAcc();
     if (flattenedParallelDim) {
       auto accType = VectorType::get(
@@ -197,24 +245,26 @@
           multiReductionOp.getSourceVectorType().getElementType());
       acc = rewriter.create<vector::ShapeCastOp>(loc, accType, acc);
     }
-    // 5. Creates the flattened form of vector.multi_reduction with inner/outer
+    // 6. Creates the flattened form of vector.multi_reduction with inner/outer
     // most dim as reduction.
-    auto newOp = rewriter.create<vector::MultiDimReductionOp>(
+    Operation *newMultiDimRedOp = rewriter.create<vector::MultiDimReductionOp>(
         loc, cast, acc, mask, multiReductionOp.getKind());
+    newMultiDimRedOp =
+        mlir::vector::maskOperation(rewriter, newMultiDimRedOp, newVectorMask);
 
-    // 6. If there are no parallel shapes, the result is a scalar.
+    // 7. If there are no parallel shapes, the result is a scalar.
     // TODO: support 0-d vectors when available.
     if (parallelShapes.empty()) {
-      rewriter.replaceOp(multiReductionOp, newOp.getDest());
+      rewriter.replaceOp(rootOp, newMultiDimRedOp->getResult(0));
       return success();
     }
 
-    // 7. Creates shape cast for the output n-D -> 2-D
+    // 8. Creates shape cast for the output n-D -> 2-D.
     VectorType outputCastedType = VectorType::get(
         parallelShapes,
         multiReductionOp.getSourceVectorType().getElementType());
     rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
-        multiReductionOp, outputCastedType, newOp.getDest());
+        rootOp, outputCastedType, newMultiDimRedOp->getResult(0));
     return success();
   }
@@ -230,6 +280,12 @@
 
   LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                 PatternRewriter &rewriter) const override {
+    auto maskableOp =
+        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
+    if (maskableOp.isMasked())
+      // TODO: Support masking.
+      return failure();
+
     auto srcRank = multiReductionOp.getSourceVectorType().getRank();
     // Rank-2 ["parallel", "reduce"] or bail.
     if (srcRank != 2)
       return failure();
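Note: in the rank-2 lowering that follows, a masked multi_reduction is
unrolled into one vector.reduction per outer row, and each reduction receives
the matching slice of the original 2-D mask. A sketch with illustrative
values (see the CHECK lines in the lowering test below):

  %m0 = vector.extract %mask[0] : vector<4x8xi1>
  %r0 = vector.mask %m0 { vector.reduction <add>, %row0, %acc0
            : vector<8xf32> into f32 } : vector<8xi1> -> f32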
@@ -274,6 +330,18 @@
     if (multiReductionOp.isReducedDim(0) || !multiReductionOp.isReducedDim(1))
       return failure();
 
+    // Vector mask setup.
+    OpBuilder::InsertionGuard guard(rewriter);
+    auto maskableOp =
+        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
+    Operation *rootOp;
+    if (maskableOp.isMasked()) {
+      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
+      rootOp = maskableOp.getMaskingOp();
+    } else {
+      rootOp = multiReductionOp;
+    }
+
     auto loc = multiReductionOp.getLoc();
     Value result = rewriter.create<arith::ConstantOp>(
         loc, multiReductionOp.getDestType(),
@@ -285,13 +353,22 @@
           loc, multiReductionOp.getSource(), ArrayRef<int64_t>{i});
       auto acc = rewriter.create<vector::ExtractOp>(
           loc, multiReductionOp.getAcc(), ArrayRef<int64_t>{i});
-      auto reducedValue = rewriter.create<vector::ReductionOp>(
+      Operation *reductionOp = rewriter.create<vector::ReductionOp>(
           loc, multiReductionOp.getKind(), v, acc);
+
+      // If masked, slice the mask and mask the new reduction operation.
+      if (maskableOp.isMasked()) {
+        Value mask = rewriter.create<vector::ExtractOp>(
+            loc, maskableOp.getMaskingOp().getMask(), ArrayRef<int64_t>{i});
+        reductionOp = mlir::vector::maskOperation(rewriter, reductionOp, mask);
+      }
+
       result = rewriter.create<vector::InsertElementOp>(
-          loc, reducedValue, result,
+          loc, reductionOp->getResult(0), result,
           rewriter.create<arith::ConstantIndexOp>(loc, i));
     }
-    rewriter.replaceOp(multiReductionOp, result);
+
+    rewriter.replaceOp(rootOp, result);
     return success();
   }
 };
@@ -307,6 +384,12 @@
 
   LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                 PatternRewriter &rewriter) const override {
+    auto maskableOp =
+        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
+    if (maskableOp.isMasked())
+      // TODO: Support masking.
+      return failure();
+
     auto srcRank = multiReductionOp.getSourceVectorType().getRank();
     // Rank-1 or bail.
     if (srcRank != 1)
       return failure();
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -1824,6 +1824,82 @@
 
 // -----
 
+func.func @vectorize_dynamic_reduction(%arg0: tensor<?x?xf32>,
+                                       %arg1: tensor<?xf32>) -> tensor<?xf32> {
+  %0 = linalg.generic { indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                                         affine_map<(d0, d1) -> (d0)>],
+                        iterator_types = ["parallel", "reduction"] }
+    ins(%arg0 : tensor<?x?xf32>)
+    outs(%arg1 : tensor<?xf32>) {
+    ^bb(%in: f32, %out: f32) :
+      %0 = arith.addf %in, %out : f32
+      linalg.yield %0 : f32
+    } -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+  transform.structured.masked_vectorize %0 vector_sizes [4, 8]
+}
+
+// CHECK-LABEL: @vectorize_dynamic_reduction(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<?xf32>) -> tensor<?xf32> {
+// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_3:.*]] = tensor.dim %[[VAL_0]], %[[VAL_2]] : tensor<?x?xf32>
+// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_5:.*]] = tensor.dim %[[VAL_0]], %[[VAL_4]] : tensor<?x?xf32>
+// CHECK: %[[VAL_8:.*]] = vector.create_mask %[[VAL_3]], %[[VAL_5]] : vector<4x8xi1>
+// CHECK: %[[VAL_9:.*]] = vector.mask %[[VAL_8]] { vector.transfer_read %[[VAL_0]]{{.*}} {in_bounds = [true, true]} : tensor<?x?xf32>, vector<4x8xf32> } : vector<4x8xi1> -> vector<4x8xf32>
+// CHECK: %[[VAL_11:.*]] = vector.create_mask %[[VAL_3]] : vector<4xi1>
+// CHECK: %[[VAL_12:.*]] = vector.mask %[[VAL_11]] { vector.transfer_read %[[VAL_1]]{{.*}} {in_bounds = [true]} : tensor<?xf32>, vector<4xf32> } : vector<4xi1> -> vector<4xf32>
+// CHECK: %[[VAL_13:.*]] = vector.mask %[[VAL_8]] { vector.multi_reduction <add>, %[[VAL_9]], %[[VAL_12]] [1] : vector<4x8xf32> to vector<4xf32> } : vector<4x8xi1> -> vector<4xf32>
+// CHECK: %[[VAL_15:.*]] = vector.mask %[[VAL_11]] { vector.transfer_write %[[VAL_13]], %[[VAL_1]]{{.*}} {in_bounds = [true]} : vector<4xf32>, tensor<?xf32> } : vector<4xi1> -> tensor<?xf32>
+// CHECK: return %[[VAL_15]] : tensor<?xf32>
+// CHECK: }
+
+// -----
+
+func.func @vectorize_dynamic_transpose_reduction(%arg0: tensor<?x?x?xf32>,
+                                                 %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.generic { indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+                                         affine_map<(d0, d1, d2) -> (d2, d1)>],
+                        iterator_types = ["reduction", "parallel", "parallel"] }
+    ins(%arg0 : tensor<?x?x?xf32>)
+    outs(%arg1 : tensor<?x?xf32>) {
+    ^bb(%in: f32, %out: f32) :
+      %0 = arith.addf %in, %out : f32
+      linalg.yield %0 : f32
+    } -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+  transform.structured.masked_vectorize %0 vector_sizes [4, 8, 16]
+}
+
+// CHECK-LABEL: @vectorize_dynamic_transpose_reduction(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?x?xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
+// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_3:.*]] = tensor.dim %[[VAL_0]], %[[VAL_2]] : tensor<?x?x?xf32>
+// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_5:.*]] = tensor.dim %[[VAL_0]], %[[VAL_4]] : tensor<?x?x?xf32>
+// CHECK: %[[VAL_6:.*]] = arith.constant 2 : index
+// CHECK: %[[VAL_7:.*]] = tensor.dim %[[VAL_0]], %[[VAL_6]] : tensor<?x?x?xf32>
+// CHECK: %[[VAL_10:.*]] = vector.create_mask %[[VAL_3]], %[[VAL_5]], %[[VAL_7]] : vector<4x8x16xi1>
+// CHECK: %[[VAL_11:.*]] = vector.mask %[[VAL_10]] { vector.transfer_read %[[VAL_0]]{{.*}} {in_bounds = [true, true, true]} : tensor<?x?x?xf32>, vector<4x8x16xf32> } : vector<4x8x16xi1> -> vector<4x8x16xf32>
+// CHECK: %[[VAL_13:.*]] = vector.create_mask %[[VAL_7]], %[[VAL_5]] : vector<16x8xi1>
+// CHECK: %[[VAL_14:.*]] = vector.mask %[[VAL_13]] { vector.transfer_read %[[VAL_1]]{{.*}} {in_bounds = [true, true], permutation_map = #{{.*}}} : tensor<?x?xf32>, vector<8x16xf32> } : vector<16x8xi1> -> vector<8x16xf32>
+// CHECK: %[[VAL_15:.*]] = vector.mask %[[VAL_10]] { vector.multi_reduction <add>, %[[VAL_11]], %[[VAL_14]] [0] : vector<4x8x16xf32> to vector<8x16xf32> } : vector<4x8x16xi1> -> vector<8x16xf32>
+// CHECK: %[[VAL_17:.*]] = vector.mask %[[VAL_13]] { vector.transfer_write %[[VAL_15]], %{{.*}} {in_bounds = [true, true], permutation_map = #{{.*}}} : vector<8x16xf32>, tensor<?x?xf32> } : vector<16x8xi1> -> tensor<?x?xf32>
+
+// -----
+
 // This is a regression test. This IR cannot be vectorized, but
 // structured.vectorize should nevertheless succeed.
@@ -1892,4 +1968,3 @@
 
 // CHECK-LABEL: @wrong_reduction_detection
 // CHECK: vector.broadcast
 // CHECK: vector.transfer_write
-
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-vector-multi-reduction-lowering-patterns | FileCheck %s
+// RUN: mlir-opt %s -test-vector-multi-reduction-lowering-patterns -split-input-file | FileCheck %s
 
 func.func @vector_multi_reduction(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
     %0 = vector.multi_reduction <mul>, %arg0, %acc [1] : vector<2x4xf32> to vector<2xf32>
@@ -19,6 +19,8 @@
 // CHECK: %[[RESULT_VEC:.+]] = vector.insertelement %[[RV1:.+]], %[[RESULT_VEC_1]][%[[C1]] : index] : vector<2xf32>
 // CHECK: return %[[RESULT_VEC]]
 
+// -----
+
 func.func @vector_multi_reduction_to_scalar(%arg0: vector<2x4xf32>, %acc: f32) -> f32 {
     %0 = vector.multi_reduction <mul>, %arg0, %acc [0, 1] : vector<2x4xf32> to f32
     return %0 : f32
@@ -31,6 +33,8 @@
 // CHECK: %[[RES:.*]] = vector.extract %[[INSERTED]][0] : vector<1xf32>
 // CHECK: return %[[RES]]
 
+// -----
+
 func.func @vector_reduction_inner(%arg0: vector<2x3x4x5xi32>, %acc: vector<2x3xi32>) -> vector<2x3xi32> {
     %0 = vector.multi_reduction <add>, %arg0, %acc [2, 3] : vector<2x3x4x5xi32> to vector<2x3xi32>
     return %0 : vector<2x3xi32>
@@ -72,6 +76,7 @@
 // CHECK: %[[RESULT:.+]] = vector.shape_cast %[[FLAT_RESULT_VEC]] : vector<6xi32> to vector<2x3xi32>
 // CHECK: return %[[RESULT]]
 
+// -----
 
 func.func @vector_multi_reduction_transposed(%arg0: vector<2x3x4x5xf32>, %acc: vector<2x5xf32>) -> vector<2x5xf32> {
     %0 = vector.multi_reduction <add>, %arg0, %acc [1, 2] : vector<2x3x4x5xf32> to vector<2x5xf32>
@@ -85,6 +90,8 @@
 // CHECK: %[[RESULT:.+]] = vector.shape_cast %{{.*}} : vector<10xf32> to vector<2x5xf32>
 // CHECK: return %[[RESULT]]
 
+// -----
+
 func.func @vector_multi_reduction_ordering(%arg0: vector<3x2x4xf32>, %acc: vector<2x4xf32>) -> vector<2x4xf32> {
     %0 = vector.multi_reduction <mul>, %arg0, %acc [0] : vector<3x2x4xf32> to vector<2x4xf32>
     return %0 : vector<2x4xf32>
@@ -135,3 +142,95 @@
 // CHECK: %[[RESULT_VEC:.+]] = vector.insertelement %[[RV7:.+]], %[[RESULT_VEC_7]][%[[C7]] : index] : vector<8xf32>
 // CHECK: %[[RESHAPED_VEC:.+]] = vector.shape_cast %[[RESULT_VEC]] : vector<8xf32> to vector<2x4xf32>
 // CHECK: return %[[RESHAPED_VEC]]
+
+// -----
+
+func.func @vectorize_dynamic_reduction(%arg0: tensor<?x?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
+  %c0 = arith.constant 0 : index
+  %dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+  %c1 = arith.constant 1 : index
+  %dim_0 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+  %c0_1 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = vector.create_mask %dim, %dim_0 : vector<4x8xi1>
+  %1 = vector.mask %0 { vector.transfer_read %arg0[%c0_1, %c0_1], %cst {in_bounds = [true, true]} : tensor<?x?xf32>, vector<4x8xf32> } : vector<4x8xi1> -> vector<4x8xf32>
+  %cst_2 = arith.constant 0.000000e+00 : f32
+  %2 = vector.create_mask %dim : vector<4xi1>
+  %3 = vector.mask %2 { vector.transfer_read %arg1[%c0_1], %cst_2 {in_bounds = [true]} : tensor<?xf32>, vector<4xf32> } : vector<4xi1> -> vector<4xf32>
+  %4 = vector.mask %0 { vector.multi_reduction <add>, %1, %3 [1] : vector<4x8xf32> to vector<4xf32> } : vector<4x8xi1> -> vector<4xf32>
+  %c0_3 = arith.constant 0 : index
+  %5 = vector.mask %2 { vector.transfer_write %4, %arg1[%c0_3] {in_bounds = [true]} : vector<4xf32>, tensor<?xf32> } : vector<4xi1> -> tensor<?xf32>
+  return %5 : tensor<?xf32>
+}
+
+// Verify that the original 2-D mask is sliced and propagated properly to the
+// vector.reduction instances.
+
+// CHECK-LABEL: func.func @vectorize_dynamic_reduction
+// CHECK: %[[VAL_8:.*]] = tensor.dim
+// CHECK: %[[VAL_9:.*]] = tensor.dim
+// CHECK: %[[VAL_10:.*]] = vector.create_mask %[[VAL_8]], %[[VAL_9]] : vector<4x8xi1>
+
+// CHECK: %[[VAL_16:.*]] = vector.extract %[[VAL_10]][0] : vector<4x8xi1>
+// CHECK: %[[VAL_17:.*]] = vector.mask %[[VAL_16]] { vector.reduction <add>, %{{.*}} : vector<8xf32> into f32 } : vector<8xi1> -> f32
+// CHECK: %[[VAL_18:.*]] = vector.insertelement
+
+// CHECK: %[[VAL_21:.*]] = vector.extract %[[VAL_10]][1] : vector<4x8xi1>
+// CHECK: %[[VAL_22:.*]] = vector.mask %[[VAL_21]] { vector.reduction <add>, %{{.*}} : vector<8xf32> into f32 } : vector<8xi1> -> f32
+// CHECK: %[[VAL_23:.*]] = vector.insertelement
+
+// CHECK: %[[VAL_26:.*]] = vector.extract %[[VAL_10]][2] : vector<4x8xi1>
+// CHECK: %[[VAL_27:.*]] = vector.mask %[[VAL_26]] { vector.reduction <add>, %{{.*}} : vector<8xf32> into f32 } : vector<8xi1> -> f32
+// CHECK: %[[VAL_28:.*]] = vector.insertelement
+
+// CHECK: %[[VAL_31:.*]] = vector.extract %[[VAL_10]][3] : vector<4x8xi1>
+// CHECK: %[[VAL_32:.*]] = vector.mask %[[VAL_31]] { vector.reduction <add>, %{{.*}} : vector<8xf32> into f32 } : vector<8xi1> -> f32
+// CHECK: %[[VAL_33:.*]] = vector.insertelement
+
+// -----
+
+func.func @vectorize_dynamic_transpose_reduction(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %c0 = arith.constant 0 : index
+  %dim = tensor.dim %arg0, %c0 : tensor<?x?x?xf32>
+  %c1 = arith.constant 1 : index
+  %dim_0 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>
+  %c2 = arith.constant 2 : index
+  %dim_1 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32>
+  %c0_2 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = vector.create_mask %dim, %dim_0, %dim_1 : vector<4x8x16xi1>
+  %1 = vector.mask %0 { vector.transfer_read %arg0[%c0_2, %c0_2, %c0_2], %cst {in_bounds = [true, true, true]} : tensor<?x?x?xf32>, vector<4x8x16xf32> } : vector<4x8x16xi1> -> vector<4x8x16xf32>
+  %cst_3 = arith.constant 0.000000e+00 : f32
+  %2 = vector.create_mask %dim_1, %dim_0 : vector<16x8xi1>
+  %3 = vector.mask %2 { vector.transfer_read %arg1[%c0_2, %c0_2], %cst_3 {in_bounds = [true, true], permutation_map = affine_map<(d0, d1) -> (d1, d0)>} : tensor<?x?xf32>, vector<8x16xf32> } : vector<16x8xi1> -> vector<8x16xf32>
+  %4 = vector.mask %0 { vector.multi_reduction <add>, %1, %3 [0] : vector<4x8x16xf32> to vector<8x16xf32> } : vector<4x8x16xi1> -> vector<8x16xf32>
+  %c0_4 = arith.constant 0 : index
+  %5 = vector.mask %2 { vector.transfer_write %4, %arg1[%c0_4, %c0_4] {in_bounds = [true, true], permutation_map = affine_map<(d0, d1) -> (d1, d0)>} : vector<8x16xf32>, tensor<?x?xf32> } : vector<16x8xi1> -> tensor<?x?xf32>
+  return %5 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: func.func @vectorize_dynamic_transpose_reduction
+// CHECK: %[[VAL_6:.*]] = tensor.dim
+// CHECK: %[[VAL_7:.*]] = tensor.dim
+// CHECK: %[[VAL_8:.*]] = tensor.dim
+// CHECK: %[[VAL_135:.*]] = vector.create_mask %{{.*}}, %{{.*}}, %{{.*}} : vector<4x8x16xi1>
+// CHECK: %[[VAL_139:.*]] = vector.transpose %[[VAL_135]], [1, 2, 0] : vector<4x8x16xi1> to vector<8x16x4xi1>
+
+// Just checking a few instances to make sure the vector mask is properly propagated:
+
+// CHECK: %[[VAL_143:.*]] = vector.extract %[[VAL_139]][0, 0] : vector<8x16x4xi1>
+// CHECK: %[[VAL_144:.*]] = vector.mask %[[VAL_143]] { vector.reduction
+// CHECK: %[[VAL_145:.*]] = vector.insertelement %[[VAL_144]]
+
+// CHECK: %[[VAL_148:.*]] = vector.extract %[[VAL_139]][0, 1] : vector<8x16x4xi1>
1] : vector<8x16x4xi1> +// CHECK: %[[VAL_149:.*]] = vector.mask %[[VAL_148]] { vector.reduction +// CHECK: %[[VAL_150:.*]] = vector.insertelement %[[VAL_149]] + +// CHECK: %[[VAL_153:.*]] = vector.extract %[[VAL_139]][0, 2] : vector<8x16x4xi1> +// CHECK: %[[VAL_154:.*]] = vector.mask %[[VAL_153]] { vector.reduction +// CHECK: %[[VAL_155:.*]] = vector.insertelement %[[VAL_154]] + +// CHECK: %[[VAL_158:.*]] = vector.extract %[[VAL_139]][0, 3] : vector<8x16x4xi1> +// CHECK: %[[VAL_159:.*]] = vector.mask %[[VAL_158]] { vector.reduction +// CHECK: %[[VAL_160:.*]] = vector.insertelement %[[VAL_159]] +