diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandardPatterns.td b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandardPatterns.td
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandardPatterns.td
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandardPatterns.td
@@ -26,3 +26,18 @@
   (Shape_IndexToSizeOp (DimOp $arg, (Shape_SizeToIndexOp $idx))),
   [],
   (addBenefit 10)>;
+
+// Convert `shape.get_extent` to `std.extract_element` when its operand is the
+// result of a `shape.from_extent_tensor` operation. The `!shape.size` index is
+// lowered with `size_to_index` and the extracted value is wrapped back with
+// `index_to_size`; that last result replaces the matched `get_extent` op.
+def GetExtentFromExtentTensorConversion : Pattern<
+  (Shape_GetExtentOp (Shape_FromExtentTensorOp $extents), $idx),
+  [
+    (Shape_SizeToIndexOp:$std_idx $idx),
+    // `extract_element` takes its indices as a range; wrap the single index.
+    (ExtractElementOp:$std_result $extents,
+        (NativeCodeCall<"ValueRange({$0})"> $std_idx)),
+    (Shape_IndexToSizeOp $std_result)
+  ],
+  [],
+  (addBenefit 10)>;
diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -143,3 +143,18 @@
   return %result : !shape.size
 }
 
+// -----
+
+// Express `get_extent` as `std.extract_element` when it relies directly on the
+// outcome of a `from_extent_tensor` operation.
+// CHECK-LABEL: @get_extent_from_extent_tensor
+// CHECK-SAME: (%[[EXTENTS:.*]]: tensor<?xindex>, %[[IDX:.*]]: index) -> index
+func @get_extent_from_extent_tensor(%extents : tensor<?xindex>,
+                                    %idx : !shape.size) -> !shape.size {
+  // CHECK: %[[RESULT:.*]] = extract_element %[[EXTENTS]][%[[IDX]]] : tensor<?xindex>
+  // CHECK: return %[[RESULT]] : index
+  %shape = shape.from_extent_tensor %extents : tensor<?xindex>
+  %result = shape.get_extent %shape, %idx
+  return %result : !shape.size
+}
+