diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td @@ -523,22 +523,23 @@ def MergeHandlesOp : TransformDialectOp<"merge_handles", [DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, - SameOperandsAndResultType]> { + MatchOpInterface, SameOperandsAndResultType]> { let summary = "Merges handles into one pointing to the union of payload ops"; let description = [{ Creates a new Transform IR handle value that points to the same Payload IR - operations as the operand handles. The Payload IR operations are listed - in the same order as they are in the operand handles, grouped by operand - handle, e.g., all Payload IR operations associated with the first handle - come first, then all Payload IR operations associated with the second handle - and so on. If `deduplicate` is set, do not add the given Payload IR - operation more than once to the final list regardless of it coming from the - same or different handles. Consumes the operands and produces a new handle. + operations, values, or parameters as the operand handles. These entities + are listed in the same order as they are in the operand handles, grouped by + operand handle, e.g., all entities associated with the first handle come + first, then all entities associated with the second handle and so on. If + `deduplicate` is set, do not add the given Payload IR operation, value, or + parameter more than once to the final list regardless of it coming from the + same or different handles; parameters are compared by value. Consumes the + operands and produces a new handle.
}]; - let arguments = (ins Variadic:$handles, + let arguments = (ins Variadic:$handles, UnitAttr:$deduplicate); - let results = (outs TransformHandleTypeInterface:$result); + let results = (outs Transform_AnyHandleOrParamType:$result); let assemblyFormat = "(`deduplicate` $deduplicate^)? $handles attr-dict `:` type($result)"; let hasFolder = 1; } diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -1180,16 +1180,56 @@ DiagnosedSilenceableFailure transform::MergeHandlesOp::apply(transform::TransformResults &results, transform::TransformState &state) { - SmallVector operations; - for (Value operand : getHandles()) - llvm::append_range(operations, state.getPayloadOps(operand)); - if (!getDeduplicate()) { - results.set(llvm::cast(getResult()), operations); - return DiagnosedSilenceableFailure::success(); - } - SetVector uniqued(operations.begin(), operations.end()); - results.set(llvm::cast(getResult()), uniqued.getArrayRef()); + ValueRange handles = getHandles(); + // Dispatch on the result type: SameOperandsAndResultType makes it equal to + // every operand type, and unlike `handles.front()` it is also valid when the + // variadic operand list is empty. + Type handleType = getResult().getType(); + if (isa<TransformHandleTypeInterface>(handleType)) { + SmallVector<Operation *> operations; + for (Value operand : handles) + llvm::append_range(operations, state.getPayloadOps(operand)); + if (!getDeduplicate()) { + results.set(llvm::cast<OpResult>(getResult()), operations); + return DiagnosedSilenceableFailure::success(); + } + + SetVector<Operation *> uniqued(operations.begin(), operations.end()); + results.set(llvm::cast<OpResult>(getResult()), uniqued.getArrayRef()); + return DiagnosedSilenceableFailure::success(); + } + + if (isa<TransformParamTypeInterface>(handleType)) { + // Attributes are uniqued in their MLIRContext, so deduplicating the + // attribute list merges parameters with equal values. + SmallVector<Attribute> attrs; + for (Value attribute : handles) + llvm::append_range(attrs, state.getParams(attribute)); + if (!getDeduplicate()) { + results.setParams(cast<OpResult>(getResult()), attrs); + return DiagnosedSilenceableFailure::success(); + } + + SetVector<Attribute> uniqued(attrs.begin(), attrs.end()); + results.setParams(cast<OpResult>(getResult()), uniqued.getArrayRef()); + return DiagnosedSilenceableFailure::success(); + } + + // Remaining case: value handles, which Transform_AnyHandleOrParamType also + // admits and which must not be queried through the parameter path. + assert(isa<TransformValueHandleTypeInterface>(handleType) && + "expected an operation handle, value handle or parameter type"); + SmallVector<Value> payloadValues; + for (Value operand : handles) + llvm::append_range(payloadValues, state.getPayloadValues(operand)); + if (!getDeduplicate()) { + results.setValues(cast<OpResult>(getResult()), payloadValues); + return DiagnosedSilenceableFailure::success(); + } + + SetVector<Value> uniqued(payloadValues.begin(), payloadValues.end()); + results.setValues(cast<OpResult>(getResult()), uniqued.getArrayRef()); return DiagnosedSilenceableFailure::success(); } diff --git 
a/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir b/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir --- a/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir +++ b/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir @@ -223,6 +223,11 @@ transform.yield } + transform.named_sequence @print_dimension_size_match(%arg0: !transform.any_op {transform.readonly}) { + transform.test_print_remark_at_operand %arg0, "matched sizes" : !transform.any_op + transform.yield + } + transform.named_sequence @match_dimension_capture(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op { // Capture multiple dimension values. Suppress failures so we can print them anyway after the capture. %0:9 = transform.match.structured failures(suppress) %arg0 @@ -253,9 +258,26 @@ transform.yield %0#0 : !transform.any_op } + // Succeeds only on structured ops whose dimension sizes ([all]) are exactly 2, 3, 4 in order. + transform.named_sequence @match_dimension_sizes(%arg0: !transform.any_op {transform.readonly}) -> (!transform.any_op) { + %0 = transform.match.structured failures(propagate) %arg0 : (!transform.any_op) -> !transform.any_op { + ^bb0(%arg1: !transform.any_op): + %1 = transform.match.structured.dim %arg1[all] : (!transform.any_op) -> !transform.param + %c2 = transform.param.constant 2 : i64 -> !transform.param + %c3 = transform.param.constant 3 : i64 -> !transform.param + %c4 = transform.param.constant 4 : i64 -> !transform.param + %2 = transform.merge_handles %c2, %c3, %c4 : !transform.param + transform.match.param.cmpi eq %1, %2 : !transform.param + + transform.match.structured.yield %arg1 : !transform.any_op + } + transform.yield %0 : !transform.any_op + } + transform.sequence failures(propagate) attributes { transform.target_tag = "transform" } { ^bb0(%arg0: !transform.any_op): - transform.foreach_match in %arg0 @match_dimension_capture -> @do_nothing : (!transform.any_op) -> !transform.any_op + %0 = transform.foreach_match in %arg0 @match_dimension_capture -> @do_nothing : (!transform.any_op) -> !transform.any_op + %1 = transform.foreach_match in %0 
@match_dimension_sizes -> @print_dimension_size_match : (!transform.any_op) -> !transform.any_op } func.func @payload(%lhs: tensor<2x4xf32>, %rhs: tensor<4x3xf32>, %out: tensor<2x3xf32>) attributes { transform.target_tag = "start_here" } { @@ -269,6 +291,7 @@ // expected-remark @below {{dimensions except -1: 2 : i64, 3 : i64}} // expected-remark @below {{dimensions except 0, -2: 4 : i64}} // expected-remark @below {{dimensions 0, -3:}} + // expected-remark @below {{matched sizes}} linalg.generic { indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"] diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir --- a/mlir/test/Dialect/Transform/test-interpreter.mlir +++ b/mlir/test/Dialect/Transform/test-interpreter.mlir @@ -1625,6 +1625,32 @@ // ----- +// Parameters are deduplicated by value: handles to equal constants merge. + +module { + + transform.sequence failures(propagate) { + ^bb0(%0: !transform.any_op): + %1 = transform.param.constant 1 -> !transform.param + %2 = transform.param.constant 1 -> !transform.param + %3 = transform.param.constant 2 -> !transform.param + %4 = transform.merge_handles %1, %2 { deduplicate } : !transform.param + // expected-remark @below {{1}} + test_print_number_of_associated_payload_ir_params %4 : !transform.param + %5 = transform.merge_handles %1, %1 { deduplicate } : !transform.param + // expected-remark @below {{1}} + test_print_number_of_associated_payload_ir_params %5 : !transform.param + %6 = transform.merge_handles %1, %3 { deduplicate } : !transform.param + // expected-remark @below {{2}} + test_print_number_of_associated_payload_ir_params %6 : !transform.param + %7 = transform.merge_handles %1, %1, %2, %3 : !transform.param + // expected-remark @below {{4}} + test_print_number_of_associated_payload_ir_params %7 : !transform.param + } +} + +// ----- + // CHECK-LABEL: func 
@test_annotation() // CHECK-NEXT: "test.annotate_me"() // CHECK-SAME: broadcast_attr = 2 : i64 diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp @@ -448,6 +448,26 @@ transform::onlyReadsHandle(getHandle(), effects); } +/// Emits a remark with the number of parameters associated with the operand. +DiagnosedSilenceableFailure +mlir::test::TestPrintNumberOfAssociatedPayloadIRParams::apply( + transform::TransformResults &results, transform::TransformState &state) { + // `param` is a required operand, so this guard is purely defensive; return + // early so a null operand is never handed to getParams. + if (!getParam()) { + emitRemark() << 0; + return DiagnosedSilenceableFailure::success(); + } + emitRemark() << llvm::range_size(state.getParams(getParam())); + return DiagnosedSilenceableFailure::success(); +} + +/// The parameter operand is only read. +void mlir::test::TestPrintNumberOfAssociatedPayloadIRParams::getEffects( + SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { + transform::onlyReadsHandle(getParam(), effects); +} + DiagnosedSilenceableFailure mlir::test::TestCopyPayloadOp::apply(transform::TransformResults &results, transform::TransformState &state) { diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td @@ -344,6 +344,15 @@ let cppNamespace = "::mlir::test"; } +def TestPrintNumberOfAssociatedPayloadIRParams + : Op, + DeclareOpInterfaceMethods]> { + let arguments = (ins TransformParamTypeInterface:$param); + let assemblyFormat = "$param attr-dict `:` type($param)"; + let cppNamespace = "::mlir::test"; +} + def TestCopyPayloadOp : Op,