diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -130,6 +130,10 @@
   let arguments = (ins IndexAttr:$value);
   let results = (outs Shape_SizeType:$result);
 
+  let builders = [
+    OpBuilder<"OpBuilder &builder, OperationState &result, int64_t value">
+  ];
+
   let assemblyFormat = "$value attr-dict";
   let hasFolder = 1;
 }
@@ -181,6 +185,7 @@
   let assemblyFormat = "attr-dict $shape";
 
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", [NoSideEffect]> {
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -364,6 +364,11 @@
 // ConstSizeOp
 //===----------------------------------------------------------------------===//
 
+void ConstSizeOp::build(OpBuilder &builder, OperationState &result,
+                        int64_t value) {
+  build(builder, result, builder.getIndexAttr(value));
+}
+
 OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return valueAttr(); }
 
 void ConstSizeOp::getAsmResultNames(
@@ -450,6 +455,45 @@
   return builder.getIndexAttr(rank);
 }
 
+// Evaluate the `rank` operation for shapes of ranked tensors at compile time.
+// Constant folding fails in cases where only the rank is constant, not the
+// shape itself.
+// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`.
+//
+// Example:
+//
+//   %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32>
+//   %rank = shape.rank %shape
+//
+// becomes
+//
+//   %rank = shape.const_size 3
+
+namespace {
+struct RankShapeOfCanonicalizationPattern : public OpRewritePattern<RankOp> {
+  using OpRewritePattern<RankOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(RankOp op,
+                                PatternRewriter &rewriter) const override {
+    auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>();
+    if (!shapeOfOp)
+      return failure(); // Shape is not produced by `shape.shape_of`.
+    auto rankedTensorType =
+        shapeOfOp.arg().getType().dyn_cast<RankedTensorType>();
+    if (!rankedTensorType)
+      return failure(); // Unranked operand: rank is not known statically.
+    int64_t rank = rankedTensorType.getRank();
+    rewriter.replaceOpWithNewOp<ConstSizeOp>(op, rank);
+    return success();
+  }
+};
+} // namespace
+
+void RankOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
+                                         MLIRContext *context) {
+  patterns.insert<RankShapeOfCanonicalizationPattern>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // NumElementsOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -466,3 +466,29 @@
   %rank = shape.rank %shape
   return %rank : !shape.size
 }
+
+// -----
+
+// Canonicalize `rank` when shape is derived from ranked tensor.
+// CHECK-LABEL: @canonicalize_rank
+func @canonicalize_rank(%arg : tensor<1x2x?xf32>) -> !shape.size {
+  // CHECK-DAG: %[[RESULT:.*]] = shape.const_size 3
+  // CHECK-DAG: return %[[RESULT]] : !shape.size
+  %shape = shape.shape_of %arg : tensor<1x2x?xf32>
+  %rank = shape.rank %shape
+  return %rank : !shape.size
+}
+
+// -----
+
+// Do not canonicalize `rank` when shape is derived from unranked tensor.
+// CHECK-LABEL: @dont_canonicalize_rank
+// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>) -> !shape.size
+func @dont_canonicalize_rank(%arg : tensor<*xf32>) -> !shape.size {
+  // CHECK-DAG: %[[SHAPE:.*]] = shape.shape_of %[[ARG]] : tensor<*xf32>
+  // CHECK-DAG: %[[SIZE:.*]] = shape.rank %[[SHAPE]]
+  // CHECK-DAG: return %[[SIZE]] : !shape.size
+  %shape = shape.shape_of %arg : tensor<*xf32>
+  %rank = shape.rank %shape
+  return %rank : !shape.size
+}