diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -130,6 +130,10 @@
   let arguments = (ins IndexAttr:$value);
   let results = (outs Shape_SizeType:$result);
 
+  let builders = [
+    OpBuilder<"OpBuilder &builder, OperationState &result, int64_t value">
+  ];
+
   let assemblyFormat = "$value attr-dict";
   let hasFolder = 1;
 }
@@ -181,6 +185,7 @@
   let assemblyFormat = "attr-dict $shape";
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", [NoSideEffect]> {
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -359,6 +359,12 @@
 // ConstSizeOp
 //===----------------------------------------------------------------------===//
 
+// Convenience builder: wraps a plain int64_t into an index attribute and
+// delegates to the attribute-based builder.
+void ConstSizeOp::build(OpBuilder &builder, OperationState &result,
+                        int64_t value) {
+  auto attr = builder.getIndexAttr(value);
+  build(builder, result, attr);
+}
+
 OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return valueAttr(); }
 
 void ConstSizeOp::getAsmResultNames(
@@ -445,6 +451,32 @@
   return builder.getIndexAttr(size);
 }
 
+namespace {
+// Fold `get_size(shape_of(%tensor))` to a constant when `%tensor` has a
+// ranked type: the rank is statically known, so the size is too.
+struct GetSizeShapeOfCanonicalizationPattern
+    : public OpRewritePattern<GetSizeOp> {
+  using OpRewritePattern<GetSizeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(GetSizeOp op,
+                                PatternRewriter &rewriter) const override {
+    auto shapeOfOp = dyn_cast_or_null<ShapeOfOp>(op.shape().getDefiningOp());
+    if (!shapeOfOp)
+      return failure();
+    // Only ranked tensors have a statically known rank; unranked tensors
+    // must keep the dynamic `get_size`.
+    auto rankedTensorType =
+        shapeOfOp.arg().getType().dyn_cast<RankedTensorType>();
+    if (!rankedTensorType)
+      return failure();
+    int64_t shapeSize = rankedTensorType.getRank();
+    rewriter.replaceOpWithNewOp<ConstSizeOp>(op.getOperation(), shapeSize);
+    return success();
+  }
+};
+} // namespace
+
+void GetSizeOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
+                                            MLIRContext *context) {
+  patterns.insert<GetSizeShapeOfCanonicalizationPattern>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // NumElementsOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -478,3 +478,29 @@
   %size = shape.get_size %shape
   return %size : !shape.size
 }
+
+// -----
+
+// Canonicalize `get_size` when shape is derived from ranked tensor.
+// CHECK-LABEL: @canonicalize_size
+func @canonicalize_size(%arg : tensor<1x2x?xf32>) -> !shape.size {
+  // CHECK-DAG: %[[RESULT:.*]] = shape.const_size 3
+  // CHECK-DAG: return %[[RESULT]] : !shape.size
+  %shape = shape.shape_of %arg : tensor<1x2x?xf32>
+  %size = shape.get_size %shape
+  return %size : !shape.size
+}
+
+// -----
+
+// Do not canonicalize `get_size` when shape is derived from unranked tensor.
+// CHECK-LABEL: @dont_canonicalize_size
+// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>) -> !shape.size
+func @dont_canonicalize_size(%arg : tensor<*xf32>) -> !shape.size {
+  // CHECK-DAG: %[[SHAPE:.*]] = shape.shape_of %[[ARG]] : tensor<*xf32>
+  // CHECK-DAG: %[[SIZE:.*]] = shape.get_size %[[SHAPE]]
+  // CHECK-DAG: return %[[SIZE]] : !shape.size
+  %shape = shape.shape_of %arg : tensor<*xf32>
+  %size = shape.get_size %shape
+  return %size : !shape.size
+}