diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -157,6 +157,11 @@
     "LLVM::LLVMDialect",
     "func::FuncDialect",
   ];
+  let options = [
+    Option<"useOpaquePointers", "use-opaque-pointers", "bool",
+           /*default=*/"false", "Generate LLVM IR using opaque pointers "
+           "instead of typed pointers">,
+  ];
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
--- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
+++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
@@ -10,6 +10,7 @@
 
 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Async/IR/Async.h"
@@ -70,11 +71,14 @@
 /// Async Runtime API function types.
 ///
 /// Because we can't create API function signature for type parametrized
-/// async.getValue type, we use opaque pointers (!llvm.ptr<i8>) instead. After
+/// async.getValue type, we use opaque pointers (!llvm.ptr) instead. After
 /// lowering all async data types become opaque pointers at runtime.
 struct AsyncAPI {
-  // All async types are lowered to opaque i8* LLVM pointers at runtime.
-  static LLVM::LLVMPointerType opaquePointerType(MLIRContext *ctx) {
+  // All async types are lowered to opaque LLVM pointers at runtime.
+  static LLVM::LLVMPointerType opaquePointerType(MLIRContext *ctx,
+                                                 bool useLLVMOpaquePointers) {
+    if (useLLVMOpaquePointers)
+      return LLVM::LLVMPointerType::get(ctx);
     return LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));
   }
 
@@ -82,8 +86,9 @@
     return LLVM::LLVMTokenType::get(ctx);
   }
 
-  static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) {
-    auto ref = opaquePointerType(ctx);
+  static FunctionType addOrDropRefFunctionType(MLIRContext *ctx,
+                                               bool useLLVMOpaquePointers) {
+    auto ref = opaquePointerType(ctx, useLLVMOpaquePointers);
     auto count = IntegerType::get(ctx, 64);
     return FunctionType::get(ctx, {ref, count}, {});
   }
@@ -92,9 +97,10 @@
     return FunctionType::get(ctx, {}, {TokenType::get(ctx)});
   }
 
-  static FunctionType createValueFunctionType(MLIRContext *ctx) {
+  static FunctionType createValueFunctionType(MLIRContext *ctx,
+                                              bool useLLVMOpaquePointers) {
     auto i64 = IntegerType::get(ctx, 64);
-    auto value = opaquePointerType(ctx);
+    auto value = opaquePointerType(ctx, useLLVMOpaquePointers);
     return FunctionType::get(ctx, {i64}, {value});
   }
 
@@ -103,9 +109,10 @@
     return FunctionType::get(ctx, {i64}, {GroupType::get(ctx)});
   }
 
-  static FunctionType getValueStorageFunctionType(MLIRContext *ctx) {
-    auto value = opaquePointerType(ctx);
-    auto storage = opaquePointerType(ctx);
+  static FunctionType getValueStorageFunctionType(MLIRContext *ctx,
+                                                  bool useLLVMOpaquePointers) {
+    auto value = opaquePointerType(ctx, useLLVMOpaquePointers);
+    auto storage = opaquePointerType(ctx, useLLVMOpaquePointers);
     return FunctionType::get(ctx, {value}, {storage});
   }
 
@@ -113,8 +120,9 @@
     return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
   }
 
-  static FunctionType emplaceValueFunctionType(MLIRContext *ctx) {
-    auto value = opaquePointerType(ctx);
+  static FunctionType emplaceValueFunctionType(MLIRContext *ctx,
+                                               bool useLLVMOpaquePointers) {
+    auto value = opaquePointerType(ctx, useLLVMOpaquePointers);
     return FunctionType::get(ctx, {value}, {});
   }
 
@@ -122,8 +130,9 @@
     return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
   }
 
-  static FunctionType setValueErrorFunctionType(MLIRContext *ctx) {
-    auto value = opaquePointerType(ctx);
+  static FunctionType setValueErrorFunctionType(MLIRContext *ctx,
+                                                bool useLLVMOpaquePointers) {
+    auto value = opaquePointerType(ctx, useLLVMOpaquePointers);
     return FunctionType::get(ctx, {value}, {});
   }
 
@@ -132,8 +141,9 @@
     return FunctionType::get(ctx, {TokenType::get(ctx)}, {i1});
   }
 
-  static FunctionType isValueErrorFunctionType(MLIRContext *ctx) {
-    auto value = opaquePointerType(ctx);
+  static FunctionType isValueErrorFunctionType(MLIRContext *ctx,
+                                               bool useLLVMOpaquePointers) {
+    auto value = opaquePointerType(ctx, useLLVMOpaquePointers);
     auto i1 = IntegerType::get(ctx, 1);
     return FunctionType::get(ctx, {value}, {i1});
   }
@@ -147,8 +157,9 @@
     return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
   }
 
-  static FunctionType awaitValueFunctionType(MLIRContext *ctx) {
-    auto value = opaquePointerType(ctx);
+  static FunctionType awaitValueFunctionType(MLIRContext *ctx,
+                                             bool useLLVMOpaquePointers) {
+    auto value = opaquePointerType(ctx, useLLVMOpaquePointers);
     return FunctionType::get(ctx, {value}, {});
   }
 
@@ -156,9 +167,15 @@
     return FunctionType::get(ctx, {GroupType::get(ctx)}, {});
   }
 
-  static FunctionType executeFunctionType(MLIRContext *ctx) {
-    auto hdl = opaquePointerType(ctx);
-    auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
+  static FunctionType executeFunctionType(MLIRContext *ctx,
+                                          bool useLLVMOpaquePointers) {
+    auto hdl = opaquePointerType(ctx, useLLVMOpaquePointers);
+    Type resume;
+    if (useLLVMOpaquePointers)
+      resume = LLVM::LLVMPointerType::get(ctx);
+    else
+      resume = LLVM::LLVMPointerType::get(
+          resumeFunctionType(ctx, useLLVMOpaquePointers));
     return FunctionType::get(ctx, {hdl, resume}, {});
   }
 
@@ -168,22 +185,42 @@
                              {i64});
   }
 
