diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h
@@ -40,6 +40,18 @@
 struct BitmaskEnumStorage;
 } // namespace detail
 
+/// Return whether `srcType` can be broadcast to `dstVectorType` under the
+/// semantics of the `vector.broadcast` op.
+enum class BroadcastableToResult {
+  Success = 0,
+  SourceRankHigher = 1,
+  DimensionMismatch = 2,
+  SourceTypeNotAVector = 3
+};
+BroadcastableToResult
+isBroadcastableTo(Type srcType, VectorType dstVectorType,
+                  std::pair<int, int> *mismatchingDims = nullptr);
+
 /// Collect a set of vector-to-vector canonicalization patterns.
 void populateVectorToVectorCanonicalizationPatterns(
     RewritePatternSet &patterns);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -147,24 +147,20 @@
   return getKindForOp(combinerOps[0]);
 }
 
-/// If `value` of assumed VectorType has a shape different than `shape`, try to
-/// build and return a new vector.broadcast to `shape`.
-/// Otherwise, just return `value`.
-// TODO: this is best effort atm and there is currently no guarantee of
-// correctness for the broadcast semantics.
+/// Broadcast `value` to a vector of `shape` if possible. Return value
+/// otherwise.
 static Value broadcastIfNeeded(OpBuilder &b, Value value,
                                ArrayRef<int64_t> shape) {
-  unsigned numDimsGtOne = std::count_if(shape.begin(), shape.end(),
-                                        [](int64_t val) { return val > 1; });
-  auto vecType = value.getType().dyn_cast<VectorType>();
-  if (shape.empty() ||
-      (vecType != nullptr &&
-       (vecType.getShape() == shape || vecType.getRank() > numDimsGtOne)))
+  // If no shape to broadcast to, just return `value`.
+  if (shape.empty())
+    return value;
+  VectorType targetVectorType =
+      VectorType::get(shape, getElementTypeOrSelf(value));
+  if (vector::isBroadcastableTo(value.getType(), targetVectorType) !=
+      vector::BroadcastableToResult::Success)
     return value;
-  auto newVecType = VectorType::get(shape, vecType ? vecType.getElementType()
-                                                   : value.getType());
-  return b.create<vector::BroadcastOp>(b.getInsertionPoint()->getLoc(),
-                                       newVecType, value);
+  Location loc = b.getInsertionPoint()->getLoc();
+  return b.createOrFold<vector::BroadcastOp>(loc, targetVectorType, value);
 }
 
 /// If value of assumed VectorType has a shape different than `shape`, build and
@@ -688,7 +684,8 @@
     // by TransferReadOp, but TransferReadOp supports only constant padding.
     auto padValue = padOp.getConstantPaddingValue();
     if (!padValue) {
-      if (!sourceType.hasStaticShape()) return failure();
+      if (!sourceType.hasStaticShape())
+        return failure();
       // Create dummy padding value.
       auto elemType = sourceType.getElementType();
       padValue = rewriter.create<ConstantOp>(padOp.getLoc(), elemType,
@@ -733,14 +730,14 @@
 
   // If `dest` is a FillOp and the TransferWriteOp would overwrite the entire
   // tensor, write directly to the FillOp's operand.
-  if (llvm::equal(vecShape, resultType.getShape())
-      && llvm::all_of(writeInBounds, [](bool b) { return b; }))
+  if (llvm::equal(vecShape, resultType.getShape()) &&
+      llvm::all_of(writeInBounds, [](bool b) { return b; }))
     if (auto fill = dest.getDefiningOp<FillOp>())
       dest = fill.output();
 
   // Generate TransferWriteOp.
-  auto writeIndices = ofrToIndexValues(
-      rewriter, padOp.getLoc(), padOp.getMixedLowPad());
+  auto writeIndices =
+      ofrToIndexValues(rewriter, padOp.getLoc(), padOp.getMixedLowPad());
   rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
       padOp, read, dest, writeIndices, writeInBounds);
 
@@ -764,9 +761,9 @@
     return success(changed);
   }
 
- protected:
-  virtual LogicalResult rewriteUser(
-      PatternRewriter &rewriter, PadTensorOp padOp, OpTy op) const = 0;
+protected:
+  virtual LogicalResult rewriteUser(PatternRewriter &rewriter,
+                                    PadTensorOp padOp, OpTy op) const = 0;
 };
 
 /// Rewrite use of PadTensorOp result in TransferReadOp. E.g.:
