diff --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td --- a/mlir/include/mlir/Dialect/SCF/SCFOps.td +++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td @@ -418,7 +418,8 @@ let assemblyFormat = "$result attr-dict `:` type($result)"; } -def YieldOp : SCF_Op<"yield", [NoSideEffect, ReturnLike, Terminator]> { +def YieldOp : SCF_Op<"yield", [NoSideEffect, ReturnLike, Terminator, + ParentOneOf<["IfOp, ForOp", "ParallelOp"]>]> { let summary = "loop yield and termination operation"; let description = [{ "scf.yield" yields an SSA value from the SCF dialect op region and @@ -437,5 +438,8 @@ OpBuilder<"OpBuilder &builder, OperationState &result", [{ /* nothing to do */ }]> ]; + // Override default verifier (defined in SCF_Op), no custom verification + // needed. + let verifier = ?; } #endif // MLIR_DIALECT_SCF_SCFOPS diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h --- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h +++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h @@ -18,6 +18,7 @@ namespace mlir { class BranchOpInterface; +class RegionBranchOpInterface; //===----------------------------------------------------------------------===// // BranchOpInterface @@ -40,12 +41,21 @@ // RegionBranchOpInterface //===----------------------------------------------------------------------===// +namespace detail { +/// Verify that types match along control flow edges described by the given op. +LogicalResult verifyTypesAlongControlFlowEdges(Operation *op); +} // namespace detail + /// This class represents a successor of a region. A region successor can either /// be another region, or the parent operation. If the successor is a region, -/// this class accepts the destination region, as well as a set of arguments +/// this class represents the destination region, as well as a set of arguments /// from that region that will be populated by values from the current region. 
-/// If the successor is the parent operation, this class accepts an optional set -/// of results that will be populated by values from the current region. +/// If the successor is the parent operation, this class represents an optional +/// set of results that will be populated by values from the current region. +/// +/// This interface assumes that the values from the current region that are used +/// to populate the successor inputs are the operands of the return-like +/// terminator operations in the blocks within this region. class RegionSuccessor { public: /// Initialize a successor that branches to another region of the parent @@ -61,6 +71,9 @@ /// parent operation. Region *getSuccessor() const { return region; } + /// Return true if the successor is the parent operation. + bool isParent() const { return region == nullptr; } + /// Return the inputs to the successor that are remapped by the exit values of /// the current region. ValueRange getSuccessorInputs() const { return inputs; } diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td --- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td +++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td @@ -103,9 +103,9 @@ let methods = [ InterfaceMethod<[{ Returns the operands of this operation used as the entry arguments when - entering the region at `index`, which was specified as a successor by - `getSuccessorRegions`. These operands should correspond 1-1 with the - successor inputs specified in `getSuccessorRegions`, and may corre + entering the region at `index`, which was specified as a successor of this + operation by `getSuccessorRegions`. These operands should correspond 1-1 + with the successor inputs specified in `getSuccessorRegions`. 
}], "OperandRange", "getSuccessorEntryOperands", (ins "unsigned":$index), [{}], /*defaultImplementation=*/[{ @@ -132,6 +132,19 @@ "SmallVectorImpl &":$regions) > ]; + + let verify = [{ + static_assert(!ConcreteOpType::template hasTrait(), + "expected operation to have non-zero regions"); + return success(); + }]; + + let extraClassDeclaration = [{ + /// Verify types along control flow edges described by this interface. + static LogicalResult verifyTypes(Operation *op) { + return detail::verifyTypesAlongControlFlowEdges(op); + } + }]; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp --- a/mlir/lib/Dialect/SCF/SCF.cpp +++ b/mlir/lib/Dialect/SCF/SCF.cpp @@ -137,7 +137,8 @@ i++; } - return success(); + + return RegionBranchOpInterface::verifyTypes(op); } static void print(OpAsmPrinter &p, ForOp op) { @@ -413,7 +414,7 @@ if (op.getNumResults() != 0 && op.elseRegion().empty()) return op.emitOpError("must have an else block if defining values"); - return success(); + return RegionBranchOpInterface::verifyTypes(op); } static ParseResult parseIfOp(OpAsmParser &parser, OperationState &result) { @@ -592,6 +593,12 @@ return op.emitOpError( "expects arguments for the induction variable to be of index type"); + // Check that the yield has no operands. + Operation *yield = body->getTerminator(); + if (yield->getNumOperands() != 0) + return yield->emitOpError() << "not allowed to have operands inside '" + << ParallelOp::getOperationName() << "'"; + // Check that the number of results is the same as the number of ReduceOps. 
SmallVector reductions(body->getOps()); auto resultsSize = op.results().size(); @@ -869,31 +876,6 @@ //===----------------------------------------------------------------------===// // YieldOp //===----------------------------------------------------------------------===// -static LogicalResult verify(YieldOp op) { - auto parentOp = op.getParentOp(); - auto results = parentOp->getResults(); - auto operands = op.getOperands(); - - if (isa(parentOp)) { - if (parentOp->getNumResults() != op.getNumOperands()) - return op.emitOpError() << "parent of yield must have same number of " - "results as the yield operands"; - for (auto e : llvm::zip(results, operands)) { - if (std::get<0>(e).getType() != std::get<1>(e).getType()) - return op.emitOpError() - << "types mismatch between yield op and its parent"; - } - } else if (isa(parentOp)) { - if (op.getNumOperands() != 0) - return op.emitOpError() - << "yield inside scf.parallel is not allowed to have operands"; - } else { - return op.emitOpError() - << "yield only terminates If, For or Parallel regions"; - } - - return success(); -} static ParseResult parseYieldOp(OpAsmParser &parser, OperationState &result) { SmallVector operands; diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp --- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp +++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp @@ -8,6 +8,7 @@ #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/IR/StandardTypes.h" +#include "llvm/ADT/SmallPtrSet.h" using namespace mlir; @@ -24,8 +25,9 @@ /// Returns the `BlockArgument` corresponding to operand `operandIndex` in some /// successor if 'operandIndex' is within the range of 'operands', or None if /// `operandIndex` isn't a successor operand index. 
-Optional mlir::detail::getBranchSuccessorArgument( - Optional operands, unsigned operandIndex, Block *successor) { +Optional +detail::getBranchSuccessorArgument(Optional operands, + unsigned operandIndex, Block *successor) { // Check that the operands are valid. if (!operands || operands->empty()) return llvm::None; @@ -43,8 +45,8 @@ /// Verify that the given operands match those of the given successor block. LogicalResult -mlir::detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo, - Optional operands) { +detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo, + Optional operands) { if (!operands) return success(); @@ -66,3 +68,139 @@ } return success(); } + +//===----------------------------------------------------------------------===// +// RegionBranchOpInterface +//===----------------------------------------------------------------------===// + +/// Verify that types match along all region control flow edges originating from +/// `sourceNo` (region # if source is a region, llvm::None if source is parent +/// op). `getInputsTypesForRegion` is a function that returns the types of the +/// inputs that flow from `sourceNo` to the given region. 
+static LogicalResult verifyTypesAlongAllEdges( + Operation *op, Optional sourceNo, + function_ref)> getInputsTypesForRegion) { + auto regionInterface = cast(op); + + SmallVector successors; + unsigned numInputs; + if (sourceNo) { + Region &srcRegion = op->getRegion(sourceNo.getValue()); + numInputs = srcRegion.getNumArguments(); + } else { + numInputs = op->getNumOperands(); + } + SmallVector operands(numInputs, nullptr); + regionInterface.getSuccessorRegions(sourceNo, operands, successors); + + for (RegionSuccessor &succ : successors) { + Optional succRegionNo; + if (!succ.isParent()) + succRegionNo = succ.getSuccessor()->getRegionNumber(); + + auto printEdgeName = [&](InFlightDiagnostic &diag) -> InFlightDiagnostic & { + diag << "from "; + if (sourceNo) + diag << "Region #" << sourceNo.getValue(); + else + diag << op->getName(); + + diag << " to "; + if (succRegionNo) + diag << "Region #" << succRegionNo.getValue(); + else + diag << op->getName(); + return diag; + }; + + TypeRange sourceTypes = getInputsTypesForRegion(succRegionNo); + TypeRange succInputsTypes = succ.getSuccessorInputs().getTypes(); + if (sourceTypes.size() != succInputsTypes.size()) { + InFlightDiagnostic diag = op->emitOpError(" region control flow edge "); + return printEdgeName(diag) + << " has " << sourceTypes.size() + << " source operands, but target successor needs " + << succInputsTypes.size(); + } + + for (auto typesIdx : + llvm::enumerate(llvm::zip(sourceTypes, succInputsTypes))) { + Type sourceType = std::get<0>(typesIdx.value()); + Type inputType = std::get<1>(typesIdx.value()); + if (sourceType != inputType) { + InFlightDiagnostic diag = op->emitOpError(" along control flow edge "); + return printEdgeName(diag) + << " source #" << typesIdx.index() << " type " << sourceType + << " should match input #" << typesIdx.index() << " type " + << inputType; + } + } + } + return success(); +} + +/// Verify that types match along control flow edges described by the given op. 
+LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) { + auto regionInterface = cast(op); + + auto inputTypesFromParent = [&](Optional regionNo) -> TypeRange { + if (regionNo.hasValue()) { + return regionInterface.getSuccessorEntryOperands(regionNo.getValue()) + .getTypes(); + } + + // If the successor of a parent op is the parent itself + // RegionBranchOpInterface does not have an API to query what the entry + // operands will be in that case. Vend out the result types of the op in + // that case so that type checking succeeds for this case. + return op->getResultTypes(); + }; + + // Verify types along control flow edges originating from the parent. + if (failed(verifyTypesAlongAllEdges(op, llvm::None, inputTypesFromParent))) + return failure(); + + // RegionBranchOpInterface should not be implemented by Ops that do not have + // attached regions. + assert(op->getNumRegions() != 0); + + // Verify types along control flow edges originating from each region. + for (unsigned regionNo : llvm::seq(0U, op->getNumRegions())) { + Region &region = op->getRegion(regionNo); + + // Since the interface cannot distinguish between different ReturnLike + // ops within the region branching to different successors, all ReturnLike + // ops in this region should have the same operand types. We will then use + // one of them as the representative for type matching. + + Operation *regionReturn = nullptr; + for (Block &block : region) { + Operation *terminator = block.getTerminator(); + if (!terminator->hasTrait()) + continue; + + if (!regionReturn) { + regionReturn = terminator; + continue; + } + + // Found more than one ReturnLike terminator. Make sure the operand types + // match with the first one. 
+ if (regionReturn->getOperandTypes() != terminator->getOperandTypes()) + return op->emitOpError("Region #") + << regionNo + << " operands mismatch between return-like terminators"; + } + + auto inputTypesFromRegion = [&](Optional regionNo) -> TypeRange { + // All successors get the same set of operands. + return regionReturn ? TypeRange(regionReturn->getOperands().getTypes()) + : TypeRange(); + }; + + if (failed(verifyTypesAlongAllEdges(op, regionNo, inputTypesFromRegion))) + return failure(); + } + + return success(); +} diff --git a/mlir/test/Dialect/SCF/invalid.mlir b/mlir/test/Dialect/SCF/invalid.mlir --- a/mlir/test/Dialect/SCF/invalid.mlir +++ b/mlir/test/Dialect/SCF/invalid.mlir @@ -325,13 +325,13 @@ func @std_if_incorrect_yield(%arg0: i1, %arg1: f32) { + // expected-error@+1 {{region control flow edge from Region #0 to scf.if has 1 source operands, but target successor needs 2}} %x, %y = scf.if %arg0 -> (f32, f32) { %0 = addf %arg1, %arg1 : f32 - // expected-error@+1 {{parent of yield must have same number of results as the yield operands}} scf.yield %0 : f32 } else { %0 = subf %arg1, %arg1 : f32 - scf.yield %0 : f32 + scf.yield %0, %0 : f32, f32 } return } @@ -396,14 +396,39 @@ return } +// ----- + +func @std_for_operands_mismatch_4(%arg0 : index, %arg1 : index, %arg2 : index) { + %s0 = constant 0.0 : f32 + %t0 = constant 1.0 : f32 + // expected-error @+1 {{along control flow edge from Region #0 to Region #0 source #1 type 'i32' should match input #1 type 'f32'}} + %result1:2 = scf.for %i0 = %arg0 to %arg1 step %arg2 + iter_args(%si = %s0, %ti = %t0) -> (f32, f32) { + %sn = addf %si, %si : f32 + %ic = constant 1 : i32 + scf.yield %sn, %ic : f32, i32 + } + return +} + + // ----- func @parallel_invalid_yield( %arg0: index, %arg1: index, %arg2: index) { scf.parallel (%i0) = (%arg0) to (%arg1) step (%arg2) { %c0 = constant 1.0 : f32 - // expected-error@+1 {{yield inside scf.parallel is not allowed to have operands}} + // expected-error@+1 {{'scf.yield' op not 
allowed to have operands inside 'scf.parallel'}} scf.yield %c0 : f32 } return } + +// ----- +func @yield_invalid_parent_op() { + "my.op"() ({ + // expected-error@+1 {{'scf.yield' op expects parent op to be one of 'scf.if, scf.for, scf.parallel'}} + scf.yield + }) : () -> () + return +}