-  static FunctionType awaitTokenAndExecuteFunctionType(MLIRContext *ctx) {
-    auto hdl = opaquePointerType(ctx);
-    auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
+  static FunctionType
+  awaitTokenAndExecuteFunctionType(MLIRContext *ctx,
+                                   bool useLLVMOpaquePointers) {
+    auto hdl = opaquePointerType(ctx, useLLVMOpaquePointers);
+    Type resume;
+    if (useLLVMOpaquePointers)
+      resume = LLVM::LLVMPointerType::get(ctx);
+    else
+      resume = LLVM::LLVMPointerType::get(
+          resumeFunctionType(ctx, useLLVMOpaquePointers));
     return FunctionType::get(ctx, {TokenType::get(ctx), hdl, resume}, {});
   }
 
-  static FunctionType awaitValueAndExecuteFunctionType(MLIRContext *ctx) {
-    auto value = opaquePointerType(ctx);
-    auto hdl = opaquePointerType(ctx);
-    auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
+  static FunctionType
+  awaitValueAndExecuteFunctionType(MLIRContext *ctx,
+                                   bool useLLVMOpaquePointers) {
+    auto value = opaquePointerType(ctx, useLLVMOpaquePointers);
+    auto hdl = opaquePointerType(ctx, useLLVMOpaquePointers);
+    Type resume;
+    if (useLLVMOpaquePointers)
+      resume = LLVM::LLVMPointerType::get(ctx);
+    else
+      resume = LLVM::LLVMPointerType::get(
+          resumeFunctionType(ctx, useLLVMOpaquePointers));
     return FunctionType::get(ctx, {value, hdl, resume}, {});
   }
 
-  static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) {
-    auto hdl = opaquePointerType(ctx);
-    auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
+  static FunctionType
+  awaitAllAndExecuteFunctionType(MLIRContext *ctx, bool useLLVMOpaquePointers) {
+    auto hdl = opaquePointerType(ctx, useLLVMOpaquePointers);
+    Type resume;
+    if (useLLVMOpaquePointers)
+      resume = LLVM::LLVMPointerType::get(ctx);
+    else
+      resume = LLVM::LLVMPointerType::get(
+          resumeFunctionType(ctx, useLLVMOpaquePointers));
     return FunctionType::get(ctx, {GroupType::get(ctx), hdl, resume}, {});
   }
 
@@ -192,16 +229,17 @@
   }
 
   // Auxiliary coroutine resume intrinsic wrapper.
-  static Type resumeFunctionType(MLIRContext *ctx) {
+  static Type resumeFunctionType(MLIRContext *ctx, bool useLLVMOpaquePointers) {
     auto voidTy = LLVM::LLVMVoidType::get(ctx);
-    auto i8Ptr = opaquePointerType(ctx);
-    return LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}, false);
+    auto ptrType = opaquePointerType(ctx, useLLVMOpaquePointers);
+    return LLVM::LLVMFunctionType::get(voidTy, {ptrType}, false);
   }
 };
 } // namespace
 
 /// Adds Async Runtime C API declarations to the module.
-static void addAsyncRuntimeApiDeclarations(ModuleOp module) {
+static void addAsyncRuntimeApiDeclarations(ModuleOp module,
+                                           bool useLLVMOpaquePointers) {
   auto builder =
       ImplicitLocOpBuilder::atBlockEnd(module.getLoc(), module.getBody());
 
@@ -212,30 +250,39 @@
   };
 
   MLIRContext *ctx = module.getContext();
-  addFuncDecl(kAddRef, AsyncAPI::addOrDropRefFunctionType(ctx));
-  addFuncDecl(kDropRef, AsyncAPI::addOrDropRefFunctionType(ctx));
+  addFuncDecl(kAddRef,
+              AsyncAPI::addOrDropRefFunctionType(ctx, useLLVMOpaquePointers));
+  addFuncDecl(kDropRef,
+              AsyncAPI::addOrDropRefFunctionType(ctx, useLLVMOpaquePointers));
   addFuncDecl(kCreateToken, AsyncAPI::createTokenFunctionType(ctx));
-  addFuncDecl(kCreateValue, AsyncAPI::createValueFunctionType(ctx));
+  addFuncDecl(kCreateValue,
+              AsyncAPI::createValueFunctionType(ctx, useLLVMOpaquePointers));
   addFuncDecl(kCreateGroup, AsyncAPI::createGroupFunctionType(ctx));
   addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx));
-  addFuncDecl(kEmplaceValue, AsyncAPI::emplaceValueFunctionType(ctx));
+  addFuncDecl(kEmplaceValue,
+              AsyncAPI::emplaceValueFunctionType(ctx, useLLVMOpaquePointers));
   addFuncDecl(kSetTokenError, AsyncAPI::setTokenErrorFunctionType(ctx));
-  addFuncDecl(kSetValueError, AsyncAPI::setValueErrorFunctionType(ctx));
+  addFuncDecl(kSetValueError,
+              AsyncAPI::setValueErrorFunctionType(ctx, useLLVMOpaquePointers));
   addFuncDecl(kIsTokenError, AsyncAPI::isTokenErrorFunctionType(ctx));
-  addFuncDecl(kIsValueError, AsyncAPI::isValueErrorFunctionType(ctx));
+  addFuncDecl(kIsValueError,
+              AsyncAPI::isValueErrorFunctionType(ctx, useLLVMOpaquePointers));
   addFuncDecl(kIsGroupError, AsyncAPI::isGroupErrorFunctionType(ctx));
   addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx));
-  addFuncDecl(kAwaitValue, AsyncAPI::awaitValueFunctionType(ctx));
+  addFuncDecl(kAwaitValue,
+              AsyncAPI::awaitValueFunctionType(ctx, useLLVMOpaquePointers));
   addFuncDecl(kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx));
-  addFuncDecl(kExecute, AsyncAPI::executeFunctionType(ctx));
-  addFuncDecl(kGetValueStorage, AsyncAPI::getValueStorageFunctionType(ctx));
+  addFuncDecl(kExecute,
+              AsyncAPI::executeFunctionType(ctx, useLLVMOpaquePointers));
+  addFuncDecl(kGetValueStorage, AsyncAPI::getValueStorageFunctionType(
+                                    ctx, useLLVMOpaquePointers));
   addFuncDecl(kAddTokenToGroup, AsyncAPI::addTokenToGroupFunctionType(ctx));
