diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td @@ -523,22 +523,22 @@ def MergeHandlesOp : TransformDialectOp<"merge_handles", [DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, - SameOperandsAndResultType]> { + MatchOpInterface, SameOperandsAndResultType]> { let summary = "Merges handles into one pointing to the union of payload ops"; let description = [{ Creates a new Transform IR handle value that points to the same Payload IR - operations as the operand handles. The Payload IR operations are listed - in the same order as they are in the operand handles, grouped by operand - handle, e.g., all Payload IR operations associated with the first handle - come first, then all Payload IR operations associated with the second handle - and so on. If `deduplicate` is set, do not add the given Payload IR - operation more than once to the final list regardless of it coming from the + operations/values/parameters as the operand handles. The Payload IR elements + are listed in the same order as they are in the operand handles, grouped by + operand handle, e.g., all Payload IR associated with the first handle comes + first, then all Payload IR associated with the second handle and so on. If + `deduplicate` is set, do not add the given Payload IR operation, value, or + parameter more than once to the final list regardless of it coming from the same or different handles. Consumes the operands and produces a new handle. }]; - let arguments = (ins Variadic:$handles, + let arguments = (ins Variadic:$handles, UnitAttr:$deduplicate); - let results = (outs TransformHandleTypeInterface:$result); + let results = (outs Transform_AnyHandleOrParamType:$result); let assemblyFormat = "(`deduplicate` $deduplicate^)? 
$handles attr-dict `:` type($result)"; let hasFolder = 1; } diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -1180,16 +1180,48 @@ DiagnosedSilenceableFailure transform::MergeHandlesOp::apply(transform::TransformResults &results, transform::TransformState &state) { - SmallVector operations; - for (Value operand : getHandles()) - llvm::append_range(operations, state.getPayloadOps(operand)); + ValueRange handles = getHandles(); + if (isa(handles.front().getType())) { + SmallVector operations; + for (Value operand : handles) + llvm::append_range(operations, state.getPayloadOps(operand)); + if (!getDeduplicate()) { + results.set(llvm::cast(getResult()), operations); + return DiagnosedSilenceableFailure::success(); + } + + SetVector uniqued(operations.begin(), operations.end()); + results.set(llvm::cast(getResult()), uniqued.getArrayRef()); + return DiagnosedSilenceableFailure::success(); + } + + if (llvm::isa(handles.front().getType())) { + SmallVector attrs; + for (Value attribute : handles) + llvm::append_range(attrs, state.getParams(attribute)); + if (!getDeduplicate()) { + results.setParams(cast(getResult()), attrs); + return DiagnosedSilenceableFailure::success(); + } + + SetVector uniqued(attrs.begin(), attrs.end()); + results.setParams(cast(getResult()), uniqued.getArrayRef()); + return DiagnosedSilenceableFailure::success(); + } + + assert( + llvm::isa(handles.front().getType()) && + "expected value handle type"); + SmallVector payloadValues; + for (Value value : handles) + llvm::append_range(payloadValues, state.getPayloadValues(value)); if (!getDeduplicate()) { - results.set(llvm::cast(getResult()), operations); + results.setValues(cast(getResult()), payloadValues); return DiagnosedSilenceableFailure::success(); } - SetVector uniqued(operations.begin(), operations.end()); - 
results.set(llvm::cast(getResult()), uniqued.getArrayRef()); + SetVector uniqued(payloadValues.begin(), payloadValues.end()); + results.setValues(cast(getResult()), uniqued.getArrayRef()); return DiagnosedSilenceableFailure::success(); } diff --git a/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir b/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir --- a/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir +++ b/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir @@ -223,6 +223,11 @@ transform.yield } + transform.named_sequence @print_dimension_size_match(%arg0: !transform.any_op {transform.readonly}) { + transform.test_print_remark_at_operand %arg0, "matched sizes" : !transform.any_op + transform.yield + } + transform.named_sequence @match_dimension_capture(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op { // Capture multiple dimension values. Suppress failures so we can print them anyway after the capture. %0:9 = transform.match.structured failures(suppress) %arg0 @@ -253,9 +258,25 @@ transform.yield %0#0 : !transform.any_op } + transform.named_sequence @match_dimension_sizes(%arg0: !transform.any_op {transform.readonly}) -> (!transform.any_op) { + %0 = transform.match.structured failures(propagate) %arg0 : (!transform.any_op) -> !transform.any_op { + ^bb0(%arg1: !transform.any_op): + %1 = transform.match.structured.dim %arg1[all] : (!transform.any_op) -> !transform.param + %c2 = transform.param.constant 2 : i64 -> !transform.param + %c3 = transform.param.constant 3 : i64 -> !transform.param + %c4 = transform.param.constant 4 : i64 -> !transform.param + %2 = transform.merge_handles %c2, %c3, %c4 : !transform.param + transform.match.param.cmpi eq %1, %2 : !transform.param + + transform.match.structured.yield %arg1 : !transform.any_op + } + transform.yield %0 : !transform.any_op + } + transform.sequence failures(propagate) attributes { transform.target_tag = "transform" } { ^bb0(%arg0: !transform.any_op): - transform.foreach_match in %arg0 
@match_dimension_capture -> @do_nothing : (!transform.any_op) -> !transform.any_op + %0 = transform.foreach_match in %arg0 @match_dimension_capture -> @do_nothing : (!transform.any_op) -> !transform.any_op + %1 = transform.foreach_match in %0 @match_dimension_sizes -> @print_dimension_size_match : (!transform.any_op) -> !transform.any_op } func.func @payload(%lhs: tensor<2x4xf32>, %rhs: tensor<4x3xf32>, %out: tensor<2x3xf32>) attributes { transform.target_tag = "start_here" } { @@ -269,6 +290,7 @@ // expected-remark @below {{dimensions except -1: 2 : i64, 3 : i64}} // expected-remark @below {{dimensions except 0, -2: 4 : i64}} // expected-remark @below {{dimensions 0, -3:}} + // expected-remark @below {{matched sizes}} linalg.generic { indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"] diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir --- a/mlir/test/Dialect/Transform/test-interpreter.mlir +++ b/mlir/test/Dialect/Transform/test-interpreter.mlir @@ -1622,9 +1622,64 @@ test_print_number_of_associated_payload_ir_ops %0 : !transform.any_op } +// ----- + +// Parameter deduplication happens by value + +module { + + transform.sequence failures(propagate) { + ^bb0(%0: !transform.any_op): + %1 = transform.param.constant 1 -> !transform.param + %2 = transform.param.constant 1 -> !transform.param + %3 = transform.param.constant 2 -> !transform.param + %4 = transform.merge_handles %1, %2 { deduplicate } : !transform.param + // expected-remark @below {{1}} + test_print_number_of_associated_payload_ir_params %4 : !transform.param + + %5 = transform.merge_handles %1, %1 { deduplicate } : !transform.param + // expected-remark @below {{1}} + test_print_number_of_associated_payload_ir_params %5 : !transform.param + + %6 = transform.merge_handles %1, %3 { deduplicate } : 
!transform.param + // expected-remark @below {{2}} + test_print_number_of_associated_payload_ir_params %6 : !transform.param + + %7 = transform.merge_handles %1, %1, %2, %3 : !transform.param + // expected-remark @below {{4}} + test_print_number_of_associated_payload_ir_params %7 : !transform.param + } +} // ----- +%0:3 = "test.get_two_results"() : () -> (i32, i32, f32) + +transform.sequence failures(propagate) { +^bb1(%arg0: !transform.any_op): + %1 = transform.structured.match ops{["test.get_two_results"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %2 = test_produce_value_handle_to_result %1, 0 : (!transform.any_op) -> !transform.any_value + %3 = test_produce_value_handle_to_result %1, 1 : (!transform.any_op) -> !transform.any_value + + %4 = transform.merge_handles %2, %2 { deduplicate } : !transform.any_value + // expected-remark @below {{1}} + test_print_number_of_associated_payload_ir_values %4 : !transform.any_value + + %5 = transform.merge_handles %2, %3 { deduplicate } : !transform.any_value + // expected-remark @below {{2}} + test_print_number_of_associated_payload_ir_values %5 : !transform.any_value + // %2 and %6 are distinct handles to the same payload value (result 0), so the deduplicated merge %7 must hold a single value. + %6 = test_produce_value_handle_to_result %1, 0 : (!transform.any_op) -> !transform.any_value + %7 = transform.merge_handles %2, %6 { deduplicate } : !transform.any_value + // expected-remark @below {{1}} + test_print_number_of_associated_payload_ir_values %7 : !transform.any_value + + %8 = transform.merge_handles %2, %2, %3, %4 : !transform.any_value + // expected-remark @below {{4}} + test_print_number_of_associated_payload_ir_values %8 : !transform.any_value +} +// ----- + // CHECK-LABEL: func @test_annotation() // CHECK-NEXT: "test.annotate_me"() // CHECK-SAME: broadcast_attr = 2 : i64 diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp +++
b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp @@ -448,6 +448,42 @@ transform::onlyReadsHandle(getHandle(), effects); } +// Emits a remark carrying the number of payload values associated with the +// operand handle; a null handle is reported as 0 without querying the state. +DiagnosedSilenceableFailure +mlir::test::TestPrintNumberOfAssociatedPayloadIRValues::apply( + transform::TransformResults &results, transform::TransformState &state) { + if (!getValueHandle()) { + emitRemark() << 0; + return DiagnosedSilenceableFailure::success(); + } + emitRemark() << llvm::range_size(state.getPayloadValues(getValueHandle())); + return DiagnosedSilenceableFailure::success(); +} + +void mlir::test::TestPrintNumberOfAssociatedPayloadIRValues::getEffects( + SmallVectorImpl &effects) { + transform::onlyReadsHandle(getValueHandle(), effects); +} + +// Emits a remark carrying the number of parameters associated with the +// operand; a null operand is reported as 0 without querying the state. +DiagnosedSilenceableFailure +mlir::test::TestPrintNumberOfAssociatedPayloadIRParams::apply( + transform::TransformResults &results, transform::TransformState &state) { + if (!getParam()) { + emitRemark() << 0; + return DiagnosedSilenceableFailure::success(); + } + emitRemark() << llvm::range_size(state.getParams(getParam())); + return DiagnosedSilenceableFailure::success(); +} + +void mlir::test::TestPrintNumberOfAssociatedPayloadIRParams::getEffects( + SmallVectorImpl &effects) { + transform::onlyReadsHandle(getParam(), effects); +} + DiagnosedSilenceableFailure mlir::test::TestCopyPayloadOp::apply(transform::TransformResults &results, transform::TransformState &state) { diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td @@ -344,6 +344,24 @@ let cppNamespace = "::mlir::test"; } +def TestPrintNumberOfAssociatedPayloadIRValues + : Op, + DeclareOpInterfaceMethods]> { + let arguments = (ins TransformValueHandleTypeInterface:$value_handle); + let assemblyFormat = "$value_handle attr-dict `:` type($value_handle)"; + let cppNamespace = "::mlir::test"; +} + +def TestPrintNumberOfAssociatedPayloadIRParams + : Op, + DeclareOpInterfaceMethods]> { + let
arguments = (ins TransformParamTypeInterface:$param); + let assemblyFormat = "$param attr-dict `:` type($param)"; + let cppNamespace = "::mlir::test"; +} + def TestCopyPayloadOp : Op,