diff --git a/mlir/include/mlir/Dialect/Async/IR/Async.h b/mlir/include/mlir/Dialect/Async/IR/Async.h
--- a/mlir/include/mlir/Dialect/Async/IR/Async.h
+++ b/mlir/include/mlir/Dialect/Async/IR/Async.h
@@ -18,9 +18,11 @@
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
-#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/FunctionInterfaces.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Interfaces/CallInterfaces.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
@@ -53,4 +55,19 @@
 } // namespace async
 } // namespace mlir
 
+namespace llvm {
+
+/// Allow stealing the low bits of async::FuncOp.
+template <>
+struct PointerLikeTypeTraits<mlir::async::FuncOp> {
+  static inline void *getAsVoidPointer(mlir::async::FuncOp val) {
+    return const_cast<void *>(val.getAsOpaquePointer());
+  }
+  static inline mlir::async::FuncOp getFromVoidPointer(void *p) {
+    return mlir::async::FuncOp::getFromOpaquePointer(p);
+  }
+  static constexpr int numLowBitsAvailable = 3;
+};
+} // namespace llvm
+
 #endif // MLIR_DIALECT_ASYNC_IR_ASYNC_H
diff --git a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
--- a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
+++ b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
@@ -18,6 +18,11 @@
 include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Interfaces/CallInterfaces.td"
+include "mlir/IR/SymbolInterfaces.td"
+include "mlir/IR/FunctionInterfaces.td"
+include "mlir/IR/OpAsmInterface.td"
+
 
 //===----------------------------------------------------------------------===//
 // Async op definitions
@@ -99,6 +104,193 @@
   }];
 }
 
