diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1390,18 +1390,16 @@
   let extraClassDeclaration = extraBaseClassDeclaration # [{
     // The result of the op is always a ranked memref.
-    MemRefType getType() { return getResult().getType().cast<MemRefType>(); }
+    MemRefType getType() { return getResult().getType(); }
     Value getViewSource() { return getSource(); }
 
-    /// Return the rank of the source ShapedType.
-    unsigned getResultRank() {
-      return getResult().getType().cast<ShapedType>().getRank();
-    }
+    /// Return the rank of the result type.
+    unsigned getResultRank() { return getType().getRank(); }
 
    /// Return the expected rank of each of the `static_offsets`, `static_sizes`
     /// and `static_strides` attributes.
     std::array<unsigned, 3> getArrayAttrMaxRanks() {
-      unsigned resultRank = getResult().getType().cast<ShapedType>().getRank();
+      unsigned resultRank = getType().getRank();
       return {1, resultRank, resultRank};
     }
@@ -1830,8 +1828,7 @@
     The representation based on offsets, sizes and strides support a
     partially-static specification via attributes specified through the
     `static_offsets`, `static_sizes` and `static_strides` arguments. A special
-    sentinel value ShapedType::kDynamic and
-    ShapedType::kDynamic encodes that the corresponding entry has
+    sentinel value ShapedType::kDynamic encodes that the corresponding entry has
     a dynamic value.
 
     A subview operation may additionally reduce the rank of the resulting view
@@ -2122,7 +2119,6 @@
 
   let extraClassDeclaration = [{
     static StringRef getPermutationAttrStrName() { return "permutation"; }
-    ShapedType getShapedType() { return getIn().getType().cast<ShapedType>(); }
   }];
 
   let hasCustomAssemblyFormat = 1;
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -226,7 +226,7 @@
     Pure,
     TypesMatchWith<"result type matches element type of tensor",
                    "tensor", "result",
-                   "$_self.cast<ShapedType>().getElementType()">]> {
+                   "$_self.cast<TensorType>().getElementType()">]> {
   let summary = "element extraction operation";
   let description = [{
     The `tensor.extract` op reads a ranked tensor and returns one element as
@@ -281,8 +281,7 @@
     The representation based on offsets, sizes and strides support a
     partially-static specification via attributes specified through the
     `static_offsets`, `static_sizes` and `static_strides` arguments. A special
-    sentinel value ShapedType::kDynamic and
-    ShapedType::kDynamic encodes that the corresponding entry has
+    sentinel value ShapedType::kDynamic encodes that the corresponding entry has
     a dynamic value.
 
     After buffer allocation, the "extract_slice" op is expected to lower into a
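Note: the sentinel scheme referenced by these descriptions is easiest to see in C++. A minimal sketch, assuming a hypothetical helper (`ShapedType::kDynamic`, `ShapedType::isDynamic`, and `RankedTensorType::get` are the real APIs these ops rely on):

```c++
#include <cassert>

#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

// Builds tensor<4x?x8xelementType>: one kDynamic sentinel per dynamic entry;
// the op then carries a matching SSA operand for each such entry.
static RankedTensorType makeMixedResultType(Type elementType) {
  SmallVector<int64_t> staticSizes = {4, ShapedType::kDynamic, 8};
  assert(ShapedType::isDynamic(staticSizes[1]) && "sentinel marks dynamic dim");
  return RankedTensorType::get(staticSizes, elementType);
}
```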
@@ -389,12 +388,12 @@
     /// rank-reduced, from the source type and the static representation of
     /// offsets, sizes and strides. Special sentinels encode the dynamic case.
     static RankedTensorType inferResultType(
-        ShapedType sourceShapedTensorType,
+        RankedTensorType sourceTensorType,
         ArrayRef<int64_t> staticOffsets,
         ArrayRef<int64_t> staticSizes,
         ArrayRef<int64_t> staticStrides);
     static RankedTensorType inferResultType(
-        ShapedType sourceShapedTensorType,
+        RankedTensorType sourceTensorType,
         ArrayRef<OpFoldResult> staticOffsets,
         ArrayRef<OpFoldResult> staticSizes,
         ArrayRef<OpFoldResult> staticStrides);
@@ -459,8 +458,8 @@
     Pure,
     TypesMatchWith<"operand types match result element type",
                    "result", "elements", "SmallVector<Type, 2>("
-                   "$_self.cast<ShapedType>().getNumElements(), "
-                   "$_self.cast<ShapedType>().getElementType())">
+                   "$_self.cast<RankedTensorType>().getNumElements(), "
+                   "$_self.cast<RankedTensorType>().getElementType())">
   ]> {
   let summary = "tensor from elements operation.";
   let description = [{
@@ -695,7 +694,7 @@
                    "$_self">,
     TypesMatchWith<"scalar type matches element type of dest",
                    "dest", "scalar",
-                   "$_self.cast<ShapedType>().getElementType()">]> {
+                   "$_self.cast<TensorType>().getElementType()">]> {
   let summary = "element insertion operation";
   let description = [{
     The `tensor.insert` op inserts a scalar into a ranked tensor `dest` as
@@ -770,8 +769,7 @@
     The representation based on offsets, sizes and strides support a
     partially-static specification via attributes specified through the
     `static_offsets`, `static_sizes` and `static_strides` arguments. A special
-    sentinel value ShapedType::kDynamic and
-    ShapedType::kDynamic encodes that the corresponding entry has
+    sentinel value ShapedType::kDynamic encodes that the corresponding entry has
     a dynamic value.
 
     After buffer allocation, the "insert_slice" op is expected to lower into a
@@ -1381,8 +1379,7 @@
     The representation based on offsets, sizes and strides support a
     partially-static specification via attributes specified through the
     `static_offsets`, `static_sizes` and `static_strides` arguments. A special
-    sentinel value ShapedType::kDynamic and
-    ShapedType::kDynamic encodes that the corresponding entry has
+    sentinel value ShapedType::kDynamic encodes that the corresponding entry has
     a dynamic value.
 
     After buffer allocation, the "parallel_insert_slice" op is expected to lower
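Note: with the tightened signature, callers hand `inferResultType` a `RankedTensorType` directly instead of casting from `ShapedType` inside the helper. A minimal usage sketch (the identity-slice helper is hypothetical; the `inferResultType` overload is the one declared above):

```c++
#include "mlir/Dialect/Tensor/IR/Tensor.h"

using namespace mlir;

// Infers the type of a full (identity) slice: offset 0, full sizes, stride 1.
static RankedTensorType inferIdentitySliceType(RankedTensorType source) {
  SmallVector<int64_t> offsets(source.getRank(), 0);
  SmallVector<int64_t> sizes(source.getShape().begin(),
                             source.getShape().end());
  SmallVector<int64_t> strides(source.getRank(), 1);
  return tensor::ExtractSliceOp::inferResultType(source, offsets, sizes,
                                                 strides);
}
```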
@@ -1790,10 +1787,10 @@
                           ArrayRef<int64_t> innerTileDims,
                           ArrayRef<int64_t> innerDimsPos,
                           ArrayRef<int64_t> outerDimsPerm = {});
 
-    // Method to get the `ShapedType` of the result based on the inner tiles,
-    // position of the inner tiles (innerDimsPos) and interchange vector of
-    // outer loops (outerDimsPerm).
-    static ShapedType inferPackedType(ShapedType sourceType,
+    // Method to get the `RankedTensorType` of the result based on the inner
+    // tiles, position of the inner tiles (innerDimsPos) and interchange vector
+    // of outer loops (outerDimsPerm).
+    static RankedTensorType inferPackedType(RankedTensorType sourceType,
                                       ArrayRef<int64_t> innerTileSizes,
                                       ArrayRef<int64_t> innerDimsPos,
                                       ArrayRef<int64_t> outerDimsPerm = {});
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -529,7 +529,7 @@
   Vector_Op<"extractelement", [Pure,
     TypesMatchWith<"result type matches element type of vector operand",
                    "vector", "result",
-                   "$_self.cast<ShapedType>().getElementType()">]>,
+                   "$_self.cast<VectorType>().getElementType()">]>,
     Arguments<(ins AnyVectorOfAnyRank:$vector,
                    Optional<AnySignlessIntegerOrIndex>:$position)>,
     Results<(outs AnyType:$result)> {
@@ -644,7 +644,7 @@
   Vector_Op<"insertelement", [Pure,
     TypesMatchWith<"source operand type matches element type of result",
                    "result", "source",
-                   "$_self.cast<ShapedType>().getElementType()">,
+                   "$_self.cast<VectorType>().getElementType()">,
     AllTypesMatch<["dest", "result"]>]>,
     Arguments<(ins AnyType:$source,
                    AnyVectorOfAnyRank:$dest,
                    Optional<AnySignlessIntegerOrIndex>:$position)>,
@@ -1884,23 +1884,15 @@
        : memref<16x16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
     ```
   }];
+
   let extraClassDeclaration = [{
-    ShapedType getBaseType() {
-      return getBase().getType().cast<ShapedType>();
-    }
-    VectorType getIndexVectorType() {
-      return getIndexVec().getType().cast<VectorType>();
-    }
-    VectorType getMaskVectorType() {
-      return getMask().getType().cast<VectorType>();
-    }
-    VectorType getPassThruVectorType() {
-      return getPassThru().getType().cast<VectorType>();
-    }
-    VectorType getVectorType() {
-      return getResult().getType().cast<VectorType>();
-    }
+    ShapedType getBaseType() { return getBase().getType(); }
+    VectorType getIndexVectorType() { return getIndexVec().getType(); }
+    VectorType getMaskVectorType() { return getMask().getType(); }
+    VectorType getPassThruVectorType() { return getPassThru().getType(); }
+    VectorType getVectorType() { return getResult().getType(); }
   }];
+
   let assemblyFormat =
       "$base `[` $indices `]` `[` $index_vec `]` `,` "
       "$mask `,` $pass_thru attr-dict `:` type($base) `,` "
@@ -1960,20 +1952,14 @@
        : memref<16x16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
     ```
   }];
+
   let extraClassDeclaration = [{
-    MemRefType getMemRefType() {
-      return getBase().getType().cast<MemRefType>();
-    }
-    VectorType getIndexVectorType() {
-      return getIndexVec().getType().cast<VectorType>();
-    }
-    VectorType getMaskVectorType() {
-      return getMask().getType().cast<VectorType>();
-    }
-    VectorType getVectorType() {
-      return getValueToStore().getType().cast<VectorType>();
-    }
+    MemRefType getMemRefType() { return getBase().getType(); }
+    VectorType getIndexVectorType() { return getIndexVec().getType(); }
+    VectorType getMaskVectorType() { return getMask().getType(); }
+    VectorType getVectorType() { return getValueToStore().getType(); }
   }];
+
   let assemblyFormat =
       "$base `[` $indices `]` `[` $index_vec `]` `,` "
       "$mask `,` $valueToStore attr-dict `:` type($base) `,` "
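Note: these accessor bodies can drop their casts because, for operands and results whose ODS type constraint maps to a concrete C++ type, TableGen emits accessors returning `mlir::TypedValue<ConcreteType>`, whose `getType()` is already typed. A minimal sketch of the pattern (the free function is hypothetical; `TypedValue` is the real class from `mlir/IR/Value.h`):

```c++
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"

using namespace mlir;

// With a TypedValue<MemRefType>-returning accessor such as the generated
// getBase(), no .cast<MemRefType>() round-trip is needed.
static MemRefType typeOfBase(TypedValue<MemRefType> base) {
  return base.getType();
}
```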
diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.td b/mlir/include/mlir/Interfaces/ViewLikeInterface.td
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.td
@@ -50,9 +50,8 @@
        `getArrayAttrMaxRanks()`[0] (resp. [1], [2]).
     3. if an entry of `static_offsets` (resp. `static_sizes`,
        `static_strides`) is equal to a special sentinel value, namely
-       `ShapedType::kDynamic` (resp. `ShapedType::kDynamic`,
-       `ShapedType::kDynamic`), then the corresponding entry is
-       a dynamic offset (resp. size, stride).
+       `ShapedType::kDynamic`, then the corresponding entry is a dynamic
+       offset (resp. size, stride).
     4. a variadic `offset` (resp. `sizes`, `strides`) operand must be present
        for each dynamic offset (resp. size, stride).
     5. `offsets`, `sizes` and `strides` operands are specified in this order
diff --git a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
--- a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
@@ -128,7 +128,7 @@
     const DataLayout *defaultLayout) const {
   uint64_t sizeDivisor = getMemRefEltSizeInBytes(type, op, defaultLayout);
   for (unsigned i = 0, e = type.getRank(); i < e; i++) {
-    if (ShapedType::isDynamic(type.getDimSize(i)))
+    if (type.isDynamicDim(i))
      continue;
     sizeDivisor = sizeDivisor * type.getDimSize(i);
   }
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -154,10 +154,10 @@
   auto computeNumElements = [&](MemRefType type,
                                 function_ref<Value()> getDynamicSize) -> Value {
     // Compute number of elements.
-    int64_t size = type.getShape()[0];
-    Value numElements = ((size == ShapedType::kDynamic)
-                             ? getDynamicSize()
-                             : createIndexConstant(rewriter, loc, size));
+    Value numElements =
+        type.isDynamicDim(0)
+            ? getDynamicSize()
+            : createIndexConstant(rewriter, loc, type.getDimSize(0));
     Type indexType = getIndexType();
     if (numElements.getType() != indexType)
       numElements = typeConverter->materializeTargetConversion(
@@ -987,7 +987,7 @@
     auto targetType = op.getTarget().getType().cast<BaseMemRefType>();
 
     // First make sure we have an unranked memref descriptor representation.
-    auto makeUnranked = [&, this](Value ranked, BaseMemRefType type) {
+    auto makeUnranked = [&, this](Value ranked, MemRefType type) {
       auto rank = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
                                                     type.getRank());
       auto *typeConverter = getTypeConverter();
@@ -1011,12 +1011,14 @@
     auto stackSaveOp =
         rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());
 
-    Value unrankedSource = srcType.hasRank()
-                               ? makeUnranked(adaptor.getSource(), srcType)
-                               : adaptor.getSource();
-    Value unrankedTarget = targetType.hasRank()
-                               ? makeUnranked(adaptor.getTarget(), targetType)
-                               : adaptor.getTarget();
+    auto srcMemRefType = srcType.dyn_cast<MemRefType>();
+    Value unrankedSource =
+        srcMemRefType ? makeUnranked(adaptor.getSource(), srcMemRefType)
+                      : adaptor.getSource();
+    auto targetMemRefType = targetType.dyn_cast<MemRefType>();
+    Value unrankedTarget =
+        targetMemRefType ? makeUnranked(adaptor.getTarget(), targetMemRefType)
+                         : adaptor.getTarget();
 
     // Now promote the unranked descriptors to the stack.
     auto one = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
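Note: the hunks above replace the indirect `ShapedType::isDynamic(type.getDimSize(i))` test with the equivalent member query `type.isDynamicDim(i)`. A small sketch mirroring the AllocLikeConversion loop (the helper itself is hypothetical):

```c++
#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

// Product of the statically known dimension sizes of a memref type.
static uint64_t numStaticElements(MemRefType type) {
  uint64_t product = 1;
  for (unsigned i = 0, e = type.getRank(); i < e; ++i) {
    // Equivalent to ShapedType::isDynamic(type.getDimSize(i)).
    if (type.isDynamicDim(i))
      continue;
    product *= type.getDimSize(i);
  }
  return product;
}
```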
@@ -1390,11 +1392,11 @@
       }
 
       Value dimSize;
-      int64_t size = targetMemRefType.getDimSize(i);
       // If the size of this dimension is dynamic, then load it at runtime
       // from the shape operand.
-      if (!ShapedType::isDynamic(size)) {
-        dimSize = createIndexConstant(rewriter, loc, size);
+      if (!targetMemRefType.isDynamicDim(i)) {
+        dimSize = createIndexConstant(rewriter, loc,
+                                      targetMemRefType.getDimSize(i));
       } else {
         Value shapeOp = reshapeOp.getShape();
         Value index = createIndexConstant(rewriter, loc, i);
@@ -1589,7 +1591,8 @@
       return rewriter.replaceOp(transposeOp, {viewMemRef}), success();
 
     auto targetMemRef = MemRefDescriptor::undef(
-        rewriter, loc, typeConverter->convertType(transposeOp.getShapedType()));
+        rewriter, loc,
+        typeConverter->convertType(transposeOp.getIn().getType()));
 
     // Copy the base and aligned pointers from the old descriptor to the new
     // one.
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1178,7 +1178,7 @@
   }
 
   // The shape of each input must match the shape of the output.
-  auto outputShape = getInit().getType().cast<ShapedType>().getShape();
+  auto outputShape = getInit().getType().getShape();
   for (Type inputArgType : TypeRange{getInputs()}) {
     auto inputElemShape = inputArgType.cast<ShapedType>().getShape();
     if (inputElemShape != outputShape) {
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -3236,7 +3236,7 @@
 LogicalResult TransposeOp::verify() {
   if (!getPermutation().isPermutation())
     return emitOpError("expected a permutation map");
-  if (getPermutation().getNumDims() != getShapedType().getRank())
+  if (getPermutation().getNumDims() != getIn().getType().getRank())
     return emitOpError("expected a permutation map of same rank as the input");
 
   auto srcType = getIn().getType().cast<MemRefType>();
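Note: the verifier above reduces to two conditions on the permutation map. A standalone sketch of that check (the free function is hypothetical; `AffineMap::isPermutation` and `getNumDims` are the real APIs):

```c++
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

// A transpose map must be a true permutation with exactly one dimension per
// rank of the transposed memref.
static bool isValidTransposeMap(AffineMap permutation, MemRefType input) {
  return permutation.isPermutation() &&
         permutation.getNumDims() == static_cast<unsigned>(input.getRank());
}
```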
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1617,27 +1617,26 @@
 /// rank-reduced, from the source type and the static representation of
 /// offsets, sizes and strides. Special sentinels encode the dynamic case.
 RankedTensorType ExtractSliceOp::inferResultType(
-    ShapedType sourceShapedTensorType, ArrayRef<int64_t> staticOffsets,
+    RankedTensorType sourceTensorType, ArrayRef<int64_t> staticOffsets,
     ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides) {
   // An extract_slice op may specify only a leading subset of offset/sizes/
   // strides in which case we complete with offset=0, sizes from memref type
   // and strides=1.
   assert(static_cast<int64_t>(staticSizes.size()) ==
-             sourceShapedTensorType.getRank() &&
+             sourceTensorType.getRank() &&
          "unexpected staticSizes not equal to rank of source");
-  return RankedTensorType::get(staticSizes,
-                               sourceShapedTensorType.getElementType());
+  return RankedTensorType::get(staticSizes, sourceTensorType.getElementType());
 }
 
 RankedTensorType ExtractSliceOp::inferResultType(
-    ShapedType sourceShapedTensorType, ArrayRef<OpFoldResult> offsets,
+    RankedTensorType sourceTensorType, ArrayRef<OpFoldResult> offsets,
     ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) {
   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
-  return ExtractSliceOp::inferResultType(sourceShapedTensorType, staticOffsets,
+  return ExtractSliceOp::inferResultType(sourceTensorType, staticOffsets,
                                          staticSizes, staticStrides);
 }
@@ -1756,22 +1755,21 @@
   build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
 }
 
-template <typename OpTy>
 static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
-                                          OpTy op, Type expectedType) {
-  auto memrefType = expectedType.cast<ShapedType>();
+                                          Operation *op,
+                                          RankedTensorType expectedType) {
   switch (result) {
   case SliceVerificationResult::Success:
     return success();
   case SliceVerificationResult::RankTooLarge:
-    return op.emitError("expected rank to be smaller or equal to ")
+    return op->emitError("expected rank to be smaller or equal to ")
           << "the other rank. ";
   case SliceVerificationResult::SizeMismatch:
-    return op.emitError("expected type to be ")
+    return op->emitError("expected type to be ")
           << expectedType << " or a rank-reduced version. (size mismatch) ";
   case SliceVerificationResult::ElemTypeMismatch:
-    return op.emitError("expected element type to be ")
-           << memrefType.getElementType();
+    return op->emitError("expected element type to be ")
+           << expectedType.getElementType();
   default:
     llvm_unreachable("unexpected extract_slice op verification result");
   }
@@ -2147,9 +2145,9 @@
 /// Rank-reducing type verification for both InsertSliceOp and
 /// ParallelInsertSliceOp.
 static SliceVerificationResult verifyInsertSliceOp(
-    ShapedType srcType, ShapedType dstType, ArrayRef<int64_t> staticOffsets,
-    ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides,
-    ShapedType *expectedType = nullptr) {
+    RankedTensorType srcType, RankedTensorType dstType,
+    ArrayRef<int64_t> staticOffsets, ArrayRef<int64_t> staticSizes,
+    ArrayRef<int64_t> staticStrides, RankedTensorType *expectedType = nullptr) {
   // insert_slice is the inverse of extract_slice, use the same type
   // inference.
   RankedTensorType expected = ExtractSliceOp::inferResultType(
@@ -2161,7 +2159,7 @@
 
 /// Verifier for InsertSliceOp.
 LogicalResult InsertSliceOp::verify() {
-  ShapedType expectedType;
+  RankedTensorType expectedType;
   SliceVerificationResult result =
       verifyInsertSliceOp(getSourceType(), getType(), getStaticOffsets(),
                           getStaticSizes(), getStaticStrides(), &expectedType);
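Note: `verifyInsertSliceOp` infers the expected type via `ExtractSliceOp::inferResultType` and then funnels into the shared rank-reduction comparison. A minimal sketch, assuming `mlir::isRankReducedType` from `BuiltinTypes.h` as that comparison (the wrapper is hypothetical):

```c++
#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

// Success means the candidate equals the expected type up to dropped unit
// dimensions (a rank-reduced version).
static bool isCompatibleSlice(RankedTensorType expected,
                              RankedTensorType candidate) {
  return isRankReducedType(expected, candidate) ==
         SliceVerificationResult::Success;
}
```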
@@ -2334,8 +2332,10 @@
     auto src =
         (sourceCastSource ? *sourceCastSource : insertSliceOp.getSource());
     auto dst = (destCastSource ? *destCastSource : insertSliceOp.getDest());
-    auto srcType = src.getType().template cast<ShapedType>();
-    auto dstType = dst.getType().template cast<ShapedType>();
+    auto srcType = src.getType().template dyn_cast<RankedTensorType>();
+    auto dstType = dst.getType().template dyn_cast<RankedTensorType>();
+    if (!srcType || !dstType)
+      return failure();
     if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.getStaticOffsets(),
                             insertSliceOp.getStaticSizes(),
                             insertSliceOp.getStaticStrides()) !=
@@ -3072,7 +3072,7 @@
     return this->emitError("expected ParallelCombiningOpInterface parent, got:")
           << *(getOperation()->getParentOp());
 
-  ShapedType expectedType;
+  RankedTensorType expectedType;
   SliceVerificationResult result =
       verifyInsertSliceOp(getSourceType(), getDestType(), getStaticOffsets(),
                           getStaticSizes(), getStaticStrides(), &expectedType);
@@ -3307,9 +3307,9 @@
     return op->emitError("invalid zero tile factor");
 
   // Verify inner_dims_pos and outer_dims_perm.
-  ShapedType unpackedType = (std::is_same<OpTy, PackOp>::value)
-                                ? packOrUnPack.getSourceType()
-                                : packOrUnPack.getDestType();
+  RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
+                                      ? packOrUnPack.getSourceType()
+                                      : packOrUnPack.getDestType();
   size_t unpackedRank = unpackedType.getRank();
   ArrayRef<int64_t> innerDimsPos = packOrUnPack.getInnerDimsPos();
   ArrayRef<int64_t> outerDimPerm = packOrUnPack.getOuterDimsPerm();
@@ -3344,7 +3344,7 @@
   // Verify result shape is greater than the minimum expected
   // by the pack operation, and that the output shape
   // represents full tiles.
-  ShapedType expectedPackedType = PackOp::inferPackedType(
+  RankedTensorType expectedPackedType = PackOp::inferPackedType(
       unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm);
   if (!areAllInBound(expectedPackedType.getShape(), packedType.getShape())) {
     return op->emitError("the shape of output is not large enough to hold the "
@@ -3594,10 +3594,10 @@
 
 /// Get the expected packed type based on source type, tile factors, position of
 /// the inner tiles and permutation of the outer tiled loop.
-ShapedType PackOp::inferPackedType(ShapedType sourceType,
-                                   ArrayRef<int64_t> innerTileSizes,
-                                   ArrayRef<int64_t> innerDimsPos,
-                                   ArrayRef<int64_t> outerDimsPerm) {
+RankedTensorType PackOp::inferPackedType(RankedTensorType sourceType,
+                                         ArrayRef<int64_t> innerTileSizes,
+                                         ArrayRef<int64_t> innerDimsPos,
+                                         ArrayRef<int64_t> outerDimsPerm) {
   SmallVector<int64_t> resultShape = getPackOpResultTypeShape(
       sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
   return RankedTensorType::get(resultShape, sourceType.getElementType());
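Note: for context on what `getPackOpResultTypeShape` computes, each tiled source dimension is ceil-divided by its tile size (dynamic in, dynamic out), and the tile sizes are appended as the innermost result dimensions. A simplified sketch, not the in-tree helper, ignoring `outer_dims_perm` for brevity:

```c++
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/MathExtras.h"
#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

static SmallVector<int64_t> packedShape(ArrayRef<int64_t> sourceShape,
                                        ArrayRef<int64_t> innerTileSizes,
                                        ArrayRef<int64_t> innerDimsPos) {
  SmallVector<int64_t> result(sourceShape.begin(), sourceShape.end());
  for (auto en : llvm::enumerate(innerDimsPos)) {
    int64_t tile = innerTileSizes[en.index()];
    int64_t &dim = result[en.value()];
    // A dynamic source dim or tile size yields a dynamic outer dim.
    dim = (ShapedType::isDynamic(dim) || ShapedType::isDynamic(tile))
              ? ShapedType::kDynamic
              : static_cast<int64_t>(llvm::divideCeil(dim, tile));
  }
  // Tile sizes become the innermost dimensions of the packed type.
  result.append(innerTileSizes.begin(), innerTileSizes.end());
  return result;
}
```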