diff --git a/mlir/include/mlir/Dialect/Async/IR/Async.h b/mlir/include/mlir/Dialect/Async/IR/Async.h --- a/mlir/include/mlir/Dialect/Async/IR/Async.h +++ b/mlir/include/mlir/Dialect/Async/IR/Async.h @@ -14,9 +14,11 @@ #ifndef MLIR_DIALECT_ASYNC_IR_ASYNC_H #define MLIR_DIALECT_ASYNC_IR_ASYNC_H +#include "mlir/IR/Builders.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpImplementation.h" +#include "mlir/IR/StandardTypes.h" #include "mlir/Interfaces/SideEffectInterfaces.h" namespace mlir { diff --git a/mlir/include/mlir/Dialect/Async/IR/AsyncBase.td b/mlir/include/mlir/Dialect/Async/IR/AsyncBase.td --- a/mlir/include/mlir/Dialect/Async/IR/AsyncBase.td +++ b/mlir/include/mlir/Dialect/Async/IR/AsyncBase.td @@ -56,7 +56,11 @@ Type valueType = type; } -def Async_AnyValueType : Type()">, - "async value type">; +def Async_AnyValueType : DialectType()">, + "async value type">; + +def Async_AnyValueOrTokenType : AnyTypeOf<[Async_AnyValueType, + Async_TokenType]>; #endif // ASYNC_BASE_TD diff --git a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td --- a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td +++ b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td @@ -24,7 +24,7 @@ class Async_Op traits = []> : Op; -def Async_ExecuteOp : Async_Op<"execute"> { +def Async_ExecuteOp : Async_Op<"execute", [AttrSizedOperandSegments]> { let summary = "Asynchronous execute operation"; let description = [{ The `body` region attached to the `async.execute` operation semantically @@ -40,24 +40,43 @@ state). All dependencies must be made explicit with async execute arguments (`async.token` or `async.value`). + `async.execute` operation takes `async.token` dependencies and `async.value` + operands separately, and starts execution of the attached body region only + when all tokens and values become ready. + + Example: + ```mlir - %done, %values = async.execute { - %0 = "compute0"(...) 
: !some.type - async.yield %1 : f32 - } : !async.token, !async.value + %dependency = ... : !async.token + %value = ... : !async.value + + %token, %results = + async.execute [%dependency](%value as %unwrapped: !async.value) + -> !async.value + { + %0 = "compute0"(%unwrapped): (f32) -> !some.type + async.yield %0 : !some.type + } %1 = "compute1"(...) : !some.type ``` + + In the example above, asynchronous execution starts only after the + dependency token and the value argument become ready. The unwrapped value + is passed to the attached body region as the %unwrapped value of f32 type. }]; - // TODO: Take async.tokens/async.values as arguments. - let arguments = (ins ); - let results = (outs Async_TokenType:$done, - Variadic:$values); + let arguments = (ins Variadic:$dependencies, + Variadic:$operands); + + let results = (outs Async_TokenType:$token, + Variadic:$results); let regions = (region SizedRegion<1>:$body); - let printer = [{ return ::mlir::async::print(p, *this); }]; - let parser = [{ return ::mlir::async::parse$cppClass(parser, result); }]; + let printer = [{ return ::print(p, *this); }]; + let parser = [{ return ::parse$cppClass(parser, result); }]; + + let verifier = [{ return ::verify(*this); }]; } def Async_YieldOp : @@ -72,7 +91,7 @@ let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?"; - let verifier = [{ return ::mlir::async::verify(*this); }]; + let verifier = [{ return ::verify(*this); }]; } #endif // ASYNC_OPS diff --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp --- a/mlir/lib/Dialect/Async/IR/Async.cpp +++ b/mlir/lib/Dialect/Async/IR/Async.cpp @@ -8,19 +8,11 @@ #include "mlir/Dialect/Async/IR/Async.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" -#include "mlir/Dialect/Traits.h" -#include "mlir/IR/Builders.h" #include "mlir/IR/DialectImplementation.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/StandardTypes.h" -#include "mlir/Transforms/InliningUtils.h" -#include "llvm/ADT/SmallString.h" #include 
"llvm/ADT/TypeSwitch.h" -#include "llvm/Support/raw_ostream.h" -namespace mlir { -namespace async { +using namespace mlir; +using namespace mlir::async; void AsyncDialect::initialize() { addOperations< @@ -69,6 +61,8 @@ /// ValueType //===----------------------------------------------------------------------===// +namespace mlir { +namespace async { namespace detail { // Storage for `async.value` type, the only member is the wrapped type. @@ -90,6 +84,8 @@ }; } // namespace detail +} // namespace async +} // namespace mlir ValueType ValueType::get(Type valueType) { return Base::get(valueType.getContext(), valueType); @@ -105,7 +101,7 @@ // Get the underlying value types from async values returned from the // parent `async.execute` operation. auto executeOp = op.getParentOfType(); - auto types = llvm::map_range(executeOp.values(), [](const OpResult &result) { + auto types = llvm::map_range(executeOp.results(), [](const OpResult &result) { return result.getType().cast().getValueType(); }); @@ -120,49 +116,140 @@ /// ExecuteOp //===----------------------------------------------------------------------===// +constexpr char kOperandSegmentSizesAttr[] = "operand_segment_sizes"; + static void print(OpAsmPrinter &p, ExecuteOp op) { - p << "async.execute "; - p.printRegion(op.body()); - p.printOptionalAttrDict(op.getAttrs()); - p << " : "; - p.printType(op.done().getType()); - if (!op.values().empty()) - p << ", "; - llvm::interleaveComma(op.values(), p, [&](const OpResult &result) { - p.printType(result.getType()); - }); + p << op.getOperationName(); + + // [%tokens,...] + if (!op.dependencies().empty()) + p << " [" << op.dependencies() << "]"; + + // (%value as %unwrapped: !async.value, ...) + if (!op.operands().empty()) { + p << " ("; + llvm::interleaveComma(op.operands(), p, [&, n = 0](Value operand) mutable { + p << operand << " as " << op.body().front().getArgument(n++) << ": " + << operand.getType(); + }); + p << ")"; + } + + // -> (!async.value, ...) 
+ p.printOptionalArrowTypeList(op.getResultTypes().drop_front(1)); + p.printOptionalAttrDictWithKeyword(op.getAttrs(), {kOperandSegmentSizesAttr}); + p.printRegion(op.body(), /*printEntryBlockArgs=*/false); } static ParseResult parseExecuteOp(OpAsmParser &parser, OperationState &result) { MLIRContext *ctx = result.getContext(); - // Parse asynchronous region. - Region *body = result.addRegion(); - if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}, - /*enableNameShadowing=*/false)) + // Sizes of parsed variadic operands, will be updated below after parsing. + int32_t numDependencies = 0; + int32_t numOperands = 0; + + auto tokenTy = TokenType::get(ctx); + + // Parse dependency tokens. + if (succeeded(parser.parseOptionalLSquare())) { + SmallVector tokenArgs; + if (parser.parseOperandList(tokenArgs) || + parser.resolveOperands(tokenArgs, tokenTy, result.operands) || + parser.parseRSquare()) + return failure(); + + numDependencies = tokenArgs.size(); + } + + // Parse async value operands (%value as %unwrapped : !async.value). + SmallVector valueArgs; + SmallVector unwrappedArgs; + SmallVector valueTypes; + SmallVector unwrappedTypes; + + if (succeeded(parser.parseOptionalLParen())) { + auto argsLoc = parser.getCurrentLocation(); + + // Parse a single instance of `%value as %unwrapped : !async.value`. + auto parseAsyncValueArg = [&]() -> ParseResult { + if (parser.parseOperand(valueArgs.emplace_back()) || + parser.parseKeyword("as") || + parser.parseOperand(unwrappedArgs.emplace_back()) || + parser.parseColonType(valueTypes.emplace_back())) + return failure(); + + auto valueTy = valueTypes.back().dyn_cast(); + unwrappedTypes.emplace_back(valueTy ? valueTy.getValueType() : Type()); + + return success(); + }; + + // If the next token is `)` skip async value arguments parsing. 
+ if (failed(parser.parseOptionalRParen())) { + do { + // Stop at the first malformed `%value as %unwrapped : type` entry. + if (failed(parseAsyncValueArg())) + return failure(); + } while (succeeded(parser.parseOptionalComma())); + + if (parser.parseRParen() || + parser.resolveOperands(valueArgs, valueTypes, argsLoc, + result.operands)) + return failure(); + } + + numOperands = valueArgs.size(); + } + + // Add derived `operand_segment_sizes` attribute based on parsed operands. + auto operandSegmentSizes = DenseIntElementsAttr::get( + VectorType::get({2}, parser.getBuilder().getI32Type()), + {numDependencies, numOperands}); + result.addAttribute(kOperandSegmentSizesAttr, operandSegmentSizes); + + // Parse the types of results returned from the async execute op. + SmallVector resultTypes; + if (parser.parseOptionalArrowTypeList(resultTypes)) return failure(); + // Async execute first result is always a completion token. + parser.addTypeToList(tokenTy, result.types); + parser.addTypesToList(resultTypes, result.types); + // Parse operation attributes. NamedAttrList attrs; - if (parser.parseOptionalAttrDict(attrs)) + if (parser.parseOptionalAttrDictWithKeyword(attrs)) return failure(); result.addAttributes(attrs); - // Parse result types. - SmallVector resultTypes; - if (parser.parseColonTypeList(resultTypes)) - return failure(); - - // First result type must be an async token type. - if (resultTypes.empty() || resultTypes.front() != TokenType::get(ctx)) + // Parse asynchronous region. + Region *body = result.addRegion(); + if (parser.parseRegion(*body, /*arguments=*/{unwrappedArgs}, + /*argTypes=*/{unwrappedTypes}, + /*enableNameShadowing=*/false)) return failure(); - parser.addTypesToList(resultTypes, result.types); return success(); } -} // namespace async -} // namespace mlir +static LogicalResult verify(ExecuteOp op) { + // Unwrap async.execute value operands types. 
+ auto unwrappedTypes = llvm::map_range(op.operands(), [](Value operand) { + return operand.getType().cast().getValueType(); + }); + + // Verify that unwrapped argument types match the body region arguments. + if (llvm::size(unwrappedTypes) != llvm::size(op.body().getArgumentTypes())) + return op.emitOpError("the number of async body region arguments does not " + "match the number of execute operation arguments"); + + if (!std::equal(unwrappedTypes.begin(), unwrappedTypes.end(), + op.body().getArgumentTypes().begin())) + return op.emitOpError("async body region argument types do not match the " + "execute operation arguments types"); + + return success(); +} #define GET_OP_CLASSES #include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc" diff --git a/mlir/test/Dialect/Async/ops.mlir b/mlir/test/Dialect/Async/ops.mlir --- a/mlir/test/Dialect/Async/ops.mlir +++ b/mlir/test/Dialect/Async/ops.mlir @@ -1,7 +1,7 @@ -// RUN: mlir-opt %s | FileCheck %s +// RUN: mlir-opt %s | FileCheck %s // CHECK-LABEL: @identity_token -func @identity_token(%arg0 : !async.token) -> !async.token { +func @identity_token(%arg0: !async.token) -> !async.token { // CHECK: return %arg0 : !async.token return %arg0 : !async.token } @@ -14,33 +14,95 @@ // CHECK-LABEL: @empty_async_execute func @empty_async_execute() -> !async.token { - %done = async.execute { + // CHECK: async.execute + %token = async.execute { async.yield - } : !async.token + } - // CHECK: return %done : !async.token - return %done : !async.token + // CHECK: return %token : !async.token + return %token : !async.token } // CHECK-LABEL: @return_async_value func @return_async_value() -> !async.value { - %done, %values = async.execute { + // CHECK: async.execute -> !async.value + %token, %results = async.execute -> !async.value { %cst = constant 1.000000e+00 : f32 async.yield %cst : f32 - } : !async.token, !async.value + } - // CHECK: return %values : !async.value - return %values : !async.value + // CHECK: return %results : !async.value + 
return %results : !async.value +} + +// CHECK-LABEL: @return_captured_value +func @return_captured_value() -> !async.token { + %cst = constant 1.000000e+00 : f32 + // CHECK: async.execute -> !async.value + %token, %results = async.execute -> !async.value { + async.yield %cst : f32 + } + + // CHECK: return %token : !async.token + return %token : !async.token } // CHECK-LABEL: @return_async_values func @return_async_values() -> (!async.value, !async.value) { - %done, %values:2 = async.execute { + %token, %results:2 = async.execute -> (!async.value, !async.value) { %cst1 = constant 1.000000e+00 : f32 %cst2 = constant 2.000000e+00 : f32 async.yield %cst1, %cst2 : f32, f32 - } : !async.token, !async.value, !async.value + } + + // CHECK: return %results#0, %results#1 : !async.value, !async.value + return %results#0, %results#1 : !async.value, !async.value +} + +// CHECK-LABEL: @async_token_dependencies +func @async_token_dependencies(%arg0: !async.token) -> !async.token { + // CHECK: async.execute [%arg0] + %token = async.execute [%arg0] { + async.yield + } + + // CHECK: return %token : !async.token + return %token : !async.token +} + +// CHECK-LABEL: @async_value_operands +func @async_value_operands(%arg0: !async.value) -> !async.token { + // CHECK: async.execute (%arg0 as %arg1: !async.value) -> !async.value + %token, %results = async.execute (%arg0 as %arg1: !async.value) -> !async.value { + async.yield %arg1 : f32 + } + + // CHECK: return %token : !async.token + return %token : !async.token +} + +// CHECK-LABEL: @async_token_and_value_operands +func @async_token_and_value_operands(%arg0: !async.token, %arg1: !async.value) -> !async.token { + // CHECK: async.execute [%arg0] (%arg1 as %arg2: !async.value) -> !async.value + %token, %results = async.execute [%arg0] (%arg1 as %arg2: !async.value) -> !async.value { + async.yield %arg2 : f32 + } + + // CHECK: return %token : !async.token + return %token : !async.token +} - // CHECK: return %values#0, %values#1 : 
!async.value, !async.value - return %values#0, %values#1 : !async.value, !async.value +// CHECK-LABEL: @empty_tokens_or_values_operands +func @empty_tokens_or_values_operands() { + // CHECK: async.execute { + %token0 = async.execute [] () -> () { async.yield } + // CHECK: async.execute { + %token1 = async.execute () -> () { async.yield } + // CHECK: async.execute { + %token2 = async.execute -> () { async.yield } + // CHECK: async.execute { + %token3 = async.execute () { async.yield } + // CHECK: async.execute { + %token4 = async.execute [] { async.yield } + return }