[mlir][transform] Check SequenceOp block argument type; fix builder insertion guard

* transform::SequenceOp::verify:
  * Asserts that the body block has exactly one argument. The assert message
    states that PossibleTopLevelTransformOpTrait must already have verified
    this.
  * When the op has a root operand, it emits "expects the type of the block
    argument to match the type of the operand" if the block argument's type
    differs from the operand's type.
  * The double-consume check now runs on that single argument only and
    reports it as "block argument #0".
* The two builder functions taking a SequenceBodyBuilderFn (presumably
  SequenceOp::build overloads; their names are outside the hunks):
  * OpBuilder::InsertionGuard is now constructed before builder.createBlock
    instead of after it. The guard therefore records, and later restores, the
    insertion point that was in effect before the body block was created.
  * The first one also spells the block argument type as `Type` instead of
    `auto`.
* ops-invalid.mlir: adds a negative test. A nested transform.sequence is passed
  %arg0 of type !transform.any_op but declares its block argument as
  !pdl.operation, and the test expects the new verifier error.

diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -645,13 +645,22 @@
 }
 
 LogicalResult transform::SequenceOp::verify() {
+  assert(getBodyBlock()->getNumArguments() == 1 &&
+         "the number of arguments must have been verified to be 1 by "
+         "PossibleTopLevelTransformOpTrait");
+
+  BlockArgument arg = getBodyBlock()->getArgument(0);
+  if (getRoot()) {
+    if (arg.getType() != getRoot().getType()) {
+      return emitOpError() << "expects the type of the block argument to match "
+                              "the type of the operand";
+    }
+  }
+
   // Check if the block argument has more than one consuming use.
-  for (BlockArgument argument : getBodyBlock()->getArguments()) {
-    auto report = [&]() {
-      return (emitOpError() << "block argument #" << argument.getArgNumber());
-    };
-    if (failed(checkDoubleConsume(argument, report)))
-      return failure();
+  if (failed(checkDoubleConsume(
+          arg, [this]() { return (emitOpError() << "block argument #0"); }))) {
+    return failure();
   }
 
   // Check properties of the nested operations they cannot check themselves.
@@ -765,12 +774,12 @@
     SequenceBodyBuilderFn bodyBuilder) {
   build(builder, state, resultTypes, failurePropagationMode, root);
   Region *region = state.regions.back().get();
-  auto bbArgType = root.getType();
+  Type bbArgType = root.getType();
+  OpBuilder::InsertionGuard guard(builder);
   Block *bodyBlock = builder.createBlock(
       region, region->begin(), TypeRange{bbArgType}, {state.location});
 
   // Populate body.
-  OpBuilder::InsertionGuard guard(builder);
   builder.setInsertionPointToStart(bodyBlock);
   bodyBuilder(builder, state.location, bodyBlock->getArgument(0));
 }
@@ -782,11 +791,11 @@
     SequenceBodyBuilderFn bodyBuilder) {
   build(builder, state, resultTypes, failurePropagationMode, /*root=*/Value());
   Region *region = state.regions.back().get();
+  OpBuilder::InsertionGuard guard(builder);
   Block *bodyBlock = builder.createBlock(
       region, region->begin(), TypeRange{bbArgType}, {state.location});
 
   // Populate body.
-  OpBuilder::InsertionGuard guard(builder);
   builder.setInsertionPointToStart(bodyBlock);
   bodyBuilder(builder, state.location, bodyBlock->getArgument(0));
 }
diff --git a/mlir/test/Dialect/Transform/ops-invalid.mlir b/mlir/test/Dialect/Transform/ops-invalid.mlir
--- a/mlir/test/Dialect/Transform/ops-invalid.mlir
+++ b/mlir/test/Dialect/Transform/ops-invalid.mlir
@@ -35,6 +35,17 @@
 
 // -----
 
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+  // expected-error @below {{expects the type of the block argument to match the type of the operand}}
+  transform.sequence %arg0: !transform.any_op failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    transform.yield
+  }
+}
+
+// -----
+
 // expected-note @below {{nested in another possible top-level op}}
 transform.with_pdl_patterns {
 ^bb0(%arg0: !pdl.operation):