diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -101,6 +101,31 @@
   let hasVerifier = 1;
 }
 
+def AnnotateOp : TransformDialectOp<"annotate",
+    [DeclareOpInterfaceMethods<TransformOpInterface>,
+     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+  let summary = "Annotates the target operation with an attribute by name";
+  let description = [{
+    Adds an attribute with the given `name` to the `target` operation. An optional
+    `param` handle can be provided to give the attribute a specific value, else a
+    UnitAttr is added. Attributes can either be added with a single attribute in the
+    `param` payload broadcasted to all target operations, or the attributes will be
+    mapped 1:1 based on the order within the handles.
+
+    Produces a silenceable failure if the length of the parameter payload does not
+    match the length of the target payload. Does not consume the provided handles.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target,
+                       StrAttr:$name,
+                       Optional<TransformParamTypeInterface>:$param);
+  let results = (outs);
+
+  let assemblyFormat =
+    "$target $name attr-dict (`=` $param^)?"
+    "`:` type($target) (`,` type($param)^)?";
+}
+
 def CastOp : TransformDialectOp<"cast",
     [TransformOpInterface, TransformEachOpTrait,
      DeclareOpInterfaceMethods,
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -301,6 +301,51 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// AnnotateOp
+//===----------------------------------------------------------------------===//
+
+/// Sets the attribute `name` on every payload op associated with `target`.
+/// Without `param`, a UnitAttr is attached. A `param` payload of size one is
+/// broadcast to all targets; any other size is mapped 1:1 onto the targets and
+/// must match their count, otherwise a silenceable error is emitted.
+DiagnosedSilenceableFailure
+transform::AnnotateOp::apply(transform::TransformResults &results,
+                             transform::TransformState &state) {
+  SmallVector<Operation *> targets =
+      llvm::to_vector(state.getPayloadOps(getTarget()));
+
+  // Default value when no `param` handle is supplied.
+  Attribute attr = UnitAttr::get(getContext());
+  if (auto paramH = getParam()) {
+    ArrayRef<Attribute> params = state.getParams(paramH);
+    if (params.size() != 1) {
+      // 1:1 mapping between parameters and targets.
+      if (targets.size() != params.size()) {
+        return emitSilenceableError()
+               << "parameter and target have different payload lengths ("
+               << params.size() << " vs " << targets.size() << ")";
+      }
+      for (auto &&[target, paramAttr] : llvm::zip_equal(targets, params))
+        target->setAttr(getName(), paramAttr);
+      return DiagnosedSilenceableFailure::success();
+    }
+    // Exactly one parameter: broadcast it to every target below.
+    attr = params[0];
+  }
+  for (Operation *target : targets)
+    target->setAttr(getName(), attr);
+  return DiagnosedSilenceableFailure::success();
+}
+
+/// Marks both handles as read-only and reports that the payload is modified.
+void transform::AnnotateOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  onlyReadsHandle(getTarget(), effects);
+  onlyReadsHandle(getParam(), effects);
+  modifiesPayload(effects);
+}
+
 //===----------------------------------------------------------------------===//
 // CastOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -1620,3 +1620,37 @@
   // expected-remark @below {{2}}
   test_print_number_of_associated_payload_ir_ops %0 : !transform.any_op
 }
+
+
+// -----
+
+// CHECK-LABEL: func @test_annotation()
+//  CHECK-NEXT:   "test.annotate_me"()
+//  CHECK-SAME:                        broadcast_attr = 2 : i64
+//  CHECK-SAME:                        new_attr = 1 : i32
+//  CHECK-SAME:                        unit_attr
+//  CHECK-NEXT:   "test.annotate_me"()
+//  CHECK-SAME:                        broadcast_attr = 2 : i64
+//  CHECK-SAME:                        existing_attr = "test"
+//  CHECK-SAME:                        new_attr = 1 : i32
+//  CHECK-SAME:                        unit_attr
+//  CHECK-NEXT:   "test.annotate_me"()
+//  CHECK-SAME:                        broadcast_attr = 2 : i64
+//  CHECK-SAME:                        new_attr = 1 : i32
+//  CHECK-SAME:                        unit_attr
+func.func @test_annotation() {
+  %0 = "test.annotate_me"() : () -> (i1)
+  %1 = "test.annotate_me"() {existing_attr = "test"} : () -> (i1)
+  %2 = "test.annotate_me"() {new_attr = 0} : () -> (i1)
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg0: !transform.any_op):
+  %0 = transform.structured.match ops{["test.annotate_me"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+  %1 = transform.test_produce_param_with_number_of_test_ops %0 : !transform.any_op
+  transform.annotate %0 "new_attr" = %1 : !transform.any_op, !transform.test_dialect_param
+
+  %2 = transform.param.constant 2 -> !transform.param<i64>
+  transform.annotate %0 "broadcast_attr" = %2 : !transform.any_op, !transform.param<i64>
+  transform.annotate %0 "unit_attr" : !transform.any_op
+}