diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
@@ -21,6 +21,13 @@ namespace transform {
 
 enum class FailurePropagationMode : uint32_t;
 class FailurePropagationModeAttr;
+
+/// A builder function that populates the body of a SequenceOp. It is invoked
+/// with the insertion point set at the start of the body block, the location
+/// of the op being built, and the block argument of the body. The SequenceOp
+/// builders do not insert a terminator themselves.
+using SequenceBodyBuilderFn = ::llvm::function_ref<void(
+    ::mlir::OpBuilder &, ::mlir::Location, ::mlir::BlockArgument)>;
 
 } // namespace transform
 } // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -389,6 +389,10 @@
     IR, typically the root operation of the pass interpreting the transform
     dialect. Operand omission is only allowed for sequences not contained in
     another sequence.
+
+    The body of the sequence terminates with an implicit or explicit
+    `transform.yield` op. The operands of the terminator are returned as the
+    results of the sequence op.
   }];
 
   let arguments = (ins FailurePropagationMode:$failure_propagation_mode,
@@ -400,6 +404,20 @@
     "($root^ `:` type($root))? (`->` type($results)^)? `failures` `(` "
     "$failure_propagation_mode `)` attr-dict-with-keyword regions";
 
+  let builders = [
+    // Build a sequence with a root.
+    OpBuilder<(ins
+      "::mlir::TypeRange":$resultTypes,
+      "::mlir::transform::FailurePropagationMode":$failure_propagation_mode,
+      "::mlir::Value":$root, "SequenceBodyBuilderFn":$bodyBuilder)>,
+
+    // Build a sequence without a root but a certain bbArg type.
+    OpBuilder<(ins
+      "::mlir::TypeRange":$resultTypes,
+      "::mlir::transform::FailurePropagationMode":$failure_propagation_mode,
+      "::mlir::Type":$bbArgType, "SequenceBodyBuilderFn":$bodyBuilder)>
+  ];
+
   let extraClassDeclaration = [{
     /// Allow the dialect prefix to be omitted.
     static StringRef getDefaultDialect() { return "transform"; }
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -765,6 +765,42 @@
   bounds.emplace_back(1, 1);
 }
 
+/// Creates the body block of the SequenceOp region in `state`, with a single
+/// argument of type `bbArgType`, and populates it by calling `bodyBuilder`.
+static void buildSequenceBody(OpBuilder &builder, OperationState &state,
+                              Type bbArgType,
+                              transform::SequenceBodyBuilderFn bodyBuilder) {
+  // `createBlock` moves the insertion point into the new block, so the guard
+  // must be set up first for the caller's insertion point to be restored.
+  OpBuilder::InsertionGuard guard(builder);
+  Region *region = state.regions.back().get();
+  Block *bodyBlock = builder.createBlock(
+      region, region->begin(), TypeRange{bbArgType}, {state.location});
+
+  // Populate body.
+  builder.setInsertionPointToStart(bodyBlock);
+  bodyBuilder(builder, state.location, bodyBlock->getArgument(0));
+}
+
+void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
+                                  TypeRange resultTypes,
+                                  FailurePropagationMode failurePropagationMode,
+                                  Value root,
+                                  SequenceBodyBuilderFn bodyBuilder) {
+  assert(root && "expected a root value; use the bbArgType builder otherwise");
+  build(builder, state, resultTypes, failurePropagationMode, root);
+  buildSequenceBody(builder, state, root.getType(), bodyBuilder);
+}
+
+void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
+                                  TypeRange resultTypes,
+                                  FailurePropagationMode failurePropagationMode,
+                                  Type bbArgType,
+                                  SequenceBodyBuilderFn bodyBuilder) {
+  build(builder, state, resultTypes, failurePropagationMode, /*root=*/Value());
+  buildSequenceBody(builder, state, bbArgType, bodyBuilder);
+}
+
 //===----------------------------------------------------------------------===//
 // WithPDLPatternsOp
 //===----------------------------------------------------------------------===//