diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -1473,7 +1473,8 @@
 
 void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                             MLIRContext *context) {
-  patterns.add<ShapeOfCastExtentTensor, ShapeOfWithTensor>(context);
+  patterns.add<ShapeOfCastExtentTensor, ShapeOfWithTensor,
+               ExtractFromShapeOfExtentTensor>(context);
 }
 
 LogicalResult mlir::shape::ShapeOfOp::inferReturnTypes(
diff --git a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
--- a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
+++ b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
@@ -16,6 +16,11 @@
   $0.getType().dyn_cast<ShapedType>().hasStaticShape()
 }]>>;
 
+def HasTensorType : Constraint<CPred<"$0.getType().isa<::mlir::TensorType>()">>;
+
+// Helper that takes the first element of a range.
+def TakeFront : NativeCodeCall<"$0.front()">;
+
 // Canonicalization patterns.
 
 def AssumingAllOneOp : Pat<(Shape_AssumingAllOp $args),
@@ -43,3 +48,12 @@
 def TensorCastConstShape : Pat <
   (Tensor_CastOp:$res (Shape_ConstShapeOp $arg)), (Shape_ConstShapeOp $arg),
   [(HasStaticShape $res)]>;
+
+// tensor.extract(shape_of(%arg), %i) -> tensor.dim(%arg, %i).
+// The extent tensor produced by shape_of is 1-D, so $indices holds exactly one
+// index and TakeFront extracts it. shape_of also accepts memrefs, but
+// tensor.dim only accepts tensors, hence the HasTensorType constraint.
+def ExtractFromShapeOfExtentTensor : Pat<
+  (Tensor_ExtractOp (Shape_ShapeOfOp $arg), $indices),
+  (Tensor_DimOp $arg, (TakeFront $indices)),
+  [(HasTensorType $arg)]>;
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -1380,3 +1380,31 @@
       -> tensor<?xindex>
   return %0 : tensor<?xindex>
 }
+
+// -----
+
+// CHECK-LABEL: func @extract_shapeof
+// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?xf64>
+func @extract_shapeof(%arg0 : tensor<?x?xf64>) -> index {
+  %c1 = constant 1 : index
+// CHECK: %[[C1:.*]] = constant 1
+  %shape = shape.shape_of %arg0 : tensor<?x?xf64> -> tensor<2xindex>
+// CHECK: %[[DIM:.*]] = tensor.dim %[[ARG0]], %[[C1]]
+  %result = tensor.extract %shape[%c1] : tensor<2xindex>
+// CHECK: return %[[DIM]]
+  return %result : index
+}
+
+// -----
+
+// shape_of also accepts memrefs, but tensor.dim does not: no rewrite here.
+// CHECK-LABEL: func @extract_shapeof_memref
+// CHECK: shape.shape_of
+// CHECK: tensor.extract
+// CHECK-NOT: tensor.dim
+func @extract_shapeof_memref(%arg0 : memref<?x?xf64>) -> index {
+  %c1 = constant 1 : index
+  %shape = shape.shape_of %arg0 : memref<?x?xf64> -> tensor<2xindex>
+  %result = tensor.extract %shape[%c1] : tensor<2xindex>
+  return %result : index
+}