diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -168,20 +168,39 @@
   let hasFolder = 1;
 }
 
-def Shape_ShapeEqOp : Shape_Op<"shape_eq", [Commutative, NoSideEffect]> {
+def Shape_ShapeEqOp : Shape_Op<"shape_eq", [NoSideEffect, Commutative,
+                                            InferTypeOpInterface]> {
   let summary = "Returns whether the input shapes or extent tensors are equal";
   let description = [{
-    Takes two shape or extent tensor operands and determines whether they are
-    equal. When extent tensors are compared to shapes they are regarded as their
-    equivalent non-error shapes. Error shapes can be tested for equality like
-    any other shape value, meaning that the error value is equal to itself.
+    Takes zero or more shape or extent tensor operands and determines whether
+    they are all equal; with fewer than two operands the result is trivially
+    true. When extent tensors are compared to shapes they are regarded as
+    their equivalent non-error shapes. Error shapes can be tested for equality
+    like any other shape value, meaning that the error value is equal to
+    itself.
   }];
 
-  let arguments = (ins Shape_ShapeOrExtentTensorType:$lhs,
-                       Shape_ShapeOrExtentTensorType:$rhs);
+  let arguments = (ins Variadic<Shape_ShapeOrExtentTensorType>:$shapes);
   let results = (outs I1:$result);
 
-  let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs)";
+  // Convenience builder for the common binary comparison.
+  let builders = [
+    OpBuilderDAG<(ins "::mlir::Value":$lhs, "::mlir::Value":$rhs),
+    [{ build($_builder, $_state, ::llvm::makeArrayRef({lhs, rhs})); }]>,
+  ];
+  let extraClassDeclaration = [{
+    // TODO: This should really be automatic. Figure out how to not need this defined.
+    static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context,
+        ::llvm::Optional<::mlir::Location> location, ::mlir::ValueRange operands,
+        ::mlir::DictionaryAttr attributes, ::mlir::RegionRange regions,
+        ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
+      inferredReturnTypes.push_back(::mlir::IntegerType::get(context,
+                                                             /*width=*/1));
+      return ::mlir::success();
+    }
+  }];
+
+  let assemblyFormat = "$shapes attr-dict `:` type($shapes)";
 
   let hasFolder = 1;
 }
diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -474,46 +474,62 @@
 LogicalResult
 ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, ArrayRef<Value> operands,
                                     ConversionPatternRewriter &rewriter) const {
   // For now, this lowering is only defined on `tensor<?xindex>` operands, not
   // on shapes.
-  if (op.lhs().getType().isa<ShapeType>() ||
-      op.rhs().getType().isa<ShapeType>()) {
+  if (!llvm::all_of(op.shapes(),
+                    [](Value v) { return !v.getType().isa<ShapeType>(); }))
     return failure();
+
+  // With fewer than two operands the shapes are trivially equal.
+  Type i1Ty = rewriter.getI1Type();
+  if (op.shapes().size() <= 1) {
+    rewriter.replaceOpWithNewOp<ConstantOp>(op, i1Ty,
+                                            rewriter.getBoolAttr(true));
+    return success();
   }
 
   ShapeEqOp::Adaptor transformed(operands);
   auto loc = op.getLoc();
   Type indexTy = rewriter.getIndexType();
   Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
-  Value lhsRank = rewriter.create<DimOp>(loc, indexTy, transformed.lhs(), zero);
-  Value rhsRank = rewriter.create<DimOp>(loc, indexTy, transformed.rhs(), zero);
-  Value eqRank =
-      rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, lhsRank, rhsRank);
-  Type i1Ty = rewriter.getI1Type();
-  rewriter.replaceOpWithNewOp<IfOp>(
-      op, i1Ty, eqRank,
-      [&](OpBuilder &b, Location loc) {
-        Value one = b.create<ConstantIndexOp>(loc, 1);
-        Value init = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(true));
-        auto loop = b.create<scf::ForOp>(
-            loc, zero, lhsRank, one, ValueRange{init},
-            [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) {
-              Value conj = args[0];
-              Value lhsExtent =
-                  b.create<tensor::ExtractOp>(loc, transformed.lhs(), iv);
-              Value rhsExtent =
-                  b.create<tensor::ExtractOp>(loc, transformed.rhs(), iv);
-              Value eqExtent = b.create<CmpIOp>(loc, CmpIPredicate::eq,
-                                                lhsExtent, rhsExtent);
-              Value conjNext = b.create<AndOp>(loc, conj, eqExtent);
-              b.create<scf::YieldOp>(loc, ValueRange({conjNext}));
-            });
-        b.create<scf::YieldOp>(loc, loop.getResults());
-      },
-      [&](OpBuilder &b, Location loc) {
-        Value result = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(false));
-        b.create<scf::YieldOp>(loc, result);
-      });
+  Value firstShape = transformed.shapes().front();
+  Value firstRank = rewriter.create<DimOp>(loc, indexTy, firstShape, zero);
+  // Compare every other shape against the first one and AND the pairwise
+  // results together.
+  Value result;
+  for (Value shape : transformed.shapes().drop_front(1)) {
+    Value rank = rewriter.create<DimOp>(loc, indexTy, shape, zero);
+    Value eqRank =
+        rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, firstRank, rank);
+    auto same = rewriter.create<IfOp>(
+        loc, i1Ty, eqRank,
+        [&](OpBuilder &b, Location loc) {
+          // Ranks match: AND together the extent-wise comparisons.
+          Value one = b.create<ConstantIndexOp>(loc, 1);
+          Value init = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(true));
+          auto loop = b.create<scf::ForOp>(
+              loc, zero, firstRank, one, ValueRange{init},
+              [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) {
+                Value conj = args[0];
+                Value lhsExtent =
+                    b.create<tensor::ExtractOp>(loc, firstShape, iv);
+                Value rhsExtent = b.create<tensor::ExtractOp>(loc, shape, iv);
+                Value eqExtent = b.create<CmpIOp>(loc, CmpIPredicate::eq,
+                                                  lhsExtent, rhsExtent);
+                Value conjNext = b.create<AndOp>(loc, conj, eqExtent);
+                b.create<scf::YieldOp>(loc, ValueRange({conjNext}));
+              });
+          b.create<scf::YieldOp>(loc, loop.getResults());
+        },
+        [&](OpBuilder &b, Location loc) {
+          // Ranks differ: the shapes cannot be equal.
+          Value falseVal = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(false));
+          b.create<scf::YieldOp>(loc, falseVal);
+        });
+    result = !result ? same.getResult(0)
+                     : rewriter.create<AndOp>(loc, result, same.getResult(0));
+  }
+  rewriter.replaceOp(op, result);
   return success();
 }
 
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -629,15 +629,25 @@
 //===----------------------------------------------------------------------===//
 
 OpFoldResult ShapeEqOp::fold(ArrayRef<Attribute> operands) {
-  if (lhs() == rhs())
-    return BoolAttr::get(getContext(), true);
-  auto lhs = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
-  if (lhs == nullptr)
-    return {};
-  auto rhs = operands[1].dyn_cast_or_null<DenseIntElementsAttr>();
-  if (rhs == nullptr)
-    return {};
-  return BoolAttr::get(getContext(), lhs == rhs);
+  // Zero or one operand, or the same SSA value repeated, is trivially equal;
+  // this holds for error shapes too, which compare equal to themselves. The
+  // early exit also keeps `drop_front` below from asserting on no operands.
+  ValueRange allShapes = shapes();
+  if (llvm::all_of(allShapes,
+                   [&](Value shape) { return shape == allShapes.front(); }))
+    return BoolAttr::get(getContext(), true);
+
+  // Otherwise fold only when every operand is a known constant.
+  Attribute first = operands.front();
+  if (!first)
+    return {};
+  bool allSame = true;
+  for (Attribute operand : operands.drop_front(1)) {
+    if (!operand)
+      return {};
+    allSame = allSame && operand == first;
+  }
+  return BoolAttr::get(getContext(), allSame);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -295,6 +295,53 @@
 
 // -----
 
+// CHECK-LABEL: @shape_eq
+// CHECK-SAME: (%[[A:.*]]: tensor<?xindex>, %[[B:.*]]: tensor<?xindex>, %[[C:.*]]: tensor<?xindex>) -> i1
+func @shape_eq(%a : tensor<?xindex>, %b : tensor<?xindex>, %c : tensor<?xindex>) -> i1 {
+  // CHECK: %[[C0:.*]] = constant 0 : index
+  // CHECK: %[[RANK_A:.*]] = dim %[[A]], %[[C0]] : tensor<?xindex>
+  // CHECK: %[[RANK_B:.*]] = dim %[[B]], %[[C0]] : tensor<?xindex>
+  // CHECK: %[[RANK_EQ:.*]] = cmpi eq, %[[RANK_A]], %[[RANK_B]]
+  // CHECK: %[[SHAPE_EQ:.*]] = scf.if %[[RANK_EQ]] -> (i1) {
+  // CHECK:   %[[C1:.*]] = constant 1 : index
+  // CHECK:   %[[INIT:.*]] = constant true
+  // CHECK:   %[[SHAPE_EQ_INNER:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[RANK_A]] step %[[C1]] iter_args(%[[CONJ:.*]] = %[[INIT]]) -> (i1) {
+  // CHECK:     %[[EXTENT_A:.*]] = tensor.extract %[[A]][%[[I]]] : tensor<?xindex>
+  // CHECK:     %[[EXTENT_B:.*]] = tensor.extract %[[B]][%[[I]]] : tensor<?xindex>
+  // CHECK:     %[[EXTENT_EQ:.*]] = cmpi eq, %[[EXTENT_A]], %[[EXTENT_B]]
+  // CHECK:     %[[CONJ_NEXT:.*]] = and %[[CONJ]], %[[EXTENT_EQ]]
+  // CHECK:     scf.yield %[[CONJ_NEXT]] : i1
+  // CHECK:   }
+  // CHECK:   scf.yield %[[SHAPE_EQ_INNER]] : i1
+  // CHECK: } else {
+  // CHECK:   %[[SHAPE_EQ_INNER:.*]] = constant false
+  // CHECK:   scf.yield %[[SHAPE_EQ_INNER]] : i1
+  // CHECK: }
+  // CHECK: %[[RANK_C:.*]] = dim %[[C]], %[[C0]] : tensor<?xindex>
+  // CHECK: %[[RANK_EQ:.*]] = cmpi eq, %[[RANK_A]], %[[RANK_C]]
+  // CHECK: %[[SHAPE_EQ2:.*]] = scf.if %[[RANK_EQ]] -> (i1) {
+  // CHECK:   %[[C1:.*]] = constant 1 : index
+  // CHECK:   %[[INIT:.*]] = constant true
+  // CHECK:   %[[SHAPE_EQ_INNER:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[RANK_A]] step %[[C1]] iter_args(%[[CONJ:.*]] = %[[INIT]]) -> (i1) {
+  // CHECK:     %[[EXTENT_A:.*]] = tensor.extract %[[A]][%[[I]]] : tensor<?xindex>
+  // CHECK:     %[[EXTENT_C:.*]] = tensor.extract %[[C]][%[[I]]] : tensor<?xindex>
+  // CHECK:     %[[EXTENT_EQ:.*]] = cmpi eq, %[[EXTENT_A]], %[[EXTENT_C]]
+  // CHECK:     %[[CONJ_NEXT:.*]] = and %[[CONJ]], %[[EXTENT_EQ]]
+  // CHECK:     scf.yield %[[CONJ_NEXT]] : i1
+  // CHECK:   }
+  // CHECK:   scf.yield %[[SHAPE_EQ_INNER]] : i1
+  // CHECK: } else {
+  // CHECK:   %[[SHAPE_EQ_INNER:.*]] = constant false
+  // CHECK:   scf.yield %[[SHAPE_EQ_INNER]] : i1
+  // CHECK: }
+  // CHECK: %[[RESULT:.*]] = and %[[SHAPE_EQ]], %[[SHAPE_EQ2]] : i1
+  // CHECK: return %[[RESULT]] : i1
+  %result = shape.shape_eq %a, %b, %c : tensor<?xindex>, tensor<?xindex>, tensor<?xindex>
+  return %result : i1
+}
+
+// -----
+
 // Don't lower `shape.broadcast` if a `shape.shape` type is involved.
 // CHECK-LABEL: @broadcast
 func @broadcast(%a : tensor<?xindex>, %b : !shape.shape) -> !shape.shape {
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -864,7 +864,8 @@
   // CHECK: return %[[RESULT]] : i1
   %a = shape.const_shape [1, 2, 3] : !shape.shape
   %b = shape.const_shape [1, 2, 3] : tensor<?xindex>
-  %result = shape.shape_eq %a, %b : !shape.shape, tensor<?xindex>
+  %c = shape.const_shape [1, 2, 3] : tensor<?xindex>
+  %result = shape.shape_eq %a, %b, %c : !shape.shape, tensor<?xindex>, tensor<?xindex>
   return %result : i1
 }
 
@@ -877,7 +878,8 @@
   // CHECK: return %[[RESULT]] : i1
   %a = shape.const_shape [1, 2, 3] : tensor<?xindex>
   %b = shape.const_shape [4, 5, 6] : tensor<?xindex>
-  %result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex>
+  %c = shape.const_shape [4, 5, 6] : tensor<?xindex>
+  %result = shape.shape_eq %a, %b, %c : tensor<?xindex>, tensor<?xindex>, tensor<?xindex>
   return %result : i1
 }
 