+def Async_FuncOp : Async_Op<"func",
+    [CallableOpInterface, FunctionOpInterface,
+     IsolatedFromAbove, OpAsmOpInterface, Symbol]> {
+  let summary = "async function operation";
+  let description = [{
+    An async function is like a normal function, but supports non-blocking
+    await. Internally, an async function is lowered to an LLVM coroutine
+    using async runtime intrinsics. It can return an async token and/or
+    async values. The token represents the execution state of the async
+    function and can be used to express dependencies on its side effects,
+    e.g. the token becomes available once everything in the body has executed.
+
+    Example:
+
+    ```mlir
+    // An async function can't return void; it must return some async type.
+    async.func @async.0() -> !async.token {
+      return
+    }
+
+    // Function returns only async value.
+    async.func @async.1() -> !async.value<i32> {
+      %0 = arith.constant 42 : i32
+      return %0 : i32
+    }
+
+    // Implicit token can be added to return types.
+    async.func @async.2() -> (!async.token, !async.value<i32>) {
+      %0 = arith.constant 42 : i32
+      return %0 : i32
+    }
+    ```
+  }];
+
+  let arguments = (ins SymbolNameAttr:$sym_name,
+                       TypeAttrOf<FunctionType>:$function_type,
+                       OptionalAttr<StrAttr>:$sym_visibility);
+
+  let regions = (region AnyRegion:$body);
+
+  let builders = [
+    OpBuilder<(ins "StringRef":$name, "FunctionType":$type,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs,
+      CArg<"ArrayRef<DictionaryAttr>", "{}">:$argAttrs)>
+  ];
+
+  let extraClassDeclaration = [{
+    //===------------------------------------------------------------------===//
+    // CallableOpInterface
+    //===------------------------------------------------------------------===//
+
+    /// Returns the region on the current operation that is callable. This may
+    /// return null in the case of an external callable object, e.g. an external
+    /// function.
+    ::mlir::Region *getCallableRegion() { return isExternal() ? nullptr
+                                                              : &getBody(); }
+
+    /// Returns the results types that the callable region produces when
+    /// executed.
+    ArrayRef<Type> getCallableResults() { return getFunctionType()
+                                                    .getResults(); }
+
+    //===------------------------------------------------------------------===//
+    // FunctionOpInterface Methods
+    //===------------------------------------------------------------------===//
+
+    /// Returns the argument types of this async function.
+    ArrayRef<Type> getArgumentTypes() { return getFunctionType().getInputs(); }
+
+    /// Returns the result types of this async function.
+    ArrayRef<Type> getResultTypes() { return getFunctionType().getResults(); }
+
+    /// Returns the number of results of this async function.
+    unsigned getNumResults() {return getResultTypes().size();}
+
+    /// Returns true if the first result is an `!async.token` (guarded so it
+    /// is safe to call on an unverified op with no results).
+    bool isStateful() {
+      return !getResultTypes().empty() && getResultTypes()[0].isa<TokenType>();
+    }
+
+    //===------------------------------------------------------------------===//
+    // OpAsmOpInterface Methods
+    //===------------------------------------------------------------------===//
+
+    /// Allow the dialect prefix to be omitted.
+    static StringRef getDefaultDialect() { return "async"; }
+
+    //===------------------------------------------------------------------===//
+    // SymbolOpInterface Methods
+    //===------------------------------------------------------------------===//
+
+    bool isDeclaration() { return isExternal(); }
+  }];
+  let hasCustomAssemblyFormat = 1;
+
+  let hasVerifier = 1;
+}
+
+def Async_CallOp : Async_Op<"call",
+    [CallOpInterface, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
+  let summary = "async call operation";
+  let description = [{
+    The `async.call` operation represents a direct call to an async function
+    that is within the same symbol scope as the call. The operands and result
+    types of the call must match the specified async function type. The callee
+    is encoded as a symbol reference attribute named "callee".
+
+    Example:
+
+    ```mlir
+    %2 = async.call @my_add(%0, %1) : (f32, f32) -> !async.value<f32>
+    ```
+  }];
+
+  let arguments = (ins FlatSymbolRefAttr:$callee, Variadic<AnyType>:$operands);
+  let results = (outs Variadic<Async_AnyValueOrTokenType>);
+
+  let builders = [
+    OpBuilder<(ins "FuncOp":$callee, CArg<"ValueRange", "{}">:$operands), [{
+      $_state.addOperands(operands);
+      $_state.addAttribute("callee", SymbolRefAttr::get(callee));
+      $_state.addTypes(callee.getFunctionType().getResults());
+    }]>,
+    OpBuilder<(ins "SymbolRefAttr":$callee, "TypeRange":$results,
+      CArg<"ValueRange", "{}">:$operands), [{
+      $_state.addOperands(operands);
+      $_state.addAttribute("callee", callee);
+      $_state.addTypes(results);
+    }]>,
+    OpBuilder<(ins "StringAttr":$callee, "TypeRange":$results,
+      CArg<"ValueRange", "{}">:$operands), [{
+      build($_builder, $_state, SymbolRefAttr::get(callee), results, operands);
+    }]>,
+    OpBuilder<(ins "StringRef":$callee, "TypeRange":$results,
+      CArg<"ValueRange", "{}">:$operands), [{
+      build($_builder, $_state, StringAttr::get($_builder.getContext(), callee),
+            results, operands);
+    }]>
+  ];
+
+  let extraClassDeclaration = [{
+    FunctionType getCalleeType();
+
+    /// Get the argument operands to the called function.
+    operand_range getArgOperands() {
+      return {arg_operand_begin(), arg_operand_end()};
+    }
+
+    operand_iterator arg_operand_begin() { return operand_begin(); }
+    operand_iterator arg_operand_end() { return operand_end(); }
+
+    /// Return the callee of this operation.
+    CallInterfaceCallable getCallableForCallee() {
+      return (*this)->getAttrOfType<SymbolRefAttr>("callee");
+    }
+  }];
+
+  let assemblyFormat = [{
+    $callee `(` $operands `)` attr-dict `:` functional-type($operands, results)
+  }];
+}
+
+def Async_ReturnOp : Async_Op<"return",
+    [Pure, HasParent<"FuncOp">, ReturnLike, Terminator]> {
+  let summary = "Async function return operation";
+  let description = [{
+    The `async.return` operation terminates the body of an `async.func`.
+    Its operands are the payloads of the function's `!async.value` results,
+    in order; the optional leading `!async.token` result takes no operand.
+
+    Example:
+
+    ```mlir
+    async.func @foo() -> !async.token {
+      return
+    }
+    ```
+  }];
+
+  let arguments = (ins Variadic<AnyType>:$operands);
+
+  let builders = [OpBuilder<(ins), [{build($_builder, $_state, llvm::None);}]>];
+
+  let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
+  let hasVerifier = 1;
+}
+
 def Async_YieldOp : Async_Op<"yield", [
     HasParent<"ExecuteOp">, Pure,
     Terminator,
diff --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp
--- a/mlir/lib/Dialect/Async/IR/Async.cpp
+++ b/mlir/lib/Dialect/Async/IR/Async.cpp
@@ -8,7 +8,10 @@
 
 #include "mlir/Dialect/Async/IR/Async.h"
 
+#include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/FunctionImplementation.h"
+#include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/TypeSwitch.h"
 
 using namespace mlir;
@@ -320,6 +323,138 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// FuncOp
+//===----------------------------------------------------------------------===//
+
+void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
+                   FunctionType type, ArrayRef<NamedAttribute> attrs,
+                   ArrayRef<DictionaryAttr> argAttrs) {
+  state.addAttribute(SymbolTable::getSymbolAttrName(),
+                     builder.getStringAttr(name));
+  state.addAttribute(FunctionOpInterface::getTypeAttrName(),
+                     TypeAttr::get(type));
+
+  state.attributes.append(attrs.begin(), attrs.end());
+  state.addRegion();
+
+  if (argAttrs.empty())
+    return;
+  assert(type.getNumInputs() == argAttrs.size());
+  function_interface_impl::addArgAndResultAttrs(builder, state, argAttrs,
+                                                /*resultAttrs=*/llvm::None);
+}
+
+ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
+  auto buildFuncType =
+      [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
+         function_interface_impl::VariadicFlag,
+         std::string &) { return builder.getFunctionType(argTypes, results); };
+
+  return function_interface_impl::parseFunctionOp(
+      parser, result, /*allowVariadic=*/false, buildFuncType);
+}
+
+void FuncOp::print(OpAsmPrinter &p) {
+  function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
+}
+
+/// Verifies that async.func has at least one result, that every result is an
+/// !async.token or !async.value, and that a token can only be the 1st result.
+LogicalResult FuncOp::verify() {
+  auto resultTypes = getResultTypes();
+  if (resultTypes.empty())
+    return emitOpError()
+           << "result is expected to be at least of size 1, but got "
+           << resultTypes.size();
+
+  for (unsigned i = 0, e = resultTypes.size(); i != e; ++i) {
+    auto type = resultTypes[i];
+    if (!type.isa<TokenType>() && !type.isa<ValueType>())
+      return emitOpError() << "result type must be async value type or async "
+                              "token type, but got "
+                           << type;
+    // Only the first result may be an (optional) async token.
+    if (type.isa<TokenType>() && i != 0) {
+      return emitOpError()
+             << "results' (optional) async token type is expected "
+                "to appear as the 1st return value, but got "
+             << i + 1;
+    }
+  }
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// CallOp
+//===----------------------------------------------------------------------===//
+
+/// Verifies that `callee` names an async.func in scope whose signature
+/// matches the operand and result types of this call.
+LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  // Check that the callee attribute was specified.
+  auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
+  if (!fnAttr)
+    return emitOpError("requires a 'callee' symbol reference attribute");
+  FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr);
+  if (!fn)
+    return emitOpError() << "'" << fnAttr.getValue()
+                         << "' does not reference a valid async function";
+
+  // Verify that the operand and result types match the callee.
+  auto fnType = fn.getFunctionType();
+  if (fnType.getNumInputs() != getNumOperands())
+    return emitOpError("incorrect number of operands for callee");
+
+  for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
+    if (getOperand(i).getType() != fnType.getInput(i))
+      return emitOpError("operand type mismatch: expected operand type ")
+             << fnType.getInput(i) << ", but provided "
+             << getOperand(i).getType() << " for operand number " << i;
+
+  if (fnType.getNumResults() != getNumResults())
+    return emitOpError("incorrect number of results for callee");
+
+  for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
+    if (getResult(i).getType() != fnType.getResult(i)) {
+      auto diag = emitOpError("result type mismatch at index ") << i;
+      diag.attachNote() << "      op result types: " << getResultTypes();
+      diag.attachNote() << "function result types: " << fnType.getResults();
+      return diag;
+    }
+
+  return success();
+}
+
+FunctionType CallOp::getCalleeType() {
+  return FunctionType::get(getContext(), getOperandTypes(), getResultTypes());
+}
+
+//===----------------------------------------------------------------------===//
+// ReturnOp
+//===----------------------------------------------------------------------===//
+
+/// Verifies that the operands match the payload types of the parent async.func
+/// !async.value results; a leading !async.token result is implicit.
+LogicalResult ReturnOp::verify() {
+  auto funcOp = (*this)->getParentOfType<FuncOp>();
+  ArrayRef<Type> resultTypes = funcOp.isStateful()
+                                   ? funcOp.getResultTypes().drop_front()
+                                   : funcOp.getResultTypes();
+  // Get the underlying value types from async types returned from the
+  // parent `async.func` operation.
+  auto types = llvm::map_range(resultTypes, [](const Type &result) {
+    return result.cast<ValueType>().getValueType();
+  });
+
+  if (getOperandTypes() != types)
+    return emitOpError("operand types do not match the types returned from "
+                       "the parent FuncOp");
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Async/ops.mlir b/mlir/test/Dialect/Async/ops.mlir
--- a/mlir/test/Dialect/Async/ops.mlir
+++ b/mlir/test/Dialect/Async/ops.mlir
@@ -136,3 +136,34 @@
   %3 = arith.addi %1, %2 : index
   return %3 : index
 }
