diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -184,11 +184,14 @@
     ShapedType::kDynamicStrideOrOffset encodes that the corresponding entry has
     a dynamic value.

-    After buffer-allocation, the "extract_slice" op is expected to lower into a
-    "subview" op.
+    After buffer allocation, the "extract_slice" op is expected to lower into a
+    memref.subview op.

     An extract_slice operation may additionally reduce the rank of the resulting
     tensor by removing dimensions that are statically known to be of size 1.
+    This rank-reduction behavior is not required by the op semantics: this
+    flexibility allows progressively dropping unit dimensions while lowering
+    between different flavors of ops that operate on tensors.

     Example:
@@ -196,8 +199,8 @@
     // Rank-reducing extract_slice.
     %1 = tensor.extract_slice %0[0, 0, 0][1, 16, 4][1, 1, 1] :
       tensor<8x16x4xf32> to tensor<16x4xf32>
-    %3 = tensor.extract_slice %2[3, 4, 2][1, 6, 3][1, 1, 1] :
-      tensor<8x16x4xf32> to tensor<6x3xf32>
+    %3 = tensor.extract_slice %2[%o0, 4, %o2][1, %sz1, 1][1, %st1, 1] :
+      tensor<8x16x4xf32> to tensor<1x?xf32>
     ```
   }];
@@ -257,24 +260,28 @@
     /// An extract_slice result type can be fully inferred from the source type
     /// and the static representation of offsets, sizes and strides. Special
     /// sentinels encode the dynamic case.
-    static Type inferResultType(RankedTensorType sourceRankedTensorType,
-                                ArrayRef<int64_t> staticOffsets,
-                                ArrayRef<int64_t> staticSizes,
-                                ArrayRef<int64_t> staticStrides);
-    static Type inferResultType(RankedTensorType sourceRankedTensorType,
-                                ArrayRef<OpFoldResult> staticOffsets,
-                                ArrayRef<OpFoldResult> staticSizes,
-                                ArrayRef<OpFoldResult> staticStrides);
-    static Type inferRankReducedResultType(unsigned resultRank,
-                                           RankedTensorType sourceRankedTensorType,
-                                           ArrayRef<int64_t> staticOffsets,
-                                           ArrayRef<int64_t> staticSizes,
-                                           ArrayRef<int64_t> staticStrides);
-    static Type inferRankReducedResultType(unsigned resultRank,
-                                           RankedTensorType sourceRankedTensorType,
-                                           ArrayRef<OpFoldResult> staticOffsets,
-                                           ArrayRef<OpFoldResult> staticSizes,
-                                           ArrayRef<OpFoldResult> staticStrides);
+    static RankedTensorType inferResultType(
+        RankedTensorType sourceRankedTensorType,
+        ArrayRef<int64_t> staticOffsets,
+        ArrayRef<int64_t> staticSizes,
+        ArrayRef<int64_t> staticStrides);
+    static RankedTensorType inferResultType(
+        RankedTensorType sourceRankedTensorType,
+        ArrayRef<OpFoldResult> staticOffsets,
+        ArrayRef<OpFoldResult> staticSizes,
+        ArrayRef<OpFoldResult> staticStrides);
+    static RankedTensorType inferRankReducedResultType(
+        unsigned resultRank,
+        RankedTensorType sourceRankedTensorType,
+        ArrayRef<int64_t> staticOffsets,
+        ArrayRef<int64_t> staticSizes,
+        ArrayRef<int64_t> staticStrides);
+    static RankedTensorType inferRankReducedResultType(
+        unsigned resultRank,
+        RankedTensorType sourceRankedTensorType,
+        ArrayRef<OpFoldResult> staticOffsets,
+        ArrayRef<OpFoldResult> staticSizes,
+        ArrayRef<OpFoldResult> staticStrides);

     /// Return the expected rank of each of the`static_offsets`, `static_sizes`
     /// and `static_strides` attributes.
@@ -469,8 +476,27 @@
     ShapedType::kDynamicStrideOrOffset encodes that the corresponding entry has
     a dynamic value.

-    After buffer-allocation, the "insert_slice" op is expected to become an
-    in-place buffer update.
+    After buffer allocation, the "insert_slice" op is expected to lower into a
+    memref.subview op.
+
+    An insert_slice operation may additionally specify insertion into a tensor
+    of higher rank than the source tensor, along dimensions that are statically
+    known to be of size 1.
+    This rank-altering behavior is not required by the op semantics: this
+    flexibility allows progressively dropping unit dimensions while lowering
+    between different flavors of ops that operate on tensors.
+    The rank-altering behavior of tensor.insert_slice matches the rank-reducing
+    behavior of tensor.extract_slice.
+
+    Example:
+
+    ```
+    // Rank-altering insert_slice.
+    %1 = tensor.insert_slice %t into %0[0, 0, 0][1, 16, 4][1, 1, 1] :
+      tensor<16x4xf32> into tensor<8x16x4xf32>
+    %3 = tensor.insert_slice %tt into %2[%o0, 4, %o2][1, %sz1, 1][1, %st1, 1] :
+      tensor<1x?xf32> into tensor<8x16x4xf32>
+    ```
   }];

   let arguments = (ins
@@ -493,8 +519,6 @@
     attr-dict `:` type($source) `into` type($dest)
   }];

-  let verifier = ?;
-
   let builders = [
     // Build a InsertSliceOp with mixed static and dynamic entries.
     OpBuilder<(ins "Value":$source, "Value":$dest,
diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -21,9 +21,10 @@
 namespace mlir {

-/// Helper function to dispatch an OpFoldResult into either the `dynamicVec` if
-/// it is a Value or into `staticVec` if it is an IntegerAttr.
-/// In the case of a Value, a copy of the `sentinel` value is also pushed to
+/// Helper function to dispatch an OpFoldResult into `staticVec` if:
+/// a) it is an IntegerAttr
+/// In other cases, the OpFoldResult is dispatched to the `dynamicVec`.
+/// In such dynamic cases, a copy of the `sentinel` value is also pushed to
 /// `staticVec`. This is useful to extract mixed static and dynamic entries that
 /// come from an AttrSizedOperandSegments trait.
 void dispatchIndexOpFoldResult(OpFoldResult ofr,
@@ -31,11 +32,8 @@
                                SmallVectorImpl<int64_t> &staticVec,
                                int64_t sentinel);

-/// Helper function to dispatch multiple OpFoldResults into either the
-/// `dynamicVec` (for Values) or into `staticVec` (for IntegerAttrs).
-/// In the case of a Value, a copy of the `sentinel` value is also pushed to
-/// `staticVec`. This is useful to extract mixed static and dynamic entries that
-/// come from an AttrSizedOperandSegments trait.
+/// Helper function to dispatch multiple OpFoldResults according to the behavior
+/// of `dispatchIndexOpFoldResult` for a single OpFoldResult.
 void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
                                 SmallVectorImpl<Value> &dynamicVec,
                                 SmallVectorImpl<int64_t> &staticVec,
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -369,6 +369,25 @@
 computeRankReductionMask(ArrayRef<int64_t> originalShape,
                          ArrayRef<int64_t> reducedShape);

+/// Enum that captures information related to verifier error conditions on
+/// slice insert/extract type of ops.
+enum class SliceVerificationResult {
+  Success,
+  RankTooLarge,
+  SizeMismatch,
+  ElemTypeMismatch,
+  // Error codes relevant to ops with a memory space and a layout annotation.
+  MemSpaceMismatch,
+  LayoutMismatch
+};
+
+/// Check if `originalType` can be rank reduced to `candidateReducedType` type
+/// by dropping some dimensions with static size `1`.
+/// Return `SliceVerificationResult::Success` on success or an appropriate error
+/// code.
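+/// For example, reducing `tensor<1x4x1xf32>` to `tensor<4xf32>` succeeds,
+/// while reducing it to `tensor<4xi8>` reports
+/// `SliceVerificationResult::ElemTypeMismatch`.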
+SliceVerificationResult isRankReducedType(ShapedType originalType, + ShapedType candidateReducedType); + //===----------------------------------------------------------------------===// // Deferred Method Definitions //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -2248,8 +2248,8 @@ Location loc = op.getLoc(); int axis = op.axis(); - Value axisValue = - rewriter.create(loc, rewriter.getIndexAttr(axis)); + Value axisValue = rewriter.createOrFold( + loc, rewriter.getIndexAttr(axis)); int rank = resultType.getRank(); SmallVector offsets, sizes, strides; sizes.reserve(rank); @@ -2257,31 +2257,41 @@ offsets.resize(rank, rewriter.create(loc, 0)); for (int i = 0; i < rank; ++i) { - sizes.push_back( - rewriter.create(loc, adaptor.getOperands()[0], i)); + sizes.push_back(rewriter.createOrFold( + loc, adaptor.getOperands()[0], i)); } Value resultDimSize = sizes[axis]; for (auto arg : adaptor.getOperands().drop_front()) { - auto size = rewriter.create(loc, arg, axisValue); - resultDimSize = rewriter.create(loc, resultDimSize, size); + auto size = rewriter.createOrFold(loc, arg, axisValue); + resultDimSize = + rewriter.createOrFold(loc, resultDimSize, size); } sizes[axis] = resultDimSize; Value init = rewriter.create( loc, resultType.getShape(), resultType.getElementType()); - Value zeroVal = rewriter.create( + Value zeroVal = rewriter.createOrFold( loc, rewriter.getZeroAttr(resultType.getElementType())); Value result = rewriter.create(loc, zeroVal, init).getResult(0); + auto toOpFoldResult = [](Value v) -> OpFoldResult { + auto op = v.getDefiningOp(); + if (!op) + return v; + return op.getValue(); + }; for (auto arg : adaptor.getOperands()) { - sizes[axis] = rewriter.create(loc, arg, axisValue); - result = rewriter.create(loc, arg, result, offsets, - sizes, strides); + sizes[axis] = rewriter.createOrFold(loc, arg, axisValue); + result = rewriter.createOrFold( + loc, arg, result, + llvm::to_vector(llvm::map_range(offsets, toOpFoldResult)), + llvm::to_vector(llvm::map_range(sizes, toOpFoldResult)), + llvm::to_vector(llvm::map_range(strides, toOpFoldResult))); offsets[axis] = - rewriter.create(loc, offsets[axis], sizes[axis]); + rewriter.createOrFold(loc, offsets[axis], sizes[axis]); } rewriter.replaceOp(op, result); return success(); diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -835,16 +835,14 @@ //===----------------------------------------------------------------------===// // InitTensorOp //===----------------------------------------------------------------------===// + void InitTensorOp::build(OpBuilder &b, OperationState &result, ArrayRef sizes, Type elementType, ArrayRef attrs) { - unsigned rank = sizes.size(); SmallVector dynamicSizes; SmallVector staticSizes; - for (unsigned i = 0; i < rank; ++i) { - dispatchIndexOpFoldResult(sizes[i], dynamicSizes, staticSizes, - ShapedType::kDynamicSize); - } + dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes, + ShapedType::kDynamicSize); auto resultType = RankedTensorType ::get(staticSizes, elementType); build(b, result, resultType, dynamicSizes, b.getI64ArrayAttr(staticSizes)); result.addAttributes(attrs); @@ -1127,19 +1125,16 @@ ArrayRef 
attrs) { assert(resultType.isa()); auto sourceType = source.getType().cast(); - unsigned rank = sourceType.getRank(); SmallVector dynamicLow, dynamicHigh; SmallVector staticLow, staticHigh; - for (unsigned i = 0; i < rank; ++i) { - // staticLow and staticHigh have full information of the padding config. - // This will grow staticLow and staticHigh with 1 value. If the config is - // dynamic (ie not a constant), dynamicLow and dynamicHigh will grow with 1 - // value as well. - dispatchIndexOpFoldResult(low[i], dynamicLow, staticLow, - ShapedType::kDynamicSize); - dispatchIndexOpFoldResult(high[i], dynamicHigh, staticHigh, - ShapedType::kDynamicSize); - } + // staticLow and staticHigh have full information of the padding config. + // This will grow staticLow and staticHigh with 1 value. If the config is + // dynamic (ie not a constant), dynamicLow and dynamicHigh will grow with 1 + // value as well. + dispatchIndexOpFoldResults(low, dynamicLow, staticLow, + ShapedType::kDynamicSize); + dispatchIndexOpFoldResults(high, dynamicHigh, staticHigh, + ShapedType::kDynamicSize); if (!resultType) { resultType = PadTensorOp::inferResultType(sourceType, staticLow, staticHigh); diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -504,11 +504,13 @@ return numOccurences; } -/// Given the type of the un-rank reduced subview result type and the -/// rank-reduced result type, computes the dropped dimensions. This accounts for -/// cases where there are multiple unit-dims, but only a subset of those are -/// dropped. For MemRefTypes these can be disambiguated using the strides. If a -/// dimension is dropped the stride must be dropped too. +/// Given the `originalType` and a `candidateReducedType` whose shape is assumed +/// to be a subset of `originalType` with some `1` entries erased, return the +/// set of indices that specifies which of the entries of `originalShape` are +/// dropped to obtain `reducedShape`. +/// This accounts for cases where there are multiple unit-dims, but only a +/// subset of those are dropped. For MemRefTypes these can be disambiguated +/// using the strides. If a dimension is dropped the stride must be dropped too. static llvm::Optional> computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType, ArrayRef sizes) { @@ -1548,8 +1550,7 @@ dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides, staticStrides, ShapedType::kDynamicStrideOrOffset); return SubViewOp::inferResultType(sourceMemRefType, staticOffsets, - staticSizes, staticStrides) - .cast(); + staticSizes, staticStrides); } Type SubViewOp::inferRankReducedResultType( @@ -1706,88 +1707,58 @@ /// For ViewLikeOpInterface. Value SubViewOp::getViewSource() { return source(); } -enum SubViewVerificationResult { - Success, - RankTooLarge, - SizeMismatch, - ElemTypeMismatch, - MemSpaceMismatch, - AffineMapMismatch -}; - /// Checks if `original` Type type can be rank reduced to `reduced` type. /// This function is slight variant of `is subsequence` algorithm where /// not matching dimension must be 1. 
-static SubViewVerificationResult
-isRankReducedType(Type originalType, Type candidateReducedType,
-                  ArrayRef<OpFoldResult> sizes, std::string *errMsg = nullptr) {
-  if (originalType == candidateReducedType)
-    return SubViewVerificationResult::Success;
-  if (!originalType.isa<ShapedType>())
-    return SubViewVerificationResult::Success;
-  if (originalType.isa<MemRefType>() && !candidateReducedType.isa<MemRefType>())
-    return SubViewVerificationResult::Success;
-
-  ShapedType originalShapedType = originalType.cast<ShapedType>();
-  ShapedType candidateReducedShapedType =
-      candidateReducedType.cast<ShapedType>();
-
-  // Rank and size logic is valid for all ShapedTypes.
-  ArrayRef<int64_t> originalShape = originalShapedType.getShape();
-  ArrayRef<int64_t> candidateReducedShape =
-      candidateReducedShapedType.getShape();
-  unsigned originalRank = originalShape.size(),
-           candidateReducedRank = candidateReducedShape.size();
-  if (candidateReducedRank > originalRank)
-    return SubViewVerificationResult::RankTooLarge;
+static SliceVerificationResult
+isRankReducedMemRefType(MemRefType originalType,
+                        MemRefType candidateReducedType,
+                        ArrayRef<OpFoldResult> sizes) {
+  auto partialRes =
+      isRankReducedType(originalType, candidateReducedType);
+  if (partialRes != SliceVerificationResult::Success)
+    return partialRes;

   MemRefType original = originalType.cast<MemRefType>();
-  MemRefType candidateReduced = candidateReducedType.cast<MemRefType>();
+  MemRefType candidateReduced =
+      candidateReducedType.cast<MemRefType>();

   auto optionalUnusedDimsMask =
       computeMemRefRankReductionMask(original, candidateReduced, sizes);

   // Sizes cannot be matched in case empty vector is returned.
   if (!optionalUnusedDimsMask.hasValue())
-    return SubViewVerificationResult::SizeMismatch;
+    return SliceVerificationResult::LayoutMismatch;

-  if (originalShapedType.getElementType() !=
-      candidateReducedShapedType.getElementType())
-    return SubViewVerificationResult::ElemTypeMismatch;
-
-  // Strided layout logic is relevant for MemRefType only.
   if (original.getMemorySpace() != candidateReduced.getMemorySpace())
-    return SubViewVerificationResult::MemSpaceMismatch;
-  return SubViewVerificationResult::Success;
+    return SliceVerificationResult::MemSpaceMismatch;
+
+  return SliceVerificationResult::Success;
 }

 template <typename OpTy>
-static LogicalResult produceSubViewErrorMsg(SubViewVerificationResult result,
-                                            OpTy op, Type expectedType,
-                                            StringRef errMsg = "") {
+static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result,
+                                            OpTy op, Type expectedType) {
   auto memrefType = expectedType.cast<ShapedType>();
   switch (result) {
-  case SubViewVerificationResult::Success:
+  case SliceVerificationResult::Success:
     return success();
-  case SubViewVerificationResult::RankTooLarge:
+  case SliceVerificationResult::RankTooLarge:
     return op.emitError("expected result rank to be smaller or equal to ")
-           << "the source rank. " << errMsg;
-  case SubViewVerificationResult::SizeMismatch:
+           << "the source rank. ";
+  case SliceVerificationResult::SizeMismatch:
     return op.emitError("expected result type to be ")
            << expectedType
-           << " or a rank-reduced version. (mismatch of result sizes) "
-           << errMsg;
-  case SubViewVerificationResult::ElemTypeMismatch:
+           << " or a rank-reduced version. 
(mismatch of result sizes) "; + case SliceVerificationResult::ElemTypeMismatch: return op.emitError("expected result element type to be ") - << memrefType.getElementType() << errMsg; - case SubViewVerificationResult::MemSpaceMismatch: - return op.emitError("expected result and source memory spaces to match.") - << errMsg; - case SubViewVerificationResult::AffineMapMismatch: + << memrefType.getElementType(); + case SliceVerificationResult::MemSpaceMismatch: + return op.emitError("expected result and source memory spaces to match."); + case SliceVerificationResult::LayoutMismatch: return op.emitError("expected result type to be ") << expectedType - << " or a rank-reduced version. (mismatch of result affine map) " - << errMsg; + << " or a rank-reduced version. (mismatch of result layout) "; } llvm_unreachable("unexpected subview verification result"); } @@ -1813,10 +1784,9 @@ extractFromI64ArrayAttr(op.static_sizes()), extractFromI64ArrayAttr(op.static_strides())); - std::string errMsg; - auto result = - isRankReducedType(expectedType, subViewType, op.getMixedSizes(), &errMsg); - return produceSubViewErrorMsg(result, op, expectedType, errMsg); + auto result = isRankReducedMemRefType(expectedType.cast(), + subViewType, op.getMixedSizes()); + return produceSubViewErrorMsg(result, op, expectedType); } raw_ostream &mlir::operator<<(raw_ostream &os, const Range &range) { diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -12,6 +12,7 @@ #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributeInterfaces.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" @@ -655,10 +656,11 @@ /// An extract_slice op result type can be fully inferred from the source type /// and the static representation of offsets, sizes and strides. Special /// sentinels encode the dynamic case. -Type ExtractSliceOp::inferResultType(RankedTensorType sourceRankedTensorType, - ArrayRef leadingStaticOffsets, - ArrayRef leadingStaticSizes, - ArrayRef leadingStaticStrides) { +RankedTensorType +ExtractSliceOp::inferResultType(RankedTensorType sourceRankedTensorType, + ArrayRef leadingStaticOffsets, + ArrayRef leadingStaticSizes, + ArrayRef leadingStaticStrides) { // An extract_slice op may specify only a leading subset of offset/sizes/ // strides in which case we complete with offset=0, sizes from memref type and // strides=1. @@ -673,11 +675,11 @@ sourceRankedTensorType.getElementType()); } -Type ExtractSliceOp::inferResultType( - RankedTensorType sourceRankedTensorType, - ArrayRef leadingStaticOffsets, - ArrayRef leadingStaticSizes, - ArrayRef leadingStaticStrides) { +RankedTensorType +ExtractSliceOp::inferResultType(RankedTensorType sourceRankedTensorType, + ArrayRef leadingStaticOffsets, + ArrayRef leadingStaticSizes, + ArrayRef leadingStaticStrides) { SmallVector staticOffsets, staticSizes, staticStrides; SmallVector dynamicOffsets, dynamicSizes, dynamicStrides; dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets, @@ -693,7 +695,7 @@ /// An extract_slice op result type can be fully inferred from the source type /// and the static representation of offsets, sizes and strides. Special /// sentinels encode the dynamic case. 
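+/// For example, a source `tensor<8x16x4xf32>` with static sizes `[1, 16, 4]`
+/// and a desired result rank of 2 yields the rank-reduced type
+/// `tensor<16x4xf32>`.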
-Type ExtractSliceOp::inferRankReducedResultType( +RankedTensorType ExtractSliceOp::inferRankReducedResultType( unsigned resultRank, RankedTensorType sourceRankedTensorType, ArrayRef leadingStaticOffsets, ArrayRef leadingStaticSizes, @@ -717,7 +719,7 @@ return inferredType; } -Type ExtractSliceOp::inferRankReducedResultType( +RankedTensorType ExtractSliceOp::inferRankReducedResultType( unsigned resultRank, RankedTensorType sourceRankedTensorType, ArrayRef leadingStaticOffsets, ArrayRef leadingStaticSizes, @@ -746,10 +748,12 @@ SmallVector staticOffsets, staticSizes, staticStrides; SmallVector dynamicOffsets, dynamicSizes, dynamicStrides; dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets, + ShapedType::kDynamicStrideOrOffset); dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes, ShapedType::kDynamicSize); dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides, + ShapedType::kDynamicStrideOrOffset); auto sourceRankedTensorType = source.getType().cast(); // Structuring implementation this way avoids duplication between builders. @@ -797,89 +801,35 @@ build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs); } -enum SliceVerificationResult { - Success, - RankTooLarge, - SizeMismatch, - ElemTypeMismatch, -}; - -/// Checks if `original` Type type can be rank reduced to `reduced` type. -/// This function is slight variant of `is subsequence` algorithm where -/// not matching dimension must be 1. -static SliceVerificationResult -isRankReducedType(Type originalType, Type candidateReducedType, - std::string *errMsg = nullptr) { - if (originalType == candidateReducedType) - return SliceVerificationResult::Success; - if (!originalType.isa()) - return SliceVerificationResult::Success; - if (originalType.isa() && - !candidateReducedType.isa()) - return SliceVerificationResult::Success; - - ShapedType originalShapedType = originalType.cast(); - ShapedType candidateReducedShapedType = - candidateReducedType.cast(); - - // Rank and size logic is valid for all ShapedTypes. - ArrayRef originalShape = originalShapedType.getShape(); - ArrayRef candidateReducedShape = - candidateReducedShapedType.getShape(); - unsigned originalRank = originalShape.size(), - candidateReducedRank = candidateReducedShape.size(); - if (candidateReducedRank > originalRank) - return SliceVerificationResult::RankTooLarge; - - auto optionalUnusedDimsMask = - computeRankReductionMask(originalShape, candidateReducedShape); - - // Sizes cannot be matched in case empty vector is returned. - if (!optionalUnusedDimsMask.hasValue()) - return SliceVerificationResult::SizeMismatch; - - if (originalShapedType.getElementType() != - candidateReducedShapedType.getElementType()) - return SliceVerificationResult::ElemTypeMismatch; - - // We are done for the tensor case. - if (originalType.isa()) - return SliceVerificationResult::Success; - - return SliceVerificationResult::Success; -} - template static LogicalResult produceSliceErrorMsg(SliceVerificationResult result, - OpTy op, Type expectedType, - StringRef errMsg = "") { + OpTy op, Type expectedType) { auto memrefType = expectedType.cast(); switch (result) { case SliceVerificationResult::Success: return success(); case SliceVerificationResult::RankTooLarge: - return op.emitError("expected result rank to be smaller or equal to ") - << "the source rank. " << errMsg; + return op.emitError("expected rank to be smaller or equal to ") + << "the other rank. 
"; case SliceVerificationResult::SizeMismatch: - return op.emitError("expected result type to be ") - << expectedType - << " or a rank-reduced version. (mismatch of result sizes) " - << errMsg; + return op.emitError("expected type to be ") + << expectedType << " or a rank-reduced version. (size mismatch) "; case SliceVerificationResult::ElemTypeMismatch: - return op.emitError("expected result element type to be ") - << memrefType.getElementType() << errMsg; + return op.emitError("expected element type to be ") + << memrefType.getElementType(); + default: + llvm_unreachable("unexpected extract_slice op verification result"); } - llvm_unreachable("unexpected extract_slice op verification result"); } /// Verifier for ExtractSliceOp. static LogicalResult verify(ExtractSliceOp op) { // Verify result type against inferred type. - auto expectedType = ExtractSliceOp::inferResultType( - op.getSourceType(), extractFromI64ArrayAttr(op.static_offsets()), - extractFromI64ArrayAttr(op.static_sizes()), - extractFromI64ArrayAttr(op.static_strides())); - auto result = isRankReducedType(expectedType, op.getType()); + auto expectedType = + ExtractSliceOp::inferResultType(op.getSourceType(), op.getMixedOffsets(), + op.getMixedSizes(), op.getMixedStrides()); + auto result = + isRankReducedType(expectedType.cast(), op.getType()); return produceSliceErrorMsg(result, op, expectedType); } @@ -1104,10 +1054,12 @@ SmallVector staticOffsets, staticSizes, staticStrides; SmallVector dynamicOffsets, dynamicSizes, dynamicStrides; dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets, + ShapedType::kDynamicStrideOrOffset); dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes, ShapedType::kDynamicSize); dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides, + ShapedType::kDynamicStrideOrOffset); build(b, result, dest.getType(), source, dest, dynamicOffsets, dynamicSizes, dynamicStrides, b.getI64ArrayAttr(staticOffsets), @@ -1128,6 +1080,19 @@ build(b, result, source, dest, offsetValues, sizeValues, strideValues); } +/// Verifier for InsertSliceOp. +static LogicalResult verify(InsertSliceOp op) { + // insert_slice is the inverse of extract_slice, use the same type inference. + auto expectedType = ExtractSliceOp::inferRankReducedResultType( + op.getSourceType().getRank(), op.getType(), + extractFromI64ArrayAttr(op.static_offsets()), + extractFromI64ArrayAttr(op.static_sizes()), + extractFromI64ArrayAttr(op.static_strides())); + auto result = + isRankReducedType(expectedType.cast(), op.getSourceType()); + return produceSliceErrorMsg(result, op, expectedType); +} + /// If we have two consecutive InsertSliceOp writing to the same slice, we /// can mutate the second InsertSliceOp's destination to the first one's. /// @@ -1202,9 +1167,16 @@ canonicalizeSubViewPart(mixedStrides, ShapedType::isDynamicStrideOrOffset); // Create the new op in canonical form. 
-    rewriter.replaceOpWithNewOp<InsertSliceOp>(
-        insertSliceOp, insertSliceOp.source(), insertSliceOp.dest(),
+    auto sourceType = ExtractSliceOp::inferRankReducedResultType(
+        insertSliceOp.getSourceType().getRank(), insertSliceOp.getType(),
         mixedOffsets, mixedSizes, mixedStrides);
+    Value toInsert = insertSliceOp.source();
+    if (sourceType != insertSliceOp.getSourceType())
+      toInsert = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
+                                                 sourceType, toInsert);
+    rewriter.replaceOpWithNewOp<InsertSliceOp>(
+        insertSliceOp, toInsert, insertSliceOp.dest(), mixedOffsets, mixedSizes,
+        mixedStrides);
     return success();
   }
 };
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -13,22 +13,24 @@
 namespace mlir {

-/// Helper function to dispatch an OpFoldResult into either the `dynamicVec` if
-/// it is a Value or into `staticVec` if it is an IntegerAttr.
-/// In the case of a Value, a copy of the `sentinel` value is also pushed to
+/// Helper function to dispatch an OpFoldResult into `staticVec` if:
+/// a) it is an IntegerAttr
+/// In other cases, the OpFoldResult is dispatched to the `dynamicVec`.
+/// In such dynamic cases, a copy of the `sentinel` value is also pushed to
 /// `staticVec`. This is useful to extract mixed static and dynamic entries that
 /// come from an AttrSizedOperandSegments trait.
 void dispatchIndexOpFoldResult(OpFoldResult ofr,
                                SmallVectorImpl<Value> &dynamicVec,
                                SmallVectorImpl<int64_t> &staticVec,
                                int64_t sentinel) {
-  if (auto v = ofr.dyn_cast<Value>()) {
-    dynamicVec.push_back(v);
-    staticVec.push_back(sentinel);
+  auto v = ofr.dyn_cast<Value>();
+  if (!v) {
+    APInt apInt = ofr.get<Attribute>().cast<IntegerAttr>().getValue();
+    staticVec.push_back(apInt.getSExtValue());
     return;
   }
-  APInt apInt = ofr.dyn_cast<Attribute>().cast<IntegerAttr>().getValue();
-  staticVec.push_back(apInt.getSExtValue());
+  dynamicVec.push_back(v);
+  staticVec.push_back(sentinel);
 }

 void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -571,7 +571,7 @@
   llvm::SmallDenseSet<unsigned> unusedDims;
   unsigned reducedIdx = 0;
   for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
-    // Greedily insert `originalIdx` if no match.
+    // Greedily insert `originalIdx` if match.
     if (reducedIdx < reducedRank &&
         originalShape[originalIdx] == reducedShape[reducedIdx]) {
       reducedIdx++;
@@ -590,6 +590,39 @@
   return unusedDims;
 }

+SliceVerificationResult
+mlir::isRankReducedType(ShapedType originalType,
+                        ShapedType candidateReducedType) {
+  if (originalType == candidateReducedType)
+    return SliceVerificationResult::Success;
+
+  ShapedType originalShapedType = originalType.cast<ShapedType>();
+  ShapedType candidateReducedShapedType =
+      candidateReducedType.cast<ShapedType>();
+
+  // Rank and size logic is valid for all ShapedTypes.
+  ArrayRef<int64_t> originalShape = originalShapedType.getShape();
+  ArrayRef<int64_t> candidateReducedShape =
+      candidateReducedShapedType.getShape();
+  unsigned originalRank = originalShape.size(),
+           candidateReducedRank = candidateReducedShape.size();
+  if (candidateReducedRank > originalRank)
+    return SliceVerificationResult::RankTooLarge;
+
+  auto optionalUnusedDimsMask =
+      computeRankReductionMask(originalShape, candidateReducedShape);
+
+  // Sizes cannot be matched in case empty vector is returned.
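+  // For example, reducing shape [1, 4, 1] to [4] yields the mask {0, 2};
+  // a size mismatch such as reducing [1, 4, 1] to [5] yields llvm::None.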
+ if (!optionalUnusedDimsMask.hasValue()) + return SliceVerificationResult::SizeMismatch; + + if (originalShapedType.getElementType() != + candidateReducedShapedType.getElementType()) + return SliceVerificationResult::ElemTypeMismatch; + + return SliceVerificationResult::Success; +} + bool mlir::detail::isSupportedMemorySpace(Attribute memorySpace) { // Empty attribute is allowed as default memory space. if (!memorySpace) diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -820,38 +820,24 @@ // CHECK: [[STRIDE:%.+]] = arith.constant 1 // CHECK: [[OFFSET:%.+]] = arith.constant 0 : index // CHECK: [[IDX0:%.+]] = arith.constant 0 : index - // CHECK: [[ARG0_DIM0:%.+]] = tensor.dim %arg0, [[IDX0]] // CHECK: [[IDX1:%.+]] = arith.constant 1 : index - // CHECK: [[ARG0_DIM1:%.+]] = tensor.dim %arg0, [[IDX1]] - // CHECK: [[ARG1_AXIS:%.+]] = tensor.dim %arg1, [[AXIS]] - // CHECK: [[RESULT_AXIS:%.+]] = arith.addi [[ARG0_DIM0]], [[ARG1_AXIS]] // CHECK: [[INIT:%.+]] = linalg.init_tensor [11, 1] // CHECK: [[CST:%.+]] = arith.constant 0.0 // CHECK: [[FILL:%.+]] = linalg.fill([[CST]], [[INIT]]) - // CHECK: [[ARG0_DIM0:%.+]] = tensor.dim %arg0, [[AXIS]] - // CHECK: [[INSERT0:%.+]] = tensor.insert_slice %arg0 into [[FILL]]{{\[}}[[OFFSET]], [[OFFSET]]] {{\[}}[[ARG0_DIM0]], [[ARG0_DIM1]]] {{\[}}[[STRIDE]], [[STRIDE]]] - // CHECK: [[NEW_OFFSET:%.+]] = arith.addi [[OFFSET]], [[ARG0_DIM0]] - // CHECK: [[ARG1_DIM0:%.+]] = tensor.dim %arg1, [[AXIS]] - // CHECK: [[INSERT1:%.+]] = tensor.insert_slice %arg1 into [[INSERT0]]{{\[}}[[NEW_OFFSET]], [[OFFSET]]] {{\[}}[[ARG1_DIM0]], [[ARG0_DIM1]]] {{\[}}[[STRIDE]], [[STRIDE]]] + // CHECK: [[INSERT0:%.+]] = tensor.insert_slice %arg0 into [[FILL]][0, 0] [5, 1] [1, 1] + // CHECK: [[INSERT1:%.+]] = tensor.insert_slice %arg1 into [[INSERT0]][5, 0] [6, 1] [1, 1] %0 = "tosa.concat"(%arg0, %arg1) { axis = 0 : i64} : (tensor<5x1xf32>, tensor<6x1xf32>) -> (tensor<11x1xf32>) // CHECK: [[AXIS:%.+]] = arith.constant 1 // CHECK: [[STRIDE:%.+]] = arith.constant 1 // CHECK: [[OFFSET:%.+]] = arith.constant 0 : index // CHECK: [[IDX0:%.+]] = arith.constant 0 : index - // CHECK: [[ARG0_DIM0:%.+]] = tensor.dim %arg0, [[IDX0]] // CHECK: [[IDX1:%.+]] = arith.constant 1 : index - // CHECK: [[ARG0_DIM1:%.+]] = tensor.dim %arg0, [[IDX1]] - // CHECK: [[ARG1_AXIS:%.+]] = tensor.dim %arg0, [[AXIS]] - // CHECK: [[RESULT_AXIS:%.+]] = arith.addi [[ARG0_DIM1]], [[ARG1_AXIS]] // CHECK: [[INIT:%.+]] = linalg.init_tensor [5, 2] // CHECK: [[CST:%.+]] = arith.constant 0.0 // CHECK: [[FILL:%.+]] = linalg.fill([[CST]], [[INIT]]) - // CHECK: [[ARG0_DIM1:%.+]] = tensor.dim %arg0, [[AXIS]] - // CHECK: [[INSERT0:%.+]] = tensor.insert_slice %arg0 into [[FILL]]{{\[}}[[OFFSET]], [[OFFSET]]] {{\[}}[[ARG0_DIM0]], [[ARG0_DIM1]]] {{\[}}[[STRIDE]], [[STRIDE]]] - // CHECK: [[NEW_OFFSET:%.+]] = arith.addi [[OFFSET]], [[ARG0_DIM1]] - // CHECK: [[ARG1_DIM1:%.+]] = tensor.dim %arg0, [[AXIS]] - // CHECK: [[INSERT1:%.+]] = tensor.insert_slice %arg0 into [[INSERT0]]{{\[}}[[OFFSET]], [[NEW_OFFSET]]] {{\[}}[[ARG0_DIM0]], [[ARG1_DIM1]]] {{\[}}[[STRIDE]], [[STRIDE]]] + // CHECK: [[INSERT0:%.+]] = tensor.insert_slice %arg0 into [[FILL]][0, 0] [5, 1] [1, 1] + // CHECK: [[INSERT1:%.+]] = tensor.insert_slice %arg0 into [[INSERT0]][0, 1] [5, 1] [1, 1] %1 = "tosa.concat"(%arg0, %arg0) { axis = 1 : i64} : (tensor<5x1xf32>, tensor<5x1xf32>) -> 
(tensor<5x2xf32>) return } diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir @@ -428,7 +428,9 @@ %A : tensor, %B : tensor {linalg.inplaceable = true}, %C : tensor {linalg.inplaceable = true}, - %idx : index) + %idx : index, + %sz1 : index, + %sz2 : index) -> (tensor, tensor, tensor) { %f0 = arith.constant 0.0 : f32 @@ -497,9 +499,9 @@ // CHECK-NEXT: tensor.insert_slice // CHECK-SAME: {__inplace_results_attr__ = ["true"]} %sC = tensor.extract_slice %C[0, 0][%idx, %idx][1, 1] : tensor to tensor - %ssC = tensor.extract_slice %sC[0, 0][4, 4][1, 1] : tensor to tensor<4x4xf32> - %FC = linalg.fill(%f0, %ssC) : f32, tensor<4x4xf32> -> tensor<4x4xf32> - %rsC = tensor.insert_slice %FC into %sC[0, 0][12345, 67890][1, 1] : tensor<4x4xf32> into tensor + %ssC = tensor.extract_slice %sC[0, 0][%sz1, 4][1, 1] : tensor to tensor + %FC = linalg.fill(%f0, %ssC) : f32, tensor -> tensor + %rsC = tensor.insert_slice %FC into %sC[0, 0][%sz2, 4][1, 1] : tensor into tensor %rC = tensor.insert_slice %rsC into %C[0, 0][%idx, %idx][1, 1] : tensor into tensor return %rA, %rB, %rC: tensor, tensor, tensor diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir --- a/mlir/test/Dialect/Linalg/roundtrip.mlir +++ b/mlir/test/Dialect/Linalg/roundtrip.mlir @@ -592,7 +592,7 @@ linalg.yield %1 : f32 } -> tensor<4xf32> - %sum_sub = tensor.insert_slice %acc into %o_[%j][%c4][1] + %sum_sub = tensor.insert_slice %acc into %o_[%j][4][1] : tensor<4xf32> into tensor<24xf32> linalg.yield %sum_sub : tensor<24xf32> } diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir --- a/mlir/test/Dialect/MemRef/invalid.mlir +++ b/mlir/test/Dialect/MemRef/invalid.mlir @@ -644,7 +644,7 @@ // ----- func @invalid_rank_reducing_subview(%arg0 : memref, %arg1 : index, %arg2 : index) { - // expected-error@+1 {{expected result type to be 'memref (d0 * s1 + s0 + d1)>>' or a rank-reduced version. (mismatch of result sizes)}} + // expected-error@+1 {{expected result type to be 'memref (d0 * s1 + s0 + d1)>>' or a rank-reduced version. (mismatch of result layout)}} %0 = memref.subview %arg0[0, %arg1][%arg2, 1][1, 1] : memref to memref return } @@ -653,7 +653,7 @@ func @static_stride_to_dynamic_stride(%arg0 : memref, %arg1 : index, %arg2 : index) -> memref { - // expected-error @+1 {{expected result type to be 'memref<1x?x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)>>' or a rank-reduced version. (mismatch of result sizes)}} + // expected-error @+1 {{expected result type to be 'memref<1x?x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)>>' or a rank-reduced version. 
(mismatch of result layout)}} %0 = memref.subview %arg0[0, 0, 0] [1, %arg1, %arg2] [1, 1, 1] : memref to memref return %0 : memref } diff --git a/mlir/test/Dialect/SCF/for-loop-canonicalization.mlir b/mlir/test/Dialect/SCF/for-loop-canonicalization.mlir --- a/mlir/test/Dialect/SCF/for-loop-canonicalization.mlir +++ b/mlir/test/Dialect/SCF/for-loop-canonicalization.mlir @@ -250,7 +250,7 @@ // CHECK: scf.for // CHECK: tensor.dim %[[t]] func @tensor_dim_of_iter_arg_insertslice(%t : tensor, - %t2 : tensor) -> index { + %t2 : tensor<10x10xf32>) -> index { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c10 = arith.constant 10 : index @@ -258,9 +258,9 @@ -> (tensor, index) { %dim = tensor.dim %arg0, %c0 : tensor %2 = tensor.insert_slice %t2 into %arg0[0, 0] [10, 10] [1, 1] - : tensor into tensor + : tensor<10x10xf32> into tensor %3 = tensor.insert_slice %t2 into %2[1, 1] [10, 10] [1, 1] - : tensor into tensor + : tensor<10x10xf32> into tensor scf.yield %3, %dim : tensor, index } return %1 : index @@ -274,7 +274,7 @@ // CHECK: scf.for // CHECK: tensor.dim %[[t]] func @tensor_dim_of_iter_arg_nested_for(%t : tensor, - %t2 : tensor) -> index { + %t2 : tensor<10x10xf32>) -> index { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c10 = arith.constant 10 : index @@ -284,7 +284,7 @@ -> (tensor, index) { %dim = tensor.dim %arg2, %c0 : tensor %4 = tensor.insert_slice %t2 into %arg2[0, 0] [10, 10] [1, 1] - : tensor into tensor + : tensor<10x10xf32> into tensor scf.yield %4, %dim : tensor, index } scf.yield %2, %3 : tensor, index @@ -292,6 +292,7 @@ return %1 : index } + // ----- // A test case that should not canonicalize because the loop is not shape diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir --- a/mlir/test/Dialect/Tensor/canonicalize.mlir +++ b/mlir/test/Dialect/Tensor/canonicalize.mlir @@ -348,8 +348,10 @@ // CHECK-NOT: tensor.cast // CHECK: return %[[S]] : tensor<4x6x16x32xi8> func @rank_reducing_insert_slice_of_cast(%a : tensor<16x32xi8>, %b : tensor<4x6x16x32xi8>) -> tensor<4x6x16x32xi8> { + %c0 = arith.constant 0: index %cast = tensor.cast %a : tensor<16x32xi8> to tensor - %res = tensor.insert_slice %cast into %b[0, 1, 0] [1, 1, 16] [1, 1, 1] : tensor into tensor<4x6x16x32xi8> + %sz = tensor.dim %cast, %c0: tensor + %res = tensor.insert_slice %cast into %b[0, 1, 0] [1, 1, %sz] [1, 1, 1] : tensor into tensor<4x6x16x32xi8> return %res : tensor<4x6x16x32xi8> } @@ -408,9 +410,10 @@ } // CHECK-LABEL: func @rank_reducing_insert_slice_canonicalize // CHECK-SAME: %[[ARG0:.+]]: tensor -// CHECK: %[[RESULT:.+]] = tensor.insert_slice %[[ARG0]] +// CHECK: %[[CAST:.*]] = tensor.cast %[[ARG0]] : tensor to tensor<4x?xf32> +// CHECK: %[[RESULT:.+]] = tensor.insert_slice %[[CAST]] // CHECK-SAME: [0, %{{.+}}, 1] [4, 1, %{{.+}}] [1, 1, 1] -// CHECK-SAME: : tensor into tensor +// CHECK-SAME: : tensor<4x?xf32> into tensor // CHEKC: return %[[RESULT]] // ----- @@ -450,7 +453,7 @@ ^bb0(%arg4: index, %arg5: index): tensor.yield %1 : i32 } : tensor - %3 = tensor.insert_slice %arg0 into %2[%c0, %arg3] [%c2, %0] [%c1, %c1] : tensor<2x?xi32> into tensor + %3 = tensor.insert_slice %arg0 into %2[0, %arg3] [2, %0] [1, 1] : tensor<2x?xi32> into tensor return %3 : tensor } // CHECK-LABEL: func @insert_slice_propagate_dest_cast @@ -462,9 +465,6 @@ // ----- func @insert_slice_output_dest_canonicalize(%arg0 : tensor<2x3xi32>, %arg1 : tensor) -> tensor<3x9xi32> { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c2 = 
arith.constant 2 : index %c9 = arith.constant 9 : index %c3 = arith.constant 3 : index %2 = tensor.extract %arg1[] : tensor @@ -472,7 +472,7 @@ ^bb0(%arg2: index, %arg3: index): tensor.yield %2 : i32 } : tensor - %5 = tensor.insert_slice %arg0 into %4[%c0, %c1] [%c2, %c3] [1, 1] : tensor<2x3xi32> into tensor + %5 = tensor.insert_slice %arg0 into %4[0, 1] [2, 3] [1, 1] : tensor<2x3xi32> into tensor %6 = tensor.cast %5 : tensor to tensor<3x9xi32> return %6 : tensor<3x9xi32> } @@ -527,8 +527,9 @@ // CHECK: %[[r:.*]] = tensor.insert_slice %[[cast]] into %[[arg1]][0, 1, 2] [64, 5, 64] [1, 1, 1] : tensor<64x5x64xf32> into tensor // CHECK: return %[[r]] func @insert_tensor_cast_on_insert_slice_src( - %arg0 : tensor, %arg1 : tensor) -> tensor { - %r = tensor.insert_slice %arg0 into %arg1[0, 1, 2] [64, 5, 64] [1, 1, 1] + %arg0 : tensor, %arg1 : tensor, %sz0: index, %sz2: index) -> tensor { + %c64 = arith.constant 64: index + %r = tensor.insert_slice %arg0 into %arg1[0, 1, 2] [%c64, 5, %c64] [1, 1, 1] : tensor into tensor return %r : tensor } @@ -559,13 +560,3 @@ // CHECK: return %[[INSERT]] return %1 : tensor } - -// ----- - -// CHECK-LABEL: func @folding_incorrect_ir_triggers_infinite_loop -func @folding_incorrect_ir_triggers_infinite_loop( - %A : tensor<4x4xf32>, %C : tensor) -> tensor { - %rC = tensor.insert_slice %A into %C[0, 0][12345, 67890][1, 1] : - tensor<4x4xf32> into tensor - return %rC: tensor -} diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir --- a/mlir/test/Dialect/Tensor/invalid.mlir +++ b/mlir/test/Dialect/Tensor/invalid.mlir @@ -149,8 +149,36 @@ // ----- -func @slice_wrong_dynamic_type(%t: tensor<8x16x4xf32>, %idx : index) { - // expected-error @+1 {{expected result type to be 'tensor<4x4x4xf32>' or a rank-reduced version. (mismatch of result sizes)}} +func @extract_slice_wrong_result_rank(%t: tensor, %idx : index) { + // expected-error @+1 {{expected rank to be smaller or equal to the other rank.}} + %0 = tensor.extract_slice %t[0][4][1] : tensor to tensor + + return +} + +// ----- + +func @extract_slice_wrong_result_rank(%t: tensor, %idx : index) { + // expected-error @+1 {{expected element type to be 'f32'}} + %0 = tensor.extract_slice %t[0][4][1] : tensor to tensor<4xi8> + + return +} + +// ----- + +func @extract_slice_wrong_static_type(%t: tensor<8x16x4xf32>, %idx : index) { + // expected-error @+1 {{expected type to be 'tensor' or a rank-reduced version. (size mismatch)}} + %0 = tensor.extract_slice %t[0, 0, 0][%idx, 4, 4][1, 1, 1] + : tensor<8x16x4xf32> to tensor<4x4x4xf32> + + return +} + +// ----- + +func @extract_slice_wrong_dynamic_type(%t: tensor<8x16x4xf32>, %idx : index) { + // expected-error @+1 {{expected type to be 'tensor<4x4x4xf32>' or a rank-reduced version. (size mismatch)}} %0 = tensor.extract_slice %t[0, 2, 0][4, 4, 4][1, 1, 1] : tensor<8x16x4xf32> to tensor @@ -159,10 +187,38 @@ // ----- -func @slice_wrong_static_type(%t: tensor<8x16x4xf32>, %idx : index) { - // expected-error @+1 {{expected result type to be 'tensor' or a rank-reduced version. 
(mismatch of result sizes)}} - %0 = tensor.extract_slice %t[0, 0, 0][%idx, 3, %idx][1, 1, 1] - : tensor<8x16x4xf32> to tensor<4x4x4xf32> +func @insert_slice_wrong_result_rank(%t1: tensor, %t2: tensor, %idx : index) { + // expected-error @+1 {{expected rank to be smaller or equal to the other rank.}} + %0 = tensor.insert_slice %t2 into %t1[0][4][1] : tensor into tensor + + return +} + +// ----- + +func @insert_slice_wrong_result_rank(%t1: tensor<4xi8>, %t2: tensor, %idx : index) { + // expected-error @+1 {{expected element type to be 'f32'}} + %0 = tensor.insert_slice %t1 into %t2[0][4][1] : tensor<4xi8> into tensor + + return +} + +// ----- + +func @insert_slice_wrong_static_type(%t1: tensor<4x4x4xf32>, %t2: tensor<8x16x4xf32>, %idx : index) { + // expected-error @+1 {{expected type to be 'tensor' or a rank-reduced version. (size mismatch)}} + %0 = tensor.insert_slice %t1 into %t2[0, 0, 0][%idx, 4, 4][1, 1, 1] + : tensor<4x4x4xf32> into tensor<8x16x4xf32> + + return +} + +// ----- + +func @insert_slice_wrong_dynamic_type(%t1: tensor, %t2: tensor<8x16x4xf32>, %idx : index) { + // expected-error @+1 {{expected type to be 'tensor<4x4x4xf32>' or a rank-reduced version. (size mismatch)}} + %0 = tensor.insert_slice %t1 into %t2[0, 2, 0][4, 4, 4][1, 1, 1] + : tensor into tensor<8x16x4xf32> return } diff --git a/mlir/test/Dialect/Tensor/ops.mlir b/mlir/test/Dialect/Tensor/ops.mlir --- a/mlir/test/Dialect/Tensor/ops.mlir +++ b/mlir/test/Dialect/Tensor/ops.mlir @@ -78,3 +78,60 @@ : (tensor, tensor) -> tensor<*xf32> return %new_unranked : tensor<*xf32> } + +// CHECK-LABEL: func @slice({{.*}}) { +func @slice(%t: tensor<8x16x4xf32>, %idx : index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + + // CHECK: tensor.extract_slice + // CHECK-SAME: tensor<8x16x4xf32> to tensor + %1 = tensor.extract_slice %t[%c0, %c0, %c0][%idx, %idx, %idx][%c1, %c1, %c1] + : tensor<8x16x4xf32> to tensor + + // CHECK: tensor.extract_slice + // CHECK-SAME: tensor<8x16x4xf32> to tensor<4x4x4xf32> + %2 = tensor.extract_slice %t[0, 2, 0][4, 4, 4][1, 1, 1] + : tensor<8x16x4xf32> to tensor<4x4x4xf32> + + // CHECK: tensor.extract_slice + // CHECK-SAME: tensor<8x16x4xf32> to tensor<4x4xf32> + %3 = tensor.extract_slice %t[0, 2, 0][4, 1, 4][1, 1, 1] + : tensor<8x16x4xf32> to tensor<4x4xf32> + + return +} + +// CHECK-LABEL: func @insert_slice({{.*}}) { +func @insert_slice( + %t: tensor<8x16x4xf32>, + %td: tensor<8x?x4xf32>, + %t2: tensor<16x32x8xf32>, + %t3: tensor<4x4xf32>, + %idx : index, + %sz : index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + + // CHECK: tensor.insert_slice + // CHECK-SAME: tensor<8x16x4xf32> into tensor<16x32x8xf32> + %1 = tensor.insert_slice %t into %t2[%c0, %c0, %c0][8, 16, 4][%c1, %c1, %c1] + : tensor<8x16x4xf32> into tensor<16x32x8xf32> + + // CHECK: tensor.insert_slice + // CHECK-SAME: tensor<8x16x4xf32> into tensor<16x32x8xf32> + %2 = tensor.insert_slice %t into %t2[%c0, %idx, %c0][8, 16, 4][%c1, 1, %c1] + : tensor<8x16x4xf32> into tensor<16x32x8xf32> + + // CHECK: tensor.insert_slice + // CHECK-SAME: tensor<4x4xf32> into tensor<8x16x4xf32> + %3 = tensor.insert_slice %t3 into %t[0, 2, 0][4, 1, 4][1, 1, 1] + : tensor<4x4xf32> into tensor<8x16x4xf32> + + // CHECK: tensor.insert_slice + // CHECK-SAME: tensor<8x?x4xf32> into tensor<8x16x4xf32> + %4 = tensor.insert_slice %td into %t[0, %idx, 0][8, %sz, 4][1, 1, 1] + : tensor<8x?x4xf32> into tensor<8x16x4xf32> + + return +} diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir --- 
a/mlir/test/IR/core-ops.mlir +++ b/mlir/test/IR/core-ops.mlir @@ -486,53 +486,3 @@ memref.assume_alignment %0, 16 : memref<4x4xf16> return } - -// CHECK-LABEL: func @slice({{.*}}) { -func @slice(%t: tensor<8x16x4xf32>, %idx : index) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - - // CHECK: tensor.extract_slice - // CHECK-SAME: tensor<8x16x4xf32> to tensor - %1 = tensor.extract_slice %t[%c0, %c0, %c0][%idx, %idx, %idx][%c1, %c1, %c1] - : tensor<8x16x4xf32> to tensor - - // CHECK: tensor.extract_slice - // CHECK-SAME: tensor<8x16x4xf32> to tensor<4x4x4xf32> - %2 = tensor.extract_slice %t[0, 2, 0][4, 4, 4][1, 1, 1] - : tensor<8x16x4xf32> to tensor<4x4x4xf32> - - // CHECK: tensor.extract_slice - // CHECK-SAME: tensor<8x16x4xf32> to tensor<4x4xf32> - %3 = tensor.extract_slice %t[0, 2, 0][4, 1, 4][1, 1, 1] - : tensor<8x16x4xf32> to tensor<4x4xf32> - - return -} - -// CHECK-LABEL: func @insert_slice({{.*}}) { -func @insert_slice( - %t: tensor<8x16x4xf32>, - %t2: tensor<16x32x8xf32>, - %t3: tensor<4x4xf32>, - %idx : index) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - - // CHECK: tensor.insert_slice - // CHECK-SAME: tensor<8x16x4xf32> into tensor<16x32x8xf32> - %1 = tensor.insert_slice %t into %t2[%c0, %c0, %c0][%idx, %idx, %idx][%c1, %c1, %c1] - : tensor<8x16x4xf32> into tensor<16x32x8xf32> - - // CHECK: tensor.insert_slice - // CHECK-SAME: tensor<8x16x4xf32> into tensor<16x32x8xf32> - %2 = tensor.insert_slice %t into %t2[%c0, %idx, %c0][%idx, 4, %idx][%c1, 1, %c1] - : tensor<8x16x4xf32> into tensor<16x32x8xf32> - - // CHECK: tensor.insert_slice - // CHECK-SAME: tensor<4x4xf32> into tensor<8x16x4xf32> - %3 = tensor.insert_slice %t3 into %t[0, 2, 0][4, 1, 4][1, 1, 1] - : tensor<4x4xf32> into tensor<8x16x4xf32> - - return -}