diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -392,7 +392,8 @@
   let assemblyFormat = "$inputs attr-dict";
 }
 
-def Shape_AssumingAllOp : Shape_Op<"assuming_all", [NoSideEffect]> {
+def Shape_AssumingAllOp : Shape_Op<"assuming_all",
+                                   [Commutative, NoSideEffect]> {
   let summary = "Return a logical AND of all witnesses.";
   let description = [{
     Used to simplify constraints as any single failing precondition is enough
@@ -417,6 +418,8 @@
   let results = (outs Shape_WitnessType:$result);
 
   let assemblyFormat = "$inputs attr-dict";
+
+  let hasCanonicalizer = 1;
 }
 
 def Shape_AssumingOp : Shape_Op<"assuming",
diff --git a/mlir/lib/Dialect/Shape/CMakeLists.txt b/mlir/lib/Dialect/Shape/CMakeLists.txt
--- a/mlir/lib/Dialect/Shape/CMakeLists.txt
+++ b/mlir/lib/Dialect/Shape/CMakeLists.txt
@@ -1,3 +1,7 @@
+set(LLVM_TARGET_DEFINITIONS IR/ShapeCanonicalization.td)
+mlir_tablegen(IR/ShapeCanonicalization.inc -gen-rewriters)
+add_public_tablegen_target(MLIRShapeCanonicalizationIncGen)
+
 add_mlir_dialect_library(MLIRShape
   IR/Shape.cpp
 
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -18,6 +18,10 @@
 using namespace mlir;
 using namespace mlir::shape;
 
+namespace {
+#include "IR/ShapeCanonicalization.inc"
+} // namespace
+
 ShapeDialect::ShapeDialect(MLIRContext *context)
     : Dialect(getDialectNamespace(), context) {
   addOperations<
@@ -151,6 +155,16 @@
   p.printOptionalAttrDict(op.getAttrs());
 }
 
+//===----------------------------------------------------------------------===//
+// AssumingAllOp
+//===----------------------------------------------------------------------===//
+void AssumingAllOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *context) {
+  // Replace with a single true witness if all inputs are true witnesses.
+  // TODO: Consider expressing this as a fold instead of a rewrite pattern.
+  patterns.insert<ConstantAssumingAll>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // BroadcastOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
@@ -0,0 +1,14 @@
+include "mlir/Dialect/Shape/IR/ShapeOps.td"
+
+// Constraints.
+// Holds when every operand of the op producing $0 is a shape.true_witness.
+def AllInputsTrueWitnesses : Constraint<CPred<[{
+  llvm::all_of($0.getDefiningOp()->getOperands(), [](mlir::Value op) {
+    return op.getDefiningOp<TrueWitnessOp>();
+  })
+  }]>>;
+
+// Canonicalization patterns.
+def ConstantAssumingAll : Pat<(Shape_AssumingAllOp:$op $input),
+          (Shape_TrueWitnessOp),
+          [(AllInputsTrueWitnesses $op)]>;
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -split-input-file -canonicalize <%s | FileCheck %s --dump-input=fail
+// RUN: mlir-opt -split-input-file -allow-unregistered-dialect -canonicalize <%s | FileCheck %s --dump-input=fail
 
 // -----
 // CHECK-LABEL: func @f
@@ -106,3 +106,34 @@
   %ret = shape.from_extents %e0, %arg0
   return %ret : !shape.shape
 }
+
+// -----
+// assuming_all with only known true witnesses is replaced by a true witness.
+// CHECK-LABEL: func @f
+func @f() {
+  // CHECK-NEXT: shape.true_witness
+  // CHECK-NEXT: consume.witness
+  // CHECK-NEXT: return
+  %0 = shape.true_witness
+  %1 = shape.true_witness
+  %2 = shape.true_witness
+  %3 = shape.assuming_all %0, %1, %2
+  "consume.witness"(%3) : (!shape.witness) -> ()
+  return
+}
+
+// -----
+// assuming_all should not be removed if not all witnesses are statically true.
+// CHECK-LABEL: func @f
+func @f() {
+  // CHECK-NEXT: %[[TRUE:.*]] = shape.true_witness
+  // CHECK-NEXT: %[[UNKNOWN:.*]] = "test.source"
+  // CHECK-NEXT: shape.assuming_all %[[TRUE]], %[[UNKNOWN]]
+  // CHECK-NEXT: consume.witness
+  // CHECK-NEXT: return
+  %0 = shape.true_witness
+  %1 = "test.source"() : () -> !shape.witness
+  %2 = shape.assuming_all %0, %1
+  "consume.witness"(%2) : (!shape.witness) -> ()
+  return
+}