-  addFuncDecl(kAwaitTokenAndExecute,
-              AsyncAPI::awaitTokenAndExecuteFunctionType(ctx));
-  addFuncDecl(kAwaitValueAndExecute,
-              AsyncAPI::awaitValueAndExecuteFunctionType(ctx));
-  addFuncDecl(kAwaitAllAndExecute,
-              AsyncAPI::awaitAllAndExecuteFunctionType(ctx));
+  addFuncDecl(kAwaitTokenAndExecute, AsyncAPI::awaitTokenAndExecuteFunctionType(
+                                         ctx, useLLVMOpaquePointers));
+  addFuncDecl(kAwaitValueAndExecute, AsyncAPI::awaitValueAndExecuteFunctionType(
+                                         ctx, useLLVMOpaquePointers));
+  addFuncDecl(kAwaitAllAndExecute, AsyncAPI::awaitAllAndExecuteFunctionType(
+                                       ctx, useLLVMOpaquePointers));
   addFuncDecl(kGetNumWorkerThreads, AsyncAPI::getNumWorkerThreads(ctx));
 }
 
@@ -248,7 +295,7 @@
 /// A function that takes a coroutine handle and calls a `llvm.coro.resume`
 /// intrinsics. We need this function to be able to pass it to the async
 /// runtime execute API.
-static void addResumeFunction(ModuleOp module) {
+static void addResumeFunction(ModuleOp module, bool useOpaquePointers) {
   if (module.lookupSymbol(kResume))
     return;
 
@@ -257,10 +304,14 @@
   auto moduleBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, module.getBody());
 
   auto voidTy = LLVM::LLVMVoidType::get(ctx);
-  auto i8Ptr = LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));
+  Type ptrType;
+  if (useOpaquePointers)
+    ptrType = LLVM::LLVMPointerType::get(ctx);
+  else
+    ptrType = LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));
 
   auto resumeOp = moduleBuilder.create<LLVM::LLVMFuncOp>(
-      kResume, LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}));
+      kResume, LLVM::LLVMFunctionType::get(voidTy, {ptrType}));
   resumeOp.setPrivate();
 
   auto *block = resumeOp.addEntryBlock();
