diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -392,6 +392,7 @@
   let assemblyFormat = "$arg attr-dict";
 
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 def Shape_YieldOp : Shape_Op<"yield",
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -536,6 +536,11 @@
   return {};
 }
 
+void SizeToIndexOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *context) {
+  patterns.insert<IndexToSizeToIndexCanonicalization>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // YieldOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
--- a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
+++ b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
@@ -16,3 +16,8 @@
 def CstrEqEqOps : Pat<(Shape_CstrEqOp:$op $shapes),
   (Shape_ConstWitnessOp ConstBoolAttrTrue),
   [(AllInputShapesEq $shapes)]>;
+
+// Fold `size_to_index(index_to_size(%x))` back to `%x`.
+def IndexToSizeToIndexCanonicalization : Pat<
+  (Shape_SizeToIndexOp (Shape_IndexToSizeOp $arg)),
+  (replaceWithValue $arg)>;
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -492,3 +492,13 @@
   %rank = shape.rank %shape
   return %rank : !shape.size
 }
+
+// Canonicalize redundant conversion from `index` to `size` and back.
+// CHECK-LABEL: @index_to_size_to_index
+// CHECK-SAME: (%[[IDX:.*]]: index) -> index
+func @index_to_size_to_index(%index : index) -> index {
+  // CHECK: return %[[IDX]] : index
+  %size = shape.index_to_size %index
+  %result = shape.size_to_index %size
+  return %result : index
+}