diff --git a/mlir/include/mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h b/mlir/include/mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h
--- a/mlir/include/mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h
+++ b/mlir/include/mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h
@@ -13,13 +13,35 @@
 
 namespace mlir {
 
+class ConversionTarget;
 class ModuleOp;
 template <typename T>
 class OperationPass;
+class MLIRContext;
+class OwningRewritePatternList;
+class TypeConverter;
 
 /// Create a pass to convert Async operations to the LLVM dialect.
 std::unique_ptr<OperationPass<ModuleOp>> createConvertAsyncToLLVMPass();
 
+/// Populates patterns for async structural type conversions.
+///
+/// A "structural" type conversion is one where the underlying ops are
+/// completely agnostic to the actual types involved and simply need to update
+/// their types. An example of this is async.execute -- the async.execute op and
+/// the corresponding async.yield ops need to update their types according to
+/// the TypeConverter, but otherwise don't care what type conversions are
+/// happening.
+///
+/// Besides adding the patterns to `patterns`, this registers conversions for
+/// the async token and value types on `typeConverter` and marks the async ops
+/// handled by those patterns as dynamically legal in `target`. The registered
+/// callbacks hold a reference to `typeConverter`, so it must stay alive for as
+/// long as `patterns` and `target` are in use.
+void populateAsyncStructuralTypeConversionsAndLegality(
+    MLIRContext *context, TypeConverter &typeConverter,
+    OwningRewritePatternList &patterns, ConversionTarget &target);
+
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_ASYNCTOLLVM_ASYNCTOLLVM_H
diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
--- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
+++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
@@ -1136,6 +1136,89 @@
 }
 } // namespace
 
+namespace {
+/// Rewrites async.execute so that its operands, region block arguments and
+/// results use the types produced by the type converter.
+class ConvertExecuteOpTypes : public OpConversionPattern<ExecuteOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(ExecuteOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Clone the op without its region and move the original body into the
+    // clone, so that the body can be converted in place below.
+    ExecuteOp newOp =
+        cast<ExecuteOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
+    rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(),
+                                newOp.getRegion().end());
+
+    // Set operands and update block argument and result types.
+    newOp->setOperands(operands);
+    if (failed(rewriter.convertRegionTypes(&newOp.getRegion(), *typeConverter)))
+      return failure();
+    for (auto result : newOp.getResults()) {
+      // A null type means the result type could not be converted; fail the
+      // pattern instead of installing a null type on the result.
+      Type convertedType = typeConverter->convertType(result.getType());
+      if (!convertedType)
+        return failure();
+      result.setType(convertedType);
+    }
+
+    rewriter.replaceOp(op, newOp.getResults());
+    return success();
+  }
+};
+
+/// Re-creates async.await on the converted operand so that the appropriate
+/// type conversion / materialization is triggered.
+class ConvertAwaitOpTypes : public OpConversionPattern<AwaitOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(AwaitOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<AwaitOp>(op, operands.front());
+    return success();
+  }
+};
+
+/// Re-creates async.yield on the converted operands so that the appropriate
+/// type conversion / materialization is triggered.
+class ConvertYieldOpTypes : public OpConversionPattern<async::YieldOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(async::YieldOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<async::YieldOp>(op, operands);
+    return success();
+  }
+};
+} // namespace
+
 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertAsyncToLLVMPass() {
   return std::make_unique<ConvertAsyncToLLVMPass>();
 }
+
+void mlir::populateAsyncStructuralTypeConversionsAndLegality(
+    MLIRContext *context, TypeConverter &typeConverter,
+    OwningRewritePatternList &patterns, ConversionTarget &target) {
+  // Token types are kept as they are.
+  typeConverter.addConversion([&](TokenType type) { return type; });
+  // Value types are rebuilt around their converted element type. A null
+  // result of the nested conversion is returned as is, so the converter
+  // reports a failure rather than building a value type from a null type.
+  typeConverter.addConversion([&](ValueType type) {
+    Type converted = typeConverter.convertType(type.getValueType());
+    return converted ? ValueType::get(converted) : converted;
+  });
+
+  patterns
+      .insert<ConvertExecuteOpTypes, ConvertAwaitOpTypes, ConvertYieldOpTypes>(
+          typeConverter, context);
+
+  // The async ops are legal exactly when the type converter accepts them.
+  target.addDynamicallyLegalOp<AwaitOp, ExecuteOp, async::YieldOp>(
+      [&](Operation *op) { return typeConverter.isLegal(op); });
+}