+
+// CHECK-LABEL: @async_func_return_token
+async.func @async_func_return_token() -> !async.token {
+  // CHECK: return
+  return
+}
+
+// CHECK-LABEL: @async_func_return_value
+async.func @async_func_return_value() -> !async.value<i32> {
+  %0 = arith.constant 42 : i32
+  // CHECK: return %[[value:.*]] : i32
+  return %0 : i32
+}
+
+// CHECK-LABEL: @async_func_return_optional_token
+async.func @async_func_return_optional_token() -> (!async.token, !async.value<i32>) {
+  %0 = arith.constant 42 : i32
+  // CHECK: return %[[value:.*]] : i32
+  return %0 : i32
+}
+
+// CHECK-LABEL: @async_call
+func.func @async_call() {
+  // CHECK: async.call @async_func_return_token
+  // CHECK: async.call @async_func_return_value
+  // CHECK: async.call @async_func_return_optional_token
+  %0 = async.call @async_func_return_token() : () -> !async.token
+  %1 = async.call @async_func_return_value() : () -> !async.value<i32>
+  %2, %3 = async.call @async_func_return_optional_token() : () -> (!async.token, !async.value<i32>)
+  return
+}
diff --git a/mlir/test/Dialect/Async/verify.mlir b/mlir/test/Dialect/Async/verify.mlir
--- a/mlir/test/Dialect/Async/verify.mlir
+++ b/mlir/test/Dialect/Async/verify.mlir
@@ -19,3 +19,29 @@
   // expected-error @+1 {{'async.await' op result type 'f64' does not match async value type 'f32'}}
   %0 = "async.await"(%arg0): (!async.value<f32>) -> f64
 }
