diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h
@@ -40,6 +40,18 @@
 struct BitmaskEnumStorage;
 } // namespace detail
 
+/// Return whether `srcType` can be broadcast to `dstVectorType` under the
+/// semantics of the `vector.broadcast` op.
+enum class BroadcastableToResult {
+  Success = 0,
+  SourceRankHigher = 1,
+  DimensionMismatch = 2,
+  Unknown = 3
+};
+BroadcastableToResult
+isBroadcastableTo(Type srcType, VectorType dstVectorType,
+                  std::pair<int, int> *mismatchingDims = nullptr);
+
 /// Collect a set of vector-to-vector canonicalization patterns.
 void populateVectorToVectorCanonicalizationPatterns(
     RewritePatternSet &patterns);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -147,24 +147,25 @@
   return getKindForOp(combinerOps[0]);
 }
 
-/// If `value` of assumed VectorType has a shape different than `shape`, try to
-/// build and return a new vector.broadcast to `shape`.
-/// Otherwise, just return `value`.
-// TODO: this is best effort atm and there is currently no guarantee of
-// correctness for the broadcast semantics.
+/// Broadcast `value` to a vector of `shape` if possible. Return `value`
+/// otherwise.
 static Value broadcastIfNeeded(OpBuilder &b, Value value,
                                ArrayRef<int64_t> shape) {
-  unsigned numDimsGtOne = std::count_if(shape.begin(), shape.end(),
-                                        [](int64_t val) { return val > 1; });
-  auto vecType = value.getType().dyn_cast<VectorType>();
-  if (shape.empty() ||
-      (vecType != nullptr &&
-       (vecType.getShape() == shape || vecType.getRank() > numDimsGtOne)))
+  // If no shape to broadcast to, just return `value`.
+  if (shape.empty())
+    return value;
+  VectorType targetVectorType =
+      VectorType::get(shape, getElementTypeOrSelf(value));
+  if (vector::isBroadcastableTo(value.getType(), targetVectorType) !=
+      vector::BroadcastableToResult::Success)
     return value;
-  auto newVecType = VectorType::get(shape, vecType ? vecType.getElementType()
-                                                   : value.getType());
-  return b.create<vector::BroadcastOp>(b.getInsertionPoint()->getLoc(),
-                                       newVecType, value);
+  Location loc = b.getInsertionPoint()->getLoc();
+  if (auto vecType = value.getType().dyn_cast<VectorType>()) {
+    // If the shape already fits, return `value`.
+    if (vecType.getShape() == shape)
+      return value;
+  }
+  return b.create<vector::BroadcastOp>(loc, targetVectorType, value);
 }
 
 /// If value of assumed VectorType has a shape different than `shape`, build and
@@ -688,7 +689,8 @@
   // by TransferReadOp, but TransferReadOp supports only constant padding.
   auto padValue = padOp.getConstantPaddingValue();
   if (!padValue) {
-    if (!sourceType.hasStaticShape()) return failure();
+    if (!sourceType.hasStaticShape())
+      return failure();
     // Create dummy padding value.
     auto elemType = sourceType.getElementType();
     padValue = rewriter.create<ConstantOp>(padOp.getLoc(), elemType,
@@ -733,14 +735,14 @@
 
   // If `dest` is a FillOp and the TransferWriteOp would overwrite the entire
   // tensor, write directly to the FillOp's operand.
-  if (llvm::equal(vecShape, resultType.getShape())
-      && llvm::all_of(writeInBounds, [](bool b) { return b; }))
+  if (llvm::equal(vecShape, resultType.getShape()) &&
+      llvm::all_of(writeInBounds, [](bool b) { return b; }))
     if (auto fill = dest.getDefiningOp<FillOp>())
       dest = fill.output();
 
   // Generate TransferWriteOp.
-  auto writeIndices = ofrToIndexValues(
-      rewriter, padOp.getLoc(), padOp.getMixedLowPad());
+  auto writeIndices =
+      ofrToIndexValues(rewriter, padOp.getLoc(), padOp.getMixedLowPad());
   rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
       padOp, read, dest, writeIndices, writeInBounds);
@@ -764,9 +766,9 @@
     return success(changed);
   }
 
- protected:
-  virtual LogicalResult rewriteUser(
-      PatternRewriter &rewriter, PadTensorOp padOp, OpTy op) const = 0;
+protected:
+  virtual LogicalResult rewriteUser(PatternRewriter &rewriter,
+                                    PadTensorOp padOp, OpTy op) const = 0;
 };
 
 /// Rewrite use of PadTensorOp result in TransferReadOp. E.g.:
@@ -790,18 +792,21 @@
 /// - Single, scalar padding value.
 struct PadTensorOpVectorizationWithTransferReadPattern
     : public VectorizePadTensorOpUserPattern<vector::TransferReadOp> {
-  using VectorizePadTensorOpUserPattern<vector::TransferReadOp>
-      ::VectorizePadTensorOpUserPattern;
+  using VectorizePadTensorOpUserPattern<
+      vector::TransferReadOp>::VectorizePadTensorOpUserPattern;
 
   LogicalResult rewriteUser(PatternRewriter &rewriter, PadTensorOp padOp,
                             vector::TransferReadOp xferOp) const override {
     // Low padding must be static 0.
-    if (!padOp.hasZeroLowPad()) return failure();
+    if (!padOp.hasZeroLowPad())
+      return failure();
     // Pad value must be a constant.
     auto padValue = padOp.getConstantPaddingValue();
-    if (!padValue) return failure();
+    if (!padValue)
+      return failure();
     // Padding value of existing `xferOp` is unused.
-    if (xferOp.hasOutOfBoundsDim() || xferOp.mask()) return failure();
+    if (xferOp.hasOutOfBoundsDim() || xferOp.mask())
+      return failure();
 
     rewriter.updateRootInPlace(xferOp, [&]() {
       SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
@@ -847,24 +852,30 @@
 /// - Single, scalar padding value.
 struct PadTensorOpVectorizationWithTransferWritePattern
     : public VectorizePadTensorOpUserPattern<vector::TransferWriteOp> {
-  using VectorizePadTensorOpUserPattern<vector::TransferWriteOp>
-      ::VectorizePadTensorOpUserPattern;
+  using VectorizePadTensorOpUserPattern<
+      vector::TransferWriteOp>::VectorizePadTensorOpUserPattern;
 
   LogicalResult rewriteUser(PatternRewriter &rewriter, PadTensorOp padOp,
                             vector::TransferWriteOp xferOp) const override {
     // Low padding must be static 0.
-    if (!padOp.hasZeroLowPad()) return failure();
+    if (!padOp.hasZeroLowPad())
+      return failure();
     // Pad value must be a constant.
     auto padValue = padOp.getConstantPaddingValue();
-    if (!padValue) return failure();
+    if (!padValue)
+      return failure();
     // TransferWriteOp result must be directly consumed by an ExtractSliceOp.
-    if (!xferOp->hasOneUse()) return failure();
+    if (!xferOp->hasOneUse())
+      return failure();
     auto trimPadding = dyn_cast<tensor::ExtractSliceOp>(*xferOp->user_begin());
-    if (!trimPadding) return failure();
+    if (!trimPadding)
+      return failure();
     // Only static zero offsets supported when trimming padding.
-    if (!trimPadding.hasZeroOffset()) return failure();
+    if (!trimPadding.hasZeroOffset())
+      return failure();
     // trimPadding must remove the amount of padding that was added earlier.
-    if (!hasSameTensorSize(padOp.source(), trimPadding)) return failure();
+    if (!hasSameTensorSize(padOp.source(), trimPadding))
+      return failure();
 
     // Insert the new TransferWriteOp at position of the old TransferWriteOp.
     rewriter.setInsertionPoint(xferOp);
@@ -894,14 +905,17 @@
     // If the input to PadTensorOp is a CastOp, try with both CastOp result
    // and CastOp operand.
     if (auto castOp = beforePadding.getDefiningOp<tensor::CastOp>())
-      if (hasSameTensorSize(castOp.source(), afterTrimming)) return true;
+      if (hasSameTensorSize(castOp.source(), afterTrimming))
+        return true;
 
     auto t1 = beforePadding.getType().dyn_cast<RankedTensorType>();
     auto t2 = afterTrimming.getType().dyn_cast<RankedTensorType>();
     // Only RankedTensorType supported.
-    if (!t1 || !t2) return false;
+    if (!t1 || !t2)
+      return false;
     // Rank of both values must be the same.
-    if (t1.getRank() != t2.getRank()) return false;
+    if (t1.getRank() != t2.getRank())
+      return false;
 
     // All static dimensions must be the same. Mixed cases (e.g., dimension
     // static in `t1` but dynamic in `t2`) are not supported.
@@ -913,7 +927,8 @@
     }
 
     // Nothing more to check if all dimensions are static.
-    if (t1.getNumDynamicDims() == 0) return true;
+    if (t1.getNumDynamicDims() == 0)
+      return true;
 
     // All dynamic sizes must be the same. The only supported case at the moment
     // is when `beforePadding` is an ExtractSliceOp (or a cast thereof).
@@ -925,29 +940,33 @@
     assert(static_cast<size_t>(t1.getRank()) ==
           beforeSlice.getMixedSizes().size());
-    assert(static_cast<size_t>(t2.getRank())
-           == afterTrimming.getMixedSizes().size());
+    assert(static_cast<size_t>(t2.getRank()) ==
+           afterTrimming.getMixedSizes().size());
 
     for (unsigned i = 0; i < t1.getRank(); ++i) {
       // Skip static dimensions.
-      if (!t1.isDynamicDim(i)) continue;
+      if (!t1.isDynamicDim(i))
+        continue;
       auto size1 = beforeSlice.getMixedSizes()[i];
       auto size2 = afterTrimming.getMixedSizes()[i];
 
       // Case 1: Same value or same constant int.
-      if (isEqualConstantIntOrValue(size1, size2)) continue;
+      if (isEqualConstantIntOrValue(size1, size2))
+        continue;
 
       // Other cases: Take a deeper look at defining ops of values.
      auto v1 = size1.dyn_cast<Value>();
      auto v2 = size2.dyn_cast<Value>();
-      if (!v1 || !v2) return false;
+      if (!v1 || !v2)
+        return false;
 
      // Case 2: Both values are identical AffineMinOps. (Should not happen if
      // CSE is run.)
      auto minOp1 = v1.getDefiningOp<AffineMinOp>();
      auto minOp2 = v2.getDefiningOp<AffineMinOp>();
-      if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap()
-          && minOp1.operands() == minOp2.operands()) continue;
+      if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap() &&
+          minOp1.operands() == minOp2.operands())
+        continue;
 
      // Add additional cases as needed.
     }
@@ -987,9 +1006,11 @@
   LogicalResult rewriteUser(PatternRewriter &rewriter, PadTensorOp padOp,
                             tensor::InsertSliceOp insertOp) const override {
     // Low padding must be static 0.
-    if (!padOp.hasZeroLowPad()) return failure();
+    if (!padOp.hasZeroLowPad())
+      return failure();
     // Only unit stride supported.
-    if (!insertOp.hasUnitStride()) return failure();
+    if (!insertOp.hasUnitStride())
+      return failure();
     // Pad value must be a constant.
     auto padValue = padOp.getConstantPaddingValue();
     if (!padValue)
@@ -1038,8 +1059,8 @@
 
 void mlir::linalg::populatePadTensorOpVectorizationPatterns(
     RewritePatternSet &patterns, PatternBenefit baseBenefit) {
-  patterns.add<GenericPadTensorOpVectorizationPattern>(
-      patterns.getContext(), baseBenefit);
+  patterns.add<GenericPadTensorOpVectorizationPattern>(patterns.getContext(),
+                                                       baseBenefit);
   // Try these specialized patterns first before resorting to the generic one.
   patterns.add<PadTensorOpVectorizationWithTransferReadPattern,
                PadTensorOpVectorizationWithTransferWritePattern,
                PadTensorOpVectorizationWithInsertSlicePattern>(
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
-template <typename MapOp> AffineMap calculateImplicitMap(MapOp op) {
+template <typename MapOp>
+AffineMap calculateImplicitMap(MapOp op) {
   SmallVector<AffineExpr, 4> perm;
   // Check which dimensions have a multiplicity greater than 1 and associate
   // them to the IDs in order.
@@ -1320,28 +1321,52 @@
 // BroadcastOp
 //===----------------------------------------------------------------------===//
 
-static LogicalResult verify(BroadcastOp op) {
-  VectorType srcVectorType = op.getSourceType().dyn_cast<VectorType>();
-  VectorType dstVectorType = op.getVectorType();
-  // Scalar to vector broadcast is always valid. A vector
-  // to vector broadcast needs some additional checking.
-  if (srcVectorType) {
-    int64_t srcRank = srcVectorType.getRank();
-    int64_t dstRank = dstVectorType.getRank();
-    if (srcRank > dstRank)
-      return op.emitOpError("source rank higher than destination rank");
-    // Source has an exact match or singleton value for all trailing dimensions
-    // (all leading dimensions are simply duplicated).
-    int64_t lead = dstRank - srcRank;
-    for (int64_t r = 0; r < srcRank; ++r) {
-      int64_t srcDim = srcVectorType.getDimSize(r);
-      int64_t dstDim = dstVectorType.getDimSize(lead + r);
-      if (srcDim != 1 && srcDim != dstDim)
-        return op.emitOpError("dimension mismatch (")
-               << srcDim << " vs. " << dstDim << ")";
+BroadcastableToResult
+mlir::vector::isBroadcastableTo(Type srcType, VectorType dstVectorType,
+                                std::pair<int, int> *mismatchingDims) {
+  // Scalars broadcast to vectors of the same elemental type.
+  if (srcType.isIntOrIndexOrFloat() && dstVectorType &&
+      getElementTypeOrSelf(srcType) == getElementTypeOrSelf(dstVectorType))
+    return BroadcastableToResult::Success;
+  // From now on, only vectors broadcast.
+  VectorType srcVectorType = srcType.dyn_cast<VectorType>();
+  if (!srcVectorType)
+    return BroadcastableToResult::Unknown;
+
+  int64_t srcRank = srcVectorType.getRank();
+  int64_t dstRank = dstVectorType.getRank();
+  if (srcRank > dstRank)
+    return BroadcastableToResult::SourceRankHigher;
+  // Source has an exact match or singleton value for all trailing dimensions
+  // (all leading dimensions are simply duplicated).
+  int64_t lead = dstRank - srcRank;
+  for (int64_t r = 0; r < srcRank; ++r) {
+    int64_t srcDim = srcVectorType.getDimSize(r);
+    int64_t dstDim = dstVectorType.getDimSize(lead + r);
+    if (srcDim != 1 && srcDim != dstDim) {
+      if (mismatchingDims) {
+        mismatchingDims->first = srcDim;
+        mismatchingDims->second = dstDim;
+      }
+      return BroadcastableToResult::DimensionMismatch;
     }
   }
-  return success();
+
+  return BroadcastableToResult::Success;
+}
+
+static LogicalResult verify(BroadcastOp op) {
+  std::pair<int, int> mismatchingDims;
+  BroadcastableToResult res = isBroadcastableTo(
+      op.getSourceType(), op.getVectorType(), &mismatchingDims);
+  if (res == BroadcastableToResult::Success)
+    return success();
+  if (res == BroadcastableToResult::SourceRankHigher)
+    return op.emitOpError("source rank higher than destination rank");
+  if (res == BroadcastableToResult::DimensionMismatch)
+    return op.emitOpError("dimension mismatch (")
+           << mismatchingDims.first << " vs. " << mismatchingDims.second << ")";
+  return op.emitOpError("unknown isBroadcastableTo behavior");
 }
 
 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
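
Usage note, not part of the patch above: a minimal sketch of how a caller might exercise the new vector::isBroadcastableTo entry point declared in VectorOps.h. The helper name canBroadcastExample and the concrete vector shapes are invented for illustration; the builtin-type constructors (FloatType::getF32, VectorType::get) are existing MLIR APIs.

#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include <utility>

// Returns true when the source type is broadcastable to the destination
// vector type under vector.broadcast semantics.
static bool canBroadcastExample(mlir::MLIRContext *ctx) {
  using namespace mlir;
  auto f32 = FloatType::getF32(ctx);
  // vector<16xf32> -> vector<4x16xf32>: the trailing dimension matches exactly
  // and the leading dimension is duplicated, so the query reports Success.
  // A scalar f32 source would also succeed; vector<8xf32> -> vector<4x16xf32>
  // would instead report DimensionMismatch with mismatch == {8, 16}.
  VectorType src = VectorType::get({16}, f32);
  VectorType dst = VectorType::get({4, 16}, f32);
  std::pair<int, int> mismatch;
  return vector::isBroadcastableTo(src, dst, &mismatch) ==
         vector::BroadcastableToResult::Success;
}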