diff --git a/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp b/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
--- a/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
+++ b/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
@@ -19,6 +19,92 @@
 using namespace mlir::shape;
 using namespace mlir::scf;
 
+namespace {
+/// Converts `shape.shape_eq` to an `scf.for` loop. For now, the lowering is
+/// only defined on `tensor<?xindex>` operands. The test for equality first
+/// compares their size and, if equal, checks every extent for equality.
+///
+/// Example:
+///
+/// %result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex>
+///
+/// becomes
+///
+/// %c0 = constant 0 : index
+/// %0 = dim %arg0, %c0 : tensor<?xindex>
+/// %1 = dim %arg1, %c0 : tensor<?xindex>
+/// %2 = cmpi "eq", %0, %1 : index
+/// %result = scf.if %2 -> (i1) {
+///   %c1 = constant 1 : index
+///   %true = constant true
+///   %4 = scf.for %arg2 = %c0 to %0 step %c1 iter_args(%arg3 = %true) -> (i1) {
+///     %5 = extract_element %arg0[%arg2] : tensor<?xindex>
+///     %6 = extract_element %arg1[%arg2] : tensor<?xindex>
+///     %7 = cmpi "eq", %5, %6 : index
+///     %8 = and %arg3, %7 : i1
+///     scf.yield %8 : i1
+///   }
+///   scf.yield %4 : i1
+/// } else {
+///   %false = constant false
+///   scf.yield %false : i1
+/// }
+///
+struct ShapeEqOpConverter : public OpConversionPattern<ShapeEqOp> {
+  using OpConversionPattern<ShapeEqOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ShapeEqOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+} // namespace
+
+LogicalResult
+ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, ArrayRef<Value> operands,
+                                    ConversionPatternRewriter &rewriter) const {
+  // For now, this lowering is only defined on `tensor<?xindex>` operands, not
+  // on shapes.
+  if (op.lhs().getType().isa<ShapeType>() ||
+      op.rhs().getType().isa<ShapeType>()) {
+    return failure();
+  }
+
+  ShapeEqOp::Adaptor transformed(operands);
+  auto loc = op.getLoc();
+  Type indexTy = rewriter.getIndexType();
+  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
+  Value lhsRank = rewriter.create<DimOp>(loc, indexTy, transformed.lhs(), zero);
+  Value rhsRank = rewriter.create<DimOp>(loc, indexTy, transformed.rhs(), zero);
+  Value eqRank =
+      rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, lhsRank, rhsRank);
+  Type i1Ty = rewriter.getI1Type();
+  rewriter.replaceOpWithNewOp<scf::IfOp>(
+      op, i1Ty, eqRank,
+      [&](OpBuilder &b, Location loc) {
+        Value one = b.create<ConstantIndexOp>(loc, 1);
+        Value init = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(true));
+        auto loop = b.create<scf::ForOp>(
+            loc, zero, lhsRank, one, ValueRange{init},
+            [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) {
+              Value conj = args[0];
+              Value lhsExtent =
+                  b.create<ExtractElementOp>(nestedLoc, transformed.lhs(), iv);
+              Value rhsExtent =
+                  b.create<ExtractElementOp>(nestedLoc, transformed.rhs(), iv);
+              Value eqExtent = b.create<CmpIOp>(nestedLoc, CmpIPredicate::eq,
+                                                lhsExtent, rhsExtent);
+              Value conjNext = b.create<AndOp>(nestedLoc, conj, eqExtent);
+              b.create<scf::YieldOp>(nestedLoc, ValueRange({conjNext}));
+            });
+        b.create<scf::YieldOp>(loc, loop.getResults());
+      },
+      [&](OpBuilder &b, Location loc) {
+        Value result = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(false));
+        b.create<scf::YieldOp>(loc, result);
+      });
+  return success();
+}
+
 namespace {
 /// Converts `shape.reduce` to `scf.for`.
 struct ReduceOpConverter : public OpConversionPattern<shape::ReduceOp> {
@@ -149,7 +235,12 @@
 
 void mlir::populateShapeToSCFConversionPatterns(
     OwningRewritePatternList &patterns, MLIRContext *ctx) {
-  patterns.insert<ReduceOpConverter, ShapeOfOpConverter>(ctx);
+  // clang-format off
+  patterns.insert<
+      ShapeEqOpConverter,
+      ReduceOpConverter,
+      ShapeOfOpConverter>(ctx);
+  // clang-format on
 }
 
 std::unique_ptr<FunctionPass> mlir::createConvertShapeToSCFPass() {
diff --git a/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir b/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
--- a/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
+++ b/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
@@ -43,3 +43,31 @@
 
   return
 }
+// -----
+
+// CHECK-LABEL: @shape_eq
+// CHECK-SAME: (%[[A:.*]]: tensor<?xindex>, %[[B:.*]]: tensor<?xindex>) -> i1
+func @shape_eq(%a : tensor<?xindex>, %b : tensor<?xindex>) -> i1 {
+  // CHECK: %[[C0:.*]] = constant 0 : index
+  // CHECK: %[[RANK_A:.*]] = dim %[[A]], %[[C0]] : tensor<?xindex>
+  // CHECK: %[[RANK_B:.*]] = dim %[[B]], %[[C0]] : tensor<?xindex>
+  // CHECK: %[[RANK_EQ:.*]] = cmpi "eq", %[[RANK_A]], %[[RANK_B]]
+  // CHECK: %[[SHAPE_EQ:.*]] = scf.if %[[RANK_EQ]] -> (i1) {
+  // CHECK:   %[[C1:.*]] = constant 1 : index
+  // CHECK:   %[[INIT:.*]] = constant true
+  // CHECK:   %[[SHAPE_EQ_INNER:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[RANK_A]] step %[[C1]] iter_args(%[[CONJ:.*]] = %[[INIT]]) -> (i1) {
+  // CHECK:     %[[EXTENT_A:.*]] = extract_element %[[A]][%[[I]]] : tensor<?xindex>
+  // CHECK:     %[[EXTENT_B:.*]] = extract_element %[[B]][%[[I]]] : tensor<?xindex>
+  // CHECK:     %[[EXTENT_EQ:.*]] = cmpi "eq", %[[EXTENT_A]], %[[EXTENT_B]]
+  // CHECK:     %[[CONJ_NEXT:.*]] = and %[[CONJ]], %[[EXTENT_EQ]]
+  // CHECK:     scf.yield %[[CONJ_NEXT]] : i1
+  // CHECK:   }
+  // CHECK:   scf.yield %[[SHAPE_EQ_INNER]] : i1
+  // CHECK: } else {
+  // CHECK:   %[[FALSE:.*]] = constant false
+  // CHECK:   scf.yield %[[FALSE]] : i1
+  // CHECK: }
+  // CHECK: return %[[SHAPE_EQ]] : i1
+  %result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex>
+  return %result : i1
+}