diff --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td
--- a/mlir/include/mlir/Dialect/SCF/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td
@@ -417,7 +417,8 @@
   let assemblyFormat = "$result attr-dict `:` type($result)";
 }
 
-def YieldOp : SCF_Op<"yield", [NoSideEffect, ReturnLike, Terminator]> {
+def YieldOp : SCF_Op<"yield", [NoSideEffect, ReturnLike, Terminator,
+                               ParentOneOf<["IfOp", "ForOp", "ParallelOp"]>]> {
   let summary = "loop yield and termination operation";
   let description = [{
     "scf.yield" yields an SSA value from the SCF dialect op region and
@@ -436,5 +437,7 @@
     OpBuilder<"OpBuilder &builder, OperationState &result",
               [{ /* nothing to do */ }]>
   ];
+  // Yielded values are checked by the parent op's verifier; none needed here.
+  let verifier = [{}];
 }
 #endif // MLIR_DIALECT_SCF_SCFOPS
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
@@ -18,6 +18,7 @@
 
 namespace mlir {
 class BranchOpInterface;
+class RegionBranchOpInterface;
 
 //===----------------------------------------------------------------------===//
 // BranchOpInterface
@@ -40,12 +41,21 @@
 // RegionBranchOpInterface
 //===----------------------------------------------------------------------===//
 
+namespace detail {
+/// Verify that types match along control flow edges described by the given op.
+LogicalResult verifyTypesAlongControlFlowEdges(Operation *op);
+} // namespace detail
+
 /// This class represents a successor of a region. A region successor can either
 /// be another region, or the parent operation. If the successor is a region,
-/// this class accepts the destination region, as well as a set of arguments
+/// this class represents the destination region, as well as a set of arguments
 /// from that region that will be populated by values from the current region.
-/// If the successor is the parent operation, this class accepts an optional set
-/// of results that will be populated by values from the current region.
+/// If the successor is the parent operation, this class represents an optional
+/// set of results that will be populated by values from the current region.
+///
+/// This interface assumes that the values from the current region that are used
+/// to populate the successor inputs are the operands of the return-like
+/// terminator operations in the blocks within this region.
 class RegionSuccessor {
 public:
   /// Initialize a successor that branches to another region of the parent
@@ -61,6 +71,9 @@
   /// parent operation.
   Region *getSuccessor() const { return region; }
 
+  /// Return true if the successor is the parent operation.
+  bool isParent() const { return region == nullptr; }
+
   /// Return the inputs to the successor that are remapped by the exit values of
   /// the current region.
   ValueRange getSuccessorInputs() const { return inputs; }
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -99,9 +99,9 @@
   let methods = [
     InterfaceMethod<[{
         Returns the operands of this operation used as the entry arguments when
-        entering the region at `index`, which was specified as a successor by
-        `getSuccessorRegions`. These operands should correspond 1-1 with the
-        successor inputs specified in `getSuccessorRegions`, and may corre
+        entering the region at `index`, which was specified as a successor of this
+        operation by `getSuccessorRegions`. These operands should correspond 1-1
+        with the successor inputs specified in `getSuccessorRegions`.
       }],
       "OperandRange", "getSuccessorEntryOperands",
       (ins "unsigned":$index), [{}], /*defaultImplementation=*/[{
@@ -128,6 +128,13 @@
         "SmallVectorImpl<RegionSuccessor> &":$regions)
     >
   ];
+
+  let extraClassDeclaration = [{
+    /// Verify types along control flow edges described by this interface.
+    static LogicalResult verifyTypes(mlir::Operation *op) {
+      return detail::verifyTypesAlongControlFlowEdges(op);
+    }
+  }];
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -137,6 +137,10 @@
 
     i++;
   }
+
+  if (failed(RegionBranchOpInterface::verifyTypes(op)))
+    return failure();
+
   return success();
 }
 
@@ -423,6 +427,9 @@
   if (op.getNumResults() != 0 && op.elseRegion().empty())
     return op.emitOpError("must have an else block if defining values");
 
+  if (failed(RegionBranchOpInterface::verifyTypes(op)))
+    return failure();
+
   return success();
 }
 
@@ -504,10 +511,13 @@
     elseRegion = nullptr;
 
   // Otherwise, the successor is dependent on the condition.
   bool condition;
-  if (auto condAttr = operands.front().dyn_cast_or_null<IntegerAttr>()) {
+  // `operands` may be empty when the caller supplies no constant information;
+  // the condition is then treated as unknown (non-constant).
+  Attribute condValue = operands.empty() ? Attribute() : operands.front();
+  if (auto condAttr = condValue.dyn_cast_or_null<IntegerAttr>()) {
     condition = condAttr.getValue().isOneValue();
   } else {
     // If the condition isn't constant, both regions may be executed.
     regions.push_back(RegionSuccessor(&thenRegion()));
     regions.push_back(RegionSuccessor(elseRegion));
@@ -602,6 +612,12 @@
     return op.emitOpError(
         "expects arguments for the induction variable to be of index type");
 
+  // Check that the yield has no operands.
+  Operation *yield = body->getTerminator();
+  if (yield->getNumOperands() != 0)
+    return yield->emitOpError()
+           << "yield inside scf.parallel is not allowed to have operands";
+
   // Check that the number of results is the same as the number of ReduceOps.
   SmallVector<ReduceOp, 4> reductions(body->getOps<ReduceOp>());
   auto resultsSize = op.results().size();
@@ -876,35 +892,6 @@
   return success();
 }
 
-//===----------------------------------------------------------------------===//
-// YieldOp
-//===----------------------------------------------------------------------===//
-static LogicalResult verify(YieldOp op) {
-  auto parentOp = op.getParentOp();
-  auto results = parentOp->getResults();
-  auto operands = op.getOperands();
-
-  if (isa<IfOp, ForOp>(parentOp)) {
-    if (parentOp->getNumResults() != op.getNumOperands())
-      return op.emitOpError() << "parent of yield must have same number of "
-                                 "results as the yield operands";
-    for (auto e : llvm::zip(results, operands)) {
-      if (std::get<0>(e).getType() != std::get<1>(e).getType())
-        return op.emitOpError()
-               << "types mismatch between yield op and its parent";
-    }
-  } else if (isa<ParallelOp>(parentOp)) {
-    if (op.getNumOperands() != 0)
-      return op.emitOpError()
-             << "yield inside scf.parallel is not allowed to have operands";
-  } else {
-    return op.emitOpError()
-           << "yield only terminates If, For or Parallel regions";
-  }
-
-  return success();
-}
-
 static ParseResult parseYieldOp(OpAsmParser &parser, OperationState &result) {
   SmallVector<OpAsmParser::OperandType, 4> operands;
   SmallVector<Type, 4> types;
diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
--- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
+++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/IR/StandardTypes.h"
+#include "llvm/ADT/SmallPtrSet.h"
 
 using namespace mlir;
 
@@ -24,8 +25,9 @@
 /// Returns the `BlockArgument` corresponding to operand `operandIndex` in some
 /// successor if 'operandIndex' is within the range of 'operands', or None if
 /// `operandIndex` isn't a successor operand index.
-Optional<BlockArgument> mlir::detail::getBranchSuccessorArgument(
-    Optional<OperandRange> operands, unsigned operandIndex, Block *successor) {
+Optional<BlockArgument>
+detail::getBranchSuccessorArgument(Optional<OperandRange> operands,
+                                   unsigned operandIndex, Block *successor) {
   // Check that the operands are valid.
   if (!operands || operands->empty())
     return llvm::None;
@@ -43,8 +45,8 @@
 
 /// Verify that the given operands match those of the given successor block.
 LogicalResult
-mlir::detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
-                                            Optional<OperandRange> operands) {
+detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
+                                      Optional<OperandRange> operands) {
   if (!operands)
     return success();
 
@@ -66,3 +68,131 @@
   }
   return success();
 }
+
+//===----------------------------------------------------------------------===//
+// RegionBranchOpInterface
+//===----------------------------------------------------------------------===//
+
+/// Verify that types match along all region control flow edges originating
+/// from `sourceNo`, which is the number of the source region, or llvm::None if
+/// the source is the parent op. `getInputsTypesForRegion` returns the types of
+/// the values that flow from `sourceNo` into the given successor, identified
+/// by its region number, or llvm::None if the successor is the parent op.
+/// Emits an error and returns failure on the first mismatching edge.
+static LogicalResult verifyTypesAlongAllSuccessors(
+    Operation *op, Optional<unsigned> sourceNo,
+    function_ref<TypeRange(Optional<unsigned>)> getInputsTypesForRegion) {
+  auto regionInterface = cast<RegionBranchOpInterface>(op);
+
+  // No constant operand values are known during verification: pass a null
+  // attribute for every operand so that implementations may safely index into
+  // `operands`.
+  SmallVector<Attribute, 4> operands(op->getNumOperands(), Attribute());
+  SmallVector<RegionSuccessor, 2> successors;
+  regionInterface.getSuccessorRegions(sourceNo, operands, successors);
+
+  // Returns "Region #<n>" for a region number, or the op name for the parent.
+  auto getNodeName = [&](Optional<unsigned> regionNo) -> std::string {
+    if (regionNo.hasValue())
+      return ("Region #" + Twine(regionNo.getValue())).str();
+    return op->getName().getStringRef().str();
+  };
+
+  for (RegionSuccessor &succ : successors) {
+    Optional<unsigned> succRegionNo;
+    if (!succ.isParent())
+      succRegionNo = succ.getSuccessor()->getRegionNumber();
+
+    // Only built when an error is reported.
+    auto getEdgeName = [&]() -> std::string {
+      return "from " + getNodeName(sourceNo) + " to " +
+             getNodeName(succRegionNo);
+    };
+
+    TypeRange sourceTypes = getInputsTypesForRegion(succRegionNo);
+    TypeRange succInputsTypes = succ.getSuccessorInputs().getTypes();
+    if (sourceTypes.size() != succInputsTypes.size())
+      return op->emitOpError("region control flow edge ")
+             << getEdgeName() << " has " << sourceTypes.size()
+             << " source operands, but target successor needs "
+             << succInputsTypes.size();
+
+    for (unsigned i = 0, e = sourceTypes.size(); i != e; ++i) {
+      Type sourceType = sourceTypes[i];
+      Type inputType = succInputsTypes[i];
+      if (sourceType != inputType)
+        return op->emitOpError("along control flow edge ")
+               << getEdgeName() << " source #" << i << " type " << sourceType
+               << " should match input #" << i << " type " << inputType;
+    }
+  }
+  return success();
+}
+
+/// Verify that types match along control flow edges described by the given op.
+LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
+  if (op->getNumRegions() == 0)
+    return op->emitOpError()
+           << "implements RegionBranchOpInterface but has no regions";
+
+  auto regionInterface = cast<RegionBranchOpInterface>(op);
+
+  // Verify types along control flow edges originating from the parent.
+  auto getParentInputTypes = [&](Optional<unsigned> regionNo) -> TypeRange {
+    if (regionNo.hasValue())
+      return regionInterface.getSuccessorEntryOperands(regionNo.getValue())
+          .getTypes();
+
+    // RegionBranchOpInterface has no API to query the operands that flow from
+    // the parent op directly to its own results, so the result types are used,
+    // which makes that edge trivially type-check.
+    return op->getResultTypes();
+  };
+  if (failed(verifyTypesAlongAllSuccessors(op, llvm::None,
+                                           getParentInputTypes)))
+    return failure();
+
+  // Verify types along control flow edges originating from each region.
+  for (unsigned regionNo = 0, e = op->getNumRegions(); regionNo != e;
+       ++regionNo) {
+    Region &region = op->getRegion(regionNo);
+
+    // The interface cannot tell which ReturnLike terminator branches to which
+    // successor, so all ReturnLike terminators in a region must produce the
+    // same operand types; the first one is used as the representative.
+    Operation *regionReturn = nullptr;
+    for (Block &block : region) {
+      if (block.empty())
+        continue;
+      Operation *terminator = &block.back();
+      if (!terminator->hasTrait<OpTrait::ReturnLike>())
+        continue;
+
+      if (!regionReturn) {
+        regionReturn = terminator;
+        continue;
+      }
+
+      if (regionReturn->getOperandTypes() != terminator->getOperandTypes())
+        return op->emitOpError("Region #")
+               << regionNo
+               << " operands mismatch between return-like terminators";
+    }
+
+    // Without a ReturnLike terminator the values leaving this region are not
+    // described by the interface; the op itself is expected to verify them.
+    if (!regionReturn)
+      continue;
+
+    // All successors of this region receive the same operands.
+    TypeRange returnTypes = regionReturn->getOperandTypes();
+    auto getRegionInputTypes = [&](Optional<unsigned>) -> TypeRange {
+      return returnTypes;
+    };
+    if (failed(verifyTypesAlongAllSuccessors(op, regionNo,
+                                             getRegionInputTypes)))
+      return failure();
+  }
+
+  return success();
+}
diff --git a/mlir/test/Dialect/SCF/invalid.mlir b/mlir/test/Dialect/SCF/invalid.mlir
--- a/mlir/test/Dialect/SCF/invalid.mlir
+++ b/mlir/test/Dialect/SCF/invalid.mlir
@@ -325,13 +325,13 @@
 
 func @std_if_incorrect_yield(%arg0: i1, %arg1: f32)
 {
+  // expected-error@+1 {{region control flow edge from Region #0 to scf.if has 1 source operands, but target successor needs 2}}
   %x, %y = scf.if %arg0 -> (f32, f32) {
     %0 = addf %arg1, %arg1 : f32
-    // expected-error@+1 {{parent of yield must have same number of results as the yield operands}}
     scf.yield %0 : f32
   } else {
     %0 = subf %arg1, %arg1 : f32
-    scf.yield %0 : f32
+    scf.yield %0, %0 : f32, f32
   }
   return
 }
@@ -396,6 +396,21 @@
   return
 }
 
+// -----
+
+func @std_for_operands_mismatch_4(%arg0 : index, %arg1 : index, %arg2 : index) {
+  %s0 = constant 0.0 : f32
+  %t0 = constant 1.0 : f32
+  // expected-error @+1 {{along control flow edge from Region #0 to Region #0 source #1 type 'i32' should match input #1 type 'f32'}}
+  %result1:2 = scf.for %i0 = %arg0 to %arg1 step %arg2
+                    iter_args(%si = %s0, %ti = %t0) -> (f32, f32) {
+    %sn = addf %si, %si : f32
+    %ic = constant 1 : i32
+    scf.yield %sn, %ic : f32, i32
+  }
+  return
+}
+
 // -----
 
 func @parallel_invalid_yield(
@@ -407,3 +422,13 @@
   }
   return
 }
+
+// -----
+
+func @yield_invalid_parent_op() {
+  "my.op"() ({
+    // expected-error@+1 {{'scf.yield' op expects parent op to be one of 'scf.if, scf.for, scf.parallel'}}
+    scf.yield
+  }) : () -> ()
+  return
+}