diff --git a/mlir/include/mlir/Dialect/LoopOps/LoopOps.td b/mlir/include/mlir/Dialect/LoopOps/LoopOps.td --- a/mlir/include/mlir/Dialect/LoopOps/LoopOps.td +++ b/mlir/include/mlir/Dialect/LoopOps/LoopOps.td @@ -37,7 +37,7 @@ def ForOp : Loop_Op<"for", [DeclareOpInterfaceMethods, - SingleBlockImplicitTerminator<"YieldOp">, + SingleBlockImplicitTerminator<"LoopYieldOp">, RecursiveSideEffects]> { let summary = "for operation"; let description = [{ @@ -171,7 +171,7 @@ } def IfOp : Loop_Op<"if", - [SingleBlockImplicitTerminator<"YieldOp">, RecursiveSideEffects]> { + [SingleBlockImplicitTerminator<"LoopYieldOp">, RecursiveSideEffects]> { let summary = "if-then-else operation"; let description = [{ The "loop.if" operation represents an if-then-else construct for @@ -243,7 +243,7 @@ } def ParallelOp : Loop_Op<"parallel", - [AttrSizedOperandSegments, SingleBlockImplicitTerminator<"YieldOp">]> { + [AttrSizedOperandSegments, SingleBlockImplicitTerminator<"LoopYieldOp">]> { let summary = "parallel for operation"; let description = [{ The "loop.parallel" operation represents a loop nest taking 4 groups of SSA @@ -256,7 +256,7 @@ The lower and upper bounds specify a half-open range: the range includes the lower bound but does not include the upper bound. The initial values have the same types as results of "loop.parallel". If there are no results, - the keyword `init` can be omitted. + the keyword `init` can be omitted. Semantically we require that the iteration space can be iterated in any order, and the loop body can be executed in parallel. 
If there are data @@ -381,7 +381,7 @@ let assemblyFormat = "$result attr-dict `:` type($result)"; } -def YieldOp : Loop_Op<"yield", [NoSideEffect, Terminator]> { +def LoopYieldOp : Loop_Op<"yield", [NoSideEffect, Terminator]> { let summary = "loop yield and termination operation"; let description = [{ "loop.yield" yields an SSA value from a loop dialect op region and diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -923,6 +923,73 @@ let summary = "base-e exponential of the specified value"; } +//===----------------------------------------------------------------------===// +// ExecuteRegionOp +//===----------------------------------------------------------------------===// + +def ExecuteRegionOp : Std_Op<"execute_region"> { + let summary = "operation that executes its region exactly once"; + let description = [{ + The `execute_region` operation executes the region held exactly once. The op + cannot have any operands, nor does its region have any arguments. All SSA + values that dominate the op can be accessed inside. The op's region can have + multiple blocks and the blocks can have terminators the same way as FuncOp. + The values returned from this op's region define the op's results. The op + primarily provides control flow encapsulation and isolation from a parent + op's control flow restrictions if any; for example, it allows representation + of inlined calls in the inside of structured control flow ops with + restrictions like affine.for/if, loop.for/if ops, and thus the optimization + of IR in such a mixed form. + + The operation uses std.yield as its terminator - using return instead as the + terminator is under discussion D71961. 
+ + Example: + + ```mlir + loop.for %i = 0 to 128 { + %y = execute_region -> i32 { + %x = load %A[%i] : memref<128xi32> + yield %x : i32 + } + } + + affine.for %i = 0 to 100 { + "foo"() : () -> () + %v = execute_region -> i64 { + cond_br %cond, ^bb1, ^bb2 + + ^bb1: + %c1 = constant 1 : i64 + br ^bb3(%c1 : i64) + + ^bb2: + %c2 = constant 2 : i64 + br ^bb3(%c2 : i64) + + ^bb3(%x : i64): + yield %x : i64 + } + "bar"(%v) : (i64) -> () + } + ``` + }]; + + let results = (outs Variadic); + + let regions = (region AnyRegion:$region); + + // TODO: If the parent is a func like op (which would be the case if all other + // ops are from the std dialect), the inliner logic could be readily used to + // inline. + let hasCanonicalizer = 0; + + // TODO: can fold if it returns a constant. + // TODO: Single block execute_region ops can be readily inlined irrespective + // of which op is a parent. Add a fold for this. + let hasFolder = 0; +} + //===----------------------------------------------------------------------===// // ExtractElementOp //===----------------------------------------------------------------------===// @@ -2053,6 +2120,37 @@ let hasFolder = 1; } +//===----------------------------------------------------------------------===// +// YieldOp +//===----------------------------------------------------------------------===// + +def YieldOp : Std_Op<"yield", [NoSideEffect, Terminator]> { + let summary = "yield operation"; + let description = [{ + The "yield" operation represents a transfer of control flow to its parent + operation with its operands being values to be sent out from the region. The + operation takes a variable number of operands and produces no results. The + operand count and types must match the results of the parent operation. For + example: + + ```mlir + %res = execute_region -> (i32, f8) { + ... 
+ yield %0, %1 : i32, f8 + } + ``` + }]; + + let arguments = (ins Variadic:$operands); + + let builders = [OpBuilder< + "Builder *b, OperationState &result", [{ build(b, result, llvm::None); }] + >]; + + let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?"; +} + + //===----------------------------------------------------------------------===// // ZeroExtendIOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp --- a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp +++ b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp @@ -332,7 +332,7 @@ LogicalResult matchAndRewrite(AffineTerminatorOp op, PatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp(op); + rewriter.replaceOpWithNewOp(op); return success(); } }; diff --git a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp --- a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp @@ -42,13 +42,13 @@ ConversionPatternRewriter &rewriter) const override; }; -/// Pattern to erase a loop::YieldOp. -class TerminatorOpConversion final : public SPIRVOpLowering { +/// Pattern to erase a LoopYieldOp. 
+class TerminatorOpConversion final : public SPIRVOpLowering { public: - using SPIRVOpLowering::SPIRVOpLowering; + using SPIRVOpLowering::SPIRVOpLowering; LogicalResult - matchAndRewrite(loop::YieldOp terminatorOp, ArrayRef operands, + matchAndRewrite(loop::LoopYieldOp terminatorOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { rewriter.eraseOp(terminatorOp); return success(); diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp --- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp +++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp @@ -338,7 +338,8 @@ class YieldOpConversion : public ConvertToLLVMPattern { public: explicit YieldOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_) - : ConvertToLLVMPattern(YieldOp::getOperationName(), context, lowering_) {} + : ConvertToLLVMPattern(linalg::YieldOp::getOperationName(), context, + lowering_) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, diff --git a/mlir/lib/Conversion/LoopToStandard/LoopToStandard.cpp b/mlir/lib/Conversion/LoopToStandard/LoopToStandard.cpp --- a/mlir/lib/Conversion/LoopToStandard/LoopToStandard.cpp +++ b/mlir/lib/Conversion/LoopToStandard/LoopToStandard.cpp @@ -304,7 +304,7 @@ // terminator is the last operation in the block because further transfoms // rely on this. 
rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock()); - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( rewriter.getInsertionBlock()->getTerminator(), forOp.getResults()); } @@ -339,7 +339,7 @@ } rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock()); - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( rewriter.getInsertionBlock()->getTerminator(), yieldOperands); rewriter.replaceOp(parallelOp, loopResults); diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -726,7 +726,7 @@ // YieldOp //===----------------------------------------------------------------------===// -static void print(OpAsmPrinter &p, YieldOp op) { +static void print(OpAsmPrinter &p, linalg::YieldOp op) { p << op.getOperationName(); if (op.getNumOperands() > 0) p << ' ' << op.getOperands(); @@ -746,7 +746,7 @@ } template -static LogicalResult verifyYield(YieldOp op, GenericOpType genericOp) { +static LogicalResult verifyYield(linalg::YieldOp op, GenericOpType genericOp) { // The operand number and types must match the view element types. 
auto nOutputs = genericOp.getNumOutputs(); if (op.getNumOperands() != nOutputs) @@ -766,7 +766,7 @@ return success(); } -static LogicalResult verify(YieldOp op) { +static LogicalResult verify(linalg::YieldOp op) { auto *parentOp = op.getParentOp(); if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty()) return op.emitOpError("expected single non-empty parent region"); diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp @@ -98,7 +98,7 @@ } Operation &terminator = block.back(); - assert(isa(terminator) && + assert(isa(terminator) && "expected an yield op in the end of the region"); for (unsigned i = 0, e = terminator.getNumOperands(); i < e; ++i) { std_store(map.lookup(terminator.getOperand(i)), outputBuffers[i], diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp @@ -135,10 +135,10 @@ auto c = m_Val(r.front().getArgument(2)); // TODO(ntv) Update this detection once we have matcher support for // specifying that any permutation of operands matches. 
- auto pattern1 = m_Op(m_Op(m_Op(a, b), c)); - auto pattern2 = m_Op(m_Op(c, m_Op(a, b))); - auto pattern3 = m_Op(m_Op(m_Op(b, a), c)); - auto pattern4 = m_Op(m_Op(c, m_Op(b, a))); + auto pattern1 = m_Op(m_Op(m_Op(a, b), c)); + auto pattern2 = m_Op(m_Op(c, m_Op(a, b))); + auto pattern3 = m_Op(m_Op(m_Op(b, a), c)); + auto pattern4 = m_Op(m_Op(c, m_Op(b, a))); return pattern1.match(&ops.back()) || pattern2.match(&ops.back()) || pattern3.match(&ops.back()) || pattern4.match(&ops.back()); } diff --git a/mlir/lib/Dialect/LoopOps/LoopOps.cpp b/mlir/lib/Dialect/LoopOps/LoopOps.cpp --- a/mlir/lib/Dialect/LoopOps/LoopOps.cpp +++ b/mlir/lib/Dialect/LoopOps/LoopOps.cpp @@ -548,7 +548,7 @@ //===----------------------------------------------------------------------===// // YieldOp //===----------------------------------------------------------------------===// -static LogicalResult verify(YieldOp op) { +static LogicalResult verify(LoopYieldOp op) { auto parentOp = op.getParentOp(); auto results = parentOp->getResults(); auto operands = op.getOperands(); @@ -574,7 +574,8 @@ return success(); } -static ParseResult parseYieldOp(OpAsmParser &parser, OperationState &result) { +static ParseResult parseLoopYieldOp(OpAsmParser &parser, + OperationState &result) { SmallVector operands; SmallVector types; llvm::SMLoc loc = parser.getCurrentLocation(); @@ -587,7 +588,7 @@ return success(); } -static void print(OpAsmPrinter &p, YieldOp op) { +static void print(OpAsmPrinter &p, LoopYieldOp op) { p << op.getOperationName(); if (op.getNumOperands() != 0) p << ' ' << op.getOperands() << " : " << op.getOperandTypes(); diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -1381,6 +1381,61 @@ return {}; } +//===----------------------------------------------------------------------===// +// ExecuteRegionOp 
+//===----------------------------------------------------------------------===// + +/// +/// (ssa-id `=`)? `execute_region` `->` function-result-type `{` +/// block+ +/// `}` +/// +/// Example: +/// std.execute_region -> i32 { +/// %idx = load %rI[%i] : memref<128xi32> +/// yield %idx : i32 +/// } +/// +static ParseResult parseExecuteRegionOp(OpAsmParser &parser, + OperationState &result) { + if (parser.parseOptionalArrowTypeList(result.types)) + return failure(); + + // Introduce the body region and parse it. + Region *body = result.addRegion(); + if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{})) + return failure(); + + // Parse the optional attribute list. It is parsed exactly once so that the + // accepted syntax matches what the printer emits after the region. + if (parser.parseOptionalAttrDict(result.attributes)) + return failure(); + + return success(); +} + +static void print(OpAsmPrinter &p, ExecuteRegionOp op) { + p << ExecuteRegionOp::getOperationName(); + if (op.getNumResults() > 0) + p << " -> " << op.getResultTypes(); + + p.printRegion(op.region(), + /*printEntryBlockArgs=*/false, + /*printBlockTerminators=*/true); + + p.printOptionalAttrDict(op.getAttrs()); +} + +static LogicalResult verify(ExecuteRegionOp op) { + if (op.region().empty()) + return op.emitOpError("region needs to have at least one block"); + + if (op.region().front().getNumArguments() > 0) + return op.emitOpError("region cannot have any arguments"); + + return success(); +} + //===----------------------------------------------------------------------===// // LoadOp //===----------------------------------------------------------------------===// @@ -2491,6 +2546,28 @@ [](APInt a, APInt b) { return a ^ b; }); } +//===----------------------------------------------------------------------===// +// YieldOp +//===----------------------------------------------------------------------===// + +static LogicalResult verify(YieldOp op) { + // The operand number and types must match those of parent op's results'. 
+ auto results = op.getParentOp()->getResults(); + if (op.getNumOperands() != results.size()) + return op.emitOpError("has ") + << op.getNumOperands() + << " operands, but parent op has " << results.size() << " results"; + + for (unsigned i = 0, e = results.size(); i != e; ++i) + if (op.getOperand(i).getType() != results[i].getType()) + return op.emitError() << "type of yield operand " << i << " (" + << op.getOperand(i).getType() + << ") doesn't match parent op's result type (" + << results[i].getType() << ")"; + + return success(); +} + //===----------------------------------------------------------------------===// // ZeroExtendIOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir --- a/mlir/test/IR/core-ops.mlir +++ b/mlir/test/IR/core-ops.mlir @@ -591,6 +591,27 @@ // CHECK: %4 = call_indirect %f_0(%arg0) : (i32) -> i32 %3 = "std.call_indirect"(%f_0, %arg0) : ((i32) -> i32, i32) -> i32 + // CHECK: execute_region -> i64 { + // CHECK-NEXT: constant + // CHECK-NEXT: yield + // CHECK-NEXT: } + %4 = execute_region -> i64 { + %c1 = constant 1 : i64 + yield %c1 : i64 + } + + // CHECK: execute_region { + // CHECK-NEXT: br ^bb1 + // CHECK-NEXT: ^bb1: + // CHECK-NEXT: yield + // CHECK-NEXT: } + "std.execute_region"() ({ + ^bb0: + br ^bb1 + ^bb1: + yield + }) : () -> () + return } diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir --- a/mlir/test/IR/invalid.mlir +++ b/mlir/test/IR/invalid.mlir @@ -1231,3 +1231,14 @@ // expected-error @+1 {{expected unsigned integer elements, but parsed negative value}} "foo"() {bar = dense<[5, -5]> : vector<2xui32>} : () -> () } + +// ----- + +func @execute_region() { + // expected-error @+1 {{region cannot have any arguments}} + "std.execute_region"() ({ + ^bb0(%i : i32): + return + }) : () -> () + return +} diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir --- 
a/mlir/test/Transforms/canonicalize.mlir +++ b/mlir/test/Transforms/canonicalize.mlir @@ -491,6 +491,33 @@ return } +// CHECK-LABEL: func @propagate_into_execute_region +func @propagate_into_execute_region() { + %cond = constant 0 : i1 + affine.for %i = 0 to 100 { + "foo"() : () -> () + %v = execute_region -> i64 { + cond_br %cond, ^bb1, ^bb2 + + ^bb1: + %c1 = constant 1 : i64 + br ^bb3(%c1 : i64) + + ^bb2: + %c2 = constant 2 : i64 + br ^bb3(%c2 : i64) + + ^bb3(%x : i64): + yield %x : i64 + } + "bar"(%v) : (i64) -> () + // CHECK: std.execute_region -> i64 { + // CHECK-NEXT: yield %c2_i64 : i64 + // CHECK-NEXT: } + } + return +} + // CHECK-LABEL: func @const_fold_propagate func @const_fold_propagate() -> memref { %VT_i = constant 512 : index