diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1846,7 +1846,7 @@
 }
 
 def Vector_GatherOp :
-  Vector_Op<"gather">,
+  Vector_Op<"gather", [DeclareOpInterfaceMethods<MaskableOpInterface>]>,
     Arguments<(ins Arg<AnyShaped, "", [MemRead]>:$base,
                Variadic<Index>:$indices,
                VectorOf<[AnyInteger, Index]>:$index_vec,
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -416,8 +416,7 @@
       vector::BroadcastableToResult::Success)
     return value;
   Location loc = b.getInsertionPoint()->getLoc();
-  return b.createOrFold<vector::BroadcastOp>(loc, targetVectorType,
-                                             value);
+  return b.createOrFold<vector::BroadcastOp>(loc, targetVectorType, value);
 }
 
 /// Create MultiDimReductionOp to compute the reduction for `reductionOp`. This
@@ -532,14 +531,16 @@
 /// VectorizationStatus::NewOp to signal the vectorization algorithm that it
 /// should map the produced operations. This function is meant to be used as a
 /// CustomVectorizationHook.
-static VectorizationResult
-vectorizeLinalgIndex(RewriterBase &rewriter, Operation *op, LinalgOp linalgOp) {
+static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter,
+                                                VectorizationState &state,
+                                                Operation *op,
+                                                LinalgOp linalgOp) {
   IndexOp indexOp = dyn_cast<IndexOp>(op);
   if (!indexOp)
     return VectorizationResult{VectorizationStatus::Failure, nullptr};
   auto loc = indexOp.getLoc();
   // Compute the static loop sizes of the index op.
-  auto targetShape = linalgOp.computeStaticLoopSizes();
+  auto targetShape = llvm::to_vector(state.getCanonicalVecShape());
   // Compute a one-dimensional index vector for the index op dimension.
   SmallVector<int64_t> constantSeq =
       llvm::to_vector<16>(llvm::seq<int64_t>(0, targetShape[indexOp.getDim()]));
@@ -597,32 +598,33 @@
 ///
 /// For tensor<45 x 80 x 15 x f32> and index [1, 2, 3], this leads to:
 ///   offset = ( ( 1 ) * 80 + 2 ) * 15 + 3
-static Value
-calculateGatherOffset(OpBuilder &b, tensor::ExtractOp extractOp,
-                      const IRMapping &bvm,
-                      const SmallVectorImpl<int64_t> &targetShape) {
+static Value calculateGatherOffset(RewriterBase &rewriter,
+                                   tensor::ExtractOp extractOp,
+                                   const IRMapping &bvm,
+                                   const ArrayRef<int64_t> targetShape) {
   // The vector of indices for GatherOp should be shaped as the output vector
-  auto indexVecType = VectorType::get(targetShape, b.getIndexType());
+  auto indexVecType = VectorType::get(targetShape, rewriter.getIndexType());
   auto loc = extractOp.getLoc();
 
-  Value offset = b.create<vector::BroadcastOp>(
-      loc, indexVecType, bvm.lookup(extractOp.getIndices()[0]));
+  Value offset = broadcastIfNeeded(
+      rewriter, bvm.lookup(extractOp.getIndices()[0]), indexVecType.getShape());
 
   const size_t numIndices = extractOp.getIndices().size();
   for (size_t i = 1; i < numIndices; i++) {
     auto dimSize = broadcastIfNeeded(
-        b,
-        b.create<arith::ConstantIndexOp>(
+        rewriter,
+        rewriter.create<arith::ConstantIndexOp>(
             loc,
             extractOp.getTensor().getType().cast<ShapedType>().getDimSize(i)),
         indexVecType.getShape());
 
-    offset = b.create<arith::MulIOp>(loc, offset, dimSize);
+    offset = rewriter.create<arith::MulIOp>(loc, offset, dimSize);
 
-    auto extractOpIndex = broadcastIfNeeded(
-        b, bvm.lookup(extractOp.getIndices()[i]), indexVecType.getShape());
+    auto extractOpIndex =
+        broadcastIfNeeded(rewriter, bvm.lookup(extractOp.getIndices()[i]),
+                          indexVecType.getShape());
 
-    offset = b.create<arith::AddIOp>(loc, extractOpIndex, offset);
+    offset = rewriter.create<arith::AddIOp>(loc, extractOpIndex, offset);
   }
 
   return offset;
@@ -632,17 +634,16 @@
 /// VectorizationStatus::NewOp to signal the vectorization algorithm that it
 /// should map the produced operations. This function is meant to be used as a
 /// CustomVectorizationHook.
-static VectorizationResult vectorizeTensorExtract(RewriterBase &rewriter,
-                                                  Operation *op,
-                                                  LinalgOp linalgOp,
-                                                  const IRMapping &bvm) {
+static VectorizationResult
+vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
+                       Operation *op, LinalgOp linalgOp, const IRMapping &bvm) {
   tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
   if (!extractOp)
     return VectorizationResult{VectorizationStatus::Failure, nullptr};
   auto loc = extractOp.getLoc();
 
   // Compute the static loop sizes of the extract op.
-  auto targetShape = linalgOp.computeStaticLoopSizes();
+  auto targetShape = state.getCanonicalVecShape();
   auto resultType =
       VectorType::get(targetShape, extractOp.getResult().getType());
 
@@ -662,9 +663,10 @@
   Value offset = calculateGatherOffset(rewriter, extractOp, bvm, targetShape);
 
   // Generate the gather load
-  auto gatherOp = rewriter.create<vector::GatherOp>(
+  Operation *gatherOp = rewriter.create<vector::GatherOp>(
       loc, resultType, extractOp.getTensor(), baseIndices, offset,
       maskConstantOp, passThruConstantOp);
+  gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);
 
   return VectorizationResult{VectorizationStatus::NewOp, gatherOp};
 }
@@ -904,14 +906,14 @@
   // 4b. Register CustomVectorizationHook for indexOp.
   CustomVectorizationHook vectorizeIndex =
       [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
-    return vectorizeLinalgIndex(rewriter, op, linalgOp);
+    return vectorizeLinalgIndex(rewriter, state, op, linalgOp);
   };
   hooks.push_back(vectorizeIndex);
 
   // 4c. Register CustomVectorizationHook for extractOp.
   CustomVectorizationHook vectorizeExtract =
       [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
-    return vectorizeTensorExtract(rewriter, op, linalgOp, bvm);
+    return vectorizeTensorExtract(rewriter, state, op, linalgOp, bvm);
   };
   hooks.push_back(vectorizeExtract);
 
@@ -1007,8 +1009,10 @@
     return failure();
 
   if (linalgOp.hasDynamicShape() &&
-      failed(vectorizeDynamicLinalgOpPrecondition(linalgOp)))
+      failed(vectorizeDynamicLinalgOpPrecondition(linalgOp))) {
+    LDBG("Dynamically-shaped op failed vectorization pre-conditions\n");
     return failure();
+  }
 
   SmallVector<CustomVectorizationPrecondition> customPreconditions;
 
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4597,6 +4597,16 @@
   return success();
 }
 
+// MaskableOpInterface methods.
+
+/// Returns the mask type expected by this operation. Mostly used for
+/// verification purposes. It requires the operation to be vectorized.
+Type GatherOp::getExpectedMaskType() {
+  auto vecType = this->getIndexVectorType();
+  return VectorType::get(vecType.getShape(),
+                         IntegerType::get(vecType.getContext(), /*width=*/1));
+}
+
 namespace {
 class GatherFolder final : public OpRewritePattern<GatherOp> {
 public:
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
@@ -11,6 +11,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/Passes.h"
@@ -109,6 +110,29 @@
   }
 };
 
+/// Lowers a masked `vector.gather` operation.
+struct MaskedGatherOpPattern : public MaskOpRewritePattern<GatherOp> {
+public:
+  using MaskOpRewritePattern<GatherOp>::MaskOpRewritePattern;
+
+  LogicalResult
+  matchAndRewriteMaskableOp(GatherOp gatherOp, MaskingOpInterface maskingOp,
+                            PatternRewriter &rewriter) const override {
+    Value passthru = maskingOp.hasPassthru()
+                         ? maskingOp.getPassthru()
+                         : rewriter.create<arith::ConstantOp>(
+                               gatherOp.getLoc(),
+                               rewriter.getZeroAttr(gatherOp.getVectorType()));
+
+    // Replace the `vector.mask` operation.
+    rewriter.replaceOpWithNewOp<GatherOp>(
+        maskingOp.getOperation(), gatherOp.getVectorType(), gatherOp.getBase(),
+        gatherOp.getIndices(), gatherOp.getIndexVec(), maskingOp.getMask(),
+        passthru);
+    return success();
+  }
+};
+
 struct LowerVectorMaskPass
     : public vector::impl::LowerVectorMaskPassBase<LowerVectorMaskPass> {
   using Base::Base;
@@ -136,8 +160,8 @@
 /// not its nested `MaskableOpInterface`.
 void vector::populateVectorMaskLoweringPatternsForSideEffectingOps(
     RewritePatternSet &patterns) {
-  patterns.add<MaskedTransferReadOpPattern, MaskedTransferWriteOpPattern>(
-      patterns.getContext());
+  patterns.add<MaskedTransferReadOpPattern, MaskedTransferWriteOpPattern,
+               MaskedGatherOpPattern>(patterns.getContext());
 }
 
 std::unique_ptr<Pass> mlir::vector::createLowerVectorMaskPass() {
diff --git a/mlir/test/Dialect/Vector/lower-vector-mask.mlir b/mlir/test/Dialect/Vector/lower-vector-mask.mlir
--- a/mlir/test/Dialect/Vector/lower-vector-mask.mlir
+++ b/mlir/test/Dialect/Vector/lower-vector-mask.mlir
@@ -48,3 +48,32 @@
 // CHECK:           return %[[VAL_4]] : tensor<?xf32>
 // CHECK:         }
 
+// -----
+
+func.func @vector_gather(%arg0: tensor<64xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %c3 = arith.constant 3 : index
+  %0 = vector.create_mask %c3 : vector<4xi1>
+  %1 = vector.mask %0 { vector.transfer_read %arg1[%c0], %cst {in_bounds = [true]} : tensor<3xf32>, vector<4xf32> } : vector<4xi1> -> vector<4xf32>
+  %cst_0 = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
+  %cst_1 = arith.constant dense<true> : vector<4xi1>
+  %cst_2 = arith.constant dense<0.000000e+00> : vector<4xf32>
+  %c0_3 = arith.constant 0 : index
+  %2 = vector.mask %0 { vector.gather %arg0[%c0_3] [%cst_0], %cst_1, %cst_2 : tensor<64xf32>, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32> } : vector<4xi1> -> vector<4xf32>
+  %c0_4 = arith.constant 0 : index
+  %3 = vector.mask %0 { vector.transfer_write %2, %arg1[%c0_4] {in_bounds = [true]} : vector<4xf32>, tensor<3xf32> } : vector<4xi1> -> tensor<3xf32>
+  return %3 : tensor<3xf32>
+}
+
+// CHECK-LABEL:   func.func @vector_gather(
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<64xf32>,
+// CHECK-SAME:      %[[VAL_1:.*]]: tensor<3xf32>) -> tensor<3xf32> {
+// CHECK:           %[[VAL_2:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+// CHECK:           %[[VAL_3:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
+// CHECK:           %[[VAL_4:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_5:.*]] = arith.constant 3 : index
+// CHECK:           %[[VAL_6:.*]] = vector.create_mask %[[VAL_5]] : vector<4xi1>
+// CHECK:           %[[VAL_7:.*]] = vector.gather %[[VAL_0]][%[[VAL_4]]] [%[[VAL_3]]], %[[VAL_6]], %[[VAL_2]] : tensor<64xf32>, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+// CHECK:           %[[VAL_8:.*]] = vector.transfer_write %[[VAL_7]], %[[VAL_1]][%[[VAL_4]]], %[[VAL_6]] {in_bounds = [true]} : vector<4xf32>, tensor<3xf32>
+// CHECK:           return %[[VAL_8]] : tensor<3xf32>
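
Illustrative sketch (not part of the patch, value names hypothetical): the net effect of MaskedGatherOpPattern is to fold the enclosing vector.mask into the gather itself, so that the masking op's mask becomes the gather's mask operand, and the masking op's passthru (or an all-zero constant when absent, built via rewriter.getZeroAttr) becomes the pass-through value:

  // Before the lowering: a gather wrapped in a vector.mask region.
  // %base, %idx, %m, %all_true, and %zero are illustrative names.
  %r = vector.mask %m {
    vector.gather %base[%c0] [%idx], %all_true, %zero
      : tensor<64xf32>, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32>
  } : vector<4xi1> -> vector<4xf32>

  // After the lowering: a native gather with %m as the mask operand.
  %r = vector.gather %base[%c0] [%idx], %m, %zero
    : tensor<64xf32>, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32>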