+
+
+// -----
+// expected-error @+1 {{'async.func' op result is expected to be at least of size 1, but got 0}}
+async.func @wrong_async_func_void_result_type(%arg0: f32) {
+  return
+}
+
+
+// -----
+// expected-error @+1 {{'async.func' op result type must be async value type or async token type, but got 'f32'}}
+async.func @wrong_async_func_result_type(%arg0: f32) -> f32 {
+  return %arg0 : f32
+}
+
+// -----
+// expected-error @+1 {{'async.func' op results' (optional) async token type is expected to appear as the 1st return value, but got 2}}
+async.func @wrong_async_func_token_type_placement(%arg0: f32) -> (!async.value<f32>, !async.token) {
+  return %arg0 : f32
+}
+
+// -----
+async.func @wrong_async_func_return_type(%arg0: f32) -> (!async.token, !async.value<i32>) {
+  // expected-error @+1 {{'async.return' op operand types do not match the types returned from the parent FuncOp}}
+  return %arg0 : f32
+}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -1349,7 +1349,9 @@
     ],
     includes = ["include"],
     deps = [
+        ":CallInterfacesTdFiles",
         ":ControlFlowInterfacesTdFiles",
+        ":FunctionInterfacesTdFiles",
         ":InferTypeOpInterfaceTdFiles",
         ":OpBaseTdFiles",
         ":SideEffectInterfacesTdFiles",