diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td @@ -11,9 +11,9 @@ include "mlir/Interfaces/CastInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" +include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/IR/OpAsmInterface.td" include "mlir/IR/SymbolInterfaces.td" -include "mlir/Dialect/PDL/IR/PDLTypes.td" include "mlir/Dialect/Transform/IR/TransformAttrs.td" include "mlir/Dialect/Transform/IR/TransformDialect.td" include "mlir/Dialect/Transform/IR/TransformEffects.td" @@ -88,12 +88,13 @@ ``` }]; - let arguments = (ins Optional:$scope); - let results = (outs Variadic:$results); + let arguments = (ins Optional:$scope); + let results = (outs Variadic:$results); let regions = (region VariadicRegion>:$alternatives); let assemblyFormat = - "($scope^)? (`->` type($results)^)? attr-dict-with-keyword regions"; + "($scope^ `:` type($scope))? (`->` type($results)^)? " + "attr-dict-with-keyword regions"; let hasVerifier = 1; } @@ -101,9 +102,8 @@ [TransformOpInterface, TransformEachOpTrait, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { - // TODO: temporarily fallback support for casting from PDL_Operation type. - let arguments = (ins AnyType:$input); - let results = (outs AnyType:$output); + let arguments = (ins TransformTypeInterface:$input); + let results = (outs TransformTypeInterface:$output); let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)"; let extraClassDeclaration = [{ @@ -143,10 +143,11 @@ merged and mapped to the same resulting handle. }]; - let arguments = (ins PDL_Operation:$target); - let results = (outs Variadic:$results); + let arguments = (ins TransformTypeInterface:$target); + let results = (outs Variadic:$results); let regions = (region SizedRegion<1>:$body); - let assemblyFormat = "$target (`->` type($results)^)? $body attr-dict"; + let assemblyFormat = + "$target `:` type($target) (`->` type($results)^)? $body attr-dict"; let hasVerifier = 1; let extraClassDeclaration = [{ @@ -182,9 +183,10 @@ on the further transformation applied to the handle produced here. }]; - let arguments = (ins PDL_Operation:$target); - let results = (outs PDL_Operation:$parent); - let assemblyFormat = "$target attr-dict"; + let arguments = (ins TransformTypeInterface:$target); + let results = (outs TransformTypeInterface:$parent); + let assemblyFormat = + "$target attr-dict `:` functional-type(operands, results)"; } def GetProducerOfOperand : TransformDialectOp<"get_producer_of_operand", @@ -200,15 +202,17 @@ computational operations, which can be empty. }]; - let arguments = (ins PDL_Operation:$target, + let arguments = (ins TransformTypeInterface:$target, I64Attr:$operand_number); - let results = (outs PDL_Operation:$parent); - let assemblyFormat = "$target `[` $operand_number `]` attr-dict"; + let results = (outs TransformTypeInterface:$parent); + let assemblyFormat = "$target `[` $operand_number `]` attr-dict `:` " + "functional-type(operands, results)"; } def MergeHandlesOp : TransformDialectOp<"merge_handles", [DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods]> { + DeclareOpInterfaceMethods, + SameOperandsAndResultType]> { let summary = "Merges handles into one pointing to the union of payload ops"; let description = [{ Creates a new Transform IR handle value that points to the same Payload IR @@ -221,10 +225,10 @@ same or different handles. Consumes the operands and produces a new handle. }]; - let arguments = (ins Variadic:$handles, + let arguments = (ins Variadic:$handles, UnitAttr:$deduplicate); - let results = (outs PDL_Operation:$result); - let assemblyFormat = "($deduplicate^)? $handles attr-dict"; + let results = (outs TransformTypeInterface:$result); + let assemblyFormat = "($deduplicate^)? $handles attr-dict `:` type($result)"; let hasFolder = 1; } @@ -246,13 +250,12 @@ operations contained in the source `handle`. Otherwise it silently fails. }]; - let arguments = (ins PDL_Operation:$handle, + let arguments = (ins TransformTypeInterface:$handle, I64Attr:$num_result_handles); - let results = (outs Variadic:$results); + let results = (outs Variadic:$results); let assemblyFormat = [{ $handle `in` `[` $num_result_handles `]` - custom(type($results), ref($num_result_handles)) - attr-dict + attr-dict `:` functional-type(operands, results) }]; } @@ -278,17 +281,19 @@ }]; let arguments = (ins - Arg:$root, + Arg:$root, SymbolRefAttr:$pattern_name); let results = (outs - Res:$matched); + Res:$matched); - let assemblyFormat = "$pattern_name `in` $root attr-dict"; + let assemblyFormat = "$pattern_name `in` $root attr-dict `:` " + "functional-type(operands, results)"; } def ReplicateOp : TransformDialectOp<"replicate", [DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods]> { + DeclareOpInterfaceMethods, + AllTypesMatch<["handles", "replicated"]>]> { let summary = "Lists payload ops multiple times in the new handle"; let description = [{ Produces a new handle associated with a list of payload IR ops that is @@ -314,12 +319,11 @@ MergeHandlesOp can be used to construct arbitrary lists with repetitions. }]; - let arguments = (ins PDL_Operation:$pattern, - Variadic:$handles); - let results = (outs Variadic:$replicated); - let assemblyFormat = - "`num` `(` $pattern `)` $handles " - "custom(type($replicated), ref($handles)) attr-dict"; + let arguments = (ins TransformTypeInterface:$pattern, + Variadic:$handles); + let results = (outs Variadic:$replicated); + let assemblyFormat = "`num` `(` $pattern `)` $handles attr-dict `:` " + "type($pattern) `,` type($handles)"; } def SequenceOp : TransformDialectOp<"sequence", @@ -358,12 +362,13 @@ }]; let arguments = (ins FailurePropagationMode:$failure_propagation_mode, - Optional:$root); - let results = (outs Variadic:$results); + Optional:$root); + let results = (outs Variadic:$results); let regions = (region SizedRegion<1>:$body); let assemblyFormat = - "($root^)? `failures` `(` $failure_propagation_mode `)` attr-dict-with-keyword regions (`:` type($results)^)?"; + "($root^ `:` type($root))? (`->` type($results)^)? `failures` `(` " + "$failure_propagation_mode `)` attr-dict-with-keyword regions"; let extraClassDeclaration = [{ /// Allow the dialect prefix to be omitted. @@ -414,10 +419,10 @@ }]; let arguments = (ins - Arg, "Root operation of the Payload IR", + Arg, "Root operation of the Payload IR", [TransformMappingRead]>:$root); let regions = (region SizedRegion<1>:$body); - let assemblyFormat = "($root^)? attr-dict-with-keyword regions"; + let assemblyFormat = "($root^ `:` type($root))? attr-dict-with-keyword regions"; let hasVerifier = 1; @@ -436,7 +441,7 @@ }]; let arguments = (ins - Arg, "Operation handles yielded back to the parent", + Arg, "Operation handles yielded back to the parent", [TransformMappingRead]>:$operands); let assemblyFormat = "operands attr-dict (`:` type($operands)^)?"; diff --git a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp @@ -42,6 +42,17 @@ } #endif // NDEBUG +namespace { +struct PDLOperationTypeTransformTypeInterfaceImpl + : public transform::TransformTypeInterface::ExternalModel< + PDLOperationTypeTransformTypeInterfaceImpl, pdl::OperationType> { + DiagnosedSilenceableFailure + checkPayload(Type type, Location loc, ArrayRef payload) const { + return DiagnosedSilenceableFailure::success(); + } +}; +} // namespace + void transform::TransformDialect::initialize() { // Using the checked versions to enable the same assertions as for the ops // from extensions. @@ -53,6 +64,9 @@ #define GET_TYPEDEF_LIST #include "mlir/Dialect/Transform/IR/TransformTypes.cpp.inc" >(); + + pdl::OperationType::attachInterface< + PDLOperationTypeTransformTypeInterfaceImpl>(*getContext()); } void transform::TransformDialect::mergeInPDLMatchHooks( diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp @@ -7,7 +7,6 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" -#include "mlir/Dialect/PDL/IR/PDLTypes.h" #include "mlir/Dialect/Transform/IR/TransformTypes.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Operation.h" @@ -71,12 +70,11 @@ if (value.use_empty()) return success(); - if (auto iface = value.getType().dyn_cast()) { - DiagnosedSilenceableFailure result = - iface.checkPayload(value.getLoc(), targets); - if (failed(result.checkAndReport())) - return failure(); - } + auto iface = value.getType().cast(); + DiagnosedSilenceableFailure result = + iface.checkPayload(value.getLoc(), targets); + if (failed(result.checkAndReport())) + return failure(); // Setting new payload for the value without cleaning it first is a misuse of // the API, assert here. @@ -128,12 +126,11 @@ } } - if (auto iface = value.getType().dyn_cast()) { - DiagnosedSilenceableFailure result = - iface.checkPayload(value.getLoc(), updated); - if (failed(result.checkAndReport())) - return failure(); - } + auto iface = value.getType().cast(); + DiagnosedSilenceableFailure result = + iface.checkPayload(value.getLoc(), updated); + if (failed(result.checkAndReport())) + return failure(); std::swap(association, updated); return success(); @@ -369,10 +366,9 @@ Block *body = &bodyRegion->front(); if (body->getNumArguments() != 1 || - !body->getArgumentTypes()[0].isa()) { - return op->emitOpError() - << "expects the entry block to have one argument of type " - << pdl::OperationType::get(op->getContext()); + !body->getArgumentTypes()[0].isa()) { + return op->emitOpError() << "expects the entry block to have one argument " + "of type implementing TransformTypeInterface"; } if (auto *parent = diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -10,6 +10,7 @@ #include "mlir/Dialect/PDL/IR/PDLOps.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" +#include "mlir/Dialect/Transform/IR/TransformTypes.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" @@ -24,31 +25,6 @@ using namespace mlir; -/// Custom parser for ReplicateOp. -static ParseResult parsePDLOpTypedResults( - OpAsmParser &parser, SmallVectorImpl &types, - const SmallVectorImpl &handles) { - types.resize(handles.size(), pdl::OperationType::get(parser.getContext())); - return success(); -} - -/// Custom printer for ReplicateOp. -static void printPDLOpTypedResults(OpAsmPrinter &, Operation *, TypeRange, - ValueRange) {} - -/// Custom parser for SplitHandlesOp. -static ParseResult parseStaticNumPDLResults(OpAsmParser &parser, - SmallVectorImpl &types, - IntegerAttr numHandlesAttr) { - types.resize(numHandlesAttr.getInt(), - pdl::OperationType::get(parser.getContext())); - return success(); -} - -/// Custom printer for SplitHandlesOp. -static void printStaticNumPDLResults(OpAsmPrinter &, Operation *, TypeRange, - IntegerAttr) {} - #define GET_OP_CLASSES #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc" @@ -269,13 +245,6 @@ LogicalResult transform::AlternativesOp::verify() { for (Region &alternative : getAlternatives()) { Block &block = alternative.front(); - if (block.getNumArguments() != 1 || - !block.getArgument(0).getType().isa()) { - return emitOpError() - << "expects region blocks to have one operand of type " - << pdl::OperationType::get(getContext()); - } - Operation *terminator = block.getTerminator(); if (terminator->getOperands().getTypes() != getResults().getTypes()) { InFlightDiagnostic diag = emitOpError() @@ -403,8 +372,9 @@ return emitOpError() << "expects the same number of results as the " "terminator has operands"; for (Value v : yieldOp.getOperands()) - if (!v.getType().isa()) - return yieldOp->emitOpError("expects only PDL_Operation operands"); + if (!v.getType().isa()) + return yieldOp->emitOpError( + "expects operands to have types implementing TransformTypeInterface"); return success(); } diff --git a/mlir/python/mlir/dialects/_transform_ops_ext.py b/mlir/python/mlir/dialects/_transform_ops_ext.py --- a/mlir/python/mlir/dialects/_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_transform_ops_ext.py @@ -5,7 +5,6 @@ try: from ..ir import * from ._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values - from ..dialects import pdl except ImportError as e: raise RuntimeError("Error loading imports from extension module") from e @@ -21,9 +20,9 @@ class GetClosestIsolatedParentOp: - def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None): + def __init__(self, result_type: Type, target: Union[Operation, Value], *, loc=None, ip=None): super().__init__( - pdl.OperationType.get(), + result_type, _get_op_result_or_value(target), loc=loc, ip=ip) @@ -38,7 +37,7 @@ loc=None, ip=None): super().__init__( - pdl.OperationType.get(), [_get_op_result_or_value(h) for h in handles], + [_get_op_result_or_value(h) for h in handles], deduplicate=deduplicate, loc=loc, ip=ip) @@ -47,13 +46,14 @@ class PDLMatchOp: def __init__(self, + result_type: Type, target: Union[Operation, Value], pattern_name: Union[Attribute, str], *, loc=None, ip=None): super().__init__( - pdl.OperationType.get(), + result_type, _get_op_result_or_value(target), _get_symbol_ref_attr(pattern_name), loc=loc, @@ -69,7 +69,7 @@ loc=None, ip=None): super().__init__( - [pdl.OperationType.get()] * len(handles), + [_get_op_result_or_value(h).type for h in handles], _get_op_result_or_value(pattern), [_get_op_result_or_value(h) for h in handles], loc=loc, @@ -78,24 +78,11 @@ class SequenceOp: - @overload - def __init__(self, failure_propagation_mode, - resultsOrRoot: Sequence[Type], - optionalRoot: Optional[Union[Operation, Value]]): - ... - - @overload - def __init__(self, failure_propagation_mode, - resultsOrRoot: Optional[Union[Operation, - Value]], optionalRoot: NoneType): - ... - - def __init__(self, failure_propagation_mode, resultsOrRoot=None, optionalRoot=None): - results = resultsOrRoot if isinstance(resultsOrRoot, Sequence) else [] - root = ( - resultsOrRoot - if not isinstance(resultsOrRoot, Sequence) else optionalRoot) - root = _get_op_result_or_value(root) if root else None + def __init__(self, failure_propagation_mode, results: Sequence[Type], + target: Union[Operation, Value, Type]): + root = _get_op_result_or_value(target) if isinstance( + target, (Operation, Value)) else None + root_type = root.type if not isinstance(target, Type) else target if not isinstance(failure_propagation_mode, Attribute): failure_propagation_mode_attr = IntegerAttr.get( IntegerType.get_signless(32), failure_propagation_mode._as_int()) @@ -104,7 +91,7 @@ super().__init__(results_=results, failure_propagation_mode=failure_propagation_mode_attr, root=root) - self.regions[0].blocks.append(pdl.OperationType.get()) + self.regions[0].blocks.append(root_type) @property def body(self) -> Block: @@ -118,15 +105,18 @@ class WithPDLPatternsOp: def __init__(self, - target: Optional[Union[Operation, Value]] = None, + target: Union[Operation, Value, Type], *, loc=None, ip=None): + root = _get_op_result_or_value(target) if not isinstance(target, + Type) else None + root_type = target if isinstance(target, Type) else root.type super().__init__( - root=_get_op_result_or_value(target) if target else None, + root=root, loc=loc, ip=ip) - self.regions[0].blocks.append(pdl.OperationType.get()) + self.regions[0].blocks.append(root_type) @property def body(self) -> Block: diff --git a/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir b/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir --- a/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir @@ -4,7 +4,7 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - sequence %arg0 failures(propagate) { + sequence %arg0 : !pdl.operation failures(propagate) { ^bb0(%arg1: !pdl.operation): %0 = transform.structured.match ops{["func.func"]} in %arg1 transform.bufferization.one_shot_bufferize %0 @@ -36,7 +36,7 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - sequence %arg0 failures(propagate) { + sequence %arg0 : !pdl.operation failures(propagate) { ^bb0(%arg1: !pdl.operation): %0 = transform.structured.match ops{["func.func"]} in %arg1 transform.bufferization.one_shot_bufferize %0 @@ -62,7 +62,7 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - sequence %arg0 failures(propagate) { + sequence %arg0 : !pdl.operation failures(propagate) { ^bb0(%arg1: !pdl.operation): %0 = transform.structured.match ops{["func.func"]} in %arg1 // expected-error @+1 {{bufferization failed}} @@ -82,7 +82,7 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - sequence %arg0 failures(propagate) { + sequence %arg0 : !pdl.operation failures(propagate) { ^bb0(%arg1: !pdl.operation): // %arg1 is the module transform.bufferization.one_shot_bufferize %arg1 diff --git a/mlir/test/Dialect/GPU/transform-gpu-failing.mlir b/mlir/test/Dialect/GPU/transform-gpu-failing.mlir --- a/mlir/test/Dialect/GPU/transform-gpu-failing.mlir +++ b/mlir/test/Dialect/GPU/transform-gpu-failing.mlir @@ -46,7 +46,7 @@ } transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 // expected-error @below {{Trying to launch a GPU kernel with gridDim = (1, 1, 1) blockDim = (1200, 9, 1). It is larger than the limits.}} @@ -90,7 +90,7 @@ } transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 // expected-error @below {{The requested GPU threads are fewer than the number of loop trip counts. Try to tile scf.foreach_thread before mapping or set small blockDim.}} @@ -119,7 +119,7 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 // expected-error @below {{unsupported dynamic blockdim size}} @@ -147,7 +147,7 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 // expected-error @below {{scf.foreach_thread with rank > 3 does not lower to gpu.thread_id}} @@ -170,7 +170,7 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg0 %foreach, %tiled = transform.structured.tile_to_foreach_thread_op %matmul num_threads [10, 20, 30] diff --git a/mlir/test/Dialect/GPU/transform-gpu.mlir b/mlir/test/Dialect/GPU/transform-gpu.mlir --- a/mlir/test/Dialect/GPU/transform-gpu.mlir +++ b/mlir/test/Dialect/GPU/transform-gpu.mlir @@ -17,22 +17,22 @@ // CHECK: memref.load %[[ARGX]][%[[BLKX]], %[[BLKY]]] // CHECK: memref.load %[[ARGY]][%[[BLKX]], %[[BLKY]]] %name = gpu.launch async[%stream] blocks(%arg3, %arg4, %arg5) in (%arg9 = %one, %arg10 = %one, %arg11 = %one) - threads(%arg6, %arg7, %arg8) in (%arg12 = %one, %arg13 = %one, %arg14 = %one) + threads(%arg6, %arg7, %arg8) in (%arg12 = %one, %arg13 = %one, %arg14 = %one) { scf.foreach_thread (%i, %j) in (%c7, %c9) { - %4 = memref.load %x[%i, %j] : !type + %4 = memref.load %x[%i, %j] : !type %5 = memref.load %y[%i, %j] : !type %6 = math.fma %alpha, %4, %5 : f32 memref.store %6, %y[%i, %j] : !type - } {thread_dim_mapping = [0, 1, 2]} + } {thread_dim_mapping = [0, 1, 2]} gpu.terminator - } + } return %y : !type } transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 transform.gpu.map_foreach_to_blocks %funcop { blockDim = [12, 9, 1]} @@ -69,16 +69,16 @@ // CHECK: memref.load %[[ARGT]][%[[TIDX]]] // CHECK: gpu.barrier %name = gpu.launch async[%stream] blocks(%arg3, %arg4, %arg5) in (%arg9 = %one, %arg10 = %one, %arg11 = %one) - threads(%arg6, %arg7, %arg8) in (%arg12 = %one, %arg13 = %one, %arg14 = %one) + threads(%arg6, %arg7, %arg8) in (%arg12 = %one, %arg13 = %one, %arg14 = %one) { scf.foreach_thread (%i, %j) in (%c7, %c9) { - %4 = memref.load %x[%i, %j] : !type + %4 = memref.load %x[%i, %j] : !type %5 = memref.load %y[%i, %j] : !type %6 = math.fma %alpha, %4, %5 : f32 memref.store %6, %y[%i, %j] : !type - } {thread_dim_mapping = [1, 0, 2]} + } {thread_dim_mapping = [1, 0, 2]} scf.foreach_thread (%i) in (%c12) { - %7 = memref.load %t[%i] : !type1d + %7 = memref.load %t[%i] : !type1d %8 = arith.addf %alpha, %7 : f32 memref.store %8, %t[%i] : !type1d } {thread_dim_mapping = [0, 1, 2]} @@ -89,7 +89,7 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 transform.gpu.map_nested_foreach_to_threads %funcop { blockDim = [12, 9, 1] } @@ -104,9 +104,9 @@ // CHECK-SAME: %[[ARGX:[0-9a-z]+]]: memref<32x64x4x32xf32> // CHECK-SAME: %[[ARGY:[0-9a-z]+]]: memref<32x64x4x32xf32> func.func @saxpy4d(%x: !type4d, %y: !type4d, %alpha : f32) -> !type4d { - %c32 = arith.constant 32 : index - %c64 = arith.constant 64 : index - %c4 = arith.constant 4 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c4 = arith.constant 4 : index // CHECK: %[[C32:.*]] = arith.constant 32 : index // CHECK: %[[C64:.*]] = arith.constant 64 : index // CHECK: %[[C4:.*]] = arith.constant 4 : index @@ -120,7 +120,7 @@ // CHECK: memref.load %[[ARGY]][%[[BLKX]], %[[BLKY]], %[[TIDY]], %[[TIDX]]] scf.foreach_thread (%i, %j) in (%c32, %c64) { scf.foreach_thread (%k, %l) in (%c4, %c32) { - %4 = memref.load %x[%i, %j, %k, %l] : !type4d + %4 = memref.load %x[%i, %j, %k, %l] : !type4d %5 = memref.load %y[%i, %j, %k, %l] : !type4d %6 = math.fma %alpha, %4, %5 : f32 memref.store %6, %y[%i, %j, %k, %l] : !type4d @@ -131,7 +131,7 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %funcop = transform.structured.match ops{["func.func"]} in %arg0 %gpuLaunch = transform.gpu.map_foreach_to_blocks %funcop { generate_gpu_launch } @@ -168,7 +168,7 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 transform.gpu.map_nested_foreach_to_threads %funcop { blockDim = [12, 9, 1], syncAfterDistribute = false } diff --git a/mlir/test/Dialect/Linalg/multisize-tiling-full.mlir b/mlir/test/Dialect/Linalg/multisize-tiling-full.mlir --- a/mlir/test/Dialect/Linalg/multisize-tiling-full.mlir +++ b/mlir/test/Dialect/Linalg/multisize-tiling-full.mlir @@ -3,7 +3,7 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): // This implements a 2D multisize tiling with target sizes [3, 10]. - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 %1:3 = transform.structured.multitile_sizes %0 { dimension = 0, target_size = 3} @@ -11,8 +11,8 @@ %2:2 = transform.structured.split %0 after %1#2 { dimension = 0 } %3:2 = transform.structured.tile %2#0 [%1#0] %4:2 = transform.structured.tile %2#1 [%1#1] - %5 = merge_handles %3#0, %4#0 - %tt:3 = replicate num(%5) %t#0, %t#1, %t#2 + %5 = merge_handles %3#0, %4#0 : !pdl.operation + %tt:3 = replicate num(%5) %t#0, %t#1, %t#2 : !pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation %6:2 = transform.structured.split %5 after %tt#2 { dimension = 1 } transform.structured.tile %6#0 [0, %tt#0] transform.structured.tile %6#1 [0, %tt#1] diff --git a/mlir/test/Dialect/Linalg/promote.mlir b/mlir/test/Dialect/Linalg/promote.mlir --- a/mlir/test/Dialect/Linalg/promote.mlir +++ b/mlir/test/Dialect/Linalg/promote.mlir @@ -68,7 +68,7 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - sequence %arg0 failures(propagate) { + sequence %arg0 : !pdl.operation failures(propagate) { ^bb0(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 %1 = transform.structured.promote %0 { use_alloca } @@ -141,7 +141,7 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - sequence %arg0 failures(propagate) { + sequence %arg0 : !pdl.operation failures(propagate) { ^bb0(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 %1 = transform.structured.promote %0 @@ -191,7 +191,7 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - sequence %arg0 failures(propagate) { + sequence %arg0 : !pdl.operation failures(propagate) { ^bb0(%arg1: !pdl.operation): %0 = transform.structured.match interface{LinalgOp} in %arg1 %1 = transform.structured.promote %0 diff --git a/mlir/test/Dialect/Linalg/promotion_options.mlir b/mlir/test/Dialect/Linalg/promotion_options.mlir --- a/mlir/test/Dialect/Linalg/promotion_options.mlir +++ b/mlir/test/Dialect/Linalg/promotion_options.mlir @@ -33,7 +33,7 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - sequence %arg0 failures(propagate) { + sequence %arg0 : !pdl.operation failures(propagate) { ^bb0(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 %1, %loops:3 = transform.structured.tile %0 [16, 16, 16] diff --git a/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir b/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir --- a/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir +++ b/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir @@ -34,7 +34,7 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 %1:2 = transform.structured.tile_to_foreach_thread_op %0 num_threads [10, 20] (mapped to dims [1, 0]) @@ -78,7 +78,7 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 %1:2 = transform.structured.tile_to_foreach_thread_op %0 num_threads [10, 21] @@ -99,13 +99,13 @@ // CHECK-SAME: %[[A:[0-9a-z]+]]: tensor // CHECK-SAME: %[[B:[0-9a-z]+]]: tensor // CHECK-SAME: %[[C:[0-9a-z]+]]: tensor -func.func @matmul_tile_size_dynamic(%A: tensor, %B: tensor, %C: tensor) -> tensor { +func.func @matmul_tile_size_dynamic(%A: tensor, %B: tensor, %C: tensor) -> tensor { // CHECK: %[[M:.+]] = tensor.dim %[[A]], %c0 : - // CHECK: %[[N:.+]] = tensor.dim %[[B]], %c1 : + // CHECK: %[[N:.+]] = tensor.dim %[[B]], %c1 : // CHECK: %[[NT0:.+]] = affine.apply #map0()[%[[M]]] // CHECK: %[[NT1:.+]] = affine.apply #map1()[%[[N]]] // CHECK: scf.foreach_thread (%[[IV0:.+]], %[[IV1:.+]]) in (%[[NT0]], %[[NT1]]) shared_outs(%[[C_BLK:.*]] = %[[C]]) - // CHECK: %[[TS0:.+]] = affine.min #[[$map2]](%[[IV0]])[%[[M]]] + // CHECK: %[[TS0:.+]] = affine.min #[[$map2]](%[[IV0]])[%[[M]]] // CHECK: %[[TS1:.+]] = affine.min #[[$map4]](%[[IV1]])[%[[N]]] // CHECK: %[[LB0:.+]] = affine.apply #[[$map5]](%[[IV0]]) // CHECK: %[[LB1:.+]] = affine.apply #[[$map6]](%[[IV1]]) @@ -122,8 +122,8 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { - ^bb1(%arg1: !pdl.operation): + transform.sequence %arg0 : !pdl.operation failures(propagate) { + ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 %1:2 = transform.structured.tile_to_foreach_thread_op %0 tile_sizes [10, 20] } @@ -145,7 +145,7 @@ // CHECK-DAG: %[[c10:.+]] = arith.constant 10 : // CHECK-DAG: %[[c15:.+]] = arith.constant 15 : // CHECK: scf.foreach_thread (%[[IV0:.+]], %[[IV1:.+]]) in (%[[c10]], %[[c15]]) shared_outs(%[[C_BLK:.*]] = %[[C]]) - // CHECK: %[[TS:.+]] = affine.min #[[$map0]](%[[IV1]]) + // CHECK: %[[TS:.+]] = affine.min #[[$map0]](%[[IV1]]) // CHECK-NOT: affine.max // CHECK-NOT: affine.min // CHECK: %[[LB0:.+]] = affine.apply #[[$map2]](%[[IV0]]) @@ -163,7 +163,7 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 %1:2 = transform.structured.tile_to_foreach_thread_op %0 tile_sizes [10, 21] @@ -188,7 +188,7 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 %1:2 = transform.structured.tile_to_foreach_thread_op %0 num_threads [2] (mapped to dims [0]) @@ -242,7 +242,7 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 %sz = transform.structured.match ops{["test.dummy"]} in %arg1 @@ -262,7 +262,7 @@ // CHECK-SAME: %[[IN2:[0-9a-z]+]]: tensor<100xf32> // CHECK-SAME: %[[ORGOUT1:[0-9a-z]+]]: tensor<100xf32> // CHECK-SAME: %[[ORGOUT2:[0-9a-z]+]]: tensor<100xf32> - func.func @tile_output_multi_1d_static(%IN1: tensor<100xf32>, %IN2: tensor<100xf32>, + func.func @tile_output_multi_1d_static(%IN1: tensor<100xf32>, %IN2: tensor<100xf32>, %OUT1: tensor<100xf32>, %OUT2: tensor<100xf32>) -> (tensor<100xf32>, tensor<100xf32>) { // CHECK-DAG: %[[c0:.+]] = arith.constant 7 : @@ -288,7 +288,7 @@ affine_map<(d0) -> (d0)>], iterator_types = ["parallel"] } ins(%IN1, %IN2 : tensor<100xf32>, tensor<100xf32>) - outs(%OUT1, %OUT2 : tensor<100xf32>, tensor<100xf32>) + outs(%OUT1, %OUT2 : tensor<100xf32>, tensor<100xf32>) { ^bb0(%a1: f32, %a2: f32, %a3: f32, %a4: f32): %1 = arith.addf %a1, %a3 : f32 @@ -300,7 +300,7 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 %foreach_thread, %tiled_generic = transform.structured.tile_to_foreach_thread_op %0 num_threads [7] @@ -321,7 +321,7 @@ // CHECK-SAME: %[[IN3:[0-9a-z]+]]: tensor<300xf32> // CHECK-SAME: %[[ORGOUT1:[0-9a-z]+]]: tensor<300x100xf32> // CHECK-SAME: %[[ORGOUT2:[0-9a-z]+]]: tensor<300xf32> - func.func @tile_output_multi_1d2d_static(%IN1: tensor<100xf32>, %IN2: tensor<100x300xf32>, %IN3: tensor<300xf32>, + func.func @tile_output_multi_1d2d_static(%IN1: tensor<100xf32>, %IN2: tensor<100x300xf32>, %IN3: tensor<300xf32>, %OUT1: tensor<300x100xf32>, %OUT2: tensor<300xf32>) -> (tensor<300x100xf32>, tensor<300xf32>) { // CHECK-DAG: %[[c0:.+]] = arith.constant 4 : @@ -336,11 +336,11 @@ // CHECK-NEXT: tensor.parallel_insert_slice %[[RES1]]#0 into %[[OUT1]][%[[LB]], 0] [75, 100] // CHECK-NEXT: tensor.parallel_insert_slice %[[RES1]]#1 into %[[OUT2]][%[[LB]]] [75] %res2, %res3 = linalg.generic { - indexing_maps = [affine_map<(d0,d1) -> (d1)>, + indexing_maps = [affine_map<(d0,d1) -> (d1)>, affine_map<(d0,d1) -> (d1,d0)>, - affine_map<(d0,d1) -> (d0)>, + affine_map<(d0,d1) -> (d0)>, affine_map<(d0,d1) -> (d0,d1)>, - affine_map<(d0,d1) -> (d0)> + affine_map<(d0,d1) -> (d0)> ], iterator_types = ["parallel", "parallel"] } ins(%IN1, %IN2, %IN3 : tensor<100xf32>, tensor<100x300xf32>, tensor<300xf32>) @@ -351,13 +351,13 @@ %3 = arith.addf %i3, %2 : f32 linalg.yield %3, %i3 : f32, f32 } -> (tensor<300x100xf32>, tensor<300xf32>) - + return %res2, %res3 : tensor<300x100xf32>, tensor<300xf32> } transform.with_pdl_patterns { ^bb0(%IN_MAT1: !pdl.operation): - transform.sequence %IN_MAT1 failures(propagate) { + transform.sequence %IN_MAT1 : !pdl.operation failures(propagate) { ^bb1(%IN_MAT2: !pdl.operation): %0 = transform.structured.match ops{["linalg.generic"]} in %IN_MAT2 %foreach_thread, %tiled_generic = transform.structured.tile_to_foreach_thread_op %0 num_threads [4] diff --git a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir --- a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir @@ -58,7 +58,7 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match interface{LinalgOp} in %arg1 %1 = transform.structured.decompose %0 diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir --- a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir @@ -43,7 +43,7 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.fill"]} in %arg1 %1 = transform.structured.match ops{["scf.foreach_thread"]} in %arg1 @@ -89,7 +89,7 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["tensor.empty"]} in %arg1 %1 = transform.structured.match ops{["scf.foreach_thread"]} in %arg1 diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir --- a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir @@ -17,7 +17,7 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.elemwise_binary"]} in %arg1 %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1]} @@ -47,7 +47,7 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.elemwise_binary"]} in %arg1 %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1]} @@ -95,7 +95,7 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [5, 0, 7], tile_interchange = [0, 2, 1]} diff --git a/mlir/test/Dialect/Linalg/transform-op-generalize.mlir b/mlir/test/Dialect/Linalg/transform-op-generalize.mlir --- a/mlir/test/Dialect/Linalg/transform-op-generalize.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-generalize.mlir @@ -12,7 +12,7 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.elemwise_unary"]} in %arg1 %1 = transform.structured.generalize %0 diff --git a/mlir/test/Dialect/Linalg/transform-op-interchange.mlir b/mlir/test/Dialect/Linalg/transform-op-interchange.mlir --- a/mlir/test/Dialect/Linalg/transform-op-interchange.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-interchange.mlir @@ -20,7 +20,7 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 transform.structured.interchange %0 { iterator_interchange = [1, 0]} @@ -37,7 +37,7 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 // expected-error @below {{transform applied to the wrong op kind}} diff --git a/mlir/test/Dialect/Linalg/transform-op-match.mlir b/mlir/test/Dialect/Linalg/transform-op-match.mlir --- a/mlir/test/Dialect/Linalg/transform-op-match.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-match.mlir @@ -11,7 +11,7 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %match_name = transform.structured.match ops{["arith.constant"]} in %arg1 transform.test_print_remark_at_operand %match_name, "matched op name" @@ -34,7 +34,7 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %match_name = transform.structured.match ops{["arith.constant"]} filter_result_type = f32 in %arg1 @@ -63,7 +63,7 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %match_attr = transform.structured.match ops{["linalg.generic"]} diff --git a/mlir/test/Dialect/Linalg/transform-op-multitile-sizes.mlir b/mlir/test/Dialect/Linalg/transform-op-multitile-sizes.mlir --- a/mlir/test/Dialect/Linalg/transform-op-multitile-sizes.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-multitile-sizes.mlir @@ -4,7 +4,7 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - sequence %arg0 failures(propagate) { + sequence %arg0 : !pdl.operation failures(propagate) { ^bb0(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 transform.structured.multitile_sizes %0 { target_size = 3, dimension = 0 } @@ -31,7 +31,7 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - sequence %arg0 failures(propagate) { + sequence %arg0 : !pdl.operation failures(propagate) { ^bb0(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 transform.structured.multitile_sizes %0 { target_size = 3, divisor = 2, dimension = 0 } diff --git a/mlir/test/Dialect/Linalg/transform-op-pad.mlir b/mlir/test/Dialect/Linalg/transform-op-pad.mlir --- a/mlir/test/Dialect/Linalg/transform-op-pad.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-pad.mlir @@ -33,7 +33,7 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 %1 = transform.structured.pad %0 {padding_values=[0.0 : f32, 0.0 : f32, 0.0 : f32], padding_dimensions=[0, 1, 2], pack_paddings=[1, 1, 0]} @@ -52,7 +52,7 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 // expected-error @below {{op expects a padding value of type 'f32', got 0 : i32}} @@ -72,7 +72,7 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 // expected-error @below {{expects a padding that parses to 'f32', got "foo"}} @@ -93,7 +93,7 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(suppress) { + transform.sequence %arg0 : !pdl.operation failures(suppress) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 // This error is silenceable and is not reported by this transform diff --git a/mlir/test/Dialect/Linalg/transform-op-scalarize.mlir b/mlir/test/Dialect/Linalg/transform-op-scalarize.mlir --- a/mlir/test/Dialect/Linalg/transform-op-scalarize.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-scalarize.mlir @@ -12,7 +12,7 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 %1, %loops = transform.structured.tile %0 [10, 0, 0] diff --git a/mlir/test/Dialect/Linalg/transform-op-split-reduction-by-scaling.mlir b/mlir/test/Dialect/Linalg/transform-op-split-reduction-by-scaling.mlir --- a/mlir/test/Dialect/Linalg/transform-op-split-reduction-by-scaling.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-split-reduction-by-scaling.mlir @@ -20,7 +20,7 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 %1:4 = transform.structured.split_reduction %0 diff --git a/mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir b/mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir --- a/mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir @@ -19,7 +19,7 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 %1:4 = transform.structured.split_reduction %0 { split_factor = 4, insert_split_dimension = 2} diff --git a/mlir/test/Dialect/Linalg/transform-op-split.mlir b/mlir/test/Dialect/Linalg/transform-op-split.mlir --- a/mlir/test/Dialect/Linalg/transform-op-split.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-split.mlir @@ -2,7 +2,7 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 %1:2 = transform.structured.split %0 after 42 { dimension = 0 } @@ -76,7 +76,7 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 %1 = transform.structured.match ops{["func.call"]} in %arg1 @@ -127,7 +127,7 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 %1:2 = transform.structured.split %0 after 4 { dimension = 0} @@ -195,7 +195,7 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 %1 = transform.structured.match ops{["func.call"]} in %arg1 @@ -224,7 +224,7 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 %1 = transform.structured.match ops{["func.call"]} in %arg1 @@ -258,7 +258,7 @@ pdl.rewrite %2 with "transform.dialect" } - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["func.return"]} in %arg1 // expected-error @below {{only applies to structured ops}} @@ -282,7 +282,7 @@ pdl.rewrite %2 with "transform.dialect" } - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 // expected-error @below {{dimension 1 does not exist in target op}} diff --git a/mlir/test/Dialect/Linalg/transform-op-tile.mlir b/mlir/test/Dialect/Linalg/transform-op-tile.mlir --- a/mlir/test/Dialect/Linalg/transform-op-tile.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-tile.mlir @@ -2,7 +2,7 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - sequence %arg0 failures(propagate) { + sequence %arg0 : !pdl.operation failures(propagate) { ^bb0(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 %1, %loops:3 = transform.structured.tile %0 [4, 4, 4] @@ -41,7 +41,7 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - sequence %arg0 failures(propagate) { + sequence %arg0 : !pdl.operation failures(propagate) { ^bb0(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 %1 = transform.structured.match ops{["func.call"]} in %arg1 diff --git a/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir b/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir --- a/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir @@ -26,10 +26,10 @@ rewrite %0 with "transform.dialect" } - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 - %1 = get_closest_isolated_parent %0 + %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation %2 = transform.structured.vectorize %1 } } @@ -75,10 +75,10 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 - %1 = get_closest_isolated_parent %0 + %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation %2 = transform.structured.vectorize %1 } } @@ -126,10 +126,10 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 - %1 = get_closest_isolated_parent %0 + %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation %2 = transform.structured.vectorize %1 {vectorize_padding} } } @@ -146,7 +146,7 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 // expected-error @below {{op requires isolated-from-above targets}} diff --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir --- a/mlir/test/Dialect/Linalg/transform-patterns.mlir +++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir @@ -144,7 +144,7 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 transform.structured.interchange %0 { iterator_interchange = [1, 2, 0]} diff --git a/mlir/test/Dialect/Linalg/transform-promotion.mlir b/mlir/test/Dialect/Linalg/transform-promotion.mlir --- a/mlir/test/Dialect/Linalg/transform-promotion.mlir +++ b/mlir/test/Dialect/Linalg/transform-promotion.mlir @@ -60,7 +60,7 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - sequence %arg0 failures(propagate) { + sequence %arg0 : !pdl.operation failures(propagate) { ^bb0(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 %1 = transform.structured.promote %0 { operands_to_promote = [0, 1, 2], use_full_tiles_by_default } @@ -123,7 +123,7 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - sequence %arg0 failures(propagate) { + sequence %arg0 : !pdl.operation failures(propagate) { ^bb0(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 %1 = transform.structured.promote %0 { operands_to_promote = [0], use_full_tiles_by_default } @@ -156,7 +156,7 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - sequence %arg0 failures(propagate) { + sequence %arg0 : !pdl.operation failures(propagate) { ^bb0(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.fill"]} in %arg1 %1 = transform.structured.promote %0 { operands_to_promote = [1], use_full_tile_buffers = [false, true], alignment = 32} @@ -190,7 +190,7 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - sequence %arg0 failures(propagate) { + sequence %arg0 : !pdl.operation failures(propagate) { ^bb0(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.fill"]} in %arg1 %1 = transform.structured.promote %0 { operands_to_promote = [1], use_full_tile_buffers = [false, true], alignment = 32} diff --git a/mlir/test/Dialect/Linalg/transform-tile-and-fuse.mlir b/mlir/test/Dialect/Linalg/transform-tile-and-fuse.mlir --- a/mlir/test/Dialect/Linalg/transform-tile-and-fuse.mlir +++ b/mlir/test/Dialect/Linalg/transform-tile-and-fuse.mlir @@ -42,7 +42,7 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): // Find the root and all producers. %root = transform.structured.match attributes{"__root__"} in %arg1 @@ -102,7 +102,7 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): // Find the root and all producers. %root = transform.structured.match attributes{"__root__"} in %arg1 diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir --- a/mlir/test/Dialect/Linalg/vectorization.mlir +++ b/mlir/test/Dialect/Linalg/vectorization.mlir @@ -14,10 +14,10 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.dot"]} in %arg1 - %1 = get_closest_isolated_parent %0 + %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation %2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns } } } @@ -36,10 +36,10 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.matvec"]} in %arg1 - %1 = get_closest_isolated_parent %0 + %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation %2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns } } } @@ -57,10 +57,10 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 - %1 = get_closest_isolated_parent %0 + %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation %2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns } } } @@ -79,10 +79,10 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.batch_matmul"]} in %arg1 - %1 = get_closest_isolated_parent %0 + %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation %2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns } } } @@ -122,10 +122,10 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 - %1 = get_closest_isolated_parent %0 + %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation %2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns, disable_transfer_permutation_map_lowering_patterns } } } @@ -165,10 +165,10 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 - %1 = get_closest_isolated_parent %0 + %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation %2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns, disable_transfer_permutation_map_lowering_patterns } } } @@ -195,10 +195,10 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 - %1 = get_closest_isolated_parent %0 + %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation %2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns, disable_transfer_permutation_map_lowering_patterns } } } @@ -238,10 +238,10 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 - %1 = get_closest_isolated_parent %0 + %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation %2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns, disable_transfer_permutation_map_lowering_patterns } } } @@ -261,10 +261,10 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 - %1 = get_closest_isolated_parent %0 + %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation %2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns } } } @@ -288,11 +288,11 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 - %1 = get_closest_isolated_parent %0 - %2 = transform.structured.vectorize %1 + %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation + %2 = transform.structured.vectorize %1 } } @@ -315,11 +315,11 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 - %1 = get_closest_isolated_parent %0 - %2 = transform.structured.vectorize %1 + %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation + %2 = transform.structured.vectorize %1 } } @@ -335,11 +335,11 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.fill"]} in %arg1 - %1 = get_closest_isolated_parent %0 - %2 = transform.structured.vectorize %1 + %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation + %2 = transform.structured.vectorize %1 } } @@ -356,11 +356,11 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.fill"]} in %arg1 - %1 = get_closest_isolated_parent %0 - %2 = transform.structured.vectorize %1 + %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation + %2 = transform.structured.vectorize %1 } } @@ -376,11 +376,11 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["memref.copy"]} in %arg1 - %1 = get_closest_isolated_parent %0 - %2 = transform.structured.vectorize %1 + %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation + %2 = transform.structured.vectorize %1 } } @@ -399,11 +399,11 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["memref.copy"]} in %arg1 - %1 = get_closest_isolated_parent %0 - %2 = transform.structured.vectorize %1 + %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation + %2 = transform.structured.vectorize %1 } } // ----- @@ -429,11 +429,11 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 - %1 = get_closest_isolated_parent %0 - %2 = transform.structured.vectorize %1 + %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation + %2 = transform.structured.vectorize %1 } } @@ -461,11 +461,11 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 - %1 = get_closest_isolated_parent %0 - %2 = transform.structured.vectorize %1 + %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation + %2 = transform.structured.vectorize %1 } } @@ -549,10 +549,10 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 - %1 = get_closest_isolated_parent %0 + %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation %2 = transform.structured.vectorize %1 { disable_transfer_permutation_map_lowering_patterns } } } @@ -643,10 +643,10 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 - %1 = get_closest_isolated_parent %0 + %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation %2 = transform.structured.vectorize %1 { disable_transfer_permutation_map_lowering_patterns } } } @@ -690,10 +690,10 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 - %1 = get_closest_isolated_parent %0 + %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation %2 = transform.structured.vectorize %1 { disable_transfer_permutation_map_lowering_patterns } } } @@ -736,10 +736,10 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 - %1 = get_closest_isolated_parent %0 + %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation %2 = transform.structured.vectorize %1 { disable_transfer_permutation_map_lowering_patterns } } } @@ -771,10 +771,10 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 - %1 = get_closest_isolated_parent %0 + %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation %2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns, disable_transfer_permutation_map_lowering_patterns } } } @@ -803,10 +803,10 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["tensor.pad"]} in %arg1 - %1 = get_closest_isolated_parent %0 + %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation %2 = transform.structured.vectorize %1 { vectorize_padding } } } @@ -835,10 +835,10 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["tensor.pad"]} in %arg1 - %1 = get_closest_isolated_parent %0 + %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation %2 = transform.structured.vectorize %1 { vectorize_padding } } } @@ -875,10 +875,10 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["tensor.pad"]} in %arg1 - %1 = get_closest_isolated_parent %0 + %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation %2 = transform.structured.vectorize %1 { vectorize_padding } } } @@ -909,10 +909,10 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["tensor.pad"]} in %arg1 - %1 = get_closest_isolated_parent %0 + %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation %2 = transform.structured.vectorize %1 { vectorize_padding } } } @@ -945,10 +945,10 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %3 = transform.structured.match ops{["tensor.pad"]} in %arg1 - %4 = get_closest_isolated_parent %3 + %4 = get_closest_isolated_parent %3 : (!pdl.operation) -> !pdl.operation %5 = transform.structured.vectorize %4 { vectorize_padding } } } @@ -985,10 +985,10 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %3 = transform.structured.match ops{["tensor.pad"]} in %arg1 - %4 = get_closest_isolated_parent %3 + %4 = get_closest_isolated_parent %3 : (!pdl.operation) -> !pdl.operation %5 = transform.structured.vectorize %4 { vectorize_padding } } } @@ -1022,10 +1022,10 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %3 = transform.structured.match ops{["tensor.pad"]} in %arg1 - %4 = get_closest_isolated_parent %3 + %4 = get_closest_isolated_parent %3 : (!pdl.operation) -> !pdl.operation %5 = transform.structured.vectorize %4 { vectorize_padding } } } @@ -1053,11 +1053,11 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %3 = transform.structured.match ops{["tensor.pad"]} in %arg1 - %4 = get_closest_isolated_parent %3 - %5 = transform.structured.vectorize %4 + %4 = get_closest_isolated_parent %3 : (!pdl.operation) -> !pdl.operation + %5 = transform.structured.vectorize %4 } } @@ -1093,10 +1093,10 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %3 = transform.structured.match ops{["tensor.pad"]} in %arg1 - %4 = get_closest_isolated_parent %3 + %4 = get_closest_isolated_parent %3 : (!pdl.operation) -> !pdl.operation %5 = transform.structured.vectorize %4 { vectorize_padding } } } @@ -1131,11 +1131,11 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %3 = transform.structured.match ops{["linalg.generic"]} in %arg1 - %4 = get_closest_isolated_parent %3 - %5 = transform.structured.vectorize %4 + %4 = get_closest_isolated_parent %3 : (!pdl.operation) -> !pdl.operation + %5 = transform.structured.vectorize %4 } } @@ -1179,10 +1179,10 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %3 = transform.structured.match ops{["linalg.generic"]} in %arg1 - %4 = get_closest_isolated_parent %3 + %4 = get_closest_isolated_parent %3 : (!pdl.operation) -> !pdl.operation %5 = transform.structured.vectorize %4 { disable_multi_reduction_to_contract_patterns, disable_transfer_permutation_map_lowering_patterns } } } @@ -1212,10 +1212,10 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %3 = transform.structured.match ops{["linalg.generic"]} in %arg1 - %4 = get_closest_isolated_parent %3 + %4 = get_closest_isolated_parent %3 : (!pdl.operation) -> !pdl.operation %5 = transform.structured.vectorize %4 { vectorize_padding } } } @@ -1246,11 +1246,11 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %3 = transform.structured.match ops{["linalg.generic"]} in %arg1 - %4 = get_closest_isolated_parent %3 - %5 = transform.structured.vectorize %4 + %4 = get_closest_isolated_parent %3 : (!pdl.operation) -> !pdl.operation + %5 = transform.structured.vectorize %4 } } @@ -1279,11 +1279,11 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %3 = transform.structured.match ops{["linalg.generic"]} in %arg1 - %4 = get_closest_isolated_parent %3 - %5 = transform.structured.vectorize %4 + %4 = get_closest_isolated_parent %3 : (!pdl.operation) -> !pdl.operation + %5 = transform.structured.vectorize %4 } } @@ -1312,11 +1312,11 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %3 = transform.structured.match ops{["linalg.generic"]} in %arg1 - %4 = get_closest_isolated_parent %3 - %5 = transform.structured.vectorize %4 + %4 = get_closest_isolated_parent %3 : (!pdl.operation) -> !pdl.operation + %5 = transform.structured.vectorize %4 } } @@ -1345,11 +1345,11 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %3 = transform.structured.match ops{["linalg.generic"]} in %arg1 - %4 = get_closest_isolated_parent %3 - %5 = transform.structured.vectorize %4 + %4 = get_closest_isolated_parent %3 : (!pdl.operation) -> !pdl.operation + %5 = transform.structured.vectorize %4 } } @@ -1378,11 +1378,11 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %3 = transform.structured.match ops{["linalg.generic"]} in %arg1 - %4 = get_closest_isolated_parent %3 - %5 = transform.structured.vectorize %4 + %4 = get_closest_isolated_parent %3 : (!pdl.operation) -> !pdl.operation + %5 = transform.structured.vectorize %4 } } @@ -1415,11 +1415,11 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %3 = transform.structured.match ops{["linalg.generic"]} in %arg1 - %4 = get_closest_isolated_parent %3 - %5 = transform.structured.vectorize %4 + %4 = get_closest_isolated_parent %3 : (!pdl.operation) -> !pdl.operation + %5 = transform.structured.vectorize %4 } } @@ -1456,15 +1456,15 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.fill"]} in %arg1 - %1 = get_closest_isolated_parent %0 - %2 = transform.structured.vectorize %1 - + %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation + %2 = transform.structured.vectorize %1 + %3 = transform.structured.match ops{["linalg.generic"]} in %arg1 - %4 = get_closest_isolated_parent %3 - %5 = transform.structured.vectorize %4 + %4 = get_closest_isolated_parent %3 : (!pdl.operation) -> !pdl.operation + %5 = transform.structured.vectorize %4 } } @@ -1506,11 +1506,11 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 - %1 = get_closest_isolated_parent %0 - %2 = transform.structured.vectorize %1 + %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation + %2 = transform.structured.vectorize %1 } } @@ -1518,7 +1518,7 @@ // ----- // This test checks that vectorization does not occur when an input indexing map -// is not a projected permutation. In the future, this can be converted to a +// is not a projected permutation. In the future, this can be converted to a // positive test when support is added. // CHECK-LABEL: func @not_projected_permutation @@ -1540,11 +1540,11 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 - %1 = get_closest_isolated_parent %0 - %2 = transform.structured.vectorize %1 + %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation + %2 = transform.structured.vectorize %1 } } @@ -1582,10 +1582,10 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 - %1 = get_closest_isolated_parent %0 + %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation %2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns, disable_transfer_permutation_map_lowering_patterns } } } diff --git a/mlir/test/Dialect/SCF/transform-ops.mlir b/mlir/test/Dialect/SCF/transform-ops.mlir --- a/mlir/test/Dialect/SCF/transform-ops.mlir +++ b/mlir/test/Dialect/SCF/transform-ops.mlir @@ -17,7 +17,7 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - sequence %arg0 failures(propagate) { + sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["arith.addi"]} in %arg1 // CHECK: = transform.loop.get_parent_for @@ -34,13 +34,13 @@ func.func @get_parent_for_op_no_loop(%arg0: index, %arg1: index) { // expected-note @below {{target op}} - arith.addi %arg0, %arg1 : index + arith.addi %arg0, %arg1 : index return } transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - sequence %arg0 failures(propagate) { + sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["arith.addi"]} in %arg1 // expected-error @below {{could not find an 'scf.for' parent}} @@ -82,7 +82,7 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - sequence %arg0 failures(propagate) { + sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["arith.addi"]} in %arg1 %1 = transform.loop.get_parent_for %0 @@ -111,7 +111,7 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - sequence %arg0 failures(propagate) { + sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["scf.while"]} in %arg1 // expected-error @below {{failed to outline}} @@ -142,7 +142,7 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - sequence %arg0 failures(propagate) { + sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["arith.addi"]} in %arg1 %1 = transform.loop.get_parent_for %0 @@ -178,7 +178,7 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - sequence %arg0 failures(propagate) { + sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["arith.addf"]} in %arg1 %1 = transform.loop.get_parent_for %0 @@ -205,7 +205,7 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - sequence %arg0 failures(propagate) { + sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["arith.addi"]} in %arg1 %1 = transform.loop.get_parent_for %0 diff --git a/mlir/test/Dialect/Transform/check-use-after-free.mlir b/mlir/test/Dialect/Transform/check-use-after-free.mlir --- a/mlir/test/Dialect/Transform/check-use-after-free.mlir +++ b/mlir/test/Dialect/Transform/check-use-after-free.mlir @@ -17,7 +17,7 @@ "transform.test_branching_transform_op_terminator"()[^bb3] : () -> () ^bb3: // expected-warning @below {{operand #0 may be used after free}} - transform.sequence %0 failures(propagate) { + transform.sequence %0 : !pdl.operation failures(propagate) { ^bb0(%arg0: !pdl.operation): } "transform.test_branching_transform_op_terminator"() : () -> () @@ -46,7 +46,7 @@ "transform.test_branching_transform_op_terminator"() : () -> () } // expected-warning @below {{operand #0 may be used after free}} - transform.sequence %0 failures(propagate) { + transform.sequence %0 : !pdl.operation failures(propagate) { ^bb0(%arg0: !pdl.operation): } return @@ -58,26 +58,26 @@ transform.sequence failures(propagate) { ^bb0(%arg0: !pdl.operation): // expected-note @below {{allocated here}} - %0 = transform.sequence %arg0 failures(propagate) attributes { ord = 1 } { + %0 = transform.sequence %arg0 : !pdl.operation -> !pdl.operation failures(propagate) attributes { ord = 1 } { ^bb1(%arg1: !pdl.operation): yield %arg1 : !pdl.operation - } : !pdl.operation - transform.sequence %0 failures(propagate) attributes { ord = 2 } { + } + transform.sequence %0 : !pdl.operation failures(propagate) attributes { ord = 2 } { ^bb2(%arg2: !pdl.operation): } - transform.sequence %0 failures(propagate) attributes { ord = 3 } { + transform.sequence %0 : !pdl.operation failures(propagate) attributes { ord = 3 } { ^bb3(%arg3: !pdl.operation): } // `transform.sequence` has recursive side effects so it has the same "free" // as the child op it contains. // expected-note @below {{freed here}} - transform.sequence %0 failures(propagate) attributes { ord = 4 } { + transform.sequence %0 : !pdl.operation failures(propagate) attributes { ord = 4 } { ^bb4(%arg4: !pdl.operation): test_consume_operand_if_matches_param_or_fail %0[42] } // expected-warning @below {{operand #0 may be used after free}} - transform.sequence %0 failures(propagate) attributes { ord = 5 } { + transform.sequence %0 : !pdl.operation failures(propagate) attributes { ord = 5 } { ^bb3(%arg3: !pdl.operation): } } @@ -90,21 +90,21 @@ transform.sequence failures(propagate) { ^bb0(%arg0: !pdl.operation): // expected-note @below {{allocated here}} - %0 = transform.sequence %arg0 failures(propagate) attributes { ord = 1 } { + %0 = transform.sequence %arg0 : !pdl.operation -> !pdl.operation failures(propagate) attributes { ord = 1 } { ^bb1(%arg1: !pdl.operation): yield %arg1 : !pdl.operation - } : !pdl.operation - transform.sequence %0 failures(propagate) attributes { ord = 2 } { + } + transform.sequence %0 : !pdl.operation failures(propagate) attributes { ord = 2 } { ^bb2(%arg2: !pdl.operation): } - transform.sequence %0 failures(propagate) attributes { ord = 3 } { + transform.sequence %0 : !pdl.operation failures(propagate) attributes { ord = 3 } { ^bb3(%arg3: !pdl.operation): } // expected-note @below {{freed here}} test_consume_operand_if_matches_param_or_fail %0[42] // expected-warning @below {{operand #0 may be used after free}} - transform.sequence %0 failures(propagate) attributes { ord = 5 } { + transform.sequence %0 : !pdl.operation failures(propagate) attributes { ord = 5 } { ^bb3(%arg3: !pdl.operation): } } @@ -127,7 +127,7 @@ "transform.test_branching_transform_op_terminator"()[^bb1] : () -> () ^bb1: // expected-warning @below {{operand #0 may be used after free}} - transform.sequence %0 failures(propagate) { + transform.sequence %0 : !pdl.operation failures(propagate) { ^bb0(%arg0: !pdl.operation): } // expected-warning @below {{operand #0 may be used after free}} diff --git a/mlir/test/Dialect/Transform/expensive-checks.mlir b/mlir/test/Dialect/Transform/expensive-checks.mlir --- a/mlir/test/Dialect/Transform/expensive-checks.mlir +++ b/mlir/test/Dialect/Transform/expensive-checks.mlir @@ -15,11 +15,11 @@ rewrite %2 with "transform.dialect" } - sequence %arg0 failures(propagate) { + sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): // expected-note @below {{handle to invalidated ops}} - %0 = pdl_match @return in %arg1 - %1 = get_closest_isolated_parent %0 + %0 = pdl_match @return in %arg1 : (!pdl.operation) -> !pdl.operation + %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation // expected-note @below {{invalidated by this transform op that consumes its operand #0}} test_consume_operand %1 // expected-error @below {{op uses a handle invalidated by a previously executed transform op}} @@ -50,11 +50,11 @@ rewrite %2 with "transform.dialect" } - sequence %arg0 failures(propagate) { + sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): - %0 = pdl_match @func in %arg1 - %1 = pdl_match @return in %arg1 - %2 = replicate num(%0) %1 + %0 = pdl_match @func in %arg1 : (!pdl.operation) -> !pdl.operation + %1 = pdl_match @return in %arg1 : (!pdl.operation) -> !pdl.operation + %2 = replicate num(%0) %1 : !pdl.operation, !pdl.operation // expected-error @below {{a handle passed as operand #0 and consumed by this operation points to a payload operation more than once}} test_consume_operand %2 test_print_remark_at_operand %0, "remark" diff --git a/mlir/test/Dialect/Transform/ops-invalid.mlir b/mlir/test/Dialect/Transform/ops-invalid.mlir --- a/mlir/test/Dialect/Transform/ops-invalid.mlir +++ b/mlir/test/Dialect/Transform/ops-invalid.mlir @@ -1,6 +1,6 @@ // RUN: mlir-opt %s -split-input-file -verify-diagnostics -// expected-error @below {{expects the entry block to have one argument of type '!pdl.operation'}} +// expected-error @below {{expects the entry block to have one argument of type implementing TransformTypeInterface}} transform.sequence failures(propagate) { } @@ -26,12 +26,12 @@ // ----- -// expected-error @below {{expects the types of the terminator operands to match the types of the resul}} -%0 = transform.sequence failures(propagate) { +// expected-error @below {{expects the types of the terminator operands to match the types of the result}} +%0 = transform.sequence -> !pdl.operation failures(propagate) { ^bb0(%arg0: !pdl.operation): // expected-note @below {{terminator}} transform.yield -} : !pdl.operation +} // ----- @@ -74,7 +74,7 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): // expected-error @below {{op cannot be nested}} - transform.with_pdl_patterns %arg0 { + transform.with_pdl_patterns %arg0 : !pdl.operation { ^bb1(%arg1: !pdl.operation): } } @@ -115,7 +115,7 @@ // expected-note @below {{used here as operand #0}} test_consume_operand_if_matches_param_or_fail %0[42] // expected-note @below {{used here as operand #0}} - transform.sequence %0 failures(propagate) { + transform.sequence %0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): test_consume_operand_if_matches_param_or_fail %arg1[42] } @@ -129,7 +129,7 @@ %0 = test_produce_param_or_forward_operand 42 // expected-note @below {{used here as operand #0}} test_consume_operand_if_matches_param_or_fail %0[42] - transform.sequence %0 failures(propagate) { + transform.sequence %0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): // expected-note @below {{used here as operand #0}} test_consume_operand_if_matches_param_or_fail %0[42] @@ -145,9 +145,9 @@ // expected-note @below {{used here as operand #0}} test_consume_operand_if_matches_param_or_fail %0[42] // expected-note @below {{used here as operand #0}} - transform.sequence %0 failures(propagate) { + transform.sequence %0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): - transform.sequence %arg1 failures(propagate) { + transform.sequence %arg1 : !pdl.operation failures(propagate) { ^bb2(%arg2: !pdl.operation): test_consume_operand_if_matches_param_or_fail %arg2[42] } @@ -167,7 +167,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !pdl.operation): // expected-error @below {{expects terminator operands to have the same type as results of the operation}} - %2 = transform.alternatives %arg1 -> !pdl.operation { + %2 = transform.alternatives %arg1 : !pdl.operation -> !pdl.operation { ^bb2(%arg2: !pdl.operation): transform.yield %arg2 : !pdl.operation }, { @@ -179,7 +179,7 @@ // ----- -// expected-error @below {{expects the entry block to have one argument of type '!pdl.operation'}} +// expected-error @below {{expects the entry block to have one argument of type implementing TransformTypeInterface}} transform.alternatives { ^bb0: transform.yield @@ -192,7 +192,7 @@ // expected-error @below {{result #0 has more than one potential consumer}} %0 = test_produce_param_or_forward_operand 42 // expected-note @below {{used here as operand #0}} - transform.foreach %0 { + transform.foreach %0 : !pdl.operation { ^bb1(%arg1: !pdl.operation): transform.test_consume_operand %arg1 } diff --git a/mlir/test/Dialect/Transform/ops.mlir b/mlir/test/Dialect/Transform/ops.mlir --- a/mlir/test/Dialect/Transform/ops.mlir +++ b/mlir/test/Dialect/Transform/ops.mlir @@ -4,9 +4,9 @@ // CHECK: ^{{.+}}(%{{.+}}: !pdl.operation): transform.sequence failures(propagate) { ^bb0(%arg0: !pdl.operation): - // CHECK: sequence %{{.+}} + // CHECK: sequence %{{.+}} : !pdl.operation // CHECK: ^{{.+}}(%{{.+}}: !pdl.operation): - sequence %arg0 failures(propagate) { + sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): } } @@ -15,8 +15,8 @@ // CHECK: ^{{.+}}(%[[ARG:.+]]: !pdl.operation): transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - // CHECK: sequence %[[ARG]] - sequence %arg0 failures(propagate) { + // CHECK: sequence %[[ARG]] : !pdl.operation + sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): } } @@ -25,27 +25,27 @@ // CHECK: ^{{.+}}(%[[ARG:.+]]: !pdl.operation): transform.sequence failures(propagate) { ^bb0(%arg0: !pdl.operation): - // CHECK: with_pdl_patterns %[[ARG]] - with_pdl_patterns %arg0 { + // CHECK: with_pdl_patterns %[[ARG]] : !pdl.operation + with_pdl_patterns %arg0 : !pdl.operation { ^bb1(%arg1: !pdl.operation): } } // Using the same value multiple times without consuming it is fine. // CHECK: transform.sequence -// CHECK: %[[V:.+]] = sequence +// CHECK: %[[V:.+]] = sequence %{{.*}} : !pdl.operation -> !pdl.operation // CHECK: sequence %[[V]] // CHECK: sequence %[[V]] transform.sequence failures(propagate) { ^bb0(%arg0: !pdl.operation): - %0 = transform.sequence %arg0 failures(propagate) { + %0 = transform.sequence %arg0 : !pdl.operation -> !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): yield %arg1 : !pdl.operation - } : !pdl.operation - transform.sequence %0 failures(propagate) { + } + transform.sequence %0 : !pdl.operation failures(propagate) { ^bb2(%arg2: !pdl.operation): } - transform.sequence %0 failures(propagate) { + transform.sequence %0 : !pdl.operation failures(propagate) { ^bb3(%arg3: !pdl.operation): } } @@ -54,7 +54,7 @@ // CHECK: foreach transform.sequence failures(propagate) { ^bb0(%arg0: !pdl.operation): - transform.foreach %arg0 { + transform.foreach %arg0 : !pdl.operation { ^bb1(%arg1: !pdl.operation): } } diff --git a/mlir/test/Dialect/Transform/selective-targeting.mlir b/mlir/test/Dialect/Transform/selective-targeting.mlir --- a/mlir/test/Dialect/Transform/selective-targeting.mlir +++ b/mlir/test/Dialect/Transform/selective-targeting.mlir @@ -74,12 +74,12 @@ rewrite %0 with "transform.dialect" } - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): - %0 = pdl_match @pdl_target_attrA in %arg1 + %0 = pdl_match @pdl_target_attrA in %arg1 : (!pdl.operation) -> !pdl.operation transform.structured.tile %0 [4, 4, 4] - %1 = pdl_match @pdl_target_attrC in %arg1 - %2 = transform.get_closest_isolated_parent %1 + %1 = pdl_match @pdl_target_attrC in %arg1 : (!pdl.operation) -> !pdl.operation + %2 = transform.get_closest_isolated_parent %1 : (!pdl.operation) -> !pdl.operation transform.structured.vectorize %2 } } @@ -121,10 +121,10 @@ rewrite %0 with "transform.dialect" } - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): - %0 = pdl_match @pdl_target in %arg1 - %1 = get_closest_isolated_parent %0 + %0 = pdl_match @pdl_target in %arg1 : (!pdl.operation) -> !pdl.operation + %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation transform.structured.vectorize %1 } } diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir --- a/mlir/test/Dialect/Transform/test-interpreter.mlir +++ b/mlir/test/Dialect/Transform/test-interpreter.mlir @@ -29,7 +29,7 @@ transform.sequence failures(propagate) { ^bb0(%arg0: !pdl.operation): - sequence %arg0 failures(propagate) { + sequence %arg0 : !pdl.operation failures(propagate) { ^bb0(%arg1: !pdl.operation): // expected-remark @below {{applying transformation "a"}} test_transform_op "a" @@ -49,7 +49,7 @@ transform.sequence failures(propagate) { ^bb0(%arg0: !pdl.operation): %0 = test_produce_param_or_forward_operand 42 - sequence %0 failures(propagate) { + sequence %0 : !pdl.operation failures(propagate) { ^bb0(%arg1: !pdl.operation): // expected-remark @below {{succeeded}} test_consume_operand_if_matches_param_or_fail %arg1[42] @@ -60,11 +60,11 @@ transform.sequence failures(propagate) { ^bb0(%arg0: !pdl.operation): - %0 = sequence %arg0 failures(propagate) { + %0 = sequence %arg0 : !pdl.operation -> !pdl.operation failures(propagate) { ^bb0(%arg1: !pdl.operation): %1 = test_produce_param_or_forward_operand 42 yield %1 : !pdl.operation - } : !pdl.operation + } // expected-remark @below {{succeeded}} test_consume_operand_if_matches_param_or_fail %0[42] } @@ -73,9 +73,9 @@ transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - sequence %arg0 failures(propagate) { + sequence %arg0 : !pdl.operation failures(propagate) { ^bb0(%arg1: !pdl.operation): - %0 = pdl_match @some in %arg1 + %0 = pdl_match @some in %arg1 : (!pdl.operation) -> !pdl.operation test_print_remark_at_operand %0, "matched" } @@ -119,11 +119,11 @@ pdl.rewrite %0 with "transform.dialect" } - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): - %f = pdl_match @const in %arg1 + %f = pdl_match @const in %arg1 : (!pdl.operation) -> !pdl.operation // CHECK: %{{.+}} = get_closest_isolated_parent %{{.+}} - %m = get_closest_isolated_parent %f + %m = get_closest_isolated_parent %f : (!pdl.operation) -> !pdl.operation test_print_remark_at_operand %m, "parent function" } } @@ -144,12 +144,12 @@ pdl.rewrite %2 with "transform.dialect" } - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): // This is necessary to run the transformation on something other than the // top-level module, "alternatives" cannot be run on that. - %0 = pdl_match @match_func in %arg1 - transform.alternatives %0 { + %0 = pdl_match @match_func in %arg1 : (!pdl.operation) -> !pdl.operation + transform.alternatives %0 : !pdl.operation { ^bb2(%arg2: !pdl.operation): %1 = transform.test_produce_param_or_forward_operand 42 // This operation fails, which triggers the next alternative without @@ -182,14 +182,14 @@ pdl.rewrite %2 with "transform.dialect" } - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): - %0 = pdl_match @match_call in %arg1 - %1 = get_closest_isolated_parent %0 + %0 = pdl_match @match_call in %arg1 : (!pdl.operation) -> !pdl.operation + %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation // expected-error @below {{all alternatives failed}} - transform.alternatives %1 { + transform.alternatives %1 : !pdl.operation { ^bb2(%arg2: !pdl.operation): - %2 = transform.pdl_match @match_call in %arg2 + %2 = transform.pdl_match @match_call in %arg2 : (!pdl.operation) -> !pdl.operation // expected-remark @below {{applying}} transform.test_emit_remark_and_erase_operand %2, "applying" {fail_after_erase} } @@ -215,24 +215,24 @@ pdl.rewrite %2 with "transform.dialect" } - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): - %0 = pdl_match @match_call in %arg1 - %1 = get_closest_isolated_parent %0 - transform.alternatives %1 { + %0 = pdl_match @match_call in %arg1 : (!pdl.operation) -> !pdl.operation + %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation + transform.alternatives %1 : !pdl.operation { ^bb2(%arg2: !pdl.operation): - %2 = transform.pdl_match @match_call in %arg2 + %2 = transform.pdl_match @match_call in %arg2 : (!pdl.operation) -> !pdl.operation // expected-remark @below {{applying}} transform.test_emit_remark_and_erase_operand %2, "applying" {fail_after_erase} }, { ^bb2(%arg2: !pdl.operation): - %2 = transform.pdl_match @match_call in %arg2 + %2 = transform.pdl_match @match_call in %arg2 : (!pdl.operation) -> !pdl.operation transform.test_print_remark_at_operand %2, "still here" // This alternative succeeds. }, { ^bb2(%arg2: !pdl.operation): // This alternative is never run, so we must not have a remark here. - %2 = transform.pdl_match @match_call in %arg2 + %2 = transform.pdl_match @match_call in %arg2 : (!pdl.operation) -> !pdl.operation transform.test_emit_remark_and_erase_operand %2, "should not happen" {fail_after_erase} } } @@ -258,18 +258,18 @@ pdl.rewrite %2 with "transform.dialect" } - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): - %0 = pdl_match @match_call in %arg1 - %1 = get_closest_isolated_parent %0 - transform.alternatives %1 { + %0 = pdl_match @match_call in %arg1 : (!pdl.operation) -> !pdl.operation + %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation + transform.alternatives %1 : !pdl.operation { ^bb2(%arg2: !pdl.operation): - %2 = transform.pdl_match @match_call in %arg2 + %2 = transform.pdl_match @match_call in %arg2 : (!pdl.operation) -> !pdl.operation // expected-remark @below {{applying}} transform.test_emit_remark_and_erase_operand %2, "applying" {fail_after_erase} }, { ^bb2(%arg2: !pdl.operation): - %2 = transform.pdl_match @match_call in %arg2 + %2 = transform.pdl_match @match_call in %arg2 : (!pdl.operation) -> !pdl.operation // expected-remark @below {{applying second time}} transform.test_emit_remark_and_erase_operand %2, "applying second time" } @@ -294,13 +294,13 @@ pdl.rewrite %2 with "transform.dialect" } - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): - %0 = pdl_match @match_call in %arg1 - %1 = get_closest_isolated_parent %0 - %2 = transform.alternatives %1 -> !pdl.operation { + %0 = pdl_match @match_call in %arg1 : (!pdl.operation) -> !pdl.operation + %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation + %2 = transform.alternatives %1 : !pdl.operation -> !pdl.operation { ^bb2(%arg2: !pdl.operation): - %3 = transform.pdl_match @match_call in %arg2 + %3 = transform.pdl_match @match_call in %arg2 : (!pdl.operation) -> !pdl.operation // expected-remark @below {{applying}} transform.test_emit_remark_and_erase_operand %3, "applying" {fail_after_erase} %4 = transform.test_produce_param_or_forward_operand 43 @@ -335,7 +335,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !pdl.operation): // expected-error @below {{scope must not contain the transforms being applied}} - transform.alternatives %arg1 { + transform.alternatives %arg1 : !pdl.operation { ^bb2(%arg2: !pdl.operation): %0 = transform.test_produce_param_or_forward_operand 42 transform.test_consume_operand_if_matches_param_or_fail %0[43] @@ -367,12 +367,12 @@ } - sequence %arg0 failures(propagate) { + sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): - %0 = transform.pdl_match @match_const in %arg1 + %0 = transform.pdl_match @match_const in %arg1 : (!pdl.operation) -> !pdl.operation %1 = transform.loop.get_parent_for %0 // expected-error @below {{only isolated-from-above ops can be alternative scopes}} - alternatives %1 { + alternatives %1 : !pdl.operation { ^bb2(%arg2: !pdl.operation): } } @@ -394,9 +394,9 @@ pdl.rewrite %2 with "transform.dialect" } - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb0(%arg1: !pdl.operation): - %0 = pdl_match @some in %arg1 + %0 = pdl_match @some in %arg1 : (!pdl.operation) -> !pdl.operation // expected-error @below {{applications of transform.test_wrong_number_of_results expected to produce 3 results (actually produced 1).}} // expected-note @below {{If you need variadic results, consider a generic `apply` instead of the specialized `applyToOne`.}} // expected-note @below {{Producing 3 null results is allowed if the use case warrants it.}} @@ -422,9 +422,9 @@ pdl.rewrite %2 with "transform.dialect" } - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb0(%arg1: !pdl.operation): - %0 = pdl_match @some in %arg1 + %0 = pdl_match @some in %arg1 : (!pdl.operation) -> !pdl.operation // expected-error @below {{applications of transform.test_wrong_number_of_multi_results expected to produce 1 results (actually produced 0)}} // expected-note @below {{If you need variadic results, consider a generic `apply` instead of the specialized `applyToOne`.}} // expected-note @below {{Producing 1 null results is allowed if the use case warrants it.}} @@ -450,9 +450,9 @@ pdl.rewrite %2 with "transform.dialect" } - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb0(%arg1: !pdl.operation): - %0 = pdl_match @some in %arg1 + %0 = pdl_match @some in %arg1 : (!pdl.operation) -> !pdl.operation // Transform matches 3 ops and produces 2 results. %1:2 = transform.test_correct_number_of_multi_results %0 } @@ -474,9 +474,9 @@ pdl.rewrite %2 with "transform.dialect" } - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb0(%arg1: !pdl.operation): - %0 = pdl_match @some in %arg1 + %0 = pdl_match @some in %arg1 : (!pdl.operation) -> !pdl.operation // Transform fails to match any but still produces 2 results. %1:2 = transform.test_correct_number_of_multi_results %0 } @@ -499,9 +499,9 @@ pdl.rewrite %2 with "transform.dialect" } - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb0(%arg1: !pdl.operation): - %0 = pdl_match @some in %arg1 + %0 = pdl_match @some in %arg1 : (!pdl.operation) -> !pdl.operation // expected-error @below {{unexpected application of transform.test_mixed_null_and_non_null_results produces both null and non null results.}} transform.test_mixed_null_and_non_null_results %0 } @@ -536,11 +536,11 @@ pdl.rewrite %2 with "transform.dialect" } - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb0(%arg1: !pdl.operation): - %0 = pdl_match @addi in %arg1 - %1 = pdl_match @subi in %arg1 - %2 = merge_handles %0, %1 + %0 = pdl_match @addi in %arg1 : (!pdl.operation) -> !pdl.operation + %1 = pdl_match @subi in %arg1 : (!pdl.operation) -> !pdl.operation + %2 = merge_handles %0, %1 : !pdl.operation test_print_remark_at_operand %2, "matched" } } @@ -563,9 +563,9 @@ pdl.rewrite %2 with "transform.dialect" } - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb0(%arg1: !pdl.operation): - %0 = pdl_match @some in %arg1 + %0 = pdl_match @some in %arg1 : (!pdl.operation) -> !pdl.operation // expected-error @below {{failed to apply}} transform.test_mixed_sucess_and_silenceable %0 } @@ -587,9 +587,9 @@ pdl.rewrite %2 with "transform.dialect" } - transform.sequence %arg0 failures(suppress) { + transform.sequence %arg0 : !pdl.operation failures(suppress) { ^bb0(%arg1: !pdl.operation): - %0 = pdl_match @some in %arg1 + %0 = pdl_match @some in %arg1 : (!pdl.operation) -> !pdl.operation // Not expecting error here because we are suppressing it. // expected-remark @below {{foo}} test_emit_remark_and_erase_operand %0, "foo" {fail_after_erase} @@ -612,9 +612,9 @@ pdl.rewrite %2 with "transform.dialect" } - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb0(%arg1: !pdl.operation): - %0 = pdl_match @some in %arg1 + %0 = pdl_match @some in %arg1 : (!pdl.operation) -> !pdl.operation // expected-error @below {{silenceable error}} // expected-remark @below {{foo}} test_emit_remark_and_erase_operand %0, "foo" {fail_after_erase} @@ -637,13 +637,13 @@ pdl.rewrite %2 with "transform.dialect" } - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb0(%arg1: !pdl.operation): - %0 = pdl_match @func in %arg1 - %1 = replicate num(%0) %arg1 + %0 = pdl_match @func in %arg1 : (!pdl.operation) -> !pdl.operation + %1 = replicate num(%0) %arg1 : !pdl.operation, !pdl.operation // expected-remark @below {{2}} test_print_number_of_associated_payload_ir_ops %1 - %2 = replicate num(%0) %1 + %2 = replicate num(%0) %1 : !pdl.operation, !pdl.operation // expected-remark @below {{4}} test_print_number_of_associated_payload_ir_ops %2 } @@ -668,10 +668,10 @@ pdl.rewrite %0 with "transform.dialect" } - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): - %f = pdl_match @const in %arg1 - transform.foreach %f { + %f = pdl_match @const in %arg1 : (!pdl.operation) -> !pdl.operation + transform.foreach %f : !pdl.operation { ^bb2(%arg2: !pdl.operation): // expected-remark @below {{1}} transform.test_print_number_of_associated_payload_ir_ops %arg2 @@ -714,12 +714,12 @@ pdl.rewrite %0 with "transform.dialect" } - transform.sequence %arg0 failures(propagate) { + transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): - %f = pdl_match @execute_region in %arg1 - %results = transform.foreach %f -> !pdl.operation { + %f = pdl_match @execute_region in %arg1 : (!pdl.operation) -> !pdl.operation + %results = transform.foreach %f : !pdl.operation -> !pdl.operation { ^bb2(%arg2: !pdl.operation): - %g = transform.pdl_match @const in %arg2 + %g = transform.pdl_match @const in %arg2 : (!pdl.operation) -> !pdl.operation transform.yield %g : !pdl.operation } @@ -741,7 +741,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !pdl.operation): %addi = transform.structured.match ops{["arith.addi"]} in %arg1 - %muli = get_producer_of_operand %addi[0] + %muli = get_producer_of_operand %addi[0] : (!pdl.operation) -> !pdl.operation transform.test_print_remark_at_operand %muli, "found muli" } @@ -757,7 +757,7 @@ ^bb1(%arg1: !pdl.operation): %muli = transform.structured.match ops{["arith.muli"]} in %arg1 // expected-error @below {{could not find a producer for operand number: 0 of}} - %bbarg = get_producer_of_operand %muli[0] + %bbarg = get_producer_of_operand %muli[0] : (!pdl.operation) -> !pdl.operation } @@ -772,12 +772,12 @@ transform.sequence failures(propagate) { ^bb1(%fun: !pdl.operation): %muli = transform.structured.match ops{["arith.muli"]} in %fun - %h:2 = split_handles %muli in [2] + %h:2 = split_handles %muli in [2] : (!pdl.operation) -> (!pdl.operation, !pdl.operation) // expected-remark @below {{1}} transform.test_print_number_of_associated_payload_ir_ops %h#0 %muli_2 = transform.structured.match ops{["arith.muli"]} in %fun // expected-error @below {{expected to contain 3 operation handles but it only contains 2 handles}} - %h_2:3 = split_handles %muli_2 in [3] + %h_2:3 = split_handles %muli_2 in [3] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation) } // ----- @@ -791,12 +791,12 @@ transform.sequence failures(suppress) { ^bb1(%fun: !pdl.operation): %muli = transform.structured.match ops{["arith.muli"]} in %fun - %h:2 = split_handles %muli in [2] + %h:2 = split_handles %muli in [2] : (!pdl.operation) -> (!pdl.operation, !pdl.operation) // expected-remark @below {{1}} transform.test_print_number_of_associated_payload_ir_ops %h#0 %muli_2 = transform.structured.match ops{["arith.muli"]} in %fun // Silenceable failure and all handles are now empty. - %h_2:3 = split_handles %muli_2 in [3] + %h_2:3 = split_handles %muli_2 in [3] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation) // expected-remark @below {{0}} transform.test_print_number_of_associated_payload_ir_ops %h_2#0 } @@ -813,9 +813,9 @@ pdl.rewrite %0 with "transform.dialect" } - sequence %arg0 failures(propagate) { + sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): - %0 = pdl_match @some in %arg1 + %0 = pdl_match @some in %arg1 : (!pdl.operation) -> !pdl.operation %2 = transform.cast %0 : !pdl.operation to !transform.test_dialect_op transform.cast %2 : !transform.test_dialect_op to !pdl.operation } @@ -833,9 +833,9 @@ pdl.rewrite %0 with "transform.dialect" } - sequence %arg0 failures(propagate) { + sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): - %0 = pdl_match @other in %arg1 + %0 = pdl_match @other in %arg1 : (!pdl.operation) -> !pdl.operation // expected-error @below {{expected the payload operation to belong to the 'test' dialect}} %2 = transform.cast %0 : !pdl.operation to !transform.test_dialect_op transform.cast %2 : !transform.test_dialect_op to !pdl.operation diff --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py --- a/mlir/test/python/dialects/transform.py +++ b/mlir/test/python/dialects/transform.py @@ -18,21 +18,22 @@ @run def testSequenceOp(): sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, - [pdl.OperationType.get()]) + [pdl.OperationType.get()], + pdl.OperationType.get()) with InsertionPoint(sequence.body): transform.YieldOp([sequence.bodyTarget]) # CHECK-LABEL: TEST: testSequenceOp - # CHECK: = transform.sequence failures(propagate) { + # CHECK: = transform.sequence -> !pdl.operation failures(propagate) { # CHECK: ^{{.*}}(%[[ARG0:.+]]: !pdl.operation): # CHECK: yield %[[ARG0]] : !pdl.operation - # CHECK: } : !pdl.operation + # CHECK: } @run def testNestedSequenceOp(): - sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE) + sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()) with InsertionPoint(sequence.body): - nested = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, sequence.bodyTarget) + nested = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], sequence.bodyTarget) with InsertionPoint(nested.body): doubly_nested = transform.SequenceOp( transform.FailurePropagationMode.PROPAGATE, @@ -44,42 +45,42 @@ # CHECK-LABEL: TEST: testNestedSequenceOp # CHECK: transform.sequence failures(propagate) { # CHECK: ^{{.*}}(%[[ARG0:.+]]: !pdl.operation): - # CHECK: sequence %[[ARG0]] failures(propagate) { + # CHECK: sequence %[[ARG0]] : !pdl.operation failures(propagate) { # CHECK: ^{{.*}}(%[[ARG1:.+]]: !pdl.operation): - # CHECK: = sequence %[[ARG1]] failures(propagate) { + # CHECK: = sequence %[[ARG1]] : !pdl.operation -> !pdl.operation failures(propagate) { # CHECK: ^{{.*}}(%[[ARG2:.+]]: !pdl.operation): # CHECK: yield %[[ARG2]] : !pdl.operation - # CHECK: } : !pdl.operation + # CHECK: } # CHECK: } # CHECK: } @run def testTransformPDLOps(): - withPdl = transform.WithPDLPatternsOp() + withPdl = transform.WithPDLPatternsOp(pdl.OperationType.get()) with InsertionPoint(withPdl.body): sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [pdl.OperationType.get()], withPdl.bodyTarget) with InsertionPoint(sequence.body): - match = transform.PDLMatchOp(sequence.bodyTarget, "pdl_matcher") + match = transform.PDLMatchOp(pdl.OperationType.get(), sequence.bodyTarget, "pdl_matcher") transform.YieldOp(match) # CHECK-LABEL: TEST: testTransformPDLOps # CHECK: transform.with_pdl_patterns { # CHECK: ^{{.*}}(%[[ARG0:.+]]: !pdl.operation): - # CHECK: = sequence %[[ARG0]] failures(propagate) { + # CHECK: = sequence %[[ARG0]] : !pdl.operation -> !pdl.operation failures(propagate) { # CHECK: ^{{.*}}(%[[ARG1:.+]]: !pdl.operation): # CHECK: %[[RES:.+]] = pdl_match @pdl_matcher in %[[ARG1]] # CHECK: yield %[[RES]] : !pdl.operation - # CHECK: } : !pdl.operation + # CHECK: } # CHECK: } @run def testGetClosestIsolatedParentOp(): - sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE) + sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()) with InsertionPoint(sequence.body): - transform.GetClosestIsolatedParentOp(sequence.bodyTarget) + transform.GetClosestIsolatedParentOp(pdl.OperationType.get(), sequence.bodyTarget) transform.YieldOp() # CHECK-LABEL: TEST: testGetClosestIsolatedParentOp # CHECK: transform.sequence @@ -89,7 +90,7 @@ @run def testMergeHandlesOp(): - sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE) + sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()) with InsertionPoint(sequence.body): transform.MergeHandlesOp([sequence.bodyTarget]) transform.YieldOp() @@ -101,13 +102,13 @@ @run def testReplicateOp(): - with_pdl = transform.WithPDLPatternsOp() + with_pdl = transform.WithPDLPatternsOp(pdl.OperationType.get()) with InsertionPoint(with_pdl.body): sequence = transform.SequenceOp( - transform.FailurePropagationMode.PROPAGATE, with_pdl.bodyTarget) + transform.FailurePropagationMode.PROPAGATE, [], with_pdl.bodyTarget) with InsertionPoint(sequence.body): - m1 = transform.PDLMatchOp(sequence.bodyTarget, "first") - m2 = transform.PDLMatchOp(sequence.bodyTarget, "second") + m1 = transform.PDLMatchOp(pdl.OperationType.get(), sequence.bodyTarget, "first") + m2 = transform.PDLMatchOp(pdl.OperationType.get(), sequence.bodyTarget, "second") transform.ReplicateOp(m1, [m2]) transform.YieldOp() # CHECK-LABEL: TEST: testReplicateOp diff --git a/mlir/test/python/dialects/transform_loop_ext.py b/mlir/test/python/dialects/transform_loop_ext.py --- a/mlir/test/python/dialects/transform_loop_ext.py +++ b/mlir/test/python/dialects/transform_loop_ext.py @@ -18,7 +18,7 @@ @run def getParentLoop(): - sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE) + sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()) with InsertionPoint(sequence.body): loop.GetParentForOp(sequence.bodyTarget, num_loops=2) transform.YieldOp() @@ -29,7 +29,7 @@ @run def loopOutline(): - sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE) + sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()) with InsertionPoint(sequence.body): loop.LoopOutlineOp(sequence.bodyTarget, func_name="foo") transform.YieldOp() @@ -40,7 +40,7 @@ @run def loopPeel(): - sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE) + sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()) with InsertionPoint(sequence.body): loop.LoopPeelOp(sequence.bodyTarget) transform.YieldOp() @@ -50,7 +50,7 @@ @run def loopPipeline(): - sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE) + sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()) with InsertionPoint(sequence.body): loop.LoopPipelineOp(sequence.bodyTarget, iteration_interval=3) transform.YieldOp() @@ -62,7 +62,7 @@ @run def loopUnroll(): - sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE) + sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()) with InsertionPoint(sequence.body): loop.LoopUnrollOp(sequence.bodyTarget, factor=42) transform.YieldOp() diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py --- a/mlir/test/python/dialects/transform_structured_ext.py +++ b/mlir/test/python/dialects/transform_structured_ext.py @@ -18,7 +18,7 @@ @run def testDecompose(): - sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE) + sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()) with InsertionPoint(sequence.body): structured.DecomposeOp(sequence.bodyTarget) transform.YieldOp() @@ -29,7 +29,7 @@ @run def testGeneralize(): - sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE) + sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()) with InsertionPoint(sequence.body): structured.GeneralizeOp(sequence.bodyTarget) transform.YieldOp() @@ -40,7 +40,7 @@ @run def testInterchange(): - sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE) + sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()) with InsertionPoint(sequence.body): structured.InterchangeOp( sequence.bodyTarget, @@ -56,7 +56,7 @@ @run def testMultitileSizes(): - sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE) + sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()) with InsertionPoint(sequence.body): structured.MultiTileSizesOp( sequence.bodyTarget, dimension=1, target_size=42) @@ -70,7 +70,7 @@ @run def testPad(): - sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE) + sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()) with InsertionPoint(sequence.body): structured.PadOp( sequence.bodyTarget, @@ -90,7 +90,7 @@ @run def testScalarize(): - sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE) + sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()) with InsertionPoint(sequence.body): structured.ScalarizeOp(sequence.bodyTarget) transform.YieldOp() @@ -100,7 +100,7 @@ @run def testSplit(): - sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE) + sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()) with InsertionPoint(sequence.body): split = structured.SplitOp(sequence.bodyTarget, dimension=1, split_point=42) structured.SplitOp( @@ -113,7 +113,7 @@ @run def testTileCompact(): - sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE) + sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()) with InsertionPoint(sequence.body): structured.TileOp(sequence.bodyTarget, sizes=[4, 8], interchange=[0, 1]) transform.YieldOp() @@ -125,7 +125,7 @@ @run def testTileAttributes(): - sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE) + sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()) attr = ArrayAttr.get( [IntegerAttr.get(IntegerType.get_signless(64), x) for x in [4, 8]]) ichange = ArrayAttr.get( @@ -141,7 +141,7 @@ @run def testTileZero(): - sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE) + sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()) with InsertionPoint(sequence.body): structured.TileOp( sequence.bodyTarget, sizes=[4, 0, 2, 0], interchange=[0, 1, 2, 3]) @@ -154,13 +154,13 @@ @run def testTileDynamic(): - with_pdl = transform.WithPDLPatternsOp() + with_pdl = transform.WithPDLPatternsOp(pdl.OperationType.get()) with InsertionPoint(with_pdl.body): - sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, + sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], with_pdl.bodyTarget) with InsertionPoint(sequence.body): - m1 = transform.PDLMatchOp(sequence.bodyTarget, "first") - m2 = transform.PDLMatchOp(sequence.bodyTarget, "second") + m1 = transform.PDLMatchOp(pdl.OperationType.get(), sequence.bodyTarget, "first") + m2 = transform.PDLMatchOp(pdl.OperationType.get(), sequence.bodyTarget, "second") structured.TileOp(sequence.bodyTarget, sizes=[m1, 3, m2, 0]) transform.YieldOp() # CHECK-LABEL: TEST: testTileDynamic @@ -171,7 +171,7 @@ @run def testVectorize(): - sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE) + sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()) with InsertionPoint(sequence.body): structured.VectorizeOp(sequence.bodyTarget, vectorize_padding=True) transform.YieldOp() diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -8398,6 +8398,7 @@ deps = [ ":CastInterfacesTdFiles", ":ControlFlowInterfacesTdFiles", + ":InferTypeOpInterfaceTdFiles", ":OpBaseTdFiles", ":PDLDialectTdFiles", ":SideEffectInterfacesTdFiles",