diff --git a/mlir/lib/Dialect/Shape/Transforms/CanonicalizeShapeToStandardPatterns.td b/mlir/lib/Dialect/Shape/Transforms/CanonicalizeShapeToStandardPatterns.td
--- a/mlir/lib/Dialect/Shape/Transforms/CanonicalizeShapeToStandardPatterns.td
+++ b/mlir/lib/Dialect/Shape/Transforms/CanonicalizeShapeToStandardPatterns.td
@@ -1,6 +1,19 @@
 include "mlir/Dialect/Shape/IR/ShapeOps.td"
 include "mlir/Dialect/StandardOps/IR/Ops.td"
 
+// Rewrite patterns.
+
 def GetExtentShapeOfConversion : Pat<
     (Shape_GetExtentOp (Shape_ShapeOfOp $arg), $idx),
     (Shape_IndexToSizeOp (DimOp $arg, (Shape_SizeToIndexOp $idx)))>;
+
+// Rewrite `get_extent(from_extent_tensor(extents), idx)` into an
+// `extract_element` on `extents`. Only the last result pattern replaces the
+// matched op; the preceding ones build the helper ops it consumes.
+def GetExtentFromExtentTensorConversion : Pattern<
+    (Shape_GetExtentOp (Shape_FromExtentTensorOp $extents), $idx),
+    [
+      (Shape_SizeToIndexOp:$std_idx $idx),
+      (ExtractElementOp:$std_result $extents, (NativeCodeCall<"ValueRange({$0})"> $std_idx)),
+      (Shape_IndexToSizeOp $std_result)
+    ]>;
diff --git a/mlir/test/Dialect/Shape/canonicalize-shape-to-standard.mlir b/mlir/test/Dialect/Shape/canonicalize-shape-to-standard.mlir
--- a/mlir/test/Dialect/Shape/canonicalize-shape-to-standard.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize-shape-to-standard.mlir
@@ -14,4 +14,18 @@
   return %result : !shape.size
 }
+// -----
+// Express `get_extent` as `std.extract_element` when it relies directly on the
+// outcome of a `from_extent_tensor` operation.
+// CHECK-LABEL: @get_extent
+// CHECK-SAME: (%[[EXTENTS:.*]]: tensor<?xindex>, %[[IDX:.*]]: !shape.size) -> !shape.size
+func @get_extent(%extents : tensor<?xindex>, %idx : !shape.size) -> !shape.size {
+  // CHECK-DAG: %[[STD_IDX:.*]] = shape.size_to_index %[[IDX]]
+  // CHECK-DAG: %[[STD_RESULT:.*]] = extract_element %[[EXTENTS]][%[[STD_IDX]]] : tensor<?xindex>
+  // CHECK-DAG: %[[RESULT:.*]] = shape.index_to_size %[[STD_RESULT]]
+  // CHECK-DAG: return %[[RESULT]] : !shape.size
+  %shape = shape.from_extent_tensor %extents: tensor<?xindex>
+  %result = shape.get_extent %shape, %idx
+  return %result : !shape.size
+}