@@ -790,18 +787,21 @@
 ///   - Single, scalar padding value.
 struct PadTensorOpVectorizationWithTransferReadPattern
     : public VectorizePadTensorOpUserPattern<vector::TransferReadOp> {
-  using VectorizePadTensorOpUserPattern<vector::TransferReadOp>
-      ::VectorizePadTensorOpUserPattern;
+  using VectorizePadTensorOpUserPattern<
+      vector::TransferReadOp>::VectorizePadTensorOpUserPattern;
 
   LogicalResult rewriteUser(PatternRewriter &rewriter, PadTensorOp padOp,
                             vector::TransferReadOp xferOp) const override {
     // Low padding must be static 0.
-    if (!padOp.hasZeroLowPad()) return failure();
+    if (!padOp.hasZeroLowPad())
+      return failure();
     // Pad value must be a constant.
     auto padValue = padOp.getConstantPaddingValue();
-    if (!padValue) return failure();
+    if (!padValue)
+      return failure();
     // Padding value of existing `xferOp` is unused.
-    if (xferOp.hasOutOfBoundsDim() || xferOp.mask()) return failure();
+    if (xferOp.hasOutOfBoundsDim() || xferOp.mask())
+      return failure();
 
     rewriter.updateRootInPlace(xferOp, [&]() {
       SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
@@ -847,24 +847,30 @@
 ///   - Single, scalar padding value.
 struct PadTensorOpVectorizationWithTransferWritePattern
    : public VectorizePadTensorOpUserPattern<vector::TransferWriteOp> {
-  using VectorizePadTensorOpUserPattern<vector::TransferWriteOp>
-      ::VectorizePadTensorOpUserPattern;
+  using VectorizePadTensorOpUserPattern<
+      vector::TransferWriteOp>::VectorizePadTensorOpUserPattern;
 
   LogicalResult rewriteUser(PatternRewriter &rewriter, PadTensorOp padOp,
                             vector::TransferWriteOp xferOp) const override {
     // Low padding must be static 0.
-    if (!padOp.hasZeroLowPad()) return failure();
+    if (!padOp.hasZeroLowPad())
+      return failure();
     // Pad value must be a constant.
     auto padValue = padOp.getConstantPaddingValue();
-    if (!padValue) return failure();
+    if (!padValue)
+      return failure();
     // TransferWriteOp result must be directly consumed by an ExtractSliceOp.
-    if (!xferOp->hasOneUse()) return failure();
+    if (!xferOp->hasOneUse())
+      return failure();
     auto trimPadding = dyn_cast<tensor::ExtractSliceOp>(*xferOp->user_begin());
-    if (!trimPadding) return failure();
+    if (!trimPadding)
+      return failure();
     // Only static zero offsets supported when trimming padding.
-    if (!trimPadding.hasZeroOffset()) return failure();
+    if (!trimPadding.hasZeroOffset())
+      return failure();
     // trimPadding must remove the amount of padding that was added earlier.
-    if (!hasSameTensorSize(padOp.source(), trimPadding)) return failure();
+    if (!hasSameTensorSize(padOp.source(), trimPadding))
+      return failure();
 
     // Insert the new TransferWriteOp at position of the old TransferWriteOp.
     rewriter.setInsertionPoint(xferOp);
@@ -894,14 +900,17 @@
     // If the input to PadTensorOp is a CastOp, try with with both CastOp result
     // and CastOp operand.
     if (auto castOp = beforePadding.getDefiningOp<tensor::CastOp>())
-      if (hasSameTensorSize(castOp.source(), afterTrimming)) return true;
+      if (hasSameTensorSize(castOp.source(), afterTrimming))
+        return true;
 
     auto t1 = beforePadding.getType().dyn_cast<RankedTensorType>();
     auto t2 = afterTrimming.getType().dyn_cast<RankedTensorType>();
     // Only RankedTensorType supported.
-    if (!t1 || !t2) return false;
+    if (!t1 || !t2)
+      return false;
     // Rank of both values must be the same.
-    if (t1.getRank() != t2.getRank()) return false;
+    if (t1.getRank() != t2.getRank())
+      return false;
 
     // All static dimensions must be the same. Mixed cases (e.g., dimension
     // static in `t1` but dynamic in `t2`) are not supported.
@@ -913,7 +922,8 @@
     }
 
     // Nothing more to check if all dimensions are static.
-    if (t1.getNumDynamicDims() == 0) return true;
+    if (t1.getNumDynamicDims() == 0)
+      return true;
 
     // All dynamic sizes must be the same. The only supported case at the moment
     // is when `beforePadding` is an ExtractSliceOp (or a cast thereof).
@@ -925,29 +935,33 @@
 
     assert(static_cast<size_t>(t1.getRank()) ==
            beforeSlice.getMixedSizes().size());
-    assert(static_cast<size_t>(t2.getRank())
-           == afterTrimming.getMixedSizes().size());
+    assert(static_cast<size_t>(t2.getRank()) ==
+           afterTrimming.getMixedSizes().size());
 
     for (unsigned i = 0; i < t1.getRank(); ++i) {
       // Skip static dimensions.
-      if (!t1.isDynamicDim(i)) continue;
+      if (!t1.isDynamicDim(i))
+        continue;
       auto size1 = beforeSlice.getMixedSizes()[i];
       auto size2 = afterTrimming.getMixedSizes()[i];
 
       // Case 1: Same value or same constant int.
-      if (isEqualConstantIntOrValue(size1, size2)) continue;
+      if (isEqualConstantIntOrValue(size1, size2))
+        continue;
 
       // Other cases: Take a deeper look at defining ops of values.
       auto v1 = size1.dyn_cast<Value>();
       auto v2 = size2.dyn_cast<Value>();
-      if (!v1 || !v2) return false;
+      if (!v1 || !v2)
+        return false;
 
       // Case 2: Both values are identical AffineMinOps. (Should not happen if
       // CSE is run.)
      auto minOp1 = v1.getDefiningOp<AffineMinOp>();
       auto minOp2 = v2.getDefiningOp<AffineMinOp>();
-      if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap()
-          && minOp1.operands() == minOp2.operands()) continue;
+      if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap() &&
+          minOp1.operands() == minOp2.operands())
+        continue;
 
       // Add additional cases as needed.
     }
@@ -987,9 +1001,11 @@
   LogicalResult rewriteUser(PatternRewriter &rewriter, PadTensorOp padOp,
                             tensor::InsertSliceOp insertOp) const override {
     // Low padding must be static 0.
-    if (!padOp.hasZeroLowPad()) return failure();
+    if (!padOp.hasZeroLowPad())
+      return failure();
     // Only unit stride supported.
-    if (!insertOp.hasUnitStride()) return failure();
+    if (!insertOp.hasUnitStride())
+      return failure();
     // Pad value must be a constant.
     auto padValue = padOp.getConstantPaddingValue();
     if (!padValue)
      return failure();
@@ -1038,8 +1054,8 @@
 
 void mlir::linalg::populatePadTensorOpVectorizationPatterns(
     RewritePatternSet &patterns, PatternBenefit baseBenefit) {
-  patterns.add<GenericPadTensorOpVectorizationPattern>(
-      patterns.getContext(), baseBenefit);
+  patterns.add<GenericPadTensorOpVectorizationPattern>(patterns.getContext(),
+                                                       baseBenefit);
   // Try these specialized patterns first before resorting to the generic one.
   patterns.add<PadTensorOpVectorizationWithTransferReadPattern,
                PadTensorOpVectorizationWithTransferWritePattern,
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1332,32 +1332,60 @@
 //===----------------------------------------------------------------------===//
 // BroadcastOp
 //===----------------------------------------------------------------------===//
 
-static LogicalResult verify(BroadcastOp op) {
-  auto srcVectorType = op.getSourceType().dyn_cast<VectorType>();
-  VectorType dstVectorType = op.getVectorType();
-  // Scalar to vector broadcast is always valid. A vector
-  // to vector broadcast needs some additional checking.
-  if (srcVectorType) {
-    int64_t srcRank = srcVectorType.getRank();
-    int64_t dstRank = dstVectorType.getRank();
-    if (srcRank > dstRank)
-      return op.emitOpError("source rank higher than destination rank");
-    // Source has an exact match or singleton value for all trailing dimensions
-    // (all leading dimensions are simply duplicated).
-    int64_t lead = dstRank - srcRank;
-    for (int64_t r = 0; r < srcRank; ++r) {
-      int64_t srcDim = srcVectorType.getDimSize(r);
-      int64_t dstDim = dstVectorType.getDimSize(lead + r);
-      if (srcDim != 1 && srcDim != dstDim)
-        return op.emitOpError("dimension mismatch (")
-               << srcDim << " vs. " << dstDim << ")";
+BroadcastableToResult
+mlir::vector::isBroadcastableTo(Type srcType, VectorType dstVectorType,
+                                std::pair<int, int> *mismatchingDims) {
+  // Broadcast scalar to vector of the same element type.
+  if (srcType.isIntOrIndexOrFloat() && dstVectorType &&
+      getElementTypeOrSelf(srcType) == getElementTypeOrSelf(dstVectorType))
+    return BroadcastableToResult::Success;
+  // From now on, only vectors broadcast.
+  VectorType srcVectorType = srcType.dyn_cast<VectorType>();
+  if (!srcVectorType)
+    return BroadcastableToResult::SourceTypeNotAVector;
+
+  int64_t srcRank = srcVectorType.getRank();
+  int64_t dstRank = dstVectorType.getRank();
+  if (srcRank > dstRank)
+    return BroadcastableToResult::SourceRankHigher;
+  // Source has an exact match or singleton value for all trailing dimensions
+  // (all leading dimensions are simply duplicated).
+  int64_t lead = dstRank - srcRank;
+  for (int64_t r = 0; r < srcRank; ++r) {
+    int64_t srcDim = srcVectorType.getDimSize(r);
+    int64_t dstDim = dstVectorType.getDimSize(lead + r);
+    if (srcDim != 1 && srcDim != dstDim) {
+      if (mismatchingDims) {
+        mismatchingDims->first = srcDim;
+        mismatchingDims->second = dstDim;
+      }
+      return BroadcastableToResult::DimensionMismatch;
     }
   }
-  return success();
+
+  return BroadcastableToResult::Success;
+}
+
+static LogicalResult verify(BroadcastOp op) {
+  std::pair<int, int> mismatchingDims;
+  BroadcastableToResult res = isBroadcastableTo(
+      op.getSourceType(), op.getVectorType(), &mismatchingDims);
+  if (res == BroadcastableToResult::Success)
+    return success();
+  if (res == BroadcastableToResult::SourceRankHigher)
+    return op.emitOpError("source rank higher than destination rank");
+  if (res == BroadcastableToResult::DimensionMismatch)
+    return op.emitOpError("dimension mismatch (")
+           << mismatchingDims.first << " vs. " << mismatchingDims.second << ")";
+  if (res == BroadcastableToResult::SourceTypeNotAVector)
+    return op.emitOpError("source type is not a vector");
+  llvm_unreachable("unexpected vector.broadcast op error");
 }
 
 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
+  if (getSourceType() == getVectorType())
+    return source();
   if (!operands[0])
     return {};
   auto vectorType = getVectorType();
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -30,6 +30,13 @@
 
 // -----
 
+func @broadcast_unknown(%arg0: memref<4x8xf32>) {
+  // expected-error@+1 {{'vector.broadcast' op source type is not a vector}}
+  %1 = vector.broadcast %arg0 : memref<4x8xf32> to vector<1x8xf32>
+}
+
+// -----
+
 func @shuffle_elt_type_mismatch(%arg0: vector<2xf32>, %arg1: vector<2xi32>) {
   // expected-error@+1 {{'vector.shuffle' op failed to verify that second operand v2 and result have same element type}}
   %1 = vector.shuffle %arg0, %arg1 [0, 1] : vector<2xf32>, vector<2xi32>
diff --git a/mlir/test/Dialect/Vector/vector-transforms.mlir b/mlir/test/Dialect/Vector/vector-transforms.mlir
--- a/mlir/test/Dialect/Vector/vector-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-transforms.mlir
@@ -493,7 +493,6 @@
 func @cast_away_broadcast_leading_one_dims(
   %arg0: vector<8xf32>, %arg1: f32, %arg2: vector<1x4xf32>) ->
   (vector<1x1x8xf32>, vector<1x1x4xf32>, vector<1x3x4xf32>, vector<1x1x4xf32>) {
-  // CHECK: vector.broadcast %{{.*}} : vector<8xf32> to vector<8xf32>
   // CHECK:  vector.shape_cast %{{.*}} : vector<8xf32> to vector<1x1x8xf32>
   %0 = vector.broadcast %arg0 : vector<8xf32> to vector<1x1x8xf32>
   // CHECK:  vector.broadcast %{{.*}} : f32 to vector<4xf32>