diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -151,6 +151,7 @@
   let results = (outs I1:$result);
 
   let assemblyFormat = "$lhs `,` $rhs attr-dict";
+  let hasFolder = 1;
 }
 
 def Shape_FromExtentsOp : Shape_Op<"from_extents", [NoSideEffect]> {
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -385,6 +385,24 @@
 
 OpFoldResult ConstWitnessOp::fold(ArrayRef<Attribute>) { return passingAttr(); }
 
+//===----------------------------------------------------------------------===//
+// SizeEqOp
+//===----------------------------------------------------------------------===//
+
+/// Folds `shape.size_eq` to a constant `i1` (a BoolAttr) when both operands
+/// have already folded to IntegerAttr constants. Returns a null OpFoldResult
+/// (i.e. no fold) when either operand is not a known constant.
+OpFoldResult SizeEqOp::fold(ArrayRef<Attribute> operands) {
+  auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
+  if (lhs == nullptr)
+    return {};
+  auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
+  if (rhs == nullptr)
+    return {};
+  bool equal = lhs.getValue().eq(rhs.getValue());
+  return BoolAttr::get(equal, getContext());
+}
+
 //===----------------------------------------------------------------------===//
 // IndexToSizeOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -492,3 +492,44 @@
   %rank = shape.rank %shape
   return %rank : !shape.size
 }
+
+// -----
+
+// Fold `size_eq` for equal and constant sizes.
+// CHECK-LABEL: @size_eq_fold_1
+func @size_eq_fold_1() -> i1 {
+  // CHECK: %[[RESULT:.*]] = constant true
+  // CHECK: return %[[RESULT]] : i1
+  %a = shape.const_size 123
+  %b = shape.const_size 123
+  %result = shape.size_eq %a, %b
+  return %result : i1
+}
+
+// -----
+
+// Fold `size_eq` for different but constant sizes.
+// CHECK-LABEL: @size_eq_fold_0
+func @size_eq_fold_0() -> i1 {
+  // CHECK: %[[RESULT:.*]] = constant false
+  // CHECK: return %[[RESULT]] : i1
+  %a = shape.const_size 123
+  %b = shape.const_size 456
+  %result = shape.size_eq %a, %b
+  return %result : i1
+}
+
+// -----
+
+// Do not fold `size_eq` for non-constant sizes.
+// CHECK-LABEL: @size_eq_do_not_fold
+// CHECK-SAME: (%[[A:.*]]: !shape.size) -> i1
+func @size_eq_do_not_fold(%a : !shape.size) -> i1 {
+  // CHECK: %[[B:.*]] = shape.const_size 456
+  // CHECK: %[[RESULT:.*]] = shape.size_eq %[[A]], %[[B]]
+  // CHECK: return %[[RESULT]] : i1
+  %b = shape.const_size 456
+  %result = shape.size_eq %a, %b
+  return %result : i1
+}
+