diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -230,6 +230,7 @@
   let assemblyFormat = "$arg attr-dict";
 
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 def Shape_JoinOp : Shape_Op<"join", []> {
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -392,6 +392,12 @@
   return {};
 }
 
+/// Registers the DRR pattern that folds `index_to_size(size_to_index(%x))` to `%x`.
+void IndexToSizeOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *context) {
+  patterns.insert<SizeToIndexToSizeCanonicalization>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // FromExtentsOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
--- a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
+++ b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
@@ -22,3 +22,7 @@
   (Shape_SizeToIndexOp (Shape_IndexToSizeOp $arg)),
   (replaceWithValue $arg)>;
 
+def SizeToIndexToSizeCanonicalization : Pat<
+  (Shape_IndexToSizeOp (Shape_SizeToIndexOp $arg)),
+  (replaceWithValue $arg)>;
+
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -455,3 +455,15 @@
   return %result : index
 }
 
+// -----
+
+// Canonicalize redundant conversion from `size` to `index` and back.
+// CHECK-LABEL: @size_to_index_to_size
+// CHECK-SAME: (%[[SIZE:.*]]: !shape.size) -> !shape.size
+func @size_to_index_to_size(%size : !shape.size) -> !shape.size {
+  // CHECK: return %[[SIZE]] : !shape.size
+  %idx = shape.size_to_index %size
+  %result = shape.index_to_size %idx
+  return %result : !shape.size
+}
+