diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -166,6 +166,18 @@
   let assemblyFormat = "`to` $target $region attr-dict `:` type($target)";
   let hasVerifier = 1;
 
+  // Only the custom builder below is generated. `bodyBuilder`, if non-null,
+  // populates the op's region; `failOnPayloadReplacementNotFound` defaults to
+  // true.
+  let skipDefaultBuilders = 1;
+  let builders = [
+    OpBuilder<(ins
+        "Value":$target,
+        CArg<"function_ref<void(OpBuilder &, Location)>", "nullptr">:
+            $bodyBuilder,
+        CArg<"bool", "true">:$failOnPayloadReplacementNotFound)>,
+  ];
+
   let extraClassDeclaration = [{
     ::mlir::DiagnosedSilenceableFailure applyToOne(
         ::mlir::Operation *target,
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -477,6 +477,29 @@
   transform::modifiesPayload(effects);
 }
 
+/// Build an `ApplyPatternsOp` whose operand is `target`.
+///
+/// An empty block is created in the op's region. If `bodyBuilder` is non-null,
+/// it is called with the builder's insertion point inside that block so the
+/// caller can populate it. The caller's insertion point is restored on return.
+void transform::ApplyPatternsOp::build(
+    OpBuilder &builder, OperationState &result, Value target,
+    function_ref<void(OpBuilder &, Location)> bodyBuilder,
+    bool failOnPayloadReplacementNotFound) {
+  result.addOperands(target);
+  // The flag lives in the op's properties, not the attribute dictionary.
+  result.getOrAddProperties<Properties>()
+      .fail_on_payload_replacement_not_found =
+      builder.getBoolAttr(failOnPayloadReplacementNotFound);
+
+  // Restore the caller's insertion point after the region body is built.
+  OpBuilder::InsertionGuard g(builder);
+  Region *region = result.addRegion();
+  builder.createBlock(region);
+  if (bodyBuilder)
+    bodyBuilder(builder, result.location);
+}
+
 //===----------------------------------------------------------------------===//
 // ApplyCanonicalizationPatternsOp
 //===----------------------------------------------------------------------===//