diff --git a/mlir/examples/toy/Ch2/mlir/Dialect.cpp b/mlir/examples/toy/Ch2/mlir/Dialect.cpp --- a/mlir/examples/toy/Ch2/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch2/mlir/Dialect.cpp @@ -139,7 +139,7 @@ // Check that the rank of the attribute type matches the rank of the constant // result type. - auto attrType = getValue().getType().cast(); + auto attrType = getValue().getType().cast(); if (attrType.getRank() != resultType.getRank()) { return emitOpError("return type must match the one of the attached value " "attribute: ") diff --git a/mlir/examples/toy/Ch3/mlir/Dialect.cpp b/mlir/examples/toy/Ch3/mlir/Dialect.cpp --- a/mlir/examples/toy/Ch3/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch3/mlir/Dialect.cpp @@ -139,7 +139,7 @@ // Check that the rank of the attribute type matches the rank of the constant // result type. - auto attrType = getValue().getType().cast(); + auto attrType = getValue().getType().cast(); if (attrType.getRank() != resultType.getRank()) { return emitOpError("return type must match the one of the attached value " "attribute: ") diff --git a/mlir/examples/toy/Ch4/mlir/Dialect.cpp b/mlir/examples/toy/Ch4/mlir/Dialect.cpp --- a/mlir/examples/toy/Ch4/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch4/mlir/Dialect.cpp @@ -199,7 +199,7 @@ // Check that the rank of the attribute type matches the rank of the constant // result type. - auto attrType = getValue().getType().cast(); + auto attrType = getValue().getType().cast(); if (attrType.getRank() != resultType.getRank()) { return emitOpError("return type must match the one of the attached value " "attribute: ") diff --git a/mlir/examples/toy/Ch5/mlir/Dialect.cpp b/mlir/examples/toy/Ch5/mlir/Dialect.cpp --- a/mlir/examples/toy/Ch5/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch5/mlir/Dialect.cpp @@ -199,7 +199,7 @@ // Check that the rank of the attribute type matches the rank of the constant // result type. - auto attrType = getValue().getType().cast(); + auto attrType = getValue().getType().cast(); if (attrType.getRank() != resultType.getRank()) { return emitOpError("return type must match the one of the attached value " "attribute: ") diff --git a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp --- a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp +++ b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp @@ -30,9 +30,8 @@ // ToyToAffine RewritePatterns //===----------------------------------------------------------------------===// -/// Convert the given TensorType into the corresponding MemRefType. -static MemRefType convertTensorToMemRef(TensorType type) { - assert(type.hasRank() && "expected only ranked shapes"); +/// Convert the given RankedTensorType into the corresponding MemRefType. +static MemRefType convertTensorToMemRef(RankedTensorType type) { return MemRefType::get(type.getShape(), type.getElementType()); } @@ -63,7 +62,7 @@ static void lowerOpToLoops(Operation *op, ValueRange operands, PatternRewriter &rewriter, LoopIterationFn processIteration) { - auto tensorType = (*op->result_type_begin()).cast(); + auto tensorType = (*op->result_type_begin()).cast(); auto loc = op->getLoc(); // Insert an allocation and deallocation for the result of this operation. @@ -144,7 +143,7 @@ // When lowering the constant operation, we allocate and assign the constant // values to a corresponding memref allocation. - auto tensorType = op.getType().cast(); + auto tensorType = op.getType().cast(); auto memRefType = convertTensorToMemRef(tensorType); auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter); diff --git a/mlir/examples/toy/Ch6/mlir/Dialect.cpp b/mlir/examples/toy/Ch6/mlir/Dialect.cpp --- a/mlir/examples/toy/Ch6/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch6/mlir/Dialect.cpp @@ -199,7 +199,7 @@ // Check that the rank of the attribute type matches the rank of the constant // result type. - auto attrType = getValue().getType().cast(); + auto attrType = getValue().getType().cast(); if (attrType.getRank() != resultType.getRank()) { return emitOpError("return type must match the one of the attached value " "attribute: ") diff --git a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp --- a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp +++ b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp @@ -30,9 +30,8 @@ // ToyToAffine RewritePatterns //===----------------------------------------------------------------------===// -/// Convert the given TensorType into the corresponding MemRefType. -static MemRefType convertTensorToMemRef(TensorType type) { - assert(type.hasRank() && "expected only ranked shapes"); +/// Convert the given RankedTensorType into the corresponding MemRefType. +static MemRefType convertTensorToMemRef(RankedTensorType type) { return MemRefType::get(type.getShape(), type.getElementType()); } @@ -63,7 +62,7 @@ static void lowerOpToLoops(Operation *op, ValueRange operands, PatternRewriter &rewriter, LoopIterationFn processIteration) { - auto tensorType = (*op->result_type_begin()).cast(); + auto tensorType = (*op->result_type_begin()).cast(); auto loc = op->getLoc(); // Insert an allocation and deallocation for the result of this operation. @@ -144,7 +143,7 @@ // When lowering the constant operation, we allocate and assign the constant // values to a corresponding memref allocation. - auto tensorType = op.getType().cast(); + auto tensorType = op.getType().cast(); auto memRefType = convertTensorToMemRef(tensorType); auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter); diff --git a/mlir/examples/toy/Ch7/mlir/Dialect.cpp b/mlir/examples/toy/Ch7/mlir/Dialect.cpp --- a/mlir/examples/toy/Ch7/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch7/mlir/Dialect.cpp @@ -195,7 +195,7 @@ // Check that the rank of the attribute type matches the rank of the // constant result type. - auto attrType = attrValue.getType().cast(); + auto attrType = attrValue.getType().cast(); if (attrType.getRank() != resultType.getRank()) { return op->emitOpError("return type must match the one of the attached " "value attribute: ") diff --git a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp --- a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp +++ b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp @@ -30,9 +30,8 @@ // ToyToAffine RewritePatterns //===----------------------------------------------------------------------===// -/// Convert the given TensorType into the corresponding MemRefType. -static MemRefType convertTensorToMemRef(TensorType type) { - assert(type.hasRank() && "expected only ranked shapes"); +/// Convert the given RankedTensorType into the corresponding MemRefType. +static MemRefType convertTensorToMemRef(RankedTensorType type) { return MemRefType::get(type.getShape(), type.getElementType()); } @@ -63,7 +62,7 @@ static void lowerOpToLoops(Operation *op, ValueRange operands, PatternRewriter &rewriter, LoopIterationFn processIteration) { - auto tensorType = (*op->result_type_begin()).cast(); + auto tensorType = (*op->result_type_begin()).cast(); auto loc = op->getLoc(); // Insert an allocation and deallocation for the result of this operation. @@ -144,7 +143,7 @@ // When lowering the constant operation, we allocate and assign the constant // values to a corresponding memref allocation. - auto tensorType = op.getType().cast(); + auto tensorType = op.getType().cast(); auto memRefType = convertTensorToMemRef(tensorType); auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter); diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td @@ -974,8 +974,8 @@ Tensor_Op, Pure])>, - Arguments<(ins AnyTensor:$src, IndexListArrayAttr:$reassociation)>, - Results<(outs AnyTensor:$result)> { + Arguments<(ins AnyRankedTensor:$src, IndexListArrayAttr:$reassociation)>, + Results<(outs AnyRankedTensor:$result)> { code commonExtraClassDeclaration = [{ static StringRef getReassociationAttrStrName() { return "reassociation"; } @@ -1210,7 +1210,7 @@ }]; let arguments = (ins - AnyTensor:$source, + AnyRankedTensor:$source, Variadic:$low, Variadic:$high, DenseI64ArrayAttr:$static_low, @@ -1219,7 +1219,7 @@ let regions = (region SizedRegion<1>:$region); - let results = (outs AnyTensor:$result); + let results = (outs AnyRankedTensor:$result); // TODO: Remove custom when AllTypesMatch supports opt. operands. let assemblyFormat = [{ @@ -1678,8 +1678,8 @@ "$_self">])> { code commonExtraClassDeclaration = [{ - size_t getSourceRank() { return getSource().getType().getRank(); }; - size_t getDestRank() { return getDest().getType().getRank(); }; + size_t getSourceRank() { return getSourceType().getRank(); }; + size_t getDestRank() { return getDestType().getRank(); }; RankedTensorType getSourceType() { return getSource().getType().cast(); }; RankedTensorType getDestType() { diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -240,6 +240,10 @@ def IsUnrankedTensorTypePred : CPred<"$_self.isa<::mlir::UnrankedTensorType>()">; +// Whether a type is a RankedTensorType +def IsRankedTensorTypePred + : CPred<"$_self.isa<::mlir::RankedTensorType>()">; + // Whether a type is a BaseMemRefType def IsBaseMemRefTypePred : CPred<"$_self.isa<::mlir::BaseMemRefType>()">; @@ -721,11 +725,21 @@ //===----------------------------------------------------------------------===// // Tensor types. -// Unranked tensor type whose element type is from the given -// `allowedTypes` list. -class UnrankedTensorOf allowedTypes> - : ShapedContainerType; +// Unranked tensor type whose element type is from the given `allowedTypes` +// list, and which additionally satisfies an optional list of predicates. +class UnrankedTensorOf allowedTypes, list preds = [], + string summary = "unranked tensor"> + : ShapedContainerType< + allowedTypes, And, + summary, "::mlir::UnrankedTensorType">; + +// Ranked tensor type whose element type is from the given `allowedTypes` list, +// and which additionally satisfies an optional list of predicates. +class RankedTensorOf allowedTypes, list preds = [], + string summary = "ranked tensor"> + : ShapedContainerType< + allowedTypes, And, + summary, "::mlir::RankedTensorType">; // Any tensor type whose element type is from the given `allowedTypes` // list, and which additionally satisfies an optional list of predicates. @@ -754,12 +768,6 @@ def F32Tensor : TensorOf<[F32]>; def F64Tensor : TensorOf<[F64]>; -class RankedTensorOf< - list allowedTypes, - list preds = [], - string summary = "ranked tensor"> - : TensorOf; - class Non0RankedTensorOf allowedTypes> : TensorOf], "non-0-ranked.tensor">; @@ -773,7 +781,7 @@ // Ranked tensor type with one of the specified types and ranks. class TensorRankOf allowedTypes, list ranks> - : TensorOf], !interleave(!foreach(rank, ranks, rank # "D"), "/") # " tensor">; diff --git a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp --- a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp +++ b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp @@ -44,7 +44,7 @@ LogicalResult matchAndRewrite(tensor::ExtractOp extractOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - TensorType tensorType = extractOp.getTensor().getType().cast(); + auto tensorType = extractOp.getTensor().getType().cast(); if (!tensorType.hasStaticShape()) return rewriter.notifyMatchFailure(extractOp, "non-static tensor"); diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -369,11 +369,16 @@ auto extractOperand = tensorCast.getOperand().getDefiningOp(); + // Cannot fold cast to unranked tensor. + auto rankedResultType = tensorCast.getType().dyn_cast(); + if (!rankedResultType) + return failure(); + if (!extractOperand || !canFoldIntoProducerOp(tensorCast) || - tensorCast.getType().getShape() == tensorCast.getSource() - .getType() - .cast() - .getShape()) + rankedResultType.getShape() == tensorCast.getSource() + .getType() + .cast() + .getShape()) return failure(); SmallVector sizes = extractOperand.getMixedSizes(); @@ -383,15 +388,15 @@ for (size_t i = 0, e = sizes.size(); i < e; i++) { if (dimMask && dimMask->count(i)) continue; - int64_t dim = tensorCast.getType().getShape()[dimIndex++]; + int64_t dim = rankedResultType.getShape()[dimIndex++]; if (ShapedType::isDynamic(dim)) continue; sizes[i] = rewriter.getIndexAttr(dim); } rewriter.replaceOpWithNewOp( - tensorCast, tensorCast.getType().cast(), - extractOperand.getSource(), extractOperand.getMixedOffsets(), sizes, + tensorCast, rankedResultType, extractOperand.getSource(), + extractOperand.getMixedOffsets(), sizes, extractOperand.getMixedStrides()); return success(); } @@ -1500,7 +1505,7 @@ return failure(); // Skip static dims. These are folded to constant ops. - TensorType resultType = expandShapeOp.getResultType(); + RankedTensorType resultType = expandShapeOp.getResultType(); if (!resultType.isDynamicDim(*dim)) return failure(); @@ -1544,7 +1549,7 @@ return failure(); // Skip static dims. These are folded to constant ops. - TensorType resultType = collapseShapeOp.getResultType(); + RankedTensorType resultType = collapseShapeOp.getResultType(); if (!resultType.isDynamicDim(*dim)) return failure(); diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir --- a/mlir/test/Dialect/Tensor/invalid.mlir +++ b/mlir/test/Dialect/Tensor/invalid.mlir @@ -2,7 +2,7 @@ // Asking the dimension of a 0-D shape doesn't make sense. func.func @dim_0_ranked(%arg : tensor, %arg1 : index) { - tensor.dim %arg, %arg1 : tensor // expected-error {{'tensor.dim' op operand #0 must be unranked.tensor of any type values or non-0-ranked.tensor of any type values, but got 'tensor'}} + tensor.dim %arg, %arg1 : tensor // expected-error {{'tensor.dim' op operand #0 must be unranked tensor of any type values or non-0-ranked.tensor of any type values, but got 'tensor'}} return }