diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td
@@ -122,4 +122,6 @@
                                                Shape_ExtentTensorType],
                                               "shape or extent tensor">;
 
+def Shape_SizeOrIndexType : AnyTypeOf<[Shape_SizeType, Index], "size or index">;
+
 #endif // SHAPE_BASE_TD
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -201,12 +201,13 @@
   }];
 
   let arguments = (ins Shape_ShapeOrExtentTensorType:$shape);
-  let results = (outs Shape_SizeType:$rank);
+  let results = (outs Shape_SizeOrIndexType:$rank);
 
-  let assemblyFormat = "$shape `:` type($shape) attr-dict";
+  let assemblyFormat = "$shape `:` type($shape) `->` type($rank) attr-dict";
 
   let hasFolder = 1;
   let hasCanonicalizer = 1;
+  let verifier = [{ return ::verify(*this); }];
 }
 
 def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", [NoSideEffect]> {
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Dialect/Shape/IR/Shape.h"
 
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Traits.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/DialectImplementation.h"
@@ -52,6 +53,8 @@
     return builder.create<ConstShapeOp>(loc, type,
                                         value.cast<DenseIntElementsAttr>());
   if (type.isa<SizeType>())
     return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>());
+  if (type.isa<IndexType>())
+    return builder.create<ConstantOp>(loc, type, value);
   return nullptr;
 }
@@ -563,7 +566,17 @@
 // RankOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
+static LogicalResult verify(shape::RankOp op) {
+  Type argTy = op.shape().getType();
+  Type resultTy = op.rank().getType();
+  if (argTy.isa<ShapeType>() && !resultTy.isa<SizeType>())
+    return op.emitOpError()
+           << "if operand is of type `shape` then the result must be of type "
+              "`size` to propagate potential errors";
+  return success();
+}
+
+OpFoldResult shape::RankOp::fold(ArrayRef<Attribute> operands) {
   auto shape = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
   if (!shape)
     return {};
@@ -587,10 +600,11 @@
 /// %rank = shape.const_size 3
 namespace {
-struct RankShapeOfCanonicalizationPattern : public OpRewritePattern<RankOp> {
-  using OpRewritePattern<RankOp>::OpRewritePattern;
+struct RankShapeOfCanonicalizationPattern
+    : public OpRewritePattern<shape::RankOp> {
+  using OpRewritePattern<shape::RankOp>::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(RankOp op,
+  LogicalResult matchAndRewrite(shape::RankOp op,
                                 PatternRewriter &rewriter) const override {
     auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>();
     if (!shapeOfOp)
       return failure();
@@ -599,15 +613,18 @@
         shapeOfOp.arg().getType().dyn_cast<RankedTensorType>();
     if (!rankedTensorType)
       return failure();
+    assert(op.getType().isa<IndexType>() &&
+           "expected `rank(shape_of( ... ))` based on a shaped argument to "
+           "yield an index type");
     int64_t rank = rankedTensorType.getRank();
-    rewriter.replaceOpWithNewOp<ConstSizeOp>(op.getOperation(), rank);
+    rewriter.replaceOpWithNewOp<ConstantIndexOp>(op.getOperation(), rank);
     return success();
   }
 };
 } // namespace
 
-void RankOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
-                                         MLIRContext *context) {
+void shape::RankOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *context) {
   patterns.insert<RankShapeOfCanonicalizationPattern>(context);
 }
diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -122,12 +122,12 @@
 // Convert `rank` to `dim` of the first dimension.
 // CHECK-LABEL: @rank
 // CHECK-SAME: (%[[SHAPE:.*]]: tensor<?xindex>) -> index
-func @rank(%shape : !shape.shape) -> !shape.size {
-  // CHECK-DAG: %[[C0:.*]] = constant 0 : index
-  // CHECK-DAG: %[[RESULT:.*]] = dim %[[SHAPE]], %[[C0]]
-  // CHECK-DAG: return %[[RESULT]] : index
-  %rank = shape.rank %shape : !shape.shape
-  return %rank : !shape.size
+func @rank(%shape : tensor<?xindex>) -> index {
+  // CHECK: %[[C0:.*]] = constant 0 : index
+  // CHECK: %[[RESULT:.*]] = dim %[[SHAPE]], %[[C0]]
+  // CHECK: return %[[RESULT]] : index
+  %rank = shape.rank %shape : tensor<?xindex> -> index
+  return %rank : index
 }
 
 // -----
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -496,10 +496,10 @@
 // Fold `rank` based on constant shape.
 // CHECK-LABEL: @fold_rank
 func @fold_rank() -> !shape.size {
-  // CHECK-DAG: %[[RESULT:.*]] = shape.const_size 5
-  // CHECK-DAG: return %[[RESULT]] : !shape.size
+  // CHECK: %[[RESULT:.*]] = shape.const_size 5
+  // CHECK: return %[[RESULT]] : !shape.size
   %shape = shape.const_shape [3, 4, 5, 6, 7] : !shape.shape
-  %rank = shape.rank %shape : !shape.shape
+  %rank = shape.rank %shape : !shape.shape -> !shape.size
   return %rank : !shape.size
 }
 
@@ -509,38 +509,64 @@
 // CHECK-LABEL: @dont_fold_rank
 // CHECK-SAME: (%[[SHAPE:.*]]: !shape.shape) -> !shape.size
 func @dont_fold_rank(%shape : !shape.shape) -> !shape.size {
-  // CHECK-DAG: %[[RESULT:.*]] = shape.rank %[[SHAPE]]
-  // CHECK-DAG: return %[[RESULT]] : !shape.size
-  %rank = shape.rank %shape : !shape.shape
+  // CHECK: %[[RESULT:.*]] = shape.rank %[[SHAPE]]
+  // CHECK: return %[[RESULT]] : !shape.size
+  %rank = shape.rank %shape : !shape.shape -> !shape.size
   return %rank : !shape.size
 }
 
 // -----
 
+// Fold `rank` based on constant extent tensor.
+// CHECK-LABEL: @fold_rank
+func @fold_rank() -> index {
+  // CHECK: %[[RESULT:.*]] = constant 5 : index
+  // CHECK: return %[[RESULT]] : index
+  %shape = shape.const_shape [3, 4, 5, 6, 7] : tensor<?xindex>
+  %rank = shape.rank %shape : tensor<?xindex> -> index
+  return %rank : index
+}
+
+// -----
+
+// Do not fold `rank` for non-constant extent tensors.
+// CHECK-LABEL: @dont_fold_rank
+// CHECK-SAME: (%[[SHAPE:.*]]: tensor<?xindex>) -> index
+func @dont_fold_rank(%shape : tensor<?xindex>) -> index {
+  // CHECK: %[[RESULT:.*]] = shape.rank %[[SHAPE]] : tensor<?xindex> -> index
+  // CHECK: return %[[RESULT]] : index
+  %rank = shape.rank %shape : tensor<?xindex> -> index
+  return %rank : index
+}
+
+// -----
+
 // Canonicalize `rank` when shape is derived from ranked tensor.
 // CHECK-LABEL: @canonicalize_rank
-func @canonicalize_rank(%arg : tensor<1x2x?xf32>) -> !shape.size {
-  // CHECK-DAG: %[[RESULT:.*]] = shape.const_size 3
-  // CHECK-DAG: return %[[RESULT]] : !shape.size
+func @canonicalize_rank(%arg : tensor<1x2x?xf32>) -> index {
+  // CHECK: %[[RESULT:.*]] = constant 3 : index
+  // CHECK: return %[[RESULT]] : index
   %shape = shape.shape_of %arg : tensor<1x2x?xf32> -> tensor<?xindex>
-  %rank = shape.rank %shape : tensor<?xindex>
-  return %rank : !shape.size
+  %rank = shape.rank %shape : tensor<?xindex> -> index
+  return %rank : index
 }
 
 // -----
 
 // Do not canonicalize `rank` when shape is derived from unranked tensor.
 // CHECK-LABEL: @dont_canonicalize_rank
-// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>) -> !shape.size
-func @dont_canonicalize_rank(%arg : tensor<*xf32>) -> !shape.size {
-  // CHECK-DAG: %[[SHAPE:.*]] = shape.shape_of %[[ARG]] : tensor<*xf32> -> tensor<?xindex>
-  // CHECK-DAG: %[[SIZE:.*]] = shape.rank %[[SHAPE]]
-  // CHECK-DAG: return %[[SIZE]] : !shape.size
+// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>) -> index
+func @dont_canonicalize_rank(%arg : tensor<*xf32>) -> index {
+  // CHECK: %[[SHAPE:.*]] = shape.shape_of %[[ARG]] : tensor<*xf32> -> tensor<?xindex>
+  // CHECK: %[[SIZE:.*]] = shape.rank %[[SHAPE]]
+  // CHECK: return %[[SIZE]] : index
   %shape = shape.shape_of %arg : tensor<*xf32> -> tensor<?xindex>
-  %rank = shape.rank %shape : tensor<?xindex>
-  return %rank : !shape.size
+  %rank = shape.rank %shape : tensor<?xindex> -> index
+  return %rank : index
 }
 
+// -----
+
 // Canonicalize redundant conversion from `index` to `size` and back.
 // CHECK-LABEL: @index_to_size_to_index
 // CHECK-SAME: (%[[IDX:.*]]: index) -> index
diff --git a/mlir/test/Dialect/Shape/invalid.mlir b/mlir/test/Dialect/Shape/invalid.mlir
--- a/mlir/test/Dialect/Shape/invalid.mlir
+++ b/mlir/test/Dialect/Shape/invalid.mlir
@@ -95,3 +95,10 @@
   %1 = shape.shape_of %shaped_arg : tensor<?x3x4xf32> -> !shape.shape
 }
 
+// -----
+
+func @rank(%arg : !shape.shape) {
+  // expected-error@+1 {{if operand is of type `shape` then the result must be of type `size` to propagate potential errors}}
+  %0 = shape.rank %arg : !shape.shape -> index
+}
+
diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -137,13 +137,13 @@
 }
 
 func @rank(%shape : !shape.shape) -> !shape.size {
-  %rank = shape.rank %shape : !shape.shape
+  %rank = shape.rank %shape : !shape.shape -> !shape.size
   return %rank : !shape.size
 }
 
-func @rank_on_extent_tensor(%shape : tensor<?xindex>) -> !shape.size {
-  %rank = shape.rank %shape : tensor<?xindex>
-  return %rank : !shape.size
+func @rank_on_extent_tensor(%shape : tensor<?xindex>) -> index {
+  %rank = shape.rank %shape : tensor<?xindex> -> index
+  return %rank : index
 }
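
Illustrative usage (a minimal sketch, not part of the patch; the function names are hypothetical, but the syntax follows the new assembly format `$shape `:` type($shape) `->` type($rank)` and the verifier rule introduced above):

  // Over an error-carrying `!shape.shape`, the result must remain
  // `!shape.size` so that error values can propagate.
  func @rank_of_shape(%shape : !shape.shape) -> !shape.size {
    %rank = shape.rank %shape : !shape.shape -> !shape.size
    return %rank : !shape.size
  }

  // Over an extent tensor, no error is possible, so a bare `index` result
  // is now allowed; folding and canonicalization then produce plain index
  // constants instead of `shape.const_size`.
  func @rank_of_extent_tensor(%extents : tensor<?xindex>) -> index {
    %rank = shape.rank %extents : tensor<?xindex> -> index
    return %rank : index
  }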