diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -184,11 +184,14 @@
    ShapedType::kDynamicStrideOrOffset encodes that the corresponding entry has
    a dynamic value.

-    After buffer-allocation, the "extract_slice" op is expected to lower into a
-    "subview" op.
+    After buffer allocation, the "extract_slice" op is expected to lower into a
+    memref.subview op.

    An extract_slice operation may additionally reduce the rank of the
    resulting tensor by removing dimensions that are statically known to be
    of size 1.
+    This rank-reduction behavior is not required by the op semantics: this
+    flexibility allows unit dimensions to be progressively dropped while
+    lowering between different flavors of ops that operate on tensors.

    Example:

    ```
@@ -196,8 +199,8 @@
    // Rank-reducing extract_slice.
    %1 = tensor.extract_slice %0[0, 0, 0][1, 16, 4][1, 1, 1] :
      tensor<8x16x4xf32> to tensor<16x4xf32>
-    %3 = tensor.extract_slice %2[3, 4, 2][1, 6, 3][1, 1, 1] :
-      tensor<8x16x4xf32> to tensor<6x3xf32>
+    %3 = tensor.extract_slice %2[%o0, 4, %o2][1, %sz1, 1][1, %st1, 1] :
+      tensor<8x16x4xf32> to tensor<1x?xf32>
    ```
  }];

@@ -257,24 +260,28 @@
    /// An extract_slice result type can be fully inferred from the source type
    /// and the static representation of offsets, sizes and strides. Special
    /// sentinels encode the dynamic case.
-    static Type inferResultType(RankedTensorType sourceRankedTensorType,
-                                ArrayRef<int64_t> staticOffsets,
-                                ArrayRef<int64_t> staticSizes,
-                                ArrayRef<int64_t> staticStrides);
-    static Type inferResultType(RankedTensorType sourceRankedTensorType,
-                                ArrayRef<OpFoldResult> staticOffsets,
-                                ArrayRef<OpFoldResult> staticSizes,
-                                ArrayRef<OpFoldResult> staticStrides);
-    static Type inferRankReducedResultType(unsigned resultRank,
-                                           RankedTensorType sourceRankedTensorType,
-                                           ArrayRef<int64_t> staticOffsets,
-                                           ArrayRef<int64_t> staticSizes,
-                                           ArrayRef<int64_t> staticStrides);
-    static Type inferRankReducedResultType(unsigned resultRank,
-                                           RankedTensorType sourceRankedTensorType,
-                                           ArrayRef<OpFoldResult> staticOffsets,
-                                           ArrayRef<OpFoldResult> staticSizes,
-                                           ArrayRef<OpFoldResult> staticStrides);
+    static RankedTensorType inferResultType(
+        RankedTensorType sourceRankedTensorType,
+        ArrayRef<int64_t> staticOffsets,
+        ArrayRef<int64_t> staticSizes,
+        ArrayRef<int64_t> staticStrides);
+    static RankedTensorType inferResultType(
+        RankedTensorType sourceRankedTensorType,
+        ArrayRef<OpFoldResult> staticOffsets,
+        ArrayRef<OpFoldResult> staticSizes,
+        ArrayRef<OpFoldResult> staticStrides);
+    static RankedTensorType inferRankReducedResultType(
+        unsigned resultRank,
+        RankedTensorType sourceRankedTensorType,
+        ArrayRef<int64_t> staticOffsets,
+        ArrayRef<int64_t> staticSizes,
+        ArrayRef<int64_t> staticStrides);
+    static RankedTensorType inferRankReducedResultType(
+        unsigned resultRank,
+        RankedTensorType sourceRankedTensorType,
+        ArrayRef<OpFoldResult> staticOffsets,
+        ArrayRef<OpFoldResult> staticSizes,
+        ArrayRef<OpFoldResult> staticStrides);

    /// Return the expected rank of each of the`static_offsets`, `static_sizes`
    /// and `static_strides` attributes.
@@ -469,8 +476,27 @@
    ShapedType::kDynamicStrideOrOffset encodes that the corresponding entry has
    a dynamic value.

-    After buffer-allocation, the "insert_slice" op is expected to become an
-    in-place buffer update.
+    After buffer allocation, the "insert_slice" op is expected to lower into a
+    memref.subview op.
+
+    An insert_slice operation may additionally specify insertion into a tensor
+    of higher rank than the source tensor, along dimensions that are statically
+    known to be of size 1.
+    This rank-altering behavior is not required by the op semantics: this
+    flexibility allows unit dimensions to be progressively dropped while
+    lowering between different flavors of ops that operate on tensors.
+    The rank-altering behavior of tensor.insert_slice matches the rank-reducing
+    behavior of tensor.extract_slice.
+
+    Example:
+
+    ```
+    // Rank-altering insert_slice.
+    %1 = tensor.insert_slice %t into %0[0, 0, 0][1, 16, 4][1, 1, 1] :
+      tensor<16x4xf32> into tensor<8x16x4xf32>
+    %3 = tensor.insert_slice %tt into %2[%o0, 4, %o2][1, %sz1, 1][1, %st1, 1] :
+      tensor<1x?xf32> into tensor<8x16x4xf32>
+    ```
  }];

  let arguments = (ins
@@ -493,8 +519,6 @@
    attr-dict `:` type($source) `into` type($dest)
  }];

-  let verifier = ?;
-
  let builders = [
    // Build a InsertSliceOp with mixed static and dynamic entries.
    OpBuilder<(ins "Value":$source, "Value":$dest,
diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -21,25 +21,26 @@
 namespace mlir {

-/// Helper function to dispatch an OpFoldResult into either the `dynamicVec` if
-/// it is a Value or into `staticVec` if it is an IntegerAttr.
-/// In the case of a Value, a copy of the `sentinel` value is also pushed to
+/// Helper function to dispatch an OpFoldResult into `staticVec` if:
+///   a) it is an IntegerAttr, or
+///   b) an IntegerAttr can be extracted from it by `extractConstantInteger`.
+/// In other cases, the OpFoldResult is dispatched to the `dynamicVec`.
+/// In such dynamic cases, a copy of the `sentinel` value is also pushed to
 /// `staticVec`. This is useful to extract mixed static and dynamic entries that
 /// come from an AttrSizedOperandSegments trait.
-void dispatchIndexOpFoldResult(OpFoldResult ofr,
-                               SmallVectorImpl<Value> &dynamicVec,
-                               SmallVectorImpl<int64_t> &staticVec,
-                               int64_t sentinel);
+void dispatchIndexOpFoldResult(
+    OpFoldResult ofr, SmallVectorImpl<Value> &dynamicVec,
+    SmallVectorImpl<int64_t> &staticVec,
+    llvm::function_ref<IntegerAttr(Value)> extractConstantInteger,
+    int64_t sentinel);

-/// Helper function to dispatch multiple OpFoldResults into either the
-/// `dynamicVec` (for Values) or into `staticVec` (for IntegerAttrs).
-/// In the case of a Value, a copy of the `sentinel` value is also pushed to
-/// `staticVec`. This is useful to extract mixed static and dynamic entries that
-/// come from an AttrSizedOperandSegments trait.
-void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
-                                SmallVectorImpl<Value> &dynamicVec,
-                                SmallVectorImpl<int64_t> &staticVec,
-                                int64_t sentinel);
+/// Helper function to dispatch multiple OpFoldResults according to the behavior
+/// of `dispatchIndexOpFoldResult` for a single OpFoldResult.
+void dispatchIndexOpFoldResults(
+    ArrayRef<OpFoldResult> ofrs, SmallVectorImpl<Value> &dynamicVec,
+    SmallVectorImpl<int64_t> &staticVec,
+    llvm::function_ref<IntegerAttr(Value)> extractConstantInteger,
+    int64_t sentinel);

 /// Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
 SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr);
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -369,6 +369,22 @@
 computeRankReductionMask(ArrayRef<int64_t> originalShape,
                          ArrayRef<int64_t> reducedShape);

+/// Enum that captures information related to verifier error conditions on
+/// slice insert/extract type of ops.
+enum class SliceVerificationResult {
+  Success,
+  RankTooLarge,
+  SizeMismatch,
+  ElemTypeMismatch,
+  // Error codes for ops with a memory space and a layout annotation.
+  MemSpaceMismatch,
+  LayoutMismatch
+};
+
+SliceVerificationResult isRankReducedType(ShapedType originalType,
+                                          ShapedType candidateReducedType,
+                                          std::string *errMsg = nullptr);
+
 //===----------------------------------------------------------------------===//
 // Deferred Method Definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -2248,8 +2248,8 @@
     Location loc = op.getLoc();
     int axis = op.axis();
-    Value axisValue =
-        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(axis));
+    Value axisValue = rewriter.createOrFold<arith::ConstantOp>(
+        loc, rewriter.getIndexAttr(axis));
     int rank = resultType.getRank();
     SmallVector<Value> offsets, sizes, strides;
     sizes.reserve(rank);
@@ -2257,31 +2257,41 @@
     offsets.resize(rank, rewriter.create<arith::ConstantIndexOp>(loc, 0));

     for (int i = 0; i < rank; ++i) {
-      sizes.push_back(
-          rewriter.create<tensor::DimOp>(loc, adaptor.getOperands()[0], i));
+      sizes.push_back(rewriter.createOrFold<tensor::DimOp>(
+          loc, adaptor.getOperands()[0], i));
     }

     Value resultDimSize = sizes[axis];
     for (auto arg : adaptor.getOperands().drop_front()) {
-      auto size = rewriter.create<tensor::DimOp>(loc, arg, axisValue);
-      resultDimSize = rewriter.create<arith::AddIOp>(loc, resultDimSize, size);
+      auto size = rewriter.createOrFold<tensor::DimOp>(loc, arg, axisValue);
+      resultDimSize =
+          rewriter.createOrFold<arith::AddIOp>(loc, resultDimSize, size);
     }
     sizes[axis] = resultDimSize;

     Value init = rewriter.create<linalg::InitTensorOp>(
         loc, resultType.getShape(), resultType.getElementType());

-    Value zeroVal = rewriter.create<arith::ConstantOp>(
+    Value zeroVal = rewriter.createOrFold<arith::ConstantOp>(
         loc, rewriter.getZeroAttr(resultType.getElementType()));
     Value result =
         rewriter.create<linalg::FillOp>(loc, zeroVal, init).getResult(0);

+    auto toOpFoldResult = [](Value v) -> OpFoldResult {
+      auto op = v.getDefiningOp<arith::ConstantIndexOp>();
+      if (!op)
+        return v;
+      return op.getValue();
+    };
     for (auto arg : adaptor.getOperands()) {
-      sizes[axis] = rewriter.create<tensor::DimOp>(loc, arg, axisValue);
-      result = rewriter.create<tensor::InsertSliceOp>(loc, arg, result, offsets,
-                                                      sizes, strides);
+      sizes[axis] = rewriter.createOrFold<tensor::DimOp>(loc, arg, axisValue);
+      result = rewriter.createOrFold<tensor::InsertSliceOp>(
+          loc, arg, result,
+          llvm::to_vector(llvm::map_range(offsets, toOpFoldResult)),
+          llvm::to_vector(llvm::map_range(sizes, toOpFoldResult)),
+          llvm::to_vector(llvm::map_range(strides, toOpFoldResult)));
       offsets[axis] =
-          rewriter.create<arith::AddIOp>(loc, offsets[axis], sizes[axis]);
+          rewriter.createOrFold<arith::AddIOp>(loc, offsets[axis], sizes[axis]);
     }
     rewriter.replaceOp(op, result);
     return success();
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -835,16 +835,23 @@
 //===----------------------------------------------------------------------===//
 // InitTensorOp
 //===----------------------------------------------------------------------===//
+
+static IntegerAttr extractConstantInteger(Value v) {
+  // TODO: Decide whether we want to activate more aggressive static constant
+  // detection.
+ // auto op = v.getDefiningOp(); + // if (!op) return IntegerAttr(); + // return op.getValue().dyn_cast(); + return IntegerAttr(); +} + void InitTensorOp::build(OpBuilder &b, OperationState &result, ArrayRef sizes, Type elementType, ArrayRef attrs) { - unsigned rank = sizes.size(); SmallVector dynamicSizes; SmallVector staticSizes; - for (unsigned i = 0; i < rank; ++i) { - dispatchIndexOpFoldResult(sizes[i], dynamicSizes, staticSizes, - ShapedType::kDynamicSize); - } + dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes, + extractConstantInteger, ShapedType::kDynamicSize); auto resultType = RankedTensorType ::get(staticSizes, elementType); build(b, result, resultType, dynamicSizes, b.getI64ArrayAttr(staticSizes)); result.addAttributes(attrs); @@ -1127,19 +1134,16 @@ ArrayRef attrs) { assert(resultType.isa()); auto sourceType = source.getType().cast(); - unsigned rank = sourceType.getRank(); SmallVector dynamicLow, dynamicHigh; SmallVector staticLow, staticHigh; - for (unsigned i = 0; i < rank; ++i) { - // staticLow and staticHigh have full information of the padding config. - // This will grow staticLow and staticHigh with 1 value. If the config is - // dynamic (ie not a constant), dynamicLow and dynamicHigh will grow with 1 - // value as well. - dispatchIndexOpFoldResult(low[i], dynamicLow, staticLow, - ShapedType::kDynamicSize); - dispatchIndexOpFoldResult(high[i], dynamicHigh, staticHigh, - ShapedType::kDynamicSize); - } + // staticLow and staticHigh have full information of the padding config. + // This will grow staticLow and staticHigh with 1 value. If the config is + // dynamic (ie not a constant), dynamicLow and dynamicHigh will grow with 1 + // value as well. + dispatchIndexOpFoldResults(low, dynamicLow, staticLow, extractConstantInteger, + ShapedType::kDynamicSize); + dispatchIndexOpFoldResults(high, dynamicHigh, staticHigh, + extractConstantInteger, ShapedType::kDynamicSize); if (!resultType) { resultType = PadTensorOp::inferResultType(sourceType, staticLow, staticHigh); @@ -1393,7 +1397,8 @@ // The shape of the result can be obtained from the sizes passed in. SmallVector dynDims; SmallVector shape; - dispatchIndexOpFoldResults(sizes, dynDims, shape, ShapedType::kDynamicSize); + dispatchIndexOpFoldResults(sizes, dynDims, shape, extractConstantInteger, + ShapedType::kDynamicSize); RankedTensorType resultType = RankedTensorType::get(shape, getResultType().getElementType()); diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -504,27 +504,31 @@ return numOccurences; } -/// Given the type of the un-rank reduced subview result type and the -/// rank-reduced result type, computes the dropped dimensions. This accounts for -/// cases where there are multiple unit-dims, but only a subset of those are -/// dropped. For MemRefTypes these can be disambiguated using the strides. If a -/// dimension is dropped the stride must be dropped too. +/// Given the `originalType` and a `candidateReducedType` whose shape is assumed +/// to be a subset of `originalType` with some `1` entries erased, return the +/// set of indices that specifies which of the entries of `originalShape` are +/// dropped to obtain `reducedShape`. +/// This accounts for cases where there are multiple unit-dims, but only a +/// subset of those are dropped. For MemRefTypes these can be disambiguated +/// using the strides. 
If a dimension is dropped the stride must be dropped too. static llvm::Optional> -computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType, +computeMemRefRankReductionMask(MemRefType originalType, + MemRefType candidateReducedType, ArrayAttr staticSizes) { llvm::SmallDenseSet unusedDims; - if (originalType.getRank() == reducedType.getRank()) + if (originalType.getRank() == candidateReducedType.getRank()) return unusedDims; for (auto dim : llvm::enumerate(staticSizes)) if (dim.value().cast().getInt() == 1) unusedDims.insert(dim.index()); + SmallVector originalStrides, candidateStrides; int64_t originalOffset, candidateOffset; if (failed( getStridesAndOffset(originalType, originalStrides, originalOffset)) || - failed( - getStridesAndOffset(reducedType, candidateStrides, candidateOffset))) + failed(getStridesAndOffset(candidateReducedType, candidateStrides, + candidateOffset))) return llvm::None; // For memrefs, a dimension is truly dropped if its corresponding stride is @@ -565,8 +569,11 @@ for (auto prunedDim : prunedUnusedDims) unusedDims.erase(prunedDim); - if (unusedDims.size() + reducedType.getRank() != originalType.getRank()) + + if (unusedDims.size() + candidateReducedType.getRank() != + originalType.getRank()) return llvm::None; + return unusedDims; } @@ -1072,6 +1079,15 @@ // ReinterpretCastOp //===----------------------------------------------------------------------===// +static IntegerAttr extractConstantInteger(Value v) { + // TODO: Decide whether we want to activate more aggressive static constant + // detection. + // auto op = v.getDefiningOp(); + // if (!op) return IntegerAttr(); + // return op.getValue().dyn_cast(); + return IntegerAttr(); +} + /// Build a ReinterpretCastOp with all dynamic entries: `staticOffsets`, /// `staticSizes` and `staticStrides` are automatically filled with /// source-memref-rank sentinel values that encode dynamic entries. 
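Not part of the patch: a minimal sketch of how a caller is expected to use the extended dispatch helpers, assuming only the declarations shown above. The `splitSizes` wrapper and the `noFold` callback are hypothetical names; the callback mirrors the conservative behavior of the `extractConstantInteger` TODO (never fold a Value to a constant), so every Value-backed entry lands in `dynamicSizes` with a `kDynamicSize` sentinel in `staticSizes`.

```
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

// Hypothetical helper: split mixed OpFoldResult sizes into static/dynamic form.
static void splitSizes(ArrayRef<OpFoldResult> sizes,
                       SmallVectorImpl<Value> &dynamicSizes,
                       SmallVectorImpl<int64_t> &staticSizes) {
  // Conservative callback: never treat a Value as a constant.
  auto noFold = [](Value) { return IntegerAttr(); };
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes, noFold,
                             ShapedType::kDynamicSize);
}
```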
@@ -1083,10 +1099,12 @@ SmallVector staticOffsets, staticSizes, staticStrides; SmallVector dynamicOffsets, dynamicSizes, dynamicStrides; dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets, + extractConstantInteger, ShapedType::kDynamicStrideOrOffset); dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes, - ShapedType::kDynamicSize); + extractConstantInteger, ShapedType::kDynamicSize); dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides, + extractConstantInteger, ShapedType::kDynamicStrideOrOffset); build(b, result, resultType, source, dynamicOffsets, dynamicSizes, dynamicStrides, b.getI64ArrayAttr(staticOffsets), @@ -1534,14 +1552,15 @@ SmallVector staticOffsets, staticSizes, staticStrides; SmallVector dynamicOffsets, dynamicSizes, dynamicStrides; dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets, - staticOffsets, ShapedType::kDynamicStrideOrOffset); + staticOffsets, extractConstantInteger, + ShapedType::kDynamicStrideOrOffset); dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes, - ShapedType::kDynamicSize); + extractConstantInteger, ShapedType::kDynamicSize); dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides, - staticStrides, ShapedType::kDynamicStrideOrOffset); + staticStrides, extractConstantInteger, + ShapedType::kDynamicStrideOrOffset); return SubViewOp::inferResultType(sourceMemRefType, staticOffsets, - staticSizes, staticStrides) - .cast(); + staticSizes, staticStrides); } Type SubViewOp::inferRankReducedResultType( @@ -1582,11 +1601,13 @@ SmallVector staticOffsets, staticSizes, staticStrides; SmallVector dynamicOffsets, dynamicSizes, dynamicStrides; dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets, - staticOffsets, ShapedType::kDynamicStrideOrOffset); + staticOffsets, extractConstantInteger, + ShapedType::kDynamicStrideOrOffset); dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes, - ShapedType::kDynamicSize); + extractConstantInteger, ShapedType::kDynamicSize); dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides, - staticStrides, ShapedType::kDynamicStrideOrOffset); + staticStrides, extractConstantInteger, + ShapedType::kDynamicStrideOrOffset); return SubViewOp::inferRankReducedResultType( resultRank, sourceRankedTensorType, staticOffsets, staticSizes, staticStrides); @@ -1602,10 +1623,12 @@ SmallVector staticOffsets, staticSizes, staticStrides; SmallVector dynamicOffsets, dynamicSizes, dynamicStrides; dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets, + extractConstantInteger, ShapedType::kDynamicStrideOrOffset); dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes, - ShapedType::kDynamicSize); + extractConstantInteger, ShapedType::kDynamicSize); dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides, + extractConstantInteger, ShapedType::kDynamicStrideOrOffset); auto sourceMemRefType = source.getType().cast(); // Structuring implementation this way avoids duplication between builders. @@ -1698,87 +1721,61 @@ /// For ViewLikeOpInterface. Value SubViewOp::getViewSource() { return source(); } -enum SubViewVerificationResult { - Success, - RankTooLarge, - SizeMismatch, - ElemTypeMismatch, - MemSpaceMismatch, - AffineMapMismatch -}; - /// Checks if `original` Type type can be rank reduced to `reduced` type. /// This function is slight variant of `is subsequence` algorithm where /// not matching dimension must be 1. 
-static SubViewVerificationResult -isRankReducedType(Type originalType, Type candidateReducedType, - ArrayAttr staticSizes, std::string *errMsg = nullptr) { - if (originalType == candidateReducedType) - return SubViewVerificationResult::Success; - if (!originalType.isa()) - return SubViewVerificationResult::Success; - if (originalType.isa() && !candidateReducedType.isa()) - return SubViewVerificationResult::Success; - - ShapedType originalShapedType = originalType.cast(); - ShapedType candidateReducedShapedType = - candidateReducedType.cast(); - - // Rank and size logic is valid for all ShapedTypes. - ArrayRef originalShape = originalShapedType.getShape(); - ArrayRef candidateReducedShape = - candidateReducedShapedType.getShape(); - unsigned originalRank = originalShape.size(), - candidateReducedRank = candidateReducedShape.size(); - if (candidateReducedRank > originalRank) - return SubViewVerificationResult::RankTooLarge; +static SliceVerificationResult +isRankReducedMemRefType(MemRefType originalType, + MemRefType candidatecandidateReducedType, + ArrayAttr staticSizes, std::string *errMsg = nullptr) { + auto partialRes = + isRankReducedType(originalType, candidatecandidateReducedType, errMsg); + if (partialRes != SliceVerificationResult::Success) + return partialRes; MemRefType original = originalType.cast(); - MemRefType candidateReduced = candidateReducedType.cast(); + MemRefType candidateReduced = + candidatecandidateReducedType.cast(); auto optionalUnusedDimsMask = computeMemRefRankReductionMask(original, candidateReduced, staticSizes); // Sizes cannot be matched in case empty vector is returned. if (!optionalUnusedDimsMask.hasValue()) - return SubViewVerificationResult::SizeMismatch; - - if (originalShapedType.getElementType() != - candidateReducedShapedType.getElementType()) - return SubViewVerificationResult::ElemTypeMismatch; + return SliceVerificationResult::LayoutMismatch; - // Strided layout logic is relevant for MemRefType only. if (original.getMemorySpace() != candidateReduced.getMemorySpace()) - return SubViewVerificationResult::MemSpaceMismatch; - return SubViewVerificationResult::Success; + return SliceVerificationResult::MemSpaceMismatch; + + return SliceVerificationResult::Success; } template -static LogicalResult produceSubViewErrorMsg(SubViewVerificationResult result, +static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result, OpTy op, Type expectedType, StringRef errMsg = "") { auto memrefType = expectedType.cast(); switch (result) { - case SubViewVerificationResult::Success: + case SliceVerificationResult::Success: return success(); - case SubViewVerificationResult::RankTooLarge: + case SliceVerificationResult::RankTooLarge: return op.emitError("expected result rank to be smaller or equal to ") << "the source rank. " << errMsg; - case SubViewVerificationResult::SizeMismatch: + case SliceVerificationResult::SizeMismatch: return op.emitError("expected result type to be ") << expectedType << " or a rank-reduced version. 
(mismatch of result sizes) " << errMsg; - case SubViewVerificationResult::ElemTypeMismatch: + case SliceVerificationResult::ElemTypeMismatch: return op.emitError("expected result element type to be ") << memrefType.getElementType() << errMsg; - case SubViewVerificationResult::MemSpaceMismatch: + case SliceVerificationResult::MemSpaceMismatch: return op.emitError("expected result and source memory spaces to match.") << errMsg; - case SubViewVerificationResult::AffineMapMismatch: + case SliceVerificationResult::LayoutMismatch: return op.emitError("expected result type to be ") << expectedType - << " or a rank-reduced version. (mismatch of result affine map) " + << " or a rank-reduced version. (mismatch of result layout) " << errMsg; } llvm_unreachable("unexpected subview verification result"); @@ -1806,8 +1803,8 @@ extractFromI64ArrayAttr(op.static_strides())); std::string errMsg; - auto result = - isRankReducedType(expectedType, subViewType, op.static_sizes(), &errMsg); + auto result = isRankReducedMemRefType( + expectedType.cast(), subViewType, op.static_sizes(), &errMsg); return produceSubViewErrorMsg(result, op, expectedType, errMsg); } diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -12,6 +12,7 @@ #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributeInterfaces.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" @@ -655,10 +656,11 @@ /// An extract_slice op result type can be fully inferred from the source type /// and the static representation of offsets, sizes and strides. Special /// sentinels encode the dynamic case. -Type ExtractSliceOp::inferResultType(RankedTensorType sourceRankedTensorType, - ArrayRef leadingStaticOffsets, - ArrayRef leadingStaticSizes, - ArrayRef leadingStaticStrides) { +RankedTensorType +ExtractSliceOp::inferResultType(RankedTensorType sourceRankedTensorType, + ArrayRef leadingStaticOffsets, + ArrayRef leadingStaticSizes, + ArrayRef leadingStaticStrides) { // An extract_slice op may specify only a leading subset of offset/sizes/ // strides in which case we complete with offset=0, sizes from memref type and // strides=1. @@ -673,19 +675,31 @@ sourceRankedTensorType.getElementType()); } -Type ExtractSliceOp::inferResultType( - RankedTensorType sourceRankedTensorType, - ArrayRef leadingStaticOffsets, - ArrayRef leadingStaticSizes, - ArrayRef leadingStaticStrides) { +static IntegerAttr extractConstantInteger(Value v) { + // TODO: Decide whether we want to activate more aggressive static constant + // detection. 
+ // auto op = v.getDefiningOp(); + // if (!op) + // return IntegerAttr(); + // return op.getValue().dyn_cast(); + return IntegerAttr(); +} + +RankedTensorType +ExtractSliceOp::inferResultType(RankedTensorType sourceRankedTensorType, + ArrayRef leadingStaticOffsets, + ArrayRef leadingStaticSizes, + ArrayRef leadingStaticStrides) { SmallVector staticOffsets, staticSizes, staticStrides; SmallVector dynamicOffsets, dynamicSizes, dynamicStrides; dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets, - staticOffsets, ShapedType::kDynamicStrideOrOffset); + staticOffsets, extractConstantInteger, + ShapedType::kDynamicStrideOrOffset); dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes, - ShapedType::kDynamicSize); + extractConstantInteger, ShapedType::kDynamicSize); dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides, - staticStrides, ShapedType::kDynamicStrideOrOffset); + staticStrides, extractConstantInteger, + ShapedType::kDynamicStrideOrOffset); return ExtractSliceOp::inferResultType(sourceRankedTensorType, staticOffsets, staticSizes, staticStrides); } @@ -693,7 +707,7 @@ /// An extract_slice op result type can be fully inferred from the source type /// and the static representation of offsets, sizes and strides. Special /// sentinels encode the dynamic case. -Type ExtractSliceOp::inferRankReducedResultType( +RankedTensorType ExtractSliceOp::inferRankReducedResultType( unsigned resultRank, RankedTensorType sourceRankedTensorType, ArrayRef leadingStaticOffsets, ArrayRef leadingStaticSizes, @@ -717,7 +731,7 @@ return inferredType; } -Type ExtractSliceOp::inferRankReducedResultType( +RankedTensorType ExtractSliceOp::inferRankReducedResultType( unsigned resultRank, RankedTensorType sourceRankedTensorType, ArrayRef leadingStaticOffsets, ArrayRef leadingStaticSizes, @@ -725,11 +739,13 @@ SmallVector staticOffsets, staticSizes, staticStrides; SmallVector dynamicOffsets, dynamicSizes, dynamicStrides; dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets, - staticOffsets, ShapedType::kDynamicStrideOrOffset); + staticOffsets, extractConstantInteger, + ShapedType::kDynamicStrideOrOffset); dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes, - ShapedType::kDynamicSize); + extractConstantInteger, ShapedType::kDynamicSize); dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides, - staticStrides, ShapedType::kDynamicStrideOrOffset); + staticStrides, extractConstantInteger, + ShapedType::kDynamicStrideOrOffset); return ExtractSliceOp::inferRankReducedResultType( resultRank, sourceRankedTensorType, staticOffsets, staticSizes, staticStrides); @@ -746,10 +762,12 @@ SmallVector staticOffsets, staticSizes, staticStrides; SmallVector dynamicOffsets, dynamicSizes, dynamicStrides; dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets, + extractConstantInteger, ShapedType::kDynamicStrideOrOffset); dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes, - ShapedType::kDynamicSize); + extractConstantInteger, ShapedType::kDynamicSize); dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides, + extractConstantInteger, ShapedType::kDynamicStrideOrOffset); auto sourceRankedTensorType = source.getType().cast(); // Structuring implementation this way avoids duplication between builders. 
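Not part of the patch: a sketch of the relationship the new insert_slice verifier (added further down) relies on, namely that the source of an insert_slice should be what a rank-reduced extract_slice of the destination would produce for the same offsets/sizes/strides. The helper name `isCompatibleInsertSliceSource` is hypothetical, and the equality check is a simplification of the rank-reduction comparison done by `isRankReducedType`.

```
#include "mlir/Dialect/Tensor/IR/Tensor.h"

using namespace mlir;

// Hypothetical check: does the source of an insert_slice line up with the
// destination slice described by its offsets/sizes/strides?
static bool isCompatibleInsertSliceSource(tensor::InsertSliceOp op) {
  RankedTensorType expected =
      tensor::ExtractSliceOp::inferRankReducedResultType(
          op.getSourceType().getRank(), op.getType(), op.getMixedOffsets(),
          op.getMixedSizes(), op.getMixedStrides());
  // The real verifier tolerates rank-reduced variants; strict equality is the
  // simplest sufficient condition.
  return expected == op.getSourceType();
}
```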
@@ -797,58 +815,6 @@ build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs); } -enum SliceVerificationResult { - Success, - RankTooLarge, - SizeMismatch, - ElemTypeMismatch, -}; - -/// Checks if `original` Type type can be rank reduced to `reduced` type. -/// This function is slight variant of `is subsequence` algorithm where -/// not matching dimension must be 1. -static SliceVerificationResult -isRankReducedType(Type originalType, Type candidateReducedType, - std::string *errMsg = nullptr) { - if (originalType == candidateReducedType) - return SliceVerificationResult::Success; - if (!originalType.isa()) - return SliceVerificationResult::Success; - if (originalType.isa() && - !candidateReducedType.isa()) - return SliceVerificationResult::Success; - - ShapedType originalShapedType = originalType.cast(); - ShapedType candidateReducedShapedType = - candidateReducedType.cast(); - - // Rank and size logic is valid for all ShapedTypes. - ArrayRef originalShape = originalShapedType.getShape(); - ArrayRef candidateReducedShape = - candidateReducedShapedType.getShape(); - unsigned originalRank = originalShape.size(), - candidateReducedRank = candidateReducedShape.size(); - if (candidateReducedRank > originalRank) - return SliceVerificationResult::RankTooLarge; - - auto optionalUnusedDimsMask = - computeRankReductionMask(originalShape, candidateReducedShape); - - // Sizes cannot be matched in case empty vector is returned. - if (!optionalUnusedDimsMask.hasValue()) - return SliceVerificationResult::SizeMismatch; - - if (originalShapedType.getElementType() != - candidateReducedShapedType.getElementType()) - return SliceVerificationResult::ElemTypeMismatch; - - // We are done for the tensor case. - if (originalType.isa()) - return SliceVerificationResult::Success; - - return SliceVerificationResult::Success; -} - template static LogicalResult produceSliceErrorMsg(SliceVerificationResult result, OpTy op, Type expectedType, @@ -858,28 +824,28 @@ case SliceVerificationResult::Success: return success(); case SliceVerificationResult::RankTooLarge: - return op.emitError("expected result rank to be smaller or equal to ") - << "the source rank. " << errMsg; + return op.emitError("expected rank to be smaller or equal to ") + << "the other rank. " << errMsg; case SliceVerificationResult::SizeMismatch: - return op.emitError("expected result type to be ") - << expectedType - << " or a rank-reduced version. (mismatch of result sizes) " + return op.emitError("expected type to be ") + << expectedType << " or a rank-reduced version. (size mismatch) " << errMsg; case SliceVerificationResult::ElemTypeMismatch: - return op.emitError("expected result element type to be ") + return op.emitError("expected element type to be ") << memrefType.getElementType() << errMsg; + default: + llvm_unreachable("unexpected extract_slice op verification result"); } - llvm_unreachable("unexpected extract_slice op verification result"); } /// Verifier for ExtractSliceOp. static LogicalResult verify(ExtractSliceOp op) { // Verify result type against inferred type. 
- auto expectedType = ExtractSliceOp::inferResultType( - op.getSourceType(), extractFromI64ArrayAttr(op.static_offsets()), - extractFromI64ArrayAttr(op.static_sizes()), - extractFromI64ArrayAttr(op.static_strides())); - auto result = isRankReducedType(expectedType, op.getType()); + auto expectedType = + ExtractSliceOp::inferResultType(op.getSourceType(), op.getMixedOffsets(), + op.getMixedSizes(), op.getMixedStrides()); + auto result = + isRankReducedType(expectedType.cast(), op.getType()); return produceSliceErrorMsg(result, op, expectedType); } @@ -1104,10 +1070,12 @@ SmallVector staticOffsets, staticSizes, staticStrides; SmallVector dynamicOffsets, dynamicSizes, dynamicStrides; dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets, + extractConstantInteger, ShapedType::kDynamicStrideOrOffset); dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes, - ShapedType::kDynamicSize); + extractConstantInteger, ShapedType::kDynamicSize); dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides, + extractConstantInteger, ShapedType::kDynamicStrideOrOffset); build(b, result, dest.getType(), source, dest, dynamicOffsets, dynamicSizes, dynamicStrides, b.getI64ArrayAttr(staticOffsets), @@ -1128,6 +1096,19 @@ build(b, result, source, dest, offsetValues, sizeValues, strideValues); } +/// Verifier for InsertSliceOp. +static LogicalResult verify(InsertSliceOp op) { + // insert_slice is the inverse of extract_slice, use the same type inference. + auto expectedType = ExtractSliceOp::inferRankReducedResultType( + op.getSourceType().getRank(), op.getType(), + extractFromI64ArrayAttr(op.static_offsets()), + extractFromI64ArrayAttr(op.static_sizes()), + extractFromI64ArrayAttr(op.static_strides())); + auto result = + isRankReducedType(expectedType.cast(), op.getSourceType()); + return produceSliceErrorMsg(result, op, expectedType); +} + /// If we have two consecutive InsertSliceOp writing to the same slice, we /// can mutate the second InsertSliceOp's destination to the first one's. /// @@ -1202,9 +1183,16 @@ canonicalizeSubViewPart(mixedStrides, ShapedType::isDynamicStrideOrOffset); // Create the new op in canonical form. - rewriter.replaceOpWithNewOp( - insertSliceOp, insertSliceOp.source(), insertSliceOp.dest(), + auto sourceType = ExtractSliceOp::inferRankReducedResultType( + insertSliceOp.getSourceType().getRank(), insertSliceOp.getType(), mixedOffsets, mixedSizes, mixedStrides); + Value toInsert = insertSliceOp.source(); + if (sourceType != insertSliceOp.getSourceType()) + toInsert = rewriter.create(insertSliceOp.getLoc(), + sourceType, toInsert); + rewriter.replaceOpWithNewOp( + insertSliceOp, toInsert, insertSliceOp.dest(), mixedOffsets, mixedSizes, + mixedStrides); return success(); } }; diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp --- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp +++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp @@ -13,30 +13,38 @@ namespace mlir { -/// Helper function to dispatch an OpFoldResult into either the `dynamicVec` if -/// it is a Value or into `staticVec` if it is an IntegerAttr. -/// In the case of a Value, a copy of the `sentinel` value is also pushed to +/// Helper function to dispatch an OpFoldResult into `staticVec` if: +/// a) it is an IntegerAttr or +/// b) it is possible to extract an IntegerAttr by extractConstantInteger +/// In other cases, the OpFoldResult is dispached to the `dynamicVec`. 
+/// In such dynamic cases, a copy of the `sentinel` value is also pushed to /// `staticVec`. This is useful to extract mixed static and dynamic entries that /// come from an AttrSizedOperandSegments trait. -void dispatchIndexOpFoldResult(OpFoldResult ofr, - SmallVectorImpl &dynamicVec, - SmallVectorImpl &staticVec, - int64_t sentinel) { - if (auto v = ofr.dyn_cast()) { - dynamicVec.push_back(v); - staticVec.push_back(sentinel); +void dispatchIndexOpFoldResult( + OpFoldResult ofr, SmallVectorImpl &dynamicVec, + SmallVectorImpl &staticVec, + llvm::function_ref extractConstantInteger, + int64_t sentinel) { + auto v = ofr.dyn_cast(); + IntegerAttr iAttr = + v ? extractConstantInteger(v) : ofr.get().cast(); + if (iAttr) { + APInt apInt = iAttr.getValue(); + staticVec.push_back(apInt.getSExtValue()); return; } - APInt apInt = ofr.dyn_cast().cast().getValue(); - staticVec.push_back(apInt.getSExtValue()); + dynamicVec.push_back(v); + staticVec.push_back(sentinel); } -void dispatchIndexOpFoldResults(ArrayRef ofrs, - SmallVectorImpl &dynamicVec, - SmallVectorImpl &staticVec, - int64_t sentinel) { +void dispatchIndexOpFoldResults( + ArrayRef ofrs, SmallVectorImpl &dynamicVec, + SmallVectorImpl &staticVec, + llvm::function_ref extractConstantInteger, + int64_t sentinel) { for (OpFoldResult ofr : ofrs) - dispatchIndexOpFoldResult(ofr, dynamicVec, staticVec, sentinel); + dispatchIndexOpFoldResult(ofr, dynamicVec, staticVec, + extractConstantInteger, sentinel); } /// Extract int64_t values from the assumed ArrayAttr of IntegerAttr. diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp --- a/mlir/lib/IR/BuiltinTypes.cpp +++ b/mlir/lib/IR/BuiltinTypes.cpp @@ -571,7 +571,7 @@ llvm::SmallDenseSet unusedDims; unsigned reducedIdx = 0; for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) { - // Greedily insert `originalIdx` if no match. + // Greedily insert `originalIdx` if match. if (reducedIdx < reducedRank && originalShape[originalIdx] == reducedShape[reducedIdx]) { reducedIdx++; @@ -590,6 +590,42 @@ return unusedDims; } +/// Checks if `originalType` can be rank reduced to `reduced` type. +/// This function is a slight variant of the `is subsequence` algorithm where +/// not matching dimension must be 1. +SliceVerificationResult mlir::isRankReducedType(ShapedType originalType, + ShapedType candidateReducedType, + std::string *errMsg) { + if (originalType == candidateReducedType) + return SliceVerificationResult::Success; + + ShapedType originalShapedType = originalType.cast(); + ShapedType candidateReducedShapedType = + candidateReducedType.cast(); + + // Rank and size logic is valid for all ShapedTypes. + ArrayRef originalShape = originalShapedType.getShape(); + ArrayRef candidateReducedShape = + candidateReducedShapedType.getShape(); + unsigned originalRank = originalShape.size(), + candidateReducedRank = candidateReducedShape.size(); + if (candidateReducedRank > originalRank) + return SliceVerificationResult::RankTooLarge; + + auto optionalUnusedDimsMask = + computeRankReductionMask(originalShape, candidateReducedShape); + + // Sizes cannot be matched in case empty vector is returned. 
+ if (!optionalUnusedDimsMask.hasValue()) + return SliceVerificationResult::SizeMismatch; + + if (originalShapedType.getElementType() != + candidateReducedShapedType.getElementType()) + return SliceVerificationResult::ElemTypeMismatch; + + return SliceVerificationResult::Success; +} + bool mlir::detail::isSupportedMemorySpace(Attribute memorySpace) { // Empty attribute is allowed as default memory space. if (!memorySpace) diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -820,38 +820,24 @@ // CHECK: [[STRIDE:%.+]] = arith.constant 1 // CHECK: [[OFFSET:%.+]] = arith.constant 0 : index // CHECK: [[IDX0:%.+]] = arith.constant 0 : index - // CHECK: [[ARG0_DIM0:%.+]] = tensor.dim %arg0, [[IDX0]] // CHECK: [[IDX1:%.+]] = arith.constant 1 : index - // CHECK: [[ARG0_DIM1:%.+]] = tensor.dim %arg0, [[IDX1]] - // CHECK: [[ARG1_AXIS:%.+]] = tensor.dim %arg1, [[AXIS]] - // CHECK: [[RESULT_AXIS:%.+]] = arith.addi [[ARG0_DIM0]], [[ARG1_AXIS]] // CHECK: [[INIT:%.+]] = linalg.init_tensor [11, 1] // CHECK: [[CST:%.+]] = arith.constant 0.0 // CHECK: [[FILL:%.+]] = linalg.fill([[CST]], [[INIT]]) - // CHECK: [[ARG0_DIM0:%.+]] = tensor.dim %arg0, [[AXIS]] - // CHECK: [[INSERT0:%.+]] = tensor.insert_slice %arg0 into [[FILL]]{{\[}}[[OFFSET]], [[OFFSET]]] {{\[}}[[ARG0_DIM0]], [[ARG0_DIM1]]] {{\[}}[[STRIDE]], [[STRIDE]]] - // CHECK: [[NEW_OFFSET:%.+]] = arith.addi [[OFFSET]], [[ARG0_DIM0]] - // CHECK: [[ARG1_DIM0:%.+]] = tensor.dim %arg1, [[AXIS]] - // CHECK: [[INSERT1:%.+]] = tensor.insert_slice %arg1 into [[INSERT0]]{{\[}}[[NEW_OFFSET]], [[OFFSET]]] {{\[}}[[ARG1_DIM0]], [[ARG0_DIM1]]] {{\[}}[[STRIDE]], [[STRIDE]]] + // CHECK: [[INSERT0:%.+]] = tensor.insert_slice %arg0 into [[FILL]][0, 0] [5, 1] [1, 1] + // CHECK: [[INSERT1:%.+]] = tensor.insert_slice %arg1 into [[INSERT0]][5, 0] [6, 1] [1, 1] %0 = "tosa.concat"(%arg0, %arg1) { axis = 0 : i64} : (tensor<5x1xf32>, tensor<6x1xf32>) -> (tensor<11x1xf32>) // CHECK: [[AXIS:%.+]] = arith.constant 1 // CHECK: [[STRIDE:%.+]] = arith.constant 1 // CHECK: [[OFFSET:%.+]] = arith.constant 0 : index // CHECK: [[IDX0:%.+]] = arith.constant 0 : index - // CHECK: [[ARG0_DIM0:%.+]] = tensor.dim %arg0, [[IDX0]] // CHECK: [[IDX1:%.+]] = arith.constant 1 : index - // CHECK: [[ARG0_DIM1:%.+]] = tensor.dim %arg0, [[IDX1]] - // CHECK: [[ARG1_AXIS:%.+]] = tensor.dim %arg0, [[AXIS]] - // CHECK: [[RESULT_AXIS:%.+]] = arith.addi [[ARG0_DIM1]], [[ARG1_AXIS]] // CHECK: [[INIT:%.+]] = linalg.init_tensor [5, 2] // CHECK: [[CST:%.+]] = arith.constant 0.0 // CHECK: [[FILL:%.+]] = linalg.fill([[CST]], [[INIT]]) - // CHECK: [[ARG0_DIM1:%.+]] = tensor.dim %arg0, [[AXIS]] - // CHECK: [[INSERT0:%.+]] = tensor.insert_slice %arg0 into [[FILL]]{{\[}}[[OFFSET]], [[OFFSET]]] {{\[}}[[ARG0_DIM0]], [[ARG0_DIM1]]] {{\[}}[[STRIDE]], [[STRIDE]]] - // CHECK: [[NEW_OFFSET:%.+]] = arith.addi [[OFFSET]], [[ARG0_DIM1]] - // CHECK: [[ARG1_DIM1:%.+]] = tensor.dim %arg0, [[AXIS]] - // CHECK: [[INSERT1:%.+]] = tensor.insert_slice %arg0 into [[INSERT0]]{{\[}}[[OFFSET]], [[NEW_OFFSET]]] {{\[}}[[ARG0_DIM0]], [[ARG1_DIM1]]] {{\[}}[[STRIDE]], [[STRIDE]]] + // CHECK: [[INSERT0:%.+]] = tensor.insert_slice %arg0 into [[FILL]][0, 0] [5, 1] [1, 1] + // CHECK: [[INSERT1:%.+]] = tensor.insert_slice %arg0 into [[INSERT0]][0, 1] [5, 1] [1, 1] %1 = "tosa.concat"(%arg0, %arg0) { axis = 1 : i64} : (tensor<5x1xf32>, tensor<5x1xf32>) -> 
(tensor<5x2xf32>) return } diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir @@ -428,7 +428,9 @@ %A : tensor, %B : tensor {linalg.inplaceable = true}, %C : tensor {linalg.inplaceable = true}, - %idx : index) + %idx : index, + %sz1 : index, + %sz2 : index) -> (tensor, tensor, tensor) { %f0 = arith.constant 0.0 : f32 @@ -497,9 +499,9 @@ // CHECK-NEXT: tensor.insert_slice // CHECK-SAME: {__inplace_results_attr__ = ["true"]} %sC = tensor.extract_slice %C[0, 0][%idx, %idx][1, 1] : tensor to tensor - %ssC = tensor.extract_slice %sC[0, 0][4, 4][1, 1] : tensor to tensor<4x4xf32> - %FC = linalg.fill(%f0, %ssC) : f32, tensor<4x4xf32> -> tensor<4x4xf32> - %rsC = tensor.insert_slice %FC into %sC[0, 0][12345, 67890][1, 1] : tensor<4x4xf32> into tensor + %ssC = tensor.extract_slice %sC[0, 0][%sz1, 4][1, 1] : tensor to tensor + %FC = linalg.fill(%f0, %ssC) : f32, tensor -> tensor + %rsC = tensor.insert_slice %FC into %sC[0, 0][%sz2, 4][1, 1] : tensor into tensor %rC = tensor.insert_slice %rsC into %C[0, 0][%idx, %idx][1, 1] : tensor into tensor return %rA, %rB, %rC: tensor, tensor, tensor diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir --- a/mlir/test/Dialect/Linalg/roundtrip.mlir +++ b/mlir/test/Dialect/Linalg/roundtrip.mlir @@ -592,7 +592,7 @@ linalg.yield %1 : f32 } -> tensor<4xf32> - %sum_sub = tensor.insert_slice %acc into %o_[%j][%c4][1] + %sum_sub = tensor.insert_slice %acc into %o_[%j][4][1] : tensor<4xf32> into tensor<24xf32> linalg.yield %sum_sub : tensor<24xf32> } diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir --- a/mlir/test/Dialect/MemRef/invalid.mlir +++ b/mlir/test/Dialect/MemRef/invalid.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -split-input-file %s -verify-diagnostics +// RUN: mlir-opt -allow-unregistered-dialect -split-input-file %s -verify-diagnostics func @dma_start_not_enough_operands() { // expected-error@+1 {{expected at least 4 operands}} @@ -488,9 +488,323 @@ // ----- +func @invalid_view(%arg0 : index, %arg1 : index, %arg2 : index) { + %0 = memref.alloc() : memref<2048xi8> + // expected-error@+1 {{expects 1 offset operand}} + %1 = memref.view %0[][%arg0, %arg1] + : memref<2048xi8> to memref + return +} + +// ----- + +func @invalid_view(%arg0 : index, %arg1 : index, %arg2 : index) { + %0 = memref.alloc() : memref<2048xi8, affine_map<(d0) -> (d0 floordiv 8, d0 mod 8)>> + // expected-error@+1 {{unsupported map for base memref type}} + %1 = memref.view %0[%arg2][%arg0, %arg1] + : memref<2048xi8, affine_map<(d0) -> (d0 floordiv 8, d0 mod 8)>> to + memref (d0 * 4 + d1 + s0)>> + return +} + +// ----- + +func @invalid_view(%arg0 : index, %arg1 : index, %arg2 : index) { + %0 = memref.alloc() : memref<2048xi8> + // expected-error@+1 {{unsupported map for result memref type}} + %1 = memref.view %0[%arg2][%arg0, %arg1] + : memref<2048xi8> to memref (d0, d1, s0)>> + return +} + +// ----- + +func @invalid_view(%arg0 : index, %arg1 : index, %arg2 : index) { + %0 = memref.alloc() : memref<2048xi8, 2> + // expected-error@+1 {{different memory spaces}} + %1 = memref.view %0[%arg2][%arg0, %arg1] : memref<2048xi8, 2> to memref + return +} + +// ----- + +func @invalid_view(%arg0 : index, %arg1 : index, %arg2 : index) { + %0 = memref.alloc() : memref<2048xi8> + // 
expected-error@+1 {{incorrect number of size operands for type}} + %1 = memref.view %0[%arg2][%arg0] + : memref<2048xi8> to memref + return +} + +// ----- + +func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) { + %0 = memref.alloc() : memref<8x16x4xf32> + // expected-error@+1 {{expected mixed offsets rank to match mixed sizes rank (2 vs 3) so the rank of the result type is well-formed}} + %1 = memref.subview %0[0, 0][2, 2, 2][1, 1, 1] + : memref<8x16x4xf32> to memref<8x16x4xf32> + return +} + +// ----- + +func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) { + %0 = memref.alloc() : memref<8x16x4xf32> + // expected-error@+1 {{expected mixed sizes rank to match mixed strides rank (3 vs 2) so the rank of the result type is well-formed}} + %1 = memref.subview %0[0, 0, 0][2, 2, 2][1, 1] + : memref<8x16x4xf32> to memref<8x16x4xf32> + return +} + +// ----- + +func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) { + %0 = memref.alloc() : memref<8x16x4xf32> + // expected-error@+1 {{expected mixed sizes rank to match mixed strides rank (3 vs 2) so the rank of the result type is well-formed}} + %1 = memref.reinterpret_cast %0 to offset: [0], sizes: [2, 2, 2], strides:[1, 1] + : memref<8x16x4xf32> to memref<8x16x4xf32> + return +} + +// ----- + +func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) { + %0 = memref.alloc() : memref<8x16x4xf32, offset: 0, strides: [64, 4, 1], 2> + // expected-error@+1 {{different memory spaces}} + %1 = memref.subview %0[0, 0, 0][%arg2, %arg2, %arg2][1, 1, 1] + : memref<8x16x4xf32, offset: 0, strides: [64, 4, 1], 2> to + memref<8x?x4xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * s0 + d1 * 4 + d2)>> + return +} + +// ----- + +func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) { + %0 = memref.alloc() : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 + d1, d1 + d2, d2)>> + // expected-error@+1 {{is not strided}} + %1 = memref.subview %0[0, 0, 0][%arg2, %arg2, %arg2][1, 1, 1] + : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 + d1, d1 + d2, d2)>> to + memref<8x?x4xf32, offset: 0, strides: [?, 4, 1]> + return +} + +// ----- + +func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) { + %0 = memref.alloc() : memref<8x16x4xf32> + // expected-error@+1 {{expected <= 3 offset values}} + %1 = memref.subview %0[%arg0, %arg1, 0, 0][%arg2, 0, 0, 0][1, 1, 1, 1] + : memref<8x16x4xf32> to + memref<8x?x4xf32, offset: 0, strides:[?, ?, 4]> + return +} + +// ----- + +func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) { + %0 = memref.alloc() : memref<8x16x4xf32> + // expected-error@+1 {{expected result element type to be 'f32'}} + %1 = memref.subview %0[0, 0, 0][8, 16, 4][1, 1, 1] + : memref<8x16x4xf32> to + memref<8x16x4xi32> + return +} + +// ----- + +func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) { + %0 = memref.alloc() : memref<8x16x4xf32> + // expected-error@+1 {{expected result rank to be smaller or equal to the source rank.}} + %1 = memref.subview %0[0, 0, 0][8, 16, 4][1, 1, 1] + : memref<8x16x4xf32> to + memref<8x16x4x3xi32> + return +} + +// ----- + +func @invalid_rank_reducing_subview(%arg0 : index, %arg1 : index, %arg2 : index) { + %0 = memref.alloc() : memref<8x16x4xf32> + // expected-error@+1 {{expected result type to be 'memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>>' or a rank-reduced version. 
(mismatch of result sizes)}} + %1 = memref.subview %0[0, 0, 0][8, 16, 4][1, 1, 1] + : memref<8x16x4xf32> to memref<16x4xf32> + return +} + +// ----- + +func @invalid_rank_reducing_subview(%arg0 : index, %arg1 : index, %arg2 : index) { + %0 = memref.alloc() : memref<8x16x4xf32> + // expected-error@+1 {{expected result type to be 'memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2 + 8)>>' or a rank-reduced version. (mismatch of result sizes)}} + %1 = memref.subview %0[0, 2, 0][8, 16, 4][1, 1, 1] + : memref<8x16x4xf32> to memref<16x4xf32> + return +} + +// ----- + +func @invalid_rank_reducing_subview(%arg0 : memref, %arg1 : index, %arg2 : index) { + // expected-error@+1 {{expected result type to be 'memref (d0 * s1 + s0 + d1)>>' or a rank-reduced version. (mismatch of result layout)}} + %0 = memref.subview %arg0[0, %arg1][%arg2, 1][1, 1] : memref to memref + return +} + +// ----- + func @static_stride_to_dynamic_stride(%arg0 : memref, %arg1 : index, %arg2 : index) -> memref { - // expected-error @+1 {{expected result type to be 'memref<1x?x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)>>' or a rank-reduced version. (mismatch of result sizes)}} + // expected-error @+1 {{expected result type to be 'memref<1x?x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)>>' or a rank-reduced version. (mismatch of result layout)}} %0 = memref.subview %arg0[0, 0, 0] [1, %arg1, %arg2] [1, 1, 1] : memref to memref return %0 : memref } + +// ----- + +func @invalid_memref_cast(%arg0 : memref<12x4x16xf32, offset:0, strides:[64, 16, 1]>) { + // expected-error@+1{{operand type 'memref<12x4x16xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 16 + d2)>>' and result type 'memref<12x4x16xf32, affine_map<(d0, d1, d2) -> (d0 * 128 + d1 * 32 + d2 * 2)>>' are cast incompatible}} + %0 = memref.cast %arg0 : memref<12x4x16xf32, offset:0, strides:[64, 16, 1]> to memref<12x4x16xf32, offset:0, strides:[128, 32, 2]> + return +} + +// ----- + +func @invalid_memref_cast(%arg0 : memref<12x4x16xf32, offset:0, strides:[64, 16, 1]>) { + // expected-error@+1{{operand type 'memref<12x4x16xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 16 + d2)>>' and result type 'memref<12x4x16xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 16 + d2 + 16)>>' are cast incompatible}} + %0 = memref.cast %arg0 : memref<12x4x16xf32, offset:0, strides:[64, 16, 1]> to memref<12x4x16xf32, offset:16, strides:[64, 16, 1]> + return +} + +// ----- + +// incompatible element types +func @invalid_memref_cast() { + %0 = memref.alloc() : memref<2x5xf32, 0> + // expected-error@+1 {{operand type 'memref<2x5xf32>' and result type 'memref<*xi32>' are cast incompatible}} + %1 = memref.cast %0 : memref<2x5xf32, 0> to memref<*xi32> + return +} + +// ----- + +func @invalid_prefetch_rw(%i : index) { + %0 = memref.alloc() : memref<10xf32> + // expected-error@+1 {{rw specifier has to be 'read' or 'write'}} + memref.prefetch %0[%i], rw, locality<0>, data : memref<10xf32> + return +} + +// ----- + +func @invalid_prefetch_cache_type(%i : index) { + %0 = memref.alloc() : memref<10xf32> + // expected-error@+1 {{cache type has to be 'data' or 'instr'}} + memref.prefetch %0[%i], read, locality<0>, false : memref<10xf32> + return +} + +// ----- + +func @invalid_prefetch_locality_hint(%i : index) { + %0 = memref.alloc() : memref<10xf32> + // expected-error@+1 {{32-bit signless integer attribute whose minimum value is 0 whose maximum value is 3}} + memref.prefetch %0[%i], read, locality<5>, data : memref<10xf32> + 
return +} + +// ----- + +// incompatible memory space +func @invalid_memref_cast() { + %0 = memref.alloc() : memref<2x5xf32, 0> + // expected-error@+1 {{operand type 'memref<2x5xf32>' and result type 'memref<*xf32, 1>' are cast incompatible}} + %1 = memref.cast %0 : memref<2x5xf32, 0> to memref<*xf32, 1> + return +} + +// ----- + +// unranked to unranked +func @invalid_memref_cast() { + %0 = memref.alloc() : memref<2x5xf32, 0> + %1 = memref.cast %0 : memref<2x5xf32, 0> to memref<*xf32, 0> + // expected-error@+1 {{operand type 'memref<*xf32>' and result type 'memref<*xf32>' are cast incompatible}} + %2 = memref.cast %1 : memref<*xf32, 0> to memref<*xf32, 0> + return +} + +// ----- + +// alignment is not power of 2. +func @assume_alignment(%0: memref<4x4xf16>) { + // expected-error@+1 {{alignment must be power of 2}} + memref.assume_alignment %0, 12 : memref<4x4xf16> + return +} + +// ----- + +// 0 alignment value. +func @assume_alignment(%0: memref<4x4xf16>) { + // expected-error@+1 {{attribute 'alignment' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive}} + memref.assume_alignment %0, 0 : memref<4x4xf16> + return +} + +// ----- + +"alloca_without_scoped_alloc_parent"() ( { + memref.alloca() : memref<1xf32> + // expected-error@-1 {{requires an ancestor op with AutomaticAllocationScope trait}} + return +}) : () -> () + +// ----- + +func @bad_alloc_wrong_dynamic_dim_count() { +^bb0: + %0 = arith.constant 7 : index + // Test alloc with wrong number of dynamic dimensions. + // expected-error@+1 {{dimension operand count does not equal memref dynamic dimension count}} + %1 = memref.alloc(%0)[%0] : memref<2x4xf32, affine_map<(d0, d1)[s0] -> ((d0 + s0), d1)>, 1> + return +} + +// ----- + +func @bad_alloc_wrong_symbol_count() { +^bb0: + %0 = arith.constant 7 : index + // Test alloc with wrong number of symbols + // expected-error@+1 {{symbol operand count does not equal memref symbol count}} + %1 = memref.alloc(%0) : memref<2x?xf32, affine_map<(d0, d1)[s0] -> ((d0 + s0), d1)>, 1> + return +} + +// ----- + +func @test_store_zero_results() { +^bb0: + %0 = memref.alloc() : memref<1024x64xf32, affine_map<(d0, d1) -> (d0, d1)>, 1> + %1 = arith.constant 0 : index + %2 = arith.constant 1 : index + %3 = memref.load %0[%1, %2] : memref<1024x64xf32, affine_map<(d0, d1) -> (d0, d1)>, 1> + // Test that store returns zero results. 
+  %4 = memref.store %3, %0[%1, %2] : memref<1024x64xf32, affine_map<(d0, d1) -> (d0, d1)>, 1> // expected-error {{cannot name an operation with no results}}
+  return
+}
+
+// -----
+
+func @test_store_zero_results2(%x: i32, %p: memref) {
+  "memref.store"(%x,%p) : (i32, memref) -> i32 // expected-error {{'memref.store' op requires zero results}}
+  return
+}
+
+// -----
+
+func @test_alloc_memref_map_rank_mismatch() {
+^bb0:
+  // expected-error@+1 {{memref layout mismatch between rank and affine map: 2 != 1}}
+  %0 = memref.alloc() : memref<1024x64xf32, affine_map<(d0) -> (d0)>, 1>
+  return
+}
diff --git a/mlir/test/Dialect/SCF/for-loop-canonicalization.mlir b/mlir/test/Dialect/SCF/for-loop-canonicalization.mlir
--- a/mlir/test/Dialect/SCF/for-loop-canonicalization.mlir
+++ b/mlir/test/Dialect/SCF/for-loop-canonicalization.mlir
@@ -250,7 +250,7 @@
 // CHECK: scf.for
 // CHECK: tensor.dim %[[t]]
 func @tensor_dim_of_iter_arg_insertslice(%t : tensor,
-    %t2 : tensor) -> index {
+    %t2 : tensor<10x10xf32>) -> index {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
   %c10 = arith.constant 10 : index
@@ -258,9 +258,9 @@
       -> (tensor, index) {
     %dim = tensor.dim %arg0, %c0 : tensor
     %2 = tensor.insert_slice %t2 into %arg0[0, 0] [10, 10] [1, 1]
-        : tensor into tensor
+        : tensor<10x10xf32> into tensor
     %3 = tensor.insert_slice %t2 into %2[1, 1] [10, 10] [1, 1]
-        : tensor into tensor
+        : tensor<10x10xf32> into tensor
     scf.yield %3, %dim : tensor, index
   }
   return %1 : index
@@ -274,7 +274,7 @@
 // CHECK: scf.for
 // CHECK: tensor.dim %[[t]]
 func @tensor_dim_of_iter_arg_nested_for(%t : tensor,
-    %t2 : tensor) -> index {
+    %t2 : tensor<10x10xf32>) -> index {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
   %c10 = arith.constant 10 : index
@@ -284,7 +284,7 @@
       -> (tensor, index) {
     %dim = tensor.dim %arg2, %c0 : tensor
     %4 = tensor.insert_slice %t2 into %arg2[0, 0] [10, 10] [1, 1]
-        : tensor into tensor
+        : tensor<10x10xf32> into tensor
     scf.yield %4, %dim : tensor, index
   }
   scf.yield %2, %3 : tensor, index
@@ -292,6 +292,7 @@
   return %1 : index
 }
+
 
 // -----
 
 // A test case that should not canonicalize because the loop is not shape
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -348,8 +348,10 @@
 // CHECK-NOT: tensor.cast
 // CHECK: return %[[S]] : tensor<4x6x16x32xi8>
 func @rank_reducing_insert_slice_of_cast(%a : tensor<16x32xi8>, %b : tensor<4x6x16x32xi8>) -> tensor<4x6x16x32xi8> {
+  %c0 = arith.constant 0: index
   %cast = tensor.cast %a : tensor<16x32xi8> to tensor
-  %res = tensor.insert_slice %cast into %b[0, 1, 0] [1, 1, 16] [1, 1, 1] : tensor into tensor<4x6x16x32xi8>
+  %sz = tensor.dim %cast, %c0: tensor
+  %res = tensor.insert_slice %cast into %b[0, 1, 0] [1, 1, %sz] [1, 1, 1] : tensor into tensor<4x6x16x32xi8>
   return %res : tensor<4x6x16x32xi8>
 }
@@ -408,9 +410,10 @@
 }
 // CHECK-LABEL: func @rank_reducing_insert_slice_canonicalize
 // CHECK-SAME: %[[ARG0:.+]]: tensor
-// CHECK: %[[RESULT:.+]] = tensor.insert_slice %[[ARG0]]
+// CHECK: %[[CAST:.*]] = tensor.cast %[[ARG0]] : tensor to tensor<4x?xf32>
+// CHECK: %[[RESULT:.+]] = tensor.insert_slice %[[CAST]]
 // CHECK-SAME: [0, %{{.+}}, 1] [4, 1, %{{.+}}] [1, 1, 1]
-// CHECK-SAME: : tensor into tensor
+// CHECK-SAME: : tensor<4x?xf32> into tensor
 // CHEKC: return %[[RESULT]]
 
 // -----
@@ -450,7 +453,7 @@
   ^bb0(%arg4: index, %arg5: index):
     tensor.yield %1 : i32
   } : tensor
-  %3 = tensor.insert_slice %arg0 into %2[%c0, %arg3] [%c2, %0] [%c1, %c1] : tensor<2x?xi32> into tensor
+  %3 = tensor.insert_slice %arg0 into %2[0, %arg3] [2, %0] [1, 1] : tensor<2x?xi32> into tensor
   return %3 : tensor
 }
 // CHECK-LABEL: func @insert_slice_propagate_dest_cast
@@ -462,9 +465,6 @@
 // -----
 
 func @insert_slice_output_dest_canonicalize(%arg0 : tensor<2x3xi32>, %arg1 : tensor) -> tensor<3x9xi32> {
-  %c0 = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
-  %c2 = arith.constant 2 : index
   %c9 = arith.constant 9 : index
   %c3 = arith.constant 3 : index
   %2 = tensor.extract %arg1[] : tensor
@@ -472,7 +472,7 @@
   ^bb0(%arg2: index, %arg3: index):
     tensor.yield %2 : i32
   } : tensor
-  %5 = tensor.insert_slice %arg0 into %4[%c0, %c1] [%c2, %c3] [1, 1] : tensor<2x3xi32> into tensor
+  %5 = tensor.insert_slice %arg0 into %4[0, 1] [2, 3] [1, 1] : tensor<2x3xi32> into tensor
   %6 = tensor.cast %5 : tensor to tensor<3x9xi32>
   return %6 : tensor<3x9xi32>
 }
@@ -527,8 +527,9 @@
 // CHECK: %[[r:.*]] = tensor.insert_slice %[[cast]] into %[[arg1]][0, 1, 2] [64, 5, 64] [1, 1, 1] : tensor<64x5x64xf32> into tensor
 // CHECK: return %[[r]]
 func @insert_tensor_cast_on_insert_slice_src(
-    %arg0 : tensor, %arg1 : tensor) -> tensor {
-  %r = tensor.insert_slice %arg0 into %arg1[0, 1, 2] [64, 5, 64] [1, 1, 1]
+    %arg0 : tensor, %arg1 : tensor, %sz0: index, %sz2: index) -> tensor {
+  %c64 = arith.constant 64: index
+  %r = tensor.insert_slice %arg0 into %arg1[0, 1, 2] [%c64, 5, %c64] [1, 1, 1]
     : tensor into tensor
   return %r : tensor
 }
@@ -559,13 +560,3 @@
 // CHECK: return %[[INSERT]]
   return %1 : tensor
 }
-
-// -----
-
-// CHECK-LABEL: func @folding_incorrect_ir_triggers_infinite_loop
-func @folding_incorrect_ir_triggers_infinite_loop(
-    %A : tensor<4x4xf32>, %C : tensor) -> tensor {
-  %rC = tensor.insert_slice %A into %C[0, 0][12345, 67890][1, 1] :
-    tensor<4x4xf32> into tensor
-  return %rC: tensor
-}
diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -1,5 +1,13 @@
 // RUN: mlir-opt <%s -split-input-file -verify-diagnostics
 
+func @dim(%arg : tensor<1x?xf32>) {
+  %c2 = arith.constant 2 : index
+  tensor.dim %arg, %c2 : tensor<1x?xf32> // expected-error {{'tensor.dim' op index is out of range}}
+  return
+}
+
+// -----
+
 func @tensor.cast_mismatching_constants(%arg0: tensor<1xf32>) {
   // expected-error@+1 {{operand type 'tensor<1xf32>' and result type 'tensor<2xf32>' are cast incompatible}}
   %0 = tensor.cast %arg0 : tensor<1xf32> to tensor<2xf32>
@@ -138,3 +146,79 @@
   tensor.reshape %buf(%shape)
     : (tensor<1xf32>, tensor<1xi32>) -> tensor<10xf32>
 }
+
+// -----
+
+func @extract_slice_wrong_result_rank(%t: tensor, %idx : index) {
+  // expected-error @+1 {{expected rank to be smaller or equal to the other rank.}}
+  %0 = tensor.extract_slice %t[0][4][1] : tensor to tensor
+
+  return
+}
+
+// -----
+
+func @extract_slice_wrong_result_rank(%t: tensor, %idx : index) {
+  // expected-error @+1 {{expected element type to be 'f32'}}
+  %0 = tensor.extract_slice %t[0][4][1] : tensor to tensor<4xi8>
+
+  return
+}
+
+// -----
+
+func @extract_slice_wrong_static_type(%t: tensor<8x16x4xf32>, %idx : index) {
+  // expected-error @+1 {{expected type to be 'tensor' or a rank-reduced version. (size mismatch)}}
+  %0 = tensor.extract_slice %t[0, 0, 0][%idx, 4, 4][1, 1, 1]
+    : tensor<8x16x4xf32> to tensor<4x4x4xf32>
+
+  return
+}
+
+// -----
+
+func @extract_slice_wrong_dynamic_type(%t: tensor<8x16x4xf32>, %idx : index) {
+  // expected-error @+1 {{expected type to be 'tensor<4x4x4xf32>' or a rank-reduced version. (size mismatch)}}
+  %0 = tensor.extract_slice %t[0, 2, 0][4, 4, 4][1, 1, 1]
+    : tensor<8x16x4xf32> to tensor
+
+  return
+}
+
+// -----
+
+func @insert_slice_wrong_result_rank(%t1: tensor, %t2: tensor, %idx : index) {
+  // expected-error @+1 {{expected rank to be smaller or equal to the other rank.}}
+  %0 = tensor.insert_slice %t2 into %t1[0][4][1] : tensor into tensor
+
+  return
+}
+
+// -----
+
+func @insert_slice_wrong_result_rank(%t1: tensor<4xi8>, %t2: tensor, %idx : index) {
+  // expected-error @+1 {{expected element type to be 'f32'}}
+  %0 = tensor.insert_slice %t1 into %t2[0][4][1] : tensor<4xi8> into tensor
+
+  return
+}
+
+// -----
+
+func @insert_slice_wrong_static_type(%t1: tensor<4x4x4xf32>, %t2: tensor<8x16x4xf32>, %idx : index) {
+  // expected-error @+1 {{expected type to be 'tensor' or a rank-reduced version. (size mismatch)}}
+  %0 = tensor.insert_slice %t1 into %t2[0, 0, 0][%idx, 4, 4][1, 1, 1]
+    : tensor<4x4x4xf32> into tensor<8x16x4xf32>
+
+  return
+}
+
+// -----
+
+func @insert_slice_wrong_dynamic_type(%t1: tensor, %t2: tensor<8x16x4xf32>, %idx : index) {
+  // expected-error @+1 {{expected type to be 'tensor<4x4x4xf32>' or a rank-reduced version. (size mismatch)}}
+  %0 = tensor.insert_slice %t1 into %t2[0, 2, 0][4, 4, 4][1, 1, 1]
+    : tensor into tensor<8x16x4xf32>
+
+  return
+}
diff --git a/mlir/test/Dialect/Tensor/ops.mlir b/mlir/test/Dialect/Tensor/ops.mlir
--- a/mlir/test/Dialect/Tensor/ops.mlir
+++ b/mlir/test/Dialect/Tensor/ops.mlir
@@ -78,3 +78,60 @@
     : (tensor, tensor) -> tensor<*xf32>
   return %new_unranked : tensor<*xf32>
 }
+
+// CHECK-LABEL: func @slice({{.*}}) {
+func @slice(%t: tensor<8x16x4xf32>, %idx : index) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+
+  // CHECK: tensor.extract_slice
+  // CHECK-SAME: tensor<8x16x4xf32> to tensor
+  %1 = tensor.extract_slice %t[%c0, %c0, %c0][%idx, %idx, %idx][%c1, %c1, %c1]
+    : tensor<8x16x4xf32> to tensor
+
+  // CHECK: tensor.extract_slice
+  // CHECK-SAME: tensor<8x16x4xf32> to tensor<4x4x4xf32>
+  %2 = tensor.extract_slice %t[0, 2, 0][4, 4, 4][1, 1, 1]
+    : tensor<8x16x4xf32> to tensor<4x4x4xf32>
+
+  // CHECK: tensor.extract_slice
+  // CHECK-SAME: tensor<8x16x4xf32> to tensor<4x4xf32>
+  %3 = tensor.extract_slice %t[0, 2, 0][4, 1, 4][1, 1, 1]
+    : tensor<8x16x4xf32> to tensor<4x4xf32>
+
+  return
+}
+
+// CHECK-LABEL: func @insert_slice({{.*}}) {
+func @insert_slice(
+    %t: tensor<8x16x4xf32>,
+    %td: tensor<8x?x4xf32>,
+    %t2: tensor<16x32x8xf32>,
+    %t3: tensor<4x4xf32>,
+    %idx : index,
+    %sz : index) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+
+  // CHECK: tensor.insert_slice
+  // CHECK-SAME: tensor<8x16x4xf32> into tensor<16x32x8xf32>
+  %1 = tensor.insert_slice %t into %t2[%c0, %c0, %c0][8, 16, 4][%c1, %c1, %c1]
+    : tensor<8x16x4xf32> into tensor<16x32x8xf32>
+
+  // CHECK: tensor.insert_slice
+  // CHECK-SAME: tensor<8x16x4xf32> into tensor<16x32x8xf32>
+  %2 = tensor.insert_slice %t into %t2[%c0, %idx, %c0][8, 16, 4][%c1, 1, %c1]
+    : tensor<8x16x4xf32> into tensor<16x32x8xf32>
+
+  // CHECK: tensor.insert_slice
+  // CHECK-SAME: tensor<4x4xf32> into tensor<8x16x4xf32>
+  %3 = tensor.insert_slice %t3 into %t[0, 2, 0][4, 1, 4][1, 1, 1]
+    : tensor<4x4xf32> into tensor<8x16x4xf32>
+
+  // CHECK: tensor.insert_slice
+  // CHECK-SAME: tensor<8x?x4xf32> into tensor<8x16x4xf32>
+  %4 = tensor.insert_slice %td into %t[0, %idx, 0][8, %sz, 4][1, 1, 1]
+    : tensor<8x?x4xf32> into tensor<8x16x4xf32>
+
+  return
+}
diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir
--- a/mlir/test/IR/core-ops.mlir
+++ b/mlir/test/IR/core-ops.mlir
@@ -486,53 +486,3 @@
   memref.assume_alignment %0, 16 : memref<4x4xf16>
   return
 }
-
-// CHECK-LABEL: func @slice({{.*}}) {
-func @slice(%t: tensor<8x16x4xf32>, %idx : index) {
-  %c0 = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
-
-  // CHECK: tensor.extract_slice
-  // CHECK-SAME: tensor<8x16x4xf32> to tensor
-  %1 = tensor.extract_slice %t[%c0, %c0, %c0][%idx, %idx, %idx][%c1, %c1, %c1]
-    : tensor<8x16x4xf32> to tensor
-
-  // CHECK: tensor.extract_slice
-  // CHECK-SAME: tensor<8x16x4xf32> to tensor<4x4x4xf32>
-  %2 = tensor.extract_slice %t[0, 2, 0][4, 4, 4][1, 1, 1]
-    : tensor<8x16x4xf32> to tensor<4x4x4xf32>
-
-  // CHECK: tensor.extract_slice
-  // CHECK-SAME: tensor<8x16x4xf32> to tensor<4x4xf32>
-  %3 = tensor.extract_slice %t[0, 2, 0][4, 1, 4][1, 1, 1]
-    : tensor<8x16x4xf32> to tensor<4x4xf32>
-
-  return
-}
-
-// CHECK-LABEL: func @insert_slice({{.*}}) {
-func @insert_slice(
-    %t: tensor<8x16x4xf32>,
-    %t2: tensor<16x32x8xf32>,
-    %t3: tensor<4x4xf32>,
-    %idx : index) {
-  %c0 = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
-
-  // CHECK: tensor.insert_slice
-  // CHECK-SAME: tensor<8x16x4xf32> into tensor<16x32x8xf32>
-  %1 = tensor.insert_slice %t into %t2[%c0, %c0, %c0][%idx, %idx, %idx][%c1, %c1, %c1]
-    : tensor<8x16x4xf32> into tensor<16x32x8xf32>
-
-  // CHECK: tensor.insert_slice
-  // CHECK-SAME: tensor<8x16x4xf32> into tensor<16x32x8xf32>
-  %2 = tensor.insert_slice %t into %t2[%c0, %idx, %c0][%idx, 4, %idx][%c1, 1, %c1]
-    : tensor<8x16x4xf32> into tensor<16x32x8xf32>
-
-  // CHECK: tensor.insert_slice
-  // CHECK-SAME: tensor<4x4xf32> into tensor<8x16x4xf32>
-  %3 = tensor.insert_slice %t3 into %t[0, 2, 0][4, 1, 4][1, 1, 1]
-    : tensor<4x4xf32> into tensor<8x16x4xf32>
-
-  return
-}
diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir
--- a/mlir/test/IR/invalid-ops.mlir
+++ b/mlir/test/IR/invalid-ops.mlir
@@ -1,13 +1,5 @@
 // RUN: mlir-opt -allow-unregistered-dialect %s -split-input-file -verify-diagnostics
 
-func @dim(%arg : tensor<1x?xf32>) {
-  %c2 = arith.constant 2 : index
-  tensor.dim %arg, %c2 : tensor<1x?xf32> // expected-error {{'tensor.dim' op index is out of range}}
-  return
-}
-
-// -----
-
 func @rank(f32) {
 ^bb(%0: f32):
   "std.rank"(%0): (f32)->index // expected-error {{'std.rank' op operand #0 must be any memref or tensor type}}
   return
 }
@@ -60,57 +52,6 @@
 // -----
 
-func @bad_alloc_wrong_dynamic_dim_count() {
-^bb0:
-  %0 = arith.constant 7 : index
-  // Test alloc with wrong number of dynamic dimensions.
-  // expected-error@+1 {{dimension operand count does not equal memref dynamic dimension count}}
-  %1 = memref.alloc(%0)[%0] : memref<2x4xf32, affine_map<(d0, d1)[s0] -> ((d0 + s0), d1)>, 1>
-  return
-}
-
-// -----
-
-func @bad_alloc_wrong_symbol_count() {
-^bb0:
-  %0 = arith.constant 7 : index
-  // Test alloc with wrong number of symbols
-  // expected-error@+1 {{symbol operand count does not equal memref symbol count}}
-  %1 = memref.alloc(%0) : memref<2x?xf32, affine_map<(d0, d1)[s0] -> ((d0 + s0), d1)>, 1>
-  return
-}
-
-// -----
-
-func @test_store_zero_results() {
-^bb0:
-  %0 = memref.alloc() : memref<1024x64xf32, affine_map<(d0, d1) -> (d0, d1)>, 1>
-  %1 = arith.constant 0 : index
-  %2 = arith.constant 1 : index
-  %3 = memref.load %0[%1, %2] : memref<1024x64xf32, affine_map<(d0, d1) -> (d0, d1)>, 1>
-  // Test that store returns zero results.
-  %4 = memref.store %3, %0[%1, %2] : memref<1024x64xf32, affine_map<(d0, d1) -> (d0, d1)>, 1> // expected-error {{cannot name an operation with no results}}
-  return
-}
-
-// -----
-
-func @test_store_zero_results2(%x: i32, %p: memref) {
-  "memref.store"(%x,%p) : (i32, memref) -> i32 // expected-error {{'memref.store' op requires zero results}}
-  return
-}
-
-// -----
-
-func @test_alloc_memref_map_rank_mismatch() {
-^bb0:
-  // expected-error@+1 {{memref layout mismatch between rank and affine map: 2 != 1}}
-  %0 = memref.alloc() : memref<1024x64xf32, affine_map<(d0) -> (d0)>, 1>
-  return
-}
-
-// -----
-
 func @calls(%arg0: i32) {
   %x = call @calls() : () -> i32 // expected-error {{incorrect number of operands for callee}}
   return
 }
@@ -197,243 +138,6 @@
 // -----
 
-func @invalid_view(%arg0 : index, %arg1 : index, %arg2 : index) {
-  %0 = memref.alloc() : memref<2048xi8>
-  // expected-error@+1 {{expects 1 offset operand}}
-  %1 = memref.view %0[][%arg0, %arg1]
-    : memref<2048xi8> to memref
-  return
-}
-
-// -----
-
-func @invalid_view(%arg0 : index, %arg1 : index, %arg2 : index) {
-  %0 = memref.alloc() : memref<2048xi8, affine_map<(d0) -> (d0 floordiv 8, d0 mod 8)>>
-  // expected-error@+1 {{unsupported map for base memref type}}
-  %1 = memref.view %0[%arg2][%arg0, %arg1]
-    : memref<2048xi8, affine_map<(d0) -> (d0 floordiv 8, d0 mod 8)>> to
-      memref (d0 * 4 + d1 + s0)>>
-  return
-}
-
-// -----
-
-func @invalid_view(%arg0 : index, %arg1 : index, %arg2 : index) {
-  %0 = memref.alloc() : memref<2048xi8>
-  // expected-error@+1 {{unsupported map for result memref type}}
-  %1 = memref.view %0[%arg2][%arg0, %arg1]
-    : memref<2048xi8> to memref (d0, d1, s0)>>
-  return
-}
-
-// -----
-
-func @invalid_view(%arg0 : index, %arg1 : index, %arg2 : index) {
-  %0 = memref.alloc() : memref<2048xi8, 2>
-  // expected-error@+1 {{different memory spaces}}
-  %1 = memref.view %0[%arg2][%arg0, %arg1] : memref<2048xi8, 2> to memref
-  return
-}
-
-// -----
-
-func @invalid_view(%arg0 : index, %arg1 : index, %arg2 : index) {
-  %0 = memref.alloc() : memref<2048xi8>
-  // expected-error@+1 {{incorrect number of size operands for type}}
-  %1 = memref.view %0[%arg2][%arg0]
-    : memref<2048xi8> to memref
-  return
-}
-
-// -----
-
-func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) {
-  %0 = memref.alloc() : memref<8x16x4xf32>
-  // expected-error@+1 {{expected mixed offsets rank to match mixed sizes rank (2 vs 3) so the rank of the result type is well-formed}}
-  %1 = memref.subview %0[0, 0][2, 2, 2][1, 1, 1]
-    : memref<8x16x4xf32> to memref<8x16x4xf32>
-  return
-}
-
-// -----
-
-func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) {
-  %0 = memref.alloc() : memref<8x16x4xf32>
-  // expected-error@+1 {{expected mixed sizes rank to match mixed strides rank (3 vs 2) so the rank of the result type is well-formed}}
-  %1 = memref.subview %0[0, 0, 0][2, 2, 2][1, 1]
-    : memref<8x16x4xf32> to memref<8x16x4xf32>
-  return
-}
-
-// -----
-
-func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) {
-  %0 = memref.alloc() : memref<8x16x4xf32>
-  // expected-error@+1 {{expected mixed sizes rank to match mixed strides rank (3 vs 2) so the rank of the result type is well-formed}}
-  %1 = memref.reinterpret_cast %0 to offset: [0], sizes: [2, 2, 2], strides:[1, 1]
-    : memref<8x16x4xf32> to memref<8x16x4xf32>
-  return
-}
-
-// -----
-
-func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) {
-  %0 = memref.alloc() : memref<8x16x4xf32, offset: 0, strides: [64, 4, 1], 2>
-  // expected-error@+1 {{different memory spaces}}
-  %1 = memref.subview %0[0, 0, 0][%arg2, %arg2, %arg2][1, 1, 1]
-    : memref<8x16x4xf32, offset: 0, strides: [64, 4, 1], 2> to
-      memref<8x?x4xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * s0 + d1 * 4 + d2)>>
-  return
-}
-
-// -----
-
-func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) {
-  %0 = memref.alloc() : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 + d1, d1 + d2, d2)>>
-  // expected-error@+1 {{is not strided}}
-  %1 = memref.subview %0[0, 0, 0][%arg2, %arg2, %arg2][1, 1, 1]
-    : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 + d1, d1 + d2, d2)>> to
-      memref<8x?x4xf32, offset: 0, strides: [?, 4, 1]>
-  return
-}
-
-// -----
-
-func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) {
-  %0 = memref.alloc() : memref<8x16x4xf32>
-  // expected-error@+1 {{expected <= 3 offset values}}
-  %1 = memref.subview %0[%arg0, %arg1, 0, 0][%arg2, 0, 0, 0][1, 1, 1, 1]
-    : memref<8x16x4xf32> to
-      memref<8x?x4xf32, offset: 0, strides:[?, ?, 4]>
-  return
-}
-
-// -----
-
-func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) {
-  %0 = memref.alloc() : memref<8x16x4xf32>
-  // expected-error@+1 {{expected result element type to be 'f32'}}
-  %1 = memref.subview %0[0, 0, 0][8, 16, 4][1, 1, 1]
-    : memref<8x16x4xf32> to
-      memref<8x16x4xi32>
-  return
-}
-
-// -----
-
-func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) {
-  %0 = memref.alloc() : memref<8x16x4xf32>
-  // expected-error@+1 {{expected result rank to be smaller or equal to the source rank.}}
-  %1 = memref.subview %0[0, 0, 0][8, 16, 4][1, 1, 1]
-    : memref<8x16x4xf32> to
-      memref<8x16x4x3xi32>
-  return
-}
-
-// -----
-
-func @invalid_rank_reducing_subview(%arg0 : index, %arg1 : index, %arg2 : index) {
-  %0 = memref.alloc() : memref<8x16x4xf32>
-  // expected-error@+1 {{expected result type to be 'memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>>' or a rank-reduced version. (mismatch of result sizes)}}
-  %1 = memref.subview %0[0, 0, 0][8, 16, 4][1, 1, 1]
-    : memref<8x16x4xf32> to memref<16x4xf32>
-  return
-}
-
-// -----
-
-func @invalid_rank_reducing_subview(%arg0 : index, %arg1 : index, %arg2 : index) {
-  %0 = memref.alloc() : memref<8x16x4xf32>
-  // expected-error@+1 {{expected result type to be 'memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2 + 8)>>' or a rank-reduced version. (mismatch of result sizes)}}
-  %1 = memref.subview %0[0, 2, 0][8, 16, 4][1, 1, 1]
-    : memref<8x16x4xf32> to memref<16x4xf32>
-  return
-}
-
-// -----
-
-func @invalid_rank_reducing_subview(%arg0 : memref, %arg1 : index, %arg2 : index) {
-  // expected-error@+1 {{expected result type to be 'memref (d0 * s1 + s0 + d1)>>' or a rank-reduced version. (mismatch of result sizes)}}
-  %0 = memref.subview %arg0[0, %arg1][%arg2, 1][1, 1] : memref to memref
-  return
-}
-
-// -----
-
-func @invalid_memref_cast(%arg0 : memref<12x4x16xf32, offset:0, strides:[64, 16, 1]>) {
-  // expected-error@+1{{operand type 'memref<12x4x16xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 16 + d2)>>' and result type 'memref<12x4x16xf32, affine_map<(d0, d1, d2) -> (d0 * 128 + d1 * 32 + d2 * 2)>>' are cast incompatible}}
-  %0 = memref.cast %arg0 : memref<12x4x16xf32, offset:0, strides:[64, 16, 1]> to memref<12x4x16xf32, offset:0, strides:[128, 32, 2]>
-  return
-}
-
-// -----
-
-func @invalid_memref_cast(%arg0 : memref<12x4x16xf32, offset:0, strides:[64, 16, 1]>) {
-  // expected-error@+1{{operand type 'memref<12x4x16xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 16 + d2)>>' and result type 'memref<12x4x16xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 16 + d2 + 16)>>' are cast incompatible}}
-  %0 = memref.cast %arg0 : memref<12x4x16xf32, offset:0, strides:[64, 16, 1]> to memref<12x4x16xf32, offset:16, strides:[64, 16, 1]>
-  return
-}
-
-// -----
-
-// incompatible element types
-func @invalid_memref_cast() {
-  %0 = memref.alloc() : memref<2x5xf32, 0>
-  // expected-error@+1 {{operand type 'memref<2x5xf32>' and result type 'memref<*xi32>' are cast incompatible}}
-  %1 = memref.cast %0 : memref<2x5xf32, 0> to memref<*xi32>
-  return
-}
-
-// -----
-
-func @invalid_prefetch_rw(%i : index) {
-  %0 = memref.alloc() : memref<10xf32>
-  // expected-error@+1 {{rw specifier has to be 'read' or 'write'}}
-  memref.prefetch %0[%i], rw, locality<0>, data : memref<10xf32>
-  return
-}
-
-// -----
-
-func @invalid_prefetch_cache_type(%i : index) {
-  %0 = memref.alloc() : memref<10xf32>
-  // expected-error@+1 {{cache type has to be 'data' or 'instr'}}
-  memref.prefetch %0[%i], read, locality<0>, false : memref<10xf32>
-  return
-}
-
-// -----
-
-func @invalid_prefetch_locality_hint(%i : index) {
-  %0 = memref.alloc() : memref<10xf32>
-  // expected-error@+1 {{32-bit signless integer attribute whose minimum value is 0 whose maximum value is 3}}
-  memref.prefetch %0[%i], read, locality<5>, data : memref<10xf32>
-  return
-}
-
-// -----
-
-// incompatible memory space
-func @invalid_memref_cast() {
-  %0 = memref.alloc() : memref<2x5xf32, 0>
-  // expected-error@+1 {{operand type 'memref<2x5xf32>' and result type 'memref<*xf32, 1>' are cast incompatible}}
-  %1 = memref.cast %0 : memref<2x5xf32, 0> to memref<*xf32, 1>
-  return
-}
-
-// -----
-
-// unranked to unranked
-func @invalid_memref_cast() {
-  %0 = memref.alloc() : memref<2x5xf32, 0>
-  %1 = memref.cast %0 : memref<2x5xf32, 0> to memref<*xf32, 0>
-  // expected-error@+1 {{operand type 'memref<*xf32>' and result type 'memref<*xf32>' are cast incompatible}}
-  %2 = memref.cast %1 : memref<*xf32, 0> to memref<*xf32, 0>
-  return
-}
-
-// -----
-
 func @atomic_rmw_idxs_rank_mismatch(%I: memref<16x10xf32>, %i : index, %val : f32) {
   // expected-error@+1 {{expects the number of subscripts to be equal to memref rank}}
   %x = atomic_rmw addf %val, %I[%i] : (f32, memref<16x10xf32>) -> f32
@@ -518,52 +222,6 @@
 // -----
 
-// alignment is not power of 2.
-func @assume_alignment(%0: memref<4x4xf16>) {
-  // expected-error@+1 {{alignment must be power of 2}}
-  memref.assume_alignment %0, 12 : memref<4x4xf16>
-  return
-}
-
-// -----
-
-// 0 alignment value.
-func @assume_alignment(%0: memref<4x4xf16>) {
-  // expected-error@+1 {{attribute 'alignment' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive}}
-  memref.assume_alignment %0, 0 : memref<4x4xf16>
-  return
-}
-
-// -----
-
-"alloca_without_scoped_alloc_parent"() ( {
-  memref.alloca() : memref<1xf32>
-  // expected-error@-1 {{requires an ancestor op with AutomaticAllocationScope trait}}
-  return
-}) : () -> ()
-
-// -----
-
-func @slice_wrong_dynamic_type(%t: tensor<8x16x4xf32>, %idx : index) {
-  // expected-error @+1 {{expected result type to be 'tensor<4x4x4xf32>' or a rank-reduced version. (mismatch of result sizes)}}
-  %0 = tensor.extract_slice %t[0, 2, 0][4, 4, 4][1, 1, 1]
-    : tensor<8x16x4xf32> to tensor
-
-  return
-}
-
-// -----
-
-func @slice_wrong_static_type(%t: tensor<8x16x4xf32>, %idx : index) {
-  // expected-error @+1 {{expected result type to be 'tensor' or a rank-reduced version. (mismatch of result sizes)}}
-  %0 = tensor.extract_slice %t[0, 0, 0][%idx, 3, %idx][1, 1, 1]
-    : tensor<8x16x4xf32> to tensor<4x4x4xf32>
-
-  return
-}
-
-// -----
-
 func @no_zero_bit_integer_attrs() {
   // expected-error @+1 {{integer constant out of range for attribute}}
   %x = "some.op"(){value = 0 : i0} : () -> f32