diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.h b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.h --- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.h +++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.h @@ -9,8 +9,8 @@ #ifndef MLIR_DIALECT_SCF_TRANSFORMOPS_SCFTRANSFORMOPS_H #define MLIR_DIALECT_SCF_TRANSFORMOPS_SCFTRANSFORMOPS_H -#include "mlir/Dialect/PDL/IR/PDLTypes.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" +#include "mlir/Dialect/Transform/IR/TransformTypes.h" #include "mlir/IR/OpImplementation.h" namespace mlir { diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td --- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td +++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td @@ -12,10 +12,12 @@ include "mlir/Dialect/Transform/IR/TransformDialect.td" include "mlir/Dialect/Transform/IR/TransformEffects.td" include "mlir/Dialect/Transform/IR/TransformInterfaces.td" -include "mlir/Dialect/PDL/IR/PDLTypes.td" +include "mlir/Dialect/Transform/IR/TransformTypes.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/IR/OpBase.td" +def Transform_ScfForOp : Transform_ConcreteOpType<"scf.for">; + def GetParentForOp : Op]> { @@ -30,12 +32,13 @@ }]; let arguments = - (ins PDL_Operation:$target, + (ins TransformTypeInterface:$target, DefaultValuedAttr, "1">:$num_loops); - let results = (outs PDL_Operation:$parent); + let results = (outs TransformTypeInterface:$parent); - let assemblyFormat = "$target attr-dict"; + let assemblyFormat = + "$target attr-dict `:` functional-type(operands, results)"; } def LoopOutlineOp : Op:$fail_if_already_divisible); // TODO: Return both the peeled loop and the remainder loop. - let results = (outs PDL_Operation:$transformed); + let results = (outs TransformTypeInterface:$transformed); - let assemblyFormat = "$target attr-dict"; + let assemblyFormat = + "$target attr-dict `:` functional-type(operands, results)"; let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( @@ -131,12 +139,13 @@ pipelined loops, which can be empty. }]; - let arguments = (ins PDL_Operation:$target, + let arguments = (ins Transform_ScfForOp:$target, DefaultValuedAttr:$iteration_interval, DefaultValuedAttr:$read_latency); - let results = (outs PDL_Operation:$transformed); + let results = (outs TransformTypeInterface:$transformed); - let assemblyFormat = "$target attr-dict"; + let assemblyFormat = + "$target attr-dict `:` functional-type(operands, results)"; let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( @@ -165,10 +174,10 @@ removed after a full unrolling. }]; - let arguments = (ins PDL_Operation:$target, + let arguments = (ins Transform_ScfForOp:$target, ConfinedAttr:$factor); - let assemblyFormat = "$target attr-dict"; + let assemblyFormat = "$target attr-dict `:` type($target)"; let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( diff --git a/mlir/lib/Dialect/SCF/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/SCF/TransformOps/CMakeLists.txt --- a/mlir/lib/Dialect/SCF/TransformOps/CMakeLists.txt +++ b/mlir/lib/Dialect/SCF/TransformOps/CMakeLists.txt @@ -11,7 +11,6 @@ MLIRAffineDialect MLIRFuncDialect MLIRIR - MLIRPDLDialect MLIRSCFDialect MLIRSCFTransforms MLIRSCFUtils diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp --- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp +++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp @@ -9,7 +9,6 @@ #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/PDL/IR/PDL.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/Transforms/Patterns.h" #include "mlir/Dialect/SCF/Transforms/Transforms.h" @@ -239,8 +238,6 @@ using Base::Base; void init() { - declareDependentDialect(); - declareGeneratedDialect(); declareGeneratedDialect(); diff --git a/mlir/python/mlir/dialects/_loop_transform_ops_ext.py b/mlir/python/mlir/dialects/_loop_transform_ops_ext.py --- a/mlir/python/mlir/dialects/_loop_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_loop_transform_ops_ext.py @@ -5,7 +5,6 @@ try: from ..ir import * from ._ods_common import get_op_result_or_value as _get_op_result_or_value - from ..dialects import pdl except ImportError as e: raise RuntimeError("Error loading imports from extension module") from e @@ -28,13 +27,14 @@ """Extension for GetParentForOp.""" def __init__(self, + result_type: Type, target: Union[Operation, Value], *, num_loops: int = 1, ip=None, loc=None): super().__init__( - pdl.OperationType.get(), + result_type, _get_op_result_or_value(target), num_loops=_get_int64_attr(num_loops, default_value=1), ip=ip, @@ -45,13 +45,14 @@ """Extension for LoopOutlineOp.""" def __init__(self, + result_type: Type, target: Union[Operation, Value], *, func_name: Union[str, StringAttr], ip=None, loc=None): super().__init__( - pdl.OperationType.get(), + result_type, _get_op_result_or_value(target), func_name=(func_name if isinstance(func_name, StringAttr) else StringAttr.get(func_name)), @@ -63,13 +64,14 @@ """Extension for LoopPeelOp.""" def __init__(self, + result_type: Type, target: Union[Operation, Value], *, fail_if_already_divisible: Union[bool, BoolAttr] = False, ip=None, loc=None): super().__init__( - pdl.OperationType.get(), + result_type, _get_op_result_or_value(target), fail_if_already_divisible=(fail_if_already_divisible if isinstance( fail_if_already_divisible, BoolAttr) else @@ -82,6 +84,7 @@ """Extension for LoopPipelineOp.""" def __init__(self, + result_type: Type, target: Union[Operation, Value], *, iteration_interval: Optional[Union[int, IntegerAttr]] = None, @@ -89,7 +92,7 @@ ip=None, loc=None): super().__init__( - pdl.OperationType.get(), + result_type, _get_op_result_or_value(target), iteration_interval=_get_int64_attr(iteration_interval, default_value=1), read_latency=_get_int64_attr(read_latency, default_value=10), diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir --- a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir @@ -51,7 +51,8 @@ ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.elemwise_binary"]} in %arg1 %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1]} - transform.loop.peel %loops#0 + %loop = transform.cast %loops#0 : !pdl.operation to !transform.op<"scf.for"> + transform.loop.peel %loop : (!transform.op<"scf.for">) -> !pdl.operation } } diff --git a/mlir/test/Dialect/Linalg/transform-op-match.mlir b/mlir/test/Dialect/Linalg/transform-op-match.mlir --- a/mlir/test/Dialect/Linalg/transform-op-match.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-match.mlir @@ -14,11 +14,11 @@ transform.sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %match_name = transform.structured.match ops{["arith.constant"]} in %arg1 - transform.test_print_remark_at_operand %match_name, "matched op name" + transform.test_print_remark_at_operand %match_name, "matched op name" : !pdl.operation transform.test_consume_operand %match_name %match_attr = transform.structured.match ops{["arith.constant"]} attributes{my_attr} in %arg1 - transform.test_print_remark_at_operand %match_attr, "matched attr name" + transform.test_print_remark_at_operand %match_attr, "matched attr name" : !pdl.operation transform.test_consume_operand %match_attr } } @@ -38,7 +38,7 @@ ^bb1(%arg1: !pdl.operation): %match_name = transform.structured.match ops{["arith.constant"]} filter_result_type = f32 in %arg1 - transform.test_print_remark_at_operand %match_name, "matched op name" + transform.test_print_remark_at_operand %match_name, "matched op name" : !pdl.operation transform.test_consume_operand %match_name } } @@ -69,7 +69,7 @@ ops{["linalg.generic"]} attributes{iterator_types = ["parallel", "parallel", "parallel"]} in %arg1 - transform.test_print_remark_at_operand %match_attr, "matched complex attr" + transform.test_print_remark_at_operand %match_attr, "matched complex attr" : !pdl.operation transform.test_consume_operand %match_attr %no_match = transform.structured.match diff --git a/mlir/test/Dialect/MemRef/transform-ops.mlir b/mlir/test/Dialect/MemRef/transform-ops.mlir --- a/mlir/test/Dialect/MemRef/transform-ops.mlir +++ b/mlir/test/Dialect/MemRef/transform-ops.mlir @@ -33,5 +33,5 @@ %0 = transform.structured.match ops{["memref.alloc"]} in %arg1 %1 = transform.memref.multibuffer %0 {factor = 2 : i64} // Verify that the returned handle is usable. - transform.test_print_remark_at_operand %1, "transformed" + transform.test_print_remark_at_operand %1, "transformed" : !pdl.operation } diff --git a/mlir/test/Dialect/SCF/transform-ops.mlir b/mlir/test/Dialect/SCF/transform-ops.mlir --- a/mlir/test/Dialect/SCF/transform-ops.mlir +++ b/mlir/test/Dialect/SCF/transform-ops.mlir @@ -21,12 +21,12 @@ ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["arith.addi"]} in %arg1 // CHECK: = transform.loop.get_parent_for - %1 = transform.loop.get_parent_for %0 - %2 = transform.loop.get_parent_for %0 { num_loops = 2 } - %3 = transform.loop.get_parent_for %0 { num_loops = 3 } - transform.test_print_remark_at_operand %1, "third loop" - transform.test_print_remark_at_operand %2, "second loop" - transform.test_print_remark_at_operand %3, "first loop" + %1 = transform.loop.get_parent_for %0 : (!pdl.operation) -> !transform.op<"scf.for"> + %2 = transform.loop.get_parent_for %0 { num_loops = 2 } : (!pdl.operation) -> !transform.op<"scf.for"> + %3 = transform.loop.get_parent_for %0 { num_loops = 3 } : (!pdl.operation) -> !transform.op<"scf.for"> + transform.test_print_remark_at_operand %1, "third loop" : !transform.op<"scf.for"> + transform.test_print_remark_at_operand %2, "second loop" : !transform.op<"scf.for"> + transform.test_print_remark_at_operand %3, "first loop" : !transform.op<"scf.for"> } } @@ -44,7 +44,7 @@ ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["arith.addi"]} in %arg1 // expected-error @below {{could not find an 'scf.for' parent}} - %1 = transform.loop.get_parent_for %0 + %1 = transform.loop.get_parent_for %0 : (!pdl.operation) -> !transform.op<"scf.for"> } } @@ -85,9 +85,9 @@ sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["arith.addi"]} in %arg1 - %1 = transform.loop.get_parent_for %0 + %1 = transform.loop.get_parent_for %0 : (!pdl.operation) -> !transform.op<"scf.for"> // CHECK: = transform.loop.outline %{{.*}} - transform.loop.outline %1 {func_name = "foo"} + transform.loop.outline %1 {func_name = "foo"} : (!transform.op<"scf.for">) -> !pdl.operation } } @@ -115,7 +115,7 @@ ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["scf.while"]} in %arg1 // expected-error @below {{failed to outline}} - transform.loop.outline %0 {func_name = "foo"} + transform.loop.outline %0 {func_name = "foo"} : (!pdl.operation) -> !pdl.operation } } @@ -145,8 +145,8 @@ sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["arith.addi"]} in %arg1 - %1 = transform.loop.get_parent_for %0 - transform.loop.peel %1 + %1 = transform.loop.get_parent_for %0 : (!pdl.operation) -> !transform.op<"scf.for"> + transform.loop.peel %1 : (!transform.op<"scf.for">) -> !pdl.operation } } @@ -181,10 +181,10 @@ sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["arith.addf"]} in %arg1 - %1 = transform.loop.get_parent_for %0 - %2 = transform.loop.pipeline %1 + %1 = transform.loop.get_parent_for %0 : (!pdl.operation) -> !transform.op<"scf.for"> + %2 = transform.loop.pipeline %1 : (!transform.op<"scf.for">) -> !pdl.operation // Verify that the returned handle is usable. - transform.test_print_remark_at_operand %2, "transformed" + transform.test_print_remark_at_operand %2, "transformed" : !pdl.operation } } @@ -208,8 +208,8 @@ sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["arith.addi"]} in %arg1 - %1 = transform.loop.get_parent_for %0 - transform.loop.unroll %1 { factor = 4 } + %1 = transform.loop.get_parent_for %0 : (!pdl.operation) -> !transform.op<"scf.for"> + transform.loop.unroll %1 { factor = 4 } : !transform.op<"scf.for"> } } diff --git a/mlir/test/Dialect/Transform/expensive-checks.mlir b/mlir/test/Dialect/Transform/expensive-checks.mlir --- a/mlir/test/Dialect/Transform/expensive-checks.mlir +++ b/mlir/test/Dialect/Transform/expensive-checks.mlir @@ -23,7 +23,7 @@ // expected-note @below {{invalidated by this transform op that consumes its operand #0}} test_consume_operand %1 // expected-error @below {{op uses a handle invalidated by a previously executed transform op}} - test_print_remark_at_operand %0, "remark" + test_print_remark_at_operand %0, "remark" : !pdl.operation } } @@ -57,7 +57,7 @@ %2 = replicate num(%0) %1 : !pdl.operation, !pdl.operation // expected-error @below {{a handle passed as operand #0 and consumed by this operation points to a payload operation more than once}} test_consume_operand %2 - test_print_remark_at_operand %0, "remark" + test_print_remark_at_operand %0, "remark" : !pdl.operation } } diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir --- a/mlir/test/Dialect/Transform/test-interpreter.mlir +++ b/mlir/test/Dialect/Transform/test-interpreter.mlir @@ -76,7 +76,7 @@ sequence %arg0 : !pdl.operation failures(propagate) { ^bb0(%arg1: !pdl.operation): %0 = pdl_match @some in %arg1 : (!pdl.operation) -> !pdl.operation - test_print_remark_at_operand %0, "matched" + test_print_remark_at_operand %0, "matched" : !pdl.operation } pdl.pattern @some : benefit(1) { @@ -124,7 +124,7 @@ %f = pdl_match @const in %arg1 : (!pdl.operation) -> !pdl.operation // CHECK: %{{.+}} = get_closest_isolated_parent %{{.+}} %m = get_closest_isolated_parent %f : (!pdl.operation) -> !pdl.operation - test_print_remark_at_operand %m, "parent function" + test_print_remark_at_operand %m, "parent function" : !pdl.operation } } @@ -227,7 +227,7 @@ }, { ^bb2(%arg2: !pdl.operation): %2 = transform.pdl_match @match_call in %arg2 : (!pdl.operation) -> !pdl.operation - transform.test_print_remark_at_operand %2, "still here" + transform.test_print_remark_at_operand %2, "still here" : !pdl.operation // This alternative succeeds. }, { ^bb2(%arg2: !pdl.operation): @@ -370,7 +370,7 @@ sequence %arg0 : !pdl.operation failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.pdl_match @match_const in %arg1 : (!pdl.operation) -> !pdl.operation - %1 = transform.loop.get_parent_for %0 + %1 = transform.loop.get_parent_for %0 : (!pdl.operation) -> !pdl.operation // expected-error @below {{only isolated-from-above ops can be alternative scopes}} alternatives %1 : !pdl.operation { ^bb2(%arg2: !pdl.operation): @@ -541,7 +541,7 @@ %0 = pdl_match @addi in %arg1 : (!pdl.operation) -> !pdl.operation %1 = pdl_match @subi in %arg1 : (!pdl.operation) -> !pdl.operation %2 = merge_handles %0, %1 : !pdl.operation - test_print_remark_at_operand %2, "matched" + test_print_remark_at_operand %2, "matched" : !pdl.operation } } @@ -675,7 +675,7 @@ ^bb2(%arg2: !pdl.operation): // expected-remark @below {{1}} transform.test_print_number_of_associated_payload_ir_ops %arg2 - transform.test_print_remark_at_operand %arg2, "transform applied" + transform.test_print_remark_at_operand %arg2, "transform applied" : !pdl.operation } } } @@ -725,7 +725,7 @@ // expected-remark @below {{3}} transform.test_print_number_of_associated_payload_ir_ops %results - transform.test_print_remark_at_operand %results, "transform applied" + transform.test_print_remark_at_operand %results, "transform applied" : !pdl.operation } } @@ -742,7 +742,7 @@ ^bb1(%arg1: !pdl.operation): %addi = transform.structured.match ops{["arith.addi"]} in %arg1 %muli = get_producer_of_operand %addi[0] : (!pdl.operation) -> !pdl.operation - transform.test_print_remark_at_operand %muli, "found muli" + transform.test_print_remark_at_operand %muli, "found muli" : !pdl.operation } // ----- diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td @@ -69,10 +69,11 @@ : Op]> { let arguments = (ins - Arg:$operand, StrAttr:$message); - let assemblyFormat = "$operand `,` $message attr-dict"; + let assemblyFormat = + "$operand `,` $message attr-dict `:` type($operand)"; let cppNamespace = "::mlir::test"; } diff --git a/mlir/test/python/dialects/transform_loop_ext.py b/mlir/test/python/dialects/transform_loop_ext.py --- a/mlir/test/python/dialects/transform_loop_ext.py +++ b/mlir/test/python/dialects/transform_loop_ext.py @@ -18,9 +18,10 @@ @run def getParentLoop(): - sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()) + sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, + [], pdl.OperationType.get()) with InsertionPoint(sequence.body): - loop.GetParentForOp(sequence.bodyTarget, num_loops=2) + loop.GetParentForOp(transform.OperationType.get("scf.for"), sequence.bodyTarget, num_loops=2) transform.YieldOp() # CHECK-LABEL: TEST: getParentLoop # CHECK: = transform.loop.get_parent_for % @@ -29,9 +30,10 @@ @run def loopOutline(): - sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()) + sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, + [], transform.OperationType.get("scf.for")) with InsertionPoint(sequence.body): - loop.LoopOutlineOp(sequence.bodyTarget, func_name="foo") + loop.LoopOutlineOp(pdl.OperationType.get(), sequence.bodyTarget, func_name="foo") transform.YieldOp() # CHECK-LABEL: TEST: loopOutline # CHECK: = transform.loop.outline % @@ -40,9 +42,10 @@ @run def loopPeel(): - sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()) + sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, + [], transform.OperationType.get("scf.for")) with InsertionPoint(sequence.body): - loop.LoopPeelOp(sequence.bodyTarget) + loop.LoopPeelOp(pdl.OperationType.get(), sequence.bodyTarget) transform.YieldOp() # CHECK-LABEL: TEST: loopPeel # CHECK: = transform.loop.peel % @@ -50,9 +53,10 @@ @run def loopPipeline(): - sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()) + sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, + [], transform.OperationType.get("scf.for")) with InsertionPoint(sequence.body): - loop.LoopPipelineOp(sequence.bodyTarget, iteration_interval=3) + loop.LoopPipelineOp(pdl.OperationType.get(), sequence.bodyTarget, iteration_interval=3) transform.YieldOp() # CHECK-LABEL: TEST: loopPipeline # CHECK: = transform.loop.pipeline % @@ -62,7 +66,8 @@ @run def loopUnroll(): - sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()) + sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, + [], transform.OperationType.get("scf.for")) with InsertionPoint(sequence.body): loop.LoopUnrollOp(sequence.bodyTarget, factor=42) transform.YieldOp()