diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -78,6 +78,24 @@
   builder.create<scf::YieldOp>(loc);
 }
 
+/// Verifies that the first block of the given `region` is terminated by a
+/// TerminatorTy. Reports errors on the given operation if it is not the case.
+/// Returns the terminator on success and a null TerminatorTy otherwise.
+template <typename TerminatorTy>
+static TerminatorTy verifyAndGetTerminator(Operation *op, Region &region,
+                                           StringRef errorMessage) {
+  Operation *terminatorOperation = nullptr;
+  if (!region.empty() && !region.front().empty()) {
+    terminatorOperation = &region.front().back();
+    if (auto yield = dyn_cast_or_null<TerminatorTy>(terminatorOperation))
+      return yield;
+  }
+  auto diag = op->emitOpError(errorMessage);
+  if (terminatorOperation)
+    diag.attachNote(terminatorOperation->getLoc()) << "terminator here";
+  return nullptr;
+}
+
 //===----------------------------------------------------------------------===//
 // ExecuteRegionOp
 //===----------------------------------------------------------------------===//
@@ -2323,10 +2341,13 @@
           "expects arguments for the induction variable to be of index type");
 
   // Check that the yield has no results
-  Operation *yield = body->getTerminator();
+  auto yield = verifyAndGetTerminator<scf::YieldOp>(
+      *this, getRegion(), "expects body to terminate with 'scf.yield'");
+  if (!yield)
+    return failure();
   if (yield->getNumOperands() != 0)
-    return yield->emitOpError() << "not allowed to have operands inside '"
-                                << ParallelOp::getOperationName() << "'";
+    return yield.emitOpError() << "not allowed to have operands inside '"
+                               << ParallelOp::getOperationName() << "'";
 
   // Check that the number of results is the same as the number of ReduceOps.
   SmallVector<ReduceOp, 4> reductions(body->getOps<ReduceOp>());
@@ -2854,23 +2875,6 @@
   return success();
 }
 
-/// Verifies that the first block of the given `region` is terminated by a
-/// YieldOp. Reports errors on the given operation if it is not the case.
-template <typename TerminatorTy>
-static TerminatorTy verifyAndGetTerminator(scf::WhileOp op, Region &region,
-                                           StringRef errorMessage) {
-  Operation *terminatorOperation = nullptr;
-  if (!region.empty() && !region.front().empty()) {
-    terminatorOperation = &region.front().back();
-    if (auto yield = dyn_cast_or_null<TerminatorTy>(terminatorOperation))
-      return yield;
-  }
-  auto diag = op.emitOpError(errorMessage);
-  if (terminatorOperation)
-    diag.attachNote(terminatorOperation->getLoc()) << "terminator here";
-  return nullptr;
-}
-
 LogicalResult scf::WhileOp::verify() {
   auto beforeTerminator = verifyAndGetTerminator<scf::ConditionOp>(
       *this, getBefore(),
diff --git a/mlir/test/Dialect/SCF/invalid.mlir b/mlir/test/Dialect/SCF/invalid.mlir
--- a/mlir/test/Dialect/SCF/invalid.mlir
+++ b/mlir/test/Dialect/SCF/invalid.mlir
@@ -672,3 +672,16 @@
     return
   }) {cases = array<i64: 1>} : (index) -> ()
 }
+
+// -----
+
+func.func @parallel_missing_terminator(%0 : index) {
+  // expected-error @below {{'scf.parallel' op expects body to terminate with 'scf.yield'}}
+  "scf.parallel"(%0, %0, %0) ({
+  ^bb0(%arg1: index):
+    // expected-note @below {{terminator here}}
+    %2 = "arith.constant"() {value = 1.000000e+00 : f32} : () -> f32
+  }) {operand_segment_sizes = array<i32: 1, 1, 1, 0>} : (index, index, index) -> ()
+  return
+}
+