diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -572,6 +572,9 @@
 //===----------------------------------------------------------------------===//
 
 OpFoldResult ShapeEqOp::fold(ArrayRef<Attribute> operands) {
+  // Identical operands compare equal; fold even if they are not constant.
+  if (lhs() == rhs())
+    return BoolAttr::get(true, getContext());
   auto lhs = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
   if (lhs == nullptr)
     return {};
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -787,7 +787,7 @@
 
 // -----
 
-// Do not fold `shape_eq` for non-constant shapes.
+// Do not fold `shape_eq` for non-constant different shapes.
 // CHECK-LABEL: @shape_eq_do_not_fold
 // CHECK-SAME: (%[[A:.*]]: !shape.shape) -> i1
 func @shape_eq_do_not_fold(%a : !shape.shape) -> i1 {
@@ -799,6 +799,18 @@
   return %result : i1
 }
 
+// -----
+
+// Fold `shape_eq` for non-constant but same shapes.
+// CHECK-LABEL: @shape_eq_do_fold
+// CHECK-SAME: (%[[A:.*]]: !shape.shape) -> i1
+func @shape_eq_do_fold(%a : !shape.shape) -> i1 {
+  // CHECK: %[[RESULT:.*]] = constant true
+  // CHECK: return %[[RESULT]] : i1
+  %result = shape.shape_eq %a, %a : !shape.shape, !shape.shape
+  return %result : i1
+}
+
 // -----
 
 // Fold `mul` for constant sizes.