diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -150,6 +150,7 @@
   let results = (outs I1:$result);
 
   let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs)";
+  let hasFolder = 1;
 }
 
 def Shape_FromExtentsOp : Shape_Op<"from_extents", [NoSideEffect]> {
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -465,6 +465,23 @@
 
 OpFoldResult ConstWitnessOp::fold(ArrayRef<Attribute>) { return passingAttr(); }
 
+//===----------------------------------------------------------------------===//
+// ShapeEqOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult ShapeEqOp::fold(ArrayRef<Attribute> operands) {
+  // Fold only if both operands are constant shapes.
+  auto lhs = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
+  if (!lhs)
+    return {};
+  auto rhs = operands[1].dyn_cast_or_null<DenseIntElementsAttr>();
+  if (!rhs)
+    return {};
+  // Attributes are uniqued, so identity compares both extents and tensor type;
+  // constant shapes of different length therefore never compare equal.
+  return BoolAttr::get(lhs == rhs, getContext());
+}
+
 //===----------------------------------------------------------------------===//
 // IndexToSizeOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -561,3 +561,56 @@
   "consume.witness"(%2) : (!shape.witness) -> ()
   return
 }
+
+// -----
+
+// Fold `shape_eq` for equal and constant shapes.
+// CHECK-LABEL: @shape_eq_fold_1
+func @shape_eq_fold_1() -> i1 {
+  // CHECK: %[[RESULT:.*]] = constant true
+  // CHECK: return %[[RESULT]] : i1
+  %a = shape.const_shape [1, 2, 3]
+  %b = shape.const_shape [1, 2, 3]
+  %result = shape.shape_eq %a, %b : !shape.shape, !shape.shape
+  return %result : i1
+}
+
+// -----
+
+// Fold `shape_eq` for different but constant shapes of the same length.
+// CHECK-LABEL: @shape_eq_fold_0
+func @shape_eq_fold_0() -> i1 {
+  // CHECK: %[[RESULT:.*]] = constant false
+  // CHECK: return %[[RESULT]] : i1
+  %a = shape.const_shape [1, 2, 3]
+  %b = shape.const_shape [4, 5, 6]
+  %result = shape.shape_eq %a, %b : !shape.shape, !shape.shape
+  return %result : i1
+}
+
+// -----
+
+// Fold `shape_eq` for different but constant shapes of different lengths.
+// CHECK-LABEL: @shape_eq_different_length_fold_0
+func @shape_eq_different_length_fold_0() -> i1 {
+  // CHECK: %[[RESULT:.*]] = constant false
+  // CHECK: return %[[RESULT]] : i1
+  %a = shape.const_shape [1, 2, 3, 4, 5, 6]
+  %b = shape.const_shape [1, 2, 3]
+  %result = shape.shape_eq %a, %b : !shape.shape, !shape.shape
+  return %result : i1
+}
+
+// -----
+
+// Do not fold `shape_eq` for non-constant shapes.
+// CHECK-LABEL: @shape_eq_do_not_fold
+// CHECK-SAME: (%[[A:.*]]: !shape.shape) -> i1
+func @shape_eq_do_not_fold(%a : !shape.shape) -> i1 {
+  // CHECK: %[[B:.*]] = shape.const_shape [4, 5, 6]
+  // CHECK: %[[RESULT:.*]] = shape.shape_eq %[[A]], %[[B]] : !shape.shape, !shape.shape
+  // CHECK: return %[[RESULT]] : i1
+  %b = shape.const_shape [4, 5, 6]
+  %result = shape.shape_eq %a, %b : !shape.shape, !shape.shape
+  return %result : i1
+}