@@ -278,10 +329,15 @@
 /// AsyncRuntimeTypeConverter only converts types from the Async dialect to
 /// their runtime type (opaque pointers) and does not convert any other types.
 class AsyncRuntimeTypeConverter : public TypeConverter {
+  bool llvmOpaquePointers = false;
+
 public:
-  AsyncRuntimeTypeConverter() {
+  AsyncRuntimeTypeConverter(const LowerToLLVMOptions &options)
+      : llvmOpaquePointers(options.useOpaquePointers) {
     addConversion([](Type type) { return type; });
-    addConversion(convertAsyncTypes);
+    addConversion([this](Type type) {
+      return convertAsyncTypes(type, llvmOpaquePointers);
+    });
 
     // Use UnrealizedConversionCast as the bridge so that we don't need to pull
     // in patterns for other dialects.
@@ -295,18 +351,52 @@
     addTargetMaterialization(addUnrealizedCast);
   }
 
-  static std::optional<Type> convertAsyncTypes(Type type) {
+  /// Returns whether LLVM opaque pointers should be used instead of typed
+  /// pointers.
+  bool useOpaquePointers() const { return llvmOpaquePointers; }
+
+  /// Creates an LLVM pointer type which may either be a typed pointer or an
+  /// opaque pointer, depending on what options the converter was constructed
+  /// with.
+  LLVM::LLVMPointerType getPointerType(Type elementType) {
+    if (llvmOpaquePointers)
+      return LLVM::LLVMPointerType::get(elementType.getContext());
+    return LLVM::LLVMPointerType::get(elementType);
+  }
+
+  static std::optional<Type> convertAsyncTypes(Type type,
+                                               bool useOpaquePointers) {
     if (type.isa<TokenType, GroupType, ValueType>())
-      return AsyncAPI::opaquePointerType(type.getContext());
+      return AsyncAPI::opaquePointerType(type.getContext(), useOpaquePointers);
 
     if (type.isa<CoroIdType, CoroStateType>())
       return AsyncAPI::tokenType(type.getContext());
     if (type.isa<CoroHandleType>())
-      return AsyncAPI::opaquePointerType(type.getContext());
+      return AsyncAPI::opaquePointerType(type.getContext(), useOpaquePointers);
 
     return std::nullopt;
   }
 };
+
+/// Base class for conversion patterns requiring AsyncRuntimeTypeConverter
+/// as type converter. Allows access to it via the 'getTypeConverter'
+/// convenience method.
+template <typename SourceOp>
+class AsyncOpConversionPattern : public OpConversionPattern<SourceOp> {
+
+  using Base = OpConversionPattern<SourceOp>;
+
+public:
+  AsyncOpConversionPattern(AsyncRuntimeTypeConverter &typeConverter,
+                           MLIRContext *context)
+      : Base(typeConverter, context) {}
+
+  /// Returns the 'AsyncRuntimeTypeConverter' of the pattern.
+  AsyncRuntimeTypeConverter *getTypeConverter() const {
+    return static_cast<AsyncRuntimeTypeConverter *>(Base::getTypeConverter());
+  }
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -314,21 +404,22 @@
 //===----------------------------------------------------------------------===//
 
 namespace {
-class CoroIdOpConversion : public OpConversionPattern<CoroIdOp> {
+class CoroIdOpConversion : public AsyncOpConversionPattern<CoroIdOp> {
 public:
-  using OpConversionPattern::OpConversionPattern;
+  using AsyncOpConversionPattern::AsyncOpConversionPattern;
 
   LogicalResult
   matchAndRewrite(CoroIdOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto token = AsyncAPI::tokenType(op->getContext());
-    auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext());
+    auto ptrType = AsyncAPI::opaquePointerType(
+        op->getContext(), getTypeConverter()->useOpaquePointers());
     auto loc = op->getLoc();
 
     // Constants for initializing coroutine frame.
    auto constZero =
         rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), 0);
-    auto nullPtr = rewriter.create<LLVM::NullOp>(loc, i8Ptr);
+    auto nullPtr = rewriter.create<LLVM::NullOp>(loc, ptrType);
 
     // Get coroutine id: @llvm.coro.id.
     rewriter.replaceOpWithNewOp<LLVM::CoroIdOp>(
@@ -344,14 +435,15 @@
 //===----------------------------------------------------------------------===//
 
 namespace {
-class CoroBeginOpConversion : public OpConversionPattern<CoroBeginOp> {
+class CoroBeginOpConversion : public AsyncOpConversionPattern<CoroBeginOp> {
 public:
-  using OpConversionPattern::OpConversionPattern;
+  using AsyncOpConversionPattern::AsyncOpConversionPattern;
 
   LogicalResult
   matchAndRewrite(CoroBeginOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext());
+    auto ptrType = AsyncAPI::opaquePointerType(
+        op->getContext(), getTypeConverter()->useOpaquePointers());
     auto loc = op->getLoc();
 
     // Get coroutine frame size: @llvm.coro.size.i64.
@@ -379,14 +471,14 @@
     // Allocate memory for the coroutine frame.
     auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn(
         op->getParentOfType<ModuleOp>(), rewriter.getI64Type(),
-        /*TODO: opaquePointers=*/false);
+        getTypeConverter()->useOpaquePointers());
    auto coroAlloc = rewriter.create<LLVM::CallOp>(
         loc, allocFuncOp, ValueRange{coroAlign, coroSize});
 
     // Begin a coroutine: @llvm.coro.begin.
     auto coroId = CoroBeginOpAdaptor(adaptor.getOperands()).getId();
     rewriter.replaceOpWithNewOp<LLVM::CoroBeginOp>(
-        op, i8Ptr, ValueRange({coroId, coroAlloc.getResult()}));
+        op, ptrType, ValueRange({coroId, coroAlloc.getResult()}));
 
     return success();
   }
@@ -398,23 +490,25 @@
 //===----------------------------------------------------------------------===//
 
 namespace {
-class CoroFreeOpConversion : public OpConversionPattern<CoroFreeOp> {
+class CoroFreeOpConversion : public AsyncOpConversionPattern<CoroFreeOp> {
 public:
-  using OpConversionPattern::OpConversionPattern;
+  using AsyncOpConversionPattern::AsyncOpConversionPattern;
 
   LogicalResult
   matchAndRewrite(CoroFreeOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext());
+    auto ptrType = AsyncAPI::opaquePointerType(
+        op->getContext(), getTypeConverter()->useOpaquePointers());
     auto loc = op->getLoc();
 
     // Get a pointer to the coroutine frame memory: @llvm.coro.free.
     auto coroMem =
-        rewriter.create<LLVM::CoroFreeOp>(loc, i8Ptr, adaptor.getOperands());
+        rewriter.create<LLVM::CoroFreeOp>(loc, ptrType, adaptor.getOperands());
 
     // Free the memory.
-    auto freeFuncOp = LLVM::lookupOrCreateFreeFn(
-        op->getParentOfType<ModuleOp>(), /*TODO: opaquePointers=*/false);
+    auto freeFuncOp =
+        LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>(),
+                                   getTypeConverter()->useOpaquePointers());
     rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFuncOp,
                                               ValueRange(coroMem.getResult()));
 
@@ -551,9 +645,9 @@
 //===----------------------------------------------------------------------===//
 
 namespace {
-class RuntimeCreateOpLowering : public OpConversionPattern<RuntimeCreateOp> {
+class RuntimeCreateOpLowering : public ConvertOpToLLVMPattern<RuntimeCreateOp> {
 public:
-  using OpConversionPattern::OpConversionPattern;
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
 
   LogicalResult
   matchAndRewrite(RuntimeCreateOp op, OpAdaptor adaptor,
@@ -576,13 +670,14 @@
       auto i64 = rewriter.getI64Type();
 
       auto storedType = converter->convertType(valueType.getValueType());
-      auto storagePtrType = LLVM::LLVMPointerType::get(storedType);
+      auto storagePtrType = getTypeConverter()->getPointerType(storedType);
 
       // %Size = getelementptr %T* null, int 1
       // %SizeI = ptrtoint %T* %Size to i64
      auto nullPtr = rewriter.create<LLVM::NullOp>(loc, storagePtrType);
-      auto gep = rewriter.create<LLVM::GEPOp>(loc, storagePtrType, nullPtr,
-                                              ArrayRef<LLVM::GEPArg>{1});
+      auto gep =
+          rewriter.create<LLVM::GEPOp>(loc, storagePtrType, storedType,
+                                       nullPtr, ArrayRef<LLVM::GEPArg>{1});
       return rewriter.create<LLVM::PtrToIntOp>(loc, i64, gep);
     };
 
@@ -603,9 +698,9 @@
 
 namespace {
 class RuntimeCreateGroupOpLowering
-    : public OpConversionPattern<RuntimeCreateGroupOp> {
+    : public ConvertOpToLLVMPattern<RuntimeCreateGroupOp> {
 public:
-  using OpConversionPattern::OpConversionPattern;
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
 
   LogicalResult
   matchAndRewrite(RuntimeCreateGroupOp op, OpAdaptor adaptor,
@@ -731,9 +826,9 @@
 
 namespace {
 class RuntimeAwaitAndResumeOpLowering
-    : public OpConversionPattern<RuntimeAwaitAndResumeOp> {
+    : public AsyncOpConversionPattern<RuntimeAwaitAndResumeOp> {
 public:
-  using OpConversionPattern::OpConversionPattern;
+  using AsyncOpConversionPattern::AsyncOpConversionPattern;
 
   LogicalResult
   matchAndRewrite(RuntimeAwaitAndResumeOp op, OpAdaptor adaptor,
@@ -748,10 +843,12 @@
     Value handle = adaptor.getHandle();
 
     // A pointer to coroutine resume intrinsic wrapper.
-    addResumeFunction(op->getParentOfType<ModuleOp>());
-    auto resumeFnTy = AsyncAPI::resumeFunctionType(op->getContext());
+    addResumeFunction(op->getParentOfType<ModuleOp>(),
+                      getTypeConverter()->useOpaquePointers());
+    auto resumeFnTy = AsyncAPI::resumeFunctionType(
+        op->getContext(), getTypeConverter()->useOpaquePointers());
     auto resumePtr = rewriter.create<LLVM::AddressOfOp>(
-        op->getLoc(), LLVM::LLVMPointerType::get(resumeFnTy), kResume);
+        op->getLoc(), getTypeConverter()->getPointerType(resumeFnTy), kResume);
 
     rewriter.create<func::CallOp>(
         op->getLoc(), apiFuncName, TypeRange(),
@@ -768,18 +865,21 @@
 //===----------------------------------------------------------------------===//
 
 namespace {
-class RuntimeResumeOpLowering : public OpConversionPattern<RuntimeResumeOp> {
+class RuntimeResumeOpLowering
+    : public AsyncOpConversionPattern<RuntimeResumeOp> {
 public:
-  using OpConversionPattern::OpConversionPattern;
+  using AsyncOpConversionPattern::AsyncOpConversionPattern;
 
   LogicalResult
   matchAndRewrite(RuntimeResumeOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // A pointer to coroutine resume intrinsic wrapper.
-    addResumeFunction(op->getParentOfType<ModuleOp>());
-    auto resumeFnTy = AsyncAPI::resumeFunctionType(op->getContext());
+    addResumeFunction(op->getParentOfType<ModuleOp>(),
+                      getTypeConverter()->useOpaquePointers());
+    auto resumeFnTy = AsyncAPI::resumeFunctionType(
+        op->getContext(), getTypeConverter()->useOpaquePointers());
     auto resumePtr = rewriter.create<LLVM::AddressOfOp>(
-        op->getLoc(), LLVM::LLVMPointerType::get(resumeFnTy), kResume);
+        op->getLoc(), getTypeConverter()->getPointerType(resumeFnTy), kResume);
 
     // Call async runtime API to execute a coroutine in the managed thread.
     auto coroHdl = adaptor.getHandle();
@@ -796,9 +896,9 @@
 //===----------------------------------------------------------------------===//
 
 namespace {
-class RuntimeStoreOpLowering : public OpConversionPattern<RuntimeStoreOp> {
+class RuntimeStoreOpLowering : public ConvertOpToLLVMPattern<RuntimeStoreOp> {
 public:
-  using OpConversionPattern::OpConversionPattern;
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
 
   LogicalResult
   matchAndRewrite(RuntimeStoreOp op, OpAdaptor adaptor,
@@ -806,10 +906,11 @@
     Location loc = op->getLoc();
 
     // Get a pointer to the async value storage from the runtime.
-    auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext());
+    auto ptrType = AsyncAPI::opaquePointerType(
+        rewriter.getContext(), getTypeConverter()->useOpaquePointers());
     auto storage = adaptor.getStorage();
-    auto storagePtr = rewriter.create<func::CallOp>(loc, kGetValueStorage,
-                                                    TypeRange(i8Ptr), storage);
+    auto storagePtr = rewriter.create<func::CallOp>(
+        loc, kGetValueStorage, TypeRange(ptrType), storage);
 
     // Cast from i8* to the LLVM pointer type.
     auto valueType = op.getValue().getType();
@@ -818,13 +919,15 @@
       return rewriter.notifyMatchFailure(
           op, "failed to convert stored value type to LLVM type");
 
-    auto castedStoragePtr = rewriter.create<LLVM::BitcastOp>(
-        loc, LLVM::LLVMPointerType::get(llvmValueType),
-        storagePtr.getResult(0));
+    Value castedStoragePtr = storagePtr.getResult(0);
+    if (!getTypeConverter()->useOpaquePointers())
+      castedStoragePtr = rewriter.create<LLVM::BitcastOp>(
+          loc, getTypeConverter()->getPointerType(llvmValueType),
+          castedStoragePtr);
 
     // Store the yielded value into the async value storage.
     auto value = adaptor.getValue();
-    rewriter.create<LLVM::StoreOp>(loc, value, castedStoragePtr.getResult());
+    rewriter.create<LLVM::StoreOp>(loc, value, castedStoragePtr);
 
     // Erase the original runtime store operation.
     rewriter.eraseOp(op);
@@ -839,9 +942,9 @@
 //===----------------------------------------------------------------------===//
 
 namespace {
-class RuntimeLoadOpLowering : public OpConversionPattern<RuntimeLoadOp> {
+class RuntimeLoadOpLowering : public ConvertOpToLLVMPattern<RuntimeLoadOp> {
 public:
-  using OpConversionPattern::OpConversionPattern;
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
 
   LogicalResult
   matchAndRewrite(RuntimeLoadOp op, OpAdaptor adaptor,
@@ -849,10 +952,11 @@
     Location loc = op->getLoc();
 
     // Get a pointer to the async value storage from the runtime.
-    auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext());
+    auto ptrType = AsyncAPI::opaquePointerType(
+        rewriter.getContext(), getTypeConverter()->useOpaquePointers());
     auto storage = adaptor.getStorage();
-    auto storagePtr = rewriter.create<func::CallOp>(loc, kGetValueStorage,
-                                                    TypeRange(i8Ptr), storage);
+    auto storagePtr = rewriter.create<func::CallOp>(
+        loc, kGetValueStorage, TypeRange(ptrType), storage);
 
     // Cast from i8* to the LLVM pointer type.
     auto valueType = op.getResult().getType();
@@ -861,12 +965,15 @@
       return rewriter.notifyMatchFailure(
          op, "failed to convert loaded value type to LLVM type");
 
-    auto castedStoragePtr = rewriter.create<LLVM::BitcastOp>(
-        loc, LLVM::LLVMPointerType::get(llvmValueType),
-        storagePtr.getResult(0));
+    Value castedStoragePtr = storagePtr.getResult(0);
+    if (!getTypeConverter()->useOpaquePointers())
+      castedStoragePtr = rewriter.create<LLVM::BitcastOp>(
+          loc, getTypeConverter()->getPointerType(llvmValueType),
+          castedStoragePtr);
 
     // Load from the casted pointer.
-    rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, castedStoragePtr.getResult());
+    rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, llvmValueType,
+                                              castedStoragePtr);
 
     return success();
   }
@@ -1000,22 +1107,28 @@
   ModuleOp module = getOperation();
   MLIRContext *ctx = module->getContext();
 
+  LowerToLLVMOptions options(ctx);
+  options.useOpaquePointers = useOpaquePointers;
+
   // Add declarations for most functions required by the coroutines lowering.
   // We delay adding the resume function until it's needed because it currently
   // fails to compile unless '-O0' is specified.
-  addAsyncRuntimeApiDeclarations(module);
+  addAsyncRuntimeApiDeclarations(module, useOpaquePointers);
 
   // Lower async.runtime and async.coro operations to Async Runtime API and
   // LLVM coroutine intrinsics.
 
   // Convert async dialect types and operations to LLVM dialect.
-  AsyncRuntimeTypeConverter converter;
+  AsyncRuntimeTypeConverter converter(options);
   RewritePatternSet patterns(ctx);
 
   // We use conversion to LLVM type to lower async.runtime load and store
   // operations.
-  LLVMTypeConverter llvmConverter(ctx);
-  llvmConverter.addConversion(AsyncRuntimeTypeConverter::convertAsyncTypes);
+  LLVMTypeConverter llvmConverter(ctx, options);
+  llvmConverter.addConversion([&](Type type) {
+    return AsyncRuntimeTypeConverter::convertAsyncTypes(
+        type, llvmConverter.useOpaquePointers());
+  });
 
   // Convert async types in function signatures and function calls.
   populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
@@ -1036,8 +1149,7 @@
   // Lower async.runtime operations that rely on LLVM type converter to convert
   // from async value payload type to the LLVM type.
   patterns.add<RuntimeCreateOpLowering, RuntimeCreateGroupOpLowering,
-               RuntimeStoreOpLowering, RuntimeLoadOpLowering>(llvmConverter,
-                                                              ctx);
+               RuntimeStoreOpLowering, RuntimeLoadOpLowering>(llvmConverter);
 
   // Lower async coroutine operations to LLVM coroutine intrinsics.
   patterns
diff --git a/mlir/test/Conversion/AsyncToLLVM/convert-coro-to-llvm.mlir b/mlir/test/Conversion/AsyncToLLVM/convert-coro-to-llvm.mlir
--- a/mlir/test/Conversion/AsyncToLLVM/convert-coro-to-llvm.mlir
+++ b/mlir/test/Conversion/AsyncToLLVM/convert-coro-to-llvm.mlir
@@ -1,9 +1,9 @@
-// RUN: mlir-opt %s -convert-async-to-llvm | FileCheck %s
+// RUN: mlir-opt %s -convert-async-to-llvm='use-opaque-pointers=1' | FileCheck %s
 
 // CHECK-LABEL: @coro_id
 func.func @coro_id() {
   // CHECK: %0 = llvm.mlir.constant(0 : i32) : i32
-  // CHECK: %1 = llvm.mlir.null : !llvm.ptr<i8>
+  // CHECK: %1 = llvm.mlir.null : !llvm.ptr
   // CHECK: %2 = llvm.intr.coro.id %0, %1, %1, %1 : !llvm.token
   %0 = async.coro.id
   return
diff --git a/mlir/test/Conversion/AsyncToLLVM/convert-runtime-to-llvm.mlir b/mlir/test/Conversion/AsyncToLLVM/convert-runtime-to-llvm.mlir
--- a/mlir/test/Conversion/AsyncToLLVM/convert-runtime-to-llvm.mlir
+++ b/mlir/test/Conversion/AsyncToLLVM/convert-runtime-to-llvm.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-async-to-llvm | FileCheck %s --dump-input=always
+// RUN: mlir-opt %s -convert-async-to-llvm='use-opaque-pointers=1' | FileCheck %s --dump-input=always
 
 // CHECK-LABEL: @create_token
 func.func @create_token() {
@@ -9,7 +9,7 @@
 
 // CHECK-LABEL: @create_value
 func.func @create_value() {
-  // CHECK: %[[NULL:.*]] = llvm.mlir.null : !llvm.ptr<f32>
+  // CHECK: %[[NULL:.*]] = llvm.mlir.null : !llvm.ptr
   // CHECK: %[[OFFSET:.*]] = llvm.getelementptr %[[NULL]][1]
   // CHECK: %[[SIZE:.*]] = llvm.ptrtoint %[[OFFSET]]
   // CHECK: %[[VALUE:.*]] = call @mlirAsyncRuntimeCreateValue(%[[SIZE]])
@@ -152,8 +152,7 @@
   // CHECK: %[[VALUE:.*]] = call @mlirAsyncRuntimeCreateValue
   %1 = async.runtime.create : !async.value<f32>
   // CHECK: %[[P0:.*]] = call @mlirAsyncRuntimeGetValueStorage(%[[VALUE]])
-  // CHECK: %[[P1:.*]] = llvm.bitcast %[[P0]] : !llvm.ptr<i8> to !llvm.ptr<f32>
-  // CHECK: llvm.store %[[CST]], %[[P1]]
+  // CHECK: llvm.store %[[CST]], %[[P0]] : f32, !llvm.ptr
   async.runtime.store %0, %1 : !async.value<f32>
   return
 }
@@ -163,8 +162,7 @@
   // CHECK: %[[VALUE:.*]] = call @mlirAsyncRuntimeCreateValue
   %0 = async.runtime.create : !async.value<f32>
   // CHECK: %[[P0:.*]] = call @mlirAsyncRuntimeGetValueStorage(%[[VALUE]])
-  // CHECK: %[[P1:.*]] = llvm.bitcast %[[P0]] : !llvm.ptr<i8> to !llvm.ptr<f32>
-  // CHECK: %[[VALUE:.*]] = llvm.load %[[P1]]
+  // CHECK: %[[VALUE:.*]] = llvm.load %[[P0]] : !llvm.ptr -> f32
   %1 = async.runtime.load %0 : !async.value<f32>
   // CHECK: return %[[VALUE]] : f32
   return %1 : f32
diff --git a/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir
--- a/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir
+++ b/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -split-input-file -async-to-async-runtime -convert-async-to-llvm | FileCheck %s
+// RUN: mlir-opt %s -split-input-file -async-to-async-runtime -convert-async-to-llvm='use-opaque-pointers=1' | FileCheck %s
 
 // CHECK-LABEL: reference_counting
 func.func @reference_counting(%arg0: !async.token) {
@@ -35,7 +35,7 @@
 // Function outlined from the async.execute operation.
 // CHECK-LABEL: func private @async_execute_fn(%arg0: f32, %arg1: memref<1xf32>)
-// CHECK-SAME: -> !llvm.ptr<i8>
+// CHECK-SAME: -> !llvm.ptr
 
 // Create token for return op, and mark a function as a coroutine.
 // CHECK: %[[RET:.*]] = call @mlirAsyncRuntimeCreateToken()
 // CHECK: %[[HDL:.*]] = llvm.intr.coro.begin
@@ -97,7 +97,7 @@
 // Function outlined from the inner async.execute operation.
 // CHECK-LABEL: func private @async_execute_fn(%arg0: f32, %arg1: memref<1xf32>)
-// CHECK-SAME: -> !llvm.ptr<i8>
+// CHECK-SAME: -> !llvm.ptr
 
 // CHECK: %[[RET_0:.*]] = call @mlirAsyncRuntimeCreateToken()
 // CHECK: %[[HDL_0:.*]] = llvm.intr.coro.begin
 // CHECK: call @mlirAsyncRuntimeExecute
@@ -108,7 +108,7 @@
 
 // Function outlined from the outer async.execute operation.
 // CHECK-LABEL: func private @async_execute_fn_0(%arg0: f32, %arg1: memref<1xf32>, %arg2: f32)
-// CHECK-SAME: -> !llvm.ptr<i8>
+// CHECK-SAME: -> !llvm.ptr
 
 // CHECK: %[[RET_1:.*]] = call @mlirAsyncRuntimeCreateToken()
 // CHECK: %[[HDL_1:.*]] = llvm.intr.coro.begin
@@ -147,7 +147,7 @@
 // Function outlined from the first async.execute operation.
 // CHECK-LABEL: func private @async_execute_fn(%arg0: f32, %arg1: memref<1xf32>)
-// CHECK-SAME: -> !llvm.ptr<i8>
+// CHECK-SAME: -> !llvm.ptr
 
 // CHECK: %[[RET_0:.*]] = call @mlirAsyncRuntimeCreateToken()
 // CHECK: %[[HDL_0:.*]] = llvm.intr.coro.begin
 // CHECK: call @mlirAsyncRuntimeExecute
@@ -156,8 +156,8 @@
 // CHECK: call @mlirAsyncRuntimeEmplaceToken(%[[RET_0]])
 
 // Function outlined from the second async.execute operation with dependency.
-// CHECK-LABEL: func private @async_execute_fn_0(%arg0: !llvm.ptr<i8>, %arg1: f32, %arg2: memref<1xf32>)
-// CHECK-SAME: -> !llvm.ptr<i8>
+// CHECK-LABEL: func private @async_execute_fn_0(%arg0: !llvm.ptr, %arg1: f32, %arg2: memref<1xf32>)
+// CHECK-SAME: -> !llvm.ptr
 
 // CHECK: %[[RET_1:.*]] = call @mlirAsyncRuntimeCreateToken()
 // CHECK: %[[HDL_1:.*]] = llvm.intr.coro.begin
@@ -200,7 +200,7 @@
 }
 
 // Function outlined from the async.execute operation.
-// CHECK: func private @async_execute_fn_0(%arg0: !llvm.ptr<i8>)
+// CHECK: func private @async_execute_fn_0(%arg0: !llvm.ptr)
 
 // CHECK: %[[RET_1:.*]] = call @mlirAsyncRuntimeCreateToken()
 // CHECK: %[[HDL_1:.*]] = llvm.intr.coro.begin
@@ -227,8 +227,7 @@
   }
 
   // CHECK: %[[STORAGE:.*]] = call @mlirAsyncRuntimeGetValueStorage(%[[RET]]#1)
-  // CHECK: %[[ST_F32:.*]] = llvm.bitcast %[[STORAGE]]
-  // CHECK: %[[LOADED:.*]] = llvm.load %[[ST_F32]] : !llvm.ptr<f32>
+  // CHECK: %[[LOADED:.*]] = llvm.load %[[STORAGE]] : !llvm.ptr -> f32
   %0 = async.await %result : !async.value<f32>
 
   return %0 : f32
@@ -247,8 +246,7 @@
 // Emplace result value.
 // CHECK: %[[CST:.*]] = arith.constant 1.230000e+02 : f32
 // CHECK: %[[STORAGE:.*]] = call @mlirAsyncRuntimeGetValueStorage(%[[VALUE]])
-// CHECK: %[[ST_F32:.*]] = llvm.bitcast %[[STORAGE]]
-// CHECK: llvm.store %[[CST]], %[[ST_F32]] : !llvm.ptr<f32>
+// CHECK: llvm.store %[[CST]], %[[STORAGE]] : f32, !llvm.ptr
 // CHECK: call @mlirAsyncRuntimeEmplaceValue(%[[VALUE]])
 
 // Emplace result token.
@@ -280,7 +278,7 @@
 // CHECK-LABEL: func private @async_execute_fn()
 
 // Function outlined from the second async.execute operation.
-// CHECK-LABEL: func private @async_execute_fn_0(%arg0: !llvm.ptr<i8>)
+// CHECK-LABEL: func private @async_execute_fn_0(%arg0: !llvm.ptr)
 
 // CHECK: %[[TOKEN:.*]] = call @mlirAsyncRuntimeCreateToken()
 // CHECK: %[[HDL:.*]] = llvm.intr.coro.begin
@@ -295,8 +293,7 @@
 
 // Get the operand value storage, cast to f32 and add the value.
 // CHECK: %[[STORAGE:.*]] = call @mlirAsyncRuntimeGetValueStorage(%arg0)
-// CHECK: %[[ST_F32:.*]] = llvm.bitcast %[[STORAGE]]
-// CHECK: %[[LOADED:.*]] = llvm.load %[[ST_F32]] : !llvm.ptr<f32>
+// CHECK: %[[LOADED:.*]] = llvm.load %[[STORAGE]] : !llvm.ptr -> f32
 // CHECK: arith.addf %[[LOADED]], %[[LOADED]] : f32
 
 // Emplace result token.
diff --git a/mlir/test/Conversion/AsyncToLLVM/typed-pointers.mlir b/mlir/test/Conversion/AsyncToLLVM/typed-pointers.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Conversion/AsyncToLLVM/typed-pointers.mlir
@@ -0,0 +1,138 @@
+// RUN: mlir-opt %s -split-input-file -async-to-async-runtime -convert-async-to-llvm='use-opaque-pointers=0' | FileCheck %s
+
+
+
+// CHECK-LABEL: @store
+func.func @store() {
+  // CHECK: %[[CST:.*]] = arith.constant 1.0
+  %0 = arith.constant 1.0 : f32
+  // CHECK: %[[VALUE:.*]] = call @mlirAsyncRuntimeCreateValue
+  %1 = async.runtime.create : !async.value<f32>
+  // CHECK: %[[P0:.*]] = call @mlirAsyncRuntimeGetValueStorage(%[[VALUE]])
+  // CHECK: %[[P1:.*]] = llvm.bitcast %[[P0]] : !llvm.ptr<i8> to !llvm.ptr<f32>
+  // CHECK: llvm.store %[[CST]], %[[P1]]
+  async.runtime.store %0, %1 : !async.value<f32>
+  return
+}
+
+// CHECK-LABEL: @load
+func.func @load() -> f32 {
+  // CHECK: %[[VALUE:.*]] = call @mlirAsyncRuntimeCreateValue
+  %0 = async.runtime.create : !async.value<f32>
+  // CHECK: %[[P0:.*]] = call @mlirAsyncRuntimeGetValueStorage(%[[VALUE]])
+  // CHECK: %[[P1:.*]] = llvm.bitcast %[[P0]] : !llvm.ptr<i8> to !llvm.ptr<f32>
+  // CHECK: %[[VALUE:.*]] = llvm.load %[[P1]]
+  %1 = async.runtime.load %0 : !async.value<f32>
+  // CHECK: return %[[VALUE]] : f32
+  return %1 : f32
+}
+
+// -----
+
+// CHECK-LABEL: execute_no_async_args
+func.func @execute_no_async_args(%arg0: f32, %arg1: memref<1xf32>) {
+  // CHECK: %[[TOKEN:.*]] = call @async_execute_fn(%arg0, %arg1)
+  %token = async.execute {
+    %c0 = arith.constant 0 : index
+    memref.store %arg0, %arg1[%c0] : memref<1xf32>
+    async.yield
+  }
+  // CHECK: call @mlirAsyncRuntimeAwaitToken(%[[TOKEN]])
+  // CHECK: %[[IS_ERROR:.*]] = call @mlirAsyncRuntimeIsTokenError(%[[TOKEN]])
+  // CHECK: %[[TRUE:.*]] = arith.constant true
+  // CHECK: %[[NOT_ERROR:.*]] = arith.xori %[[IS_ERROR]], %[[TRUE]] : i1
+  // CHECK: cf.assert %[[NOT_ERROR]]
+  // CHECK-NEXT: return
+  async.await %token : !async.token
+  return
+}
+
+// Function outlined from the async.execute operation.
+// CHECK-LABEL: func private @async_execute_fn(%arg0: f32, %arg1: memref<1xf32>)
+// CHECK-SAME: -> !llvm.ptr<i8>
+
+// Create token for return op, and mark a function as a coroutine.
+// CHECK: %[[RET:.*]] = call @mlirAsyncRuntimeCreateToken()
+// CHECK: %[[HDL:.*]] = llvm.intr.coro.begin
+
+// Pass a suspended coroutine to the async runtime.
+// CHECK: %[[STATE:.*]] = llvm.intr.coro.save
+// CHECK: %[[RESUME:.*]] = llvm.mlir.addressof @__resume
+// CHECK: call @mlirAsyncRuntimeExecute(%[[HDL]], %[[RESUME]])
+// CHECK: %[[SUSPENDED:.*]] = llvm.intr.coro.suspend %[[STATE]]
+
+// Decide the next block based on the code returned from suspend.
+// CHECK: %[[SEXT:.*]] = llvm.sext %[[SUSPENDED]] : i8 to i32
+// CHECK: llvm.switch %[[SEXT]] : i32, ^[[SUSPEND:[b0-9]+]]
+// CHECK-NEXT: 0: ^[[RESUME:[b0-9]+]]
+// CHECK-NEXT: 1: ^[[CLEANUP:[b0-9]+]]
+
+// Resume coroutine after suspension.
+// CHECK: ^[[RESUME]]:
+// CHECK: memref.store %arg0, %arg1[%c0] : memref<1xf32>
+// CHECK: call @mlirAsyncRuntimeEmplaceToken(%[[RET]])
+
+// Delete coroutine.
+// CHECK: ^[[CLEANUP]]:
+// CHECK: %[[MEM:.*]] = llvm.intr.coro.free
+// CHECK: llvm.call @free(%[[MEM]])
+
+// Suspend coroutine, and also a return statement for ramp function.
+// CHECK: ^[[SUSPEND]]:
+// CHECK: llvm.intr.coro.end
+// CHECK: return %[[RET]]
+
+// -----
+
+// CHECK-LABEL: execute_and_return_f32
+func.func @execute_and_return_f32() -> f32 {
+  // CHECK: %[[RET:.*]]:2 = call @async_execute_fn
+  %token, %result = async.execute -> !async.value<f32> {
+    %c0 = arith.constant 123.0 : f32
+    async.yield %c0 : f32
+  }
+
+  // CHECK: %[[STORAGE:.*]] = call @mlirAsyncRuntimeGetValueStorage(%[[RET]]#1)
+  // CHECK: %[[ST_F32:.*]] = llvm.bitcast %[[STORAGE]]
+  // CHECK: %[[LOADED:.*]] = llvm.load %[[ST_F32]] : !llvm.ptr<f32>
+  %0 = async.await %result : !async.value<f32>
+
+  return %0 : f32
+}
+
+// Function outlined from the async.execute operation.
+// CHECK-LABEL: func private @async_execute_fn()
+// CHECK: %[[TOKEN:.*]] = call @mlirAsyncRuntimeCreateToken()
+// CHECK: %[[VALUE:.*]] = call @mlirAsyncRuntimeCreateValue
+// CHECK: %[[HDL:.*]] = llvm.intr.coro.begin
+
+// Suspend coroutine in the beginning.
+// CHECK: call @mlirAsyncRuntimeExecute(%[[HDL]],
+// CHECK: llvm.intr.coro.suspend
+
+// Emplace result value.
+// CHECK: %[[CST:.*]] = arith.constant 1.230000e+02 : f32
+// CHECK: %[[STORAGE:.*]] = call @mlirAsyncRuntimeGetValueStorage(%[[VALUE]])
+// CHECK: %[[ST_F32:.*]] = llvm.bitcast %[[STORAGE]]
+// CHECK: llvm.store %[[CST]], %[[ST_F32]] : !llvm.ptr<f32>
+// CHECK: call @mlirAsyncRuntimeEmplaceValue(%[[VALUE]])
+
+// Emplace result token.
+// CHECK: call @mlirAsyncRuntimeEmplaceToken(%[[TOKEN]])
+
+// -----
+
+// CHECK-LABEL: @await_and_resume_group
+func.func @await_and_resume_group() {
+  %c = arith.constant 1 : index
+  %0 = async.coro.id
+  // CHECK: %[[HDL:.*]] = llvm.intr.coro.begin
+  %1 = async.coro.begin %0
+  // CHECK: %[[TOKEN:.*]] = call @mlirAsyncRuntimeCreateGroup
+  %2 = async.runtime.create_group %c : !async.group
+  // CHECK: %[[RESUME:.*]] = llvm.mlir.addressof @__resume
+  // CHECK: call @mlirAsyncRuntimeAwaitAllInGroupAndExecute
+  // CHECK-SAME: (%[[TOKEN]], %[[HDL]], %[[RESUME]])
+  async.runtime.await_and_resume %2, %1 : !async.group
+  return
+}