diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -161,6 +161,7 @@
   let results = (outs I1:$result);
 
   let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs)";
+  let hasFolder = 1;
 }
 
 def Shape_FromExtentsOp : Shape_Op<"from_extents", [NoSideEffect]> {
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -545,6 +545,26 @@
   return BoolAttr::get(lhs == rhs, getContext());
 }
 
+//===----------------------------------------------------------------------===//
+// SizeEqOp
+//===----------------------------------------------------------------------===//
+
+// Folds `size_eq` to a constant boolean when both operands are constant
+// integer attributes (e.g. produced by `shape.const_size` or an index
+// `constant`); otherwise returns an empty result and leaves the op in place.
+OpFoldResult SizeEqOp::fold(ArrayRef<Attribute> operands) {
+  auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
+  if (!lhs)
+    return {};
+  auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
+  if (!rhs)
+    return {};
+  // APInt::eq requires equal bit widths; constant sizes and index constants
+  // are both index-typed attributes. NOTE(review): revisit if the operand
+  // type constraint is ever widened beyond size/index.
+  bool equal = lhs.getValue().eq(rhs.getValue());
+  return BoolAttr::get(equal, getContext());
+}
+
 //===----------------------------------------------------------------------===//
 // IndexToSizeOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -721,6 +721,44 @@
 
 // -----
 
+// Fold `size_eq` for equal and constant sizes.
+// CHECK-LABEL: @size_eq_fold_1
+func @size_eq_fold_1() -> i1 {
+  // CHECK: %[[RESULT:.*]] = constant true
+  // CHECK: return %[[RESULT]] : i1
+  %a = shape.const_size 123
+  %b = shape.const_size 123
+  %result = shape.size_eq %a, %b : !shape.size, !shape.size
+  return %result : i1
+}
+
+// -----
+
+// Fold `size_eq` for different but constant sizes.
+// CHECK-LABEL: @size_eq_fold_0
+func @size_eq_fold_0() -> i1 {
+  // CHECK: %[[RESULT:.*]] = constant false
+  // CHECK: return %[[RESULT]] : i1
+  %a = constant 123 : index
+  %b = shape.const_size 456
+  %result = shape.size_eq %a, %b : index, !shape.size
+  return %result : i1
+}
+
+// -----
+
+// Do not fold `size_eq` for non-constant sizes.
+// CHECK-LABEL: @size_eq_do_not_fold
+func @size_eq_do_not_fold(%a : !shape.size) -> i1 {
+  // CHECK: %[[RESULT:.*]] = shape.size_eq
+  // CHECK: return %[[RESULT]] : i1
+  %b = shape.const_size 456
+  %result = shape.size_eq %a, %b : !shape.size, !shape.size
+  return %result : i1
+}
+
+// -----
+
 // Fold `shape_eq` for equal and constant shapes.
 // CHECK-LABEL: @shape_eq_fold_1
 func @shape_eq_fold_1() -> i1 {