diff --git a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
--- a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
@@ -133,10 +133,61 @@
 };
 } // namespace
 
+namespace {
+/// Rebuilds an `scf.while` op with type-converted result types and the
+/// remapped operands, moves both of its regions into the new op, and
+/// converts the types of the block arguments of those regions.
+class ConvertWhileOpTypes : public OpConversionPattern<WhileOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(WhileOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto *converter = getTypeConverter();
+    assert(converter);
+    SmallVector<Type> newResultTypes;
+    if (failed(converter->convertTypes(op.getResultTypes(), newResultTypes)))
+      return failure();
+
+    WhileOp::Adaptor adaptor(operands);
+    auto newOp = rewriter.create<WhileOp>(op.getLoc(), newResultTypes,
+                                          adaptor.getOperands());
+    // Move each of the op's two regions into the new op, then convert the
+    // types of their block arguments.
+    for (auto i : {0u, 1u}) {
+      auto &dstRegion = newOp.getRegion(i);
+      rewriter.inlineRegionBefore(op.getRegion(i), dstRegion, dstRegion.end());
+      if (failed(rewriter.convertRegionTypes(&dstRegion, *converter)))
+        return rewriter.notifyMatchFailure(op, "could not convert body types");
+    }
+    rewriter.replaceOp(op, newOp.getResults());
+    return success();
+  }
+};
+} // namespace
+
+namespace {
+/// Updates an `scf.condition` terminator in place so that it uses the
+/// remapped (type-converted) operands.
+class ConvertConditionOpTypes : public OpConversionPattern<ConditionOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ConditionOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.updateRootInPlace(op, [&]() { op->setOperands(operands); });
+    return success();
+  }
+};
+} // namespace
+
 void mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
     ConversionTarget &target) {
-  patterns.add<ConvertForOpTypes, ConvertIfOpTypes, ConvertYieldOpTypes>(
+  patterns.add<ConvertForOpTypes, ConvertIfOpTypes, ConvertYieldOpTypes,
+               ConvertWhileOpTypes, ConvertConditionOpTypes>(
       typeConverter, patterns.getContext());
   target.addDynamicallyLegalOp<ForOp, IfOp>([&](Operation *op) {
     return typeConverter.isLegal(op->getResultTypes());
@@ -144,8 +195,11 @@
   target.addDynamicallyLegalOp<scf::YieldOp>([&](scf::YieldOp op) {
     // We only have conversions for a subset of ops that use scf.yield
     // terminators.
-    if (!isa<ForOp, IfOp>(op->getParentOp()))
+    if (!isa<ForOp, IfOp, WhileOp>(op->getParentOp()))
       return true;
     return typeConverter.isLegal(op.getOperandTypes());
   });
+  // scf.condition and scf.while are legal once their operand and result
+  // types are legal.
+  target.addDynamicallyLegalOp<ConditionOp, WhileOp>(
+      [&](Operation *op) { return typeConverter.isLegal(op); });
 }
diff --git a/mlir/test/Dialect/SCF/bufferize.mlir b/mlir/test/Dialect/SCF/bufferize.mlir
--- a/mlir/test/Dialect/SCF/bufferize.mlir
+++ b/mlir/test/Dialect/SCF/bufferize.mlir
@@ -79,3 +79,25 @@
   }
   return %ret : tensor<?xf32>
 }
+
+// CHECK-LABEL: func @bufferize_while(
+// CHECK-SAME: %[[ARG0:.*]]: i64, %[[ARG1:.*]]: i64, %[[ARG2:.*]]: tensor<?xf32>
+// CHECK: %[[M:.*]] = memref.buffer_cast %[[ARG2]] : memref<?xf32>
+// CHECK: %[[RES1:.*]]:3 = scf.while (%{{.*}} = %[[ARG0]], %{{.*}} = %[[M]]) : (i64, memref<?xf32>) -> (i64, i64, memref<?xf32>)
+// CHECK: scf.condition(%{{.*}}) %{{.*}}, %{{.*}}, %{{.*}} : i64, i64, memref<?xf32>
+// CHECK: ^bb0(%{{.*}}: i64, %{{.*}}: i64, %{{.*}}: memref<?xf32>):
+// CHECK: scf.yield %{{.*}}, %{{.*}} : i64, memref<?xf32>
+// CHECK: %[[RES2:.*]] = memref.tensor_load %[[RES1]]#2 : memref<?xf32>
+// CHECK: return %[[RES1]]#1, %[[RES2]] : i64, tensor<?xf32>
+func @bufferize_while(%arg0: i64, %arg1: i64, %arg2: tensor<?xf32>) -> (i64, tensor<?xf32>) {
+  %c2_i64 = constant 2 : i64
+  %0:3 = scf.while (%arg3 = %arg0, %arg4 = %arg2) : (i64, tensor<?xf32>) -> (i64, i64, tensor<?xf32>) {
+    %1 = cmpi slt, %arg3, %arg1 : i64
+    scf.condition(%1) %arg3, %arg3, %arg4 : i64, i64, tensor<?xf32>
+  } do {
+  ^bb0(%arg5: i64, %arg6: i64, %arg7: tensor<?xf32>):
+    %1 = muli %arg6, %c2_i64 : i64
+    scf.yield %1, %arg7 : i64, tensor<?xf32>
+  }
+  return %0#1, %0#2 : i64, tensor<?xf32>
+}