diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -197,21 +197,21 @@ - ops: an optional StrArrayAttr specifying the concrete name of an op. Multiple names can be specified. Matched ops must have one of specified names. - - attribute: an optional Str specifying the name of an attribute that - matched ops must have. + - attributes: an optional DictionaryAttr (the `op_attrs` argument); matched + ops must have every listed attribute, each with exactly the given value. Note: Only ops that satisfy all specified constraints are matched. TODO: Extend with regions to allow a limited form of constraints. #### Return modes - + This op traverses the ops nested under `target` and returns the handles to all the operations that match the requirements. This op fails if the target is not a handle to exactly one operation. Otherwise it succeeds. - + This operation does not consume the target handle and produces new handles: it is a navigation op. }]; @@ -219,14 +219,14 @@ let arguments = (ins PDL_Operation:$target, OptionalAttr<StrArrayAttr>:$ops, OptionalAttr<MatchInterfaceEnum>:$interface, - OptionalAttr<StrAttr>:$attribute); + OptionalAttr<DictionaryAttr>:$op_attrs); // TODO: variadic results when needed. let results = (outs PDL_Operation:$results); let assemblyFormat = [{ - (`ops` `{` $ops^ `}`)? - (`interface` `{` $interface^ `}`)? - (`attribute` `{` $attribute^ `}`)? + (`ops` `{` $ops^ `}`)? + (`interface` `{` $interface^ `}`)? + (`attributes` $op_attrs^)?
`in` $target attr-dict }]; } diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -467,8 +467,17 @@ return WalkResult::advance(); } - if (getAttribute().has_value() && !op->hasAttr(getAttribute().value())) - return WalkResult::advance(); + // Check if all specified attributes match. Every entry of `op_attrs` + // names a payload-op attribute (keys such as "ops" are not special). A + // missing attribute makes getAttr return a null Attribute, which never + // compares equal to a dictionary value. + if (getOpAttrs().has_value()) { + DictionaryAttr opAttrs = getOpAttrs().value(); + for (NamedAttribute attr : opAttrs) { + if (op->getAttr(attr.getName()) != attr.getValue()) + return WalkResult::advance(); + } + } // All constraints are satisfied. res.push_back(op); diff --git a/mlir/test/Dialect/Linalg/transform-op-match.mlir b/mlir/test/Dialect/Linalg/transform-op-match.mlir --- a/mlir/test/Dialect/Linalg/transform-op-match.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-match.mlir @@ -17,8 +17,45 @@ transform.test_print_remark_at_operand %match_name, "matched op name" transform.test_consume_operand %match_name - %match_attr = transform.structured.match ops{["arith.constant"]} attribute{"my_attr"} in %arg1 + %match_attr = transform.structured.match ops{["arith.constant"]} attributes{my_attr} in %arg1 transform.test_print_remark_at_operand %match_attr, "matched attr name" transform.test_consume_operand %match_attr } } + +// ----- + +#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d1, d0, d2)> +func.func @match_complex_attribute(%arg0: tensor<12x128x32xf32>) + -> tensor<128x12x32xf32> { + %0 = linalg.init_tensor [128, 12, 32] : tensor<128x12x32xf32> + // expected-remark @below {{matched complex attr}} + %1 = linalg.generic {indexing_maps = [#map0, #map1],
iterator_types = ["parallel", "parallel", "parallel"]} + ins(%arg0 : tensor<12x128x32xf32>) + outs(%0 : tensor<128x12x32xf32>) { + ^bb0(%arg1: f32, %arg2: f32): + linalg.yield %arg1 : f32 + } -> tensor<128x12x32xf32> + return %1 : tensor<128x12x32xf32> +} +// Match payload ops by the exact value of a non-unit (array) attribute. +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 { + ^bb1(%arg1: !pdl.operation): + %match_attr = transform.structured.match + ops{["linalg.generic"]} + attributes{iterator_types = ["parallel", "parallel", "parallel"]} + in %arg1 + transform.test_print_remark_at_operand %match_attr, "matched complex attr" + transform.test_consume_operand %match_attr + // Same attribute name but a different value: nothing should match. + %no_match = transform.structured.match + attributes{iterator_types = ["parallel", "parallel", "reduction"]} + in %arg1 + // expected-remark @below {{0}} + transform.test_print_number_of_associated_payload_ir_ops %no_match + } +} diff --git a/mlir/test/Dialect/Linalg/transform-tile-and-fuse.mlir b/mlir/test/Dialect/Linalg/transform-tile-and-fuse.mlir --- a/mlir/test/Dialect/Linalg/transform-tile-and-fuse.mlir +++ b/mlir/test/Dialect/Linalg/transform-tile-and-fuse.mlir @@ -45,8 +45,8 @@ transform.sequence %arg0 { ^bb1(%arg1: !pdl.operation): // Find the root and all producers. - %root = transform.structured.match attribute{"__root__"} in %arg1 - %producers = transform.structured.match attribute{"__producer__"} in %arg1 + %root = transform.structured.match attributes{"__root__"} in %arg1 + %producers = transform.structured.match attributes{"__producer__"} in %arg1 // Tile the root. %foreach_thread_op, %tiled_op = transform.structured.tile_to_foreach_thread_op %root num_threads [10, 20] @@ -105,8 +105,8 @@ transform.sequence %arg0 { ^bb1(%arg1: !pdl.operation): // Find the root and all producers. 
- %root = transform.structured.match attribute{"__root__"} in %arg1 - %producers = transform.structured.match attribute{"__producer__"} in %arg1 + %root = transform.structured.match attributes{"__root__"} in %arg1 + %producers = transform.structured.match attributes{"__producer__"} in %arg1 %reversed_producers = transform.test_reverse_payload_ops %producers // Tile the root.