diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -2184,36 +2184,50 @@ LogicalResult matchAndRewrite(vector::TransferWriteOp write, PatternRewriter &rewriter) const override { if (maxTransferRank && write.getVectorType().getRank() > *maxTransferRank) - return failure(); + return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) { + diag << "rank exceeds maxTransferRank: " << write; + }); // Permutations are handled by VectorToSCF or // populateVectorTransferPermutationMapLoweringPatterns. if ( // pass-through for the 0-d corner case. !write.getPermutationMap().isMinorIdentity()) - return failure(); + return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) { + diag << "permutation map is not minor identity: " << write; + }); auto memRefType = write.getShapedType().dyn_cast(); if (!memRefType) - return failure(); + return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) { + diag << "not a memref type: " << write; + }); // Non-unit strides are handled by VectorToSCF. if (!vector::isLastMemrefDimUnitStride(memRefType)) - return failure(); + return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) { + diag << "most minor stride is not 1: " << write; + }); // `vector.store` supports vector types as memref's elements only when the // type of the vector value being written is the same as the element type. auto memrefElTy = memRefType.getElementType(); if (memrefElTy.isa() && memrefElTy != write.getVectorType()) - return failure(); + return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) { + diag << "elemental type mismatch: " << write; + }); // Otherwise, element types of the memref and the vector must match. if (!memrefElTy.isa() && memrefElTy != write.getVectorType().getElementType()) - return failure(); + return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) { + diag << "elemental type mismatch: " << write; + }); // Out-of-bounds dims are handled by MaterializeTransferMask. if (write.hasOutOfBoundsDim()) - return failure(); + return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) { + diag << "out of bounds dim: " << write; + }); if (write.getMask()) { rewriter.replaceOpWithNewOp( write, write.getSource(), write.getIndices(), write.getMask(),