diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -395,7 +395,26 @@
     }
   }
 
-  // TODO: Fold shape equality.
+  // Fold shape equality: two constant shapes are equal iff they have the same
+  // number of elements and every pair of corresponding elements matches.
+  if (auto lhs = operands[0].dyn_cast_or_null<DenseIntElementsAttr>()) {
+    if (auto rhs = operands[1].dyn_cast_or_null<DenseIntElementsAttr>()) {
+      auto iLhs = lhs.begin(), iRhs = rhs.begin();
+      auto endLhs = lhs.end(), endRhs = rhs.end();
+      auto sizeLhs = endLhs - iLhs, sizeRhs = endRhs - iRhs;
+      Builder builder(getContext());
+      // Shapes of different length can never be equal.
+      if (sizeLhs != sizeRhs)
+        return builder.getI1IntegerAttr(false);
+      // Lengths match, so both iterators reach their end together.
+      while (iLhs < endLhs) {
+        assert(iRhs < endRhs);
+        if (*iLhs++ != *iRhs++)
+          return builder.getI1IntegerAttr(false);
+      }
+      return builder.getI1IntegerAttr(true);
+    }
+  }
   return {};
 }
 
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -483,3 +483,55 @@
   return %result : i1
 }
 
+// -----
+
+// Fold `eq` for equal and constant shapes.
+// CHECK-LABEL: @eq_shape_fold_1
+func @eq_shape_fold_1() -> i1 {
+  // CHECK: %[[RESULT:.*]] = constant true
+  // CHECK: return %[[RESULT]] : i1
+  %a = shape.const_shape [1, 2, 3]
+  %b = shape.const_shape [1, 2, 3]
+  %result = shape.eq %a, %b : !shape.shape
+  return %result : i1
+}
+
+// -----
+
+// Fold `eq` for different but constant shapes of same length.
+// CHECK-LABEL: @eq_shape_fold_0
+func @eq_shape_fold_0() -> i1 {
+  // CHECK: %[[RESULT:.*]] = constant false
+  // CHECK: return %[[RESULT]] : i1
+  %a = shape.const_shape [1, 2, 3]
+  %b = shape.const_shape [4, 5, 6]
+  %result = shape.eq %a, %b : !shape.shape
+  return %result : i1
+}
+
+// -----
+
+// Fold `eq` for different but constant shapes of different length.
+// CHECK-LABEL: @eq_shape_fold_diff_length
+func @eq_shape_fold_diff_length() -> i1 {
+  // CHECK: %[[RESULT:.*]] = constant false
+  // CHECK: return %[[RESULT]] : i1
+  %a = shape.const_shape [1, 2, 3]
+  %b = shape.const_shape [1, 2, 3, 4, 5, 6]
+  %result = shape.eq %a, %b : !shape.shape
+  return %result : i1
+}
+
+// -----
+
+// Do not fold `eq` for non-constant shapes.
+// CHECK-LABEL: @eq_shape_do_not_fold
+// CHECK-SAME: (%[[A:.*]]: !shape.shape) -> i1
+func @eq_shape_do_not_fold(%a : !shape.shape) -> i1 {
+  // CHECK: %[[B:.*]] = shape.const_shape [4, 5, 6]
+  // CHECK: %[[RESULT:.*]] = shape.eq %[[A]], %[[B]] : !shape.shape
+  // CHECK: return %[[RESULT]] : i1
+  %b = shape.const_shape [4, 5, 6]
+  %result = shape.eq %a, %b : !shape.shape
+  return %result : i1
+}