diff --git a/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp b/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
--- a/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
+++ b/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
@@ -21,6 +21,70 @@
 using namespace mlir::scf;
 
 /// Conversion patterns.
+namespace {
+/// Converts `shape.eq` on shapes to `scf`: an `scf.if` on rank equality whose
+/// then-branch and-reduces the per-extent comparisons in an `scf.for`.
+struct EqOpConverter : public OpConversionPattern<EqOp> {
+  using OpConversionPattern<EqOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(EqOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+} // namespace
+
+LogicalResult
+EqOpConverter::matchAndRewrite(EqOp op, ArrayRef<Value> operands,
+                               ConversionPatternRewriter &rewriter) const {
+  Type operandTy = op.lhs().getType();
+
+  // Size equality is lowered to `std` and the pattern can be found in the
+  // corresponding pass.
+  if (!operandTy.isa<ShapeType>())
+    return failure();
+
+  // Lower shape equality: shapes are equal iff their ranks match and all
+  // corresponding extents match.
+  EqOp::Adaptor transformed(operands);
+  auto loc = op.getLoc();
+  Type indexTy = rewriter.getIndexType();
+  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
+  Value lhsRank = rewriter.create<DimOp>(loc, indexTy, transformed.lhs(), zero);
+  Value rhsRank = rewriter.create<DimOp>(loc, indexTy, transformed.rhs(), zero);
+  Value eqRank =
+      rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, lhsRank, rhsRank);
+  Type i1Ty = rewriter.getI1Type();
+  rewriter.replaceOpWithNewOp<scf::IfOp>(
+      op, i1Ty, eqRank,
+      // Ranks agree: and-reduce the per-extent comparisons over [0, rank).
+      [&](OpBuilder &b, Location loc) {
+        Value one = b.create<ConstantIndexOp>(loc, 1);
+        Attribute trueAttr = b.getI1IntegerAttr(true);
+        Value init = b.create<ConstantOp>(loc, i1Ty, trueAttr);
+        auto loop = b.create<scf::ForOp>(
+            loc, zero, lhsRank, one, ValueRange{init},
+            [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) {
+              Value conj = args[0];
+              Value lhsExtent =
+                  b.create<ExtractElementOp>(nestedLoc, transformed.lhs(), iv);
+              Value rhsExtent =
+                  b.create<ExtractElementOp>(nestedLoc, transformed.rhs(), iv);
+              Value eqExtent = b.create<CmpIOp>(nestedLoc, CmpIPredicate::eq,
+                                                lhsExtent, rhsExtent);
+              Value conjNext = b.create<AndOp>(nestedLoc, conj, eqExtent);
+              b.create<scf::YieldOp>(nestedLoc, ValueRange({conjNext}));
+            });
+        b.create<scf::YieldOp>(loc, loop.getResults());
+      },
+      // Ranks differ: the shapes cannot be equal.
+      [&](OpBuilder &b, Location loc) {
+        Attribute falseAttr = b.getI1IntegerAttr(false);
+        Value result = b.create<ConstantOp>(loc, i1Ty, falseAttr);
+        b.create<scf::YieldOp>(loc, result);
+      });
+  return success();
+}
+
 namespace {
 /// Converts `shape.reduce` to `scf.for`.
 struct ReduceOpConverter : public OpConversionPattern<shape::ReduceOp> {
@@ -118,7 +182,7 @@
 
 void mlir::populateShapeToSCFConversionPatterns(
     OwningRewritePatternList &patterns, MLIRContext *ctx) {
-  patterns.insert<ReduceOpConverter>(ctx);
+  patterns.insert<EqOpConverter, ReduceOpConverter>(ctx);
 }
 
 std::unique_ptr<FunctionPass> mlir::createConvertShapeToSCFPass() {
diff --git a/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir b/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
--- a/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
+++ b/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
@@ -21,3 +21,33 @@
 // CHECK-DAG: %[[ACC_NEXT:.*]] = muli %[[ACC]], %[[EXTENT]]
 // CHECK-DAG: scf.yield %[[ACC_NEXT]] : index
 // CHECK-DAG: }
+
+// -----
+
+// CHECK-LABEL: @eq_shape
+// CHECK-SAME: (%[[A:.*]]: tensor<?xindex>, %[[B:.*]]: tensor<?xindex>) -> i1
+func @eq_shape(%a : !shape.shape, %b : !shape.shape) -> i1 {
+  // CHECK: %[[C0:.*]] = constant 0 : index
+  // CHECK: %[[RANK_A:.*]] = dim %[[A]], %[[C0]] : tensor<?xindex>
+  // CHECK: %[[RANK_B:.*]] = dim %[[B]], %[[C0]] : tensor<?xindex>
+  // CHECK: %[[RANK_EQ:.*]] = cmpi "eq", %[[RANK_A]], %[[RANK_B]]
+  // CHECK: %[[SHAPE_EQ:.*]] = scf.if %[[RANK_EQ]] -> (i1) {
+  // CHECK: %[[C1:.*]] = constant 1 : index
+  // CHECK: %[[INIT:.*]] = constant true
+  // CHECK: %[[SHAPE_EQ_INNER:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[RANK_A]] step %[[C1]]
+  // CHECK-SAME: iter_args(%[[CONJ:.*]] = %[[INIT]]) -> (i1) {
+  // CHECK: %[[EXTENT_A:.*]] = extract_element %[[A]][%[[I]]] : tensor<?xindex>
+  // CHECK: %[[EXTENT_B:.*]] = extract_element %[[B]][%[[I]]] : tensor<?xindex>
+  // CHECK: %[[EXTENT_EQ:.*]] = cmpi "eq", %[[EXTENT_A]], %[[EXTENT_B]]
+  // CHECK: %[[CONJ_NEXT:.*]] = and %[[CONJ]], %[[EXTENT_EQ]]
+  // CHECK: scf.yield %[[CONJ_NEXT]] : i1
+  // CHECK: }
+  // CHECK: scf.yield %[[SHAPE_EQ_INNER]] : i1
+  // CHECK: } else {
+  // CHECK: %[[SHAPE_EQ_INNER:.*]] = constant false
+  // CHECK: scf.yield %[[SHAPE_EQ_INNER]] : i1
+  // CHECK: }
+  // CHECK: return %[[SHAPE_EQ]] : i1
+  %result = shape.eq %a, %b : !shape.shape
+  return %result : i1
+}