diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -105,6 +105,7 @@
   let printer = [{ return ::print(p, *this); }];
   let parser = [{ return ::parse$cppClass(parser, result); }];
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 def Shape_ConstSizeOp : Shape_Op<"const_size", [
@@ -630,6 +631,7 @@
   let assemblyFormat = "$inputs attr-dict";
 
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 
   let verifier = [{ return ::verify(*this); }];
 }
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -271,6 +271,14 @@
 //===----------------------------------------------------------------------===//
 // AssumingAllOp
 //===----------------------------------------------------------------------===//
+
+void AssumingAllOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *context) {
+  // Forward the witness of a single-operand assuming_all directly to its
+  // users (AssumingAllOneOp, defined in ShapeCanonicalization.td).
+  patterns.insert<AssumingAllOneOp>(context);
+}
+
 OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) {
   // Iterate in reverse to first handle all constant operands. They are
   // guaranteed to be the tail of the inputs because this is commutative.
@@ -394,6 +402,13 @@
 
 OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return shapeAttr(); }
 
+void ConstShapeOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *context) {
+  // Replace tensor_cast(const_shape) with a const_shape of the cast's static
+  // result type (TensorCastConstShape, defined in ShapeCanonicalization.td).
+  patterns.insert<TensorCastConstShape>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // CstrBroadcastableOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
--- a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
+++ b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
@@ -1,4 +1,5 @@
 include "mlir/Dialect/Shape/IR/ShapeOps.td"
+include "mlir/Dialect/StandardOps/IR/Ops.td"
 
 def AllInputShapesEq : Constraint<CPred< [{
   llvm::all_of($0, [&](mlir::Value val) {
@@ -6,8 +7,19 @@
   })
 }]>>;
 
+def HasSingleElement : Constraint<CPred< [{
+  $0.size() == 1
+}]>>;
+
+def HasStaticShape : Constraint<CPred< [{
+  $0.getType().cast<ShapedType>().hasStaticShape()
+}]>>;
+
 // Canonicalization patterns.
 
+def AssumingAllOneOp : Pat<(Shape_AssumingAllOp:$op $x), (replaceWithValue $x),
+  [(HasSingleElement $x)]>;
+
 def CstrBroadcastableEqOps : Pat<(Shape_CstrBroadcastableOp:$op $x, $x),
   (Shape_ConstWitnessOp ConstBoolAttrTrue)>;
 
@@ -23,3 +35,9 @@
   (Shape_IndexToSizeOp (Shape_SizeToIndexOp $arg)),
   (replaceWithValue $arg)>;
 
+// A tensor_cast of a const_shape to a statically shaped tensor type is replaced
+// by a new const_shape of that static type. The cast result must not simply be
+// replaced by the original const_shape value, whose type may differ.
+def TensorCastConstShape : Pat <
+  (TensorCastOp:$res (Shape_ConstShapeOp $arg)), (Shape_ConstShapeOp $arg),
+  [(HasStaticShape $res)]>;
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -433,14 +433,16 @@
 // commutative.
 // CHECK-LABEL: func @f
 func @f() {
-  // CHECK-NEXT: %[[UNKNOWN:.*]] = "test.source"
-  // CHECK-NEXT: shape.assuming_all %[[UNKNOWN]]
+  // CHECK-NEXT: %[[UNKNOWN1:.*]] = "test.source"
+  // CHECK-NEXT: %[[UNKNOWN2:.*]] = "test.source"
+  // CHECK-NEXT: shape.assuming_all %[[UNKNOWN1]], %[[UNKNOWN2]]
   // CHECK-NEXT: consume.witness
   // CHECK-NEXT: return
   %0 = shape.const_witness true
   %1 = "test.source"() : () -> !shape.witness
-  %2 = shape.assuming_all %0, %1
-  "consume.witness"(%2) : (!shape.witness) -> ()
+  %2 = "test.source"() : () -> !shape.witness
+  %3 = shape.assuming_all %0, %1, %2
+  "consume.witness"(%3) : (!shape.witness) -> ()
   return
 }
 
@@ -854,3 +856,41 @@
   %casted = shape.to_extent_tensor %arg : tensor<?xindex> -> tensor<?xindex>
   return %casted : tensor<?xindex>
 }
+
+// -----
+
+// Fold assuming_all with a single input
+// CHECK-LABEL: @fold_assuming_all_single_element
+func @fold_assuming_all_single_element(%arg: tensor<?xindex>) {
+  // CHECK-NOT: assuming_all
+  %0 = "test.source"() : () -> (!shape.witness)
+  %1 = shape.assuming_all %0
+  "consume.witness"(%1) : (!shape.witness) -> ()
+  return
+}
+
+// -----
+
+// Fold tensor_cast of a const_shape to const_shape
+// CHECK-LABEL: @fold_tensor_cast_of_const_shape
+func @fold_tensor_cast_of_const_shape(%arg: tensor<?xindex>) {
+  // CHECK-NOT: tensor_cast
+  %0 = shape.const_shape [2] : tensor<?xindex>
+  %1 = tensor_cast %0 : tensor<?xindex> to tensor<1xindex>
+  %2 = shape.cstr_broadcastable %1, %0 : tensor<1xindex>, tensor<?xindex>
+  "consume.witness"(%2) : (!shape.witness) -> ()
+  return
+}
+
+// -----
+
+// The const_shape that replaces the tensor_cast keeps the cast's static type.
+// CHECK-LABEL: @fold_tensor_cast_of_const_shape_keeps_type
+func @fold_tensor_cast_of_const_shape_keeps_type() -> tensor<1xindex> {
+  // CHECK: %[[CS:.*]] = shape.const_shape [2] : tensor<1xindex>
+  // CHECK-NOT: tensor_cast
+  // CHECK: return %[[CS]] : tensor<1xindex>
+  %0 = shape.const_shape [2] : tensor<?xindex>
+  %1 = tensor_cast %0 : tensor<?xindex> to tensor<1xindex>
+  return %1 : tensor<1xindex>
+}