diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -146,6 +146,8 @@
   let results = (outs I1:$result);
 
   let assemblyFormat = "$lhs `,` $rhs `:` type($lhs) attr-dict";
+
+  let hasFolder = 1;
 }
 
 def Shape_FromExtentsOp : Shape_Op<"from_extents", [NoSideEffect]> {
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -109,6 +109,7 @@
   FloatAttr getF32FloatAttr(float value);
   FloatAttr getF64FloatAttr(double value);
 
+  IntegerAttr getI1IntegerAttr(bool value);
   IntegerAttr getI8IntegerAttr(int8_t value);
   IntegerAttr getI16IntegerAttr(int16_t value);
   IntegerAttr getI32IntegerAttr(int32_t value);
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -380,6 +380,25 @@
 
 OpFoldResult ConstWitnessOp::fold(ArrayRef<Attribute>) { return passingAttr(); }
 
+//===----------------------------------------------------------------------===//
+// EqOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult EqOp::fold(ArrayRef<Attribute> operands) {
+  // Fold size equality when both operands are constant sizes. An entry of
+  // `operands` is null when the corresponding operand is not a constant.
+  if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) {
+    if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
+      bool eq = lhs.getValue().eq(rhs.getValue());
+      Builder builder(getContext());
+      return builder.getI1IntegerAttr(eq);
+    }
+  }
+
+  // TODO: Fold shape equality.
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // IndexToSizeOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -165,6 +165,10 @@
   return IntegerAttr::get(getIntegerType(8), APInt(8, value));
 }
 
+IntegerAttr Builder::getI1IntegerAttr(bool value) {
+  return IntegerAttr::get(getI1Type(), APInt(1, value));
+}
+
 IntegerAttr Builder::getIntegerAttr(Type type, int64_t value) {
   if (type.isIndex())
     return IntegerAttr::get(type, APInt(64, value));
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -442,3 +442,43 @@
   "consume.witness"(%0) : (!shape.witness) -> ()
   return
 }
+
+// -----
+
+// Fold `eq` for equal and constant sizes.
+// CHECK-LABEL: @eq_size_fold_1
+func @eq_size_fold_1() -> i1 {
+  // CHECK: %[[RESULT:.*]] = constant true
+  // CHECK: return %[[RESULT]] : i1
+  %a = shape.const_size 123
+  %b = shape.const_size 123
+  %result = shape.eq %a, %b : !shape.size
+  return %result : i1
+}
+
+// -----
+
+// Fold `eq` for different but constant sizes.
+// CHECK-LABEL: @eq_size_fold_0
+func @eq_size_fold_0() -> i1 {
+  // CHECK: %[[RESULT:.*]] = constant false
+  // CHECK: return %[[RESULT]] : i1
+  %a = shape.const_size 123
+  %b = shape.const_size 456
+  %result = shape.eq %a, %b : !shape.size
+  return %result : i1
+}
+
+// -----
+
+// Do not fold `eq` for non-constant sizes.
+// CHECK-LABEL: @eq_size_do_not_fold
+// CHECK-SAME: (%[[A:.*]]: !shape.size) -> i1
+func @eq_size_do_not_fold(%a : !shape.size) -> i1 {
+  // CHECK: %[[B:.*]] = shape.const_size 456
+  // CHECK: %[[RESULT:.*]] = shape.eq %[[A]], %[[B]] : !shape.size
+  // CHECK: return %[[RESULT]] : i1
+  %b = shape.const_size 456
+  %result = shape.eq %a, %b : !shape.size
+  return %result : i1
+}