diff --git a/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.td b/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.td --- a/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.td +++ b/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.td @@ -18,61 +18,4 @@ def Transform_AffineForOp : Transform_ConcreteOpType<"affine.for">; -def AffineGetParentForOp : Op]> { - let summary = - "Gets a handle to the parent 'affine.for' loop of the given operation"; - let description = [{ - Produces a handle to the n-th (default 1) parent `affine.for` loop for each - Payload IR operation associated with the operand. Fails if such a loop - cannot be found. The list of operations associated with the handle contains - parent operations in the same order as the list associated with the operand, - except for operations that are parents to more than one input which are only - present once. - }]; - - let arguments = - (ins TransformTypeInterface:$target, - DefaultValuedAttr, - "1">:$num_loops); - let results = (outs TransformTypeInterface:$parent); - - let assemblyFormat = - "$target attr-dict `:` functional-type(operands, results)"; -} - - -def AffineLoopUnrollOp : Op { - let summary = "Unrolls the given loop with the given unroll factor"; - let description = [{ - Unrolls each loop associated with the given handle to have up to the given - number of loop body copies per iteration. If the unroll factor is larger - than the loop trip count, the latter is used as the unroll factor instead. - - #### Return modes - - This operation ignores non-affine::For ops and drops them in the return. - If all the operations referred to by the `target` PDLOperation unroll - properly, the transform succeeds. Otherwise the transform silently fails. - - Does not return handles as the operation may result in the loop being - removed after a full unrolling.
- }]; - - let arguments = (ins Transform_AffineForOp:$target, - ConfinedAttr:$factor); - - let assemblyFormat = "$target attr-dict `:` type($target)"; - - let extraClassDeclaration = [{ - ::mlir::DiagnosedSilenceableFailure applyToOne( - ::mlir::AffineForOp target, - ::llvm::SmallVector<::mlir::Operation *> & results, - ::mlir::transform::TransformState & state); - }]; -} - #endif // Affine_TRANSFORM_OPS diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td --- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td +++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td @@ -23,19 +23,20 @@ DeclareOpInterfaceMethods]> { let summary = "Gets a handle to the parent 'for' loop of the given operation"; let description = [{ - Produces a handle to the n-th (default 1) parent `scf.for` loop for each - Payload IR operation associated with the operand. Fails if such a loop - cannot be found. The list of operations associated with the handle contains - parent operations in the same order as the list associated with the operand, - except for operations that are parents to more than one input which are only - present once. + Produces a handle to the n-th (default 1) parent `scf.for` or `affine.for` + (when the affine flag is true) loop for each Payload IR operation + associated with the operand. Fails if such a loop cannot be found. The list + of operations associated with the handle contains parent operations in the + same order as the list associated with the operand, except for operations + that are parents to more than one input which are only present once.
}]; let arguments = (ins TransformTypeInterface:$target, DefaultValuedAttr, - "1">:$num_loops); + "1">:$num_loops, + DefaultValuedAttr:$affine); let results = (outs TransformTypeInterface:$parent); let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)"; @@ -166,22 +167,23 @@ #### Return modes - This operation ignores non-scf::For ops and drops them in the return. - If all the operations referred to by the `target` PDLOperation unroll - properly, the transform succeeds. Otherwise the transform silently fails. + This operation fails silenceably on payload ops that are neither `scf.for` + nor `affine.for`. If all the operations referred to by the `target` handle + unroll properly, the transform succeeds. Otherwise the transform silently + fails. Does not return handles as the operation may result in the loop being removed after a full unrolling. }]; - let arguments = (ins Transform_ScfForOp:$target, + let arguments = (ins TransformTypeInterface:$target, ConfinedAttr:$factor); let assemblyFormat = "$target attr-dict `:` type($target)"; let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( - ::mlir::scf::ForOp target, + ::mlir::Operation *target, ::llvm::SmallVector<::mlir::Operation *> &results, ::mlir::transform::TransformState &state); }]; diff --git a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp --- a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp +++ b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp @@ -22,52 +22,6 @@ }; } // namespace -//===----------------------------------------------------------------------===// -// AffineGetParentForOp -//===----------------------------------------------------------------------===// - -DiagnosedSilenceableFailure -transform::AffineGetParentForOp::apply(transform::TransformResults &results, - transform::TransformState
&state) { - SetVector parents; - for (Operation *target : state.getPayloadOps(getTarget())) { - AffineForOp loop; - Operation *current = target; - for (unsigned i = 0, e = getNumLoops(); i < e; ++i) { - loop = current->getParentOfType(); - if (!loop) { - DiagnosedSilenceableFailure diag = emitSilenceableError() - << "could not find an '" - << AffineForOp::getOperationName() - << "' parent"; - diag.attachNote(target->getLoc()) << "target op"; - results.set(getResult().cast(), {}); - return diag; - } - current = loop; - } - parents.insert(loop); - } - results.set(getResult().cast(), parents.getArrayRef()); - return DiagnosedSilenceableFailure::success(); -} - -//===----------------------------------------------------------------------===// -// LoopUnrollOp -//===----------------------------------------------------------------------===// - -DiagnosedSilenceableFailure -transform::AffineLoopUnrollOp::applyToOne(AffineForOp target, - SmallVector &results, - transform::TransformState &state) { - if (failed(loopUnrollByFactor(target, getFactor()))) { - Diagnostic diag(target->getLoc(), DiagnosticSeverity::Note); - diag << "op failed to unroll"; - return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag)); - } - return DiagnosedSilenceableFailure(success()); -} - //===----------------------------------------------------------------------===// // Transform op registration //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp --- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp +++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp @@ -8,6 +8,7 @@ #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Affine/LoopUtils.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include
"mlir/Dialect/SCF/Transforms/Patterns.h" @@ -30,21 +31,23 @@ //===----------------------------------------------------------------------===// // GetParentForOp //===----------------------------------------------------------------------===// - DiagnosedSilenceableFailure transform::GetParentForOp::apply(transform::TransformResults &results, transform::TransformState &state) { SetVector parents; for (Operation *target : state.getPayloadOps(getTarget())) { - scf::ForOp loop; - Operation *current = target; + Operation *loop = nullptr, *current = target; for (unsigned i = 0, e = getNumLoops(); i < e; ++i) { - loop = current->getParentOfType(); + loop = getAffine() ? current->getParentOfType() + : current->getParentOfType(); + if (!loop) { - DiagnosedSilenceableFailure diag = emitSilenceableError() - << "could not find an '" - << scf::ForOp::getOperationName() - << "' parent"; + DiagnosedSilenceableFailure diag = + emitSilenceableError() + << "could not find an '" + << (getAffine() ? AffineForOp::getOperationName() + : scf::ForOp::getOperationName()) + << "' parent"; diag.attachNote(target->getLoc()) << "target op"; results.set(getResult().cast(), {}); return diag; @@ -215,12 +218,18 @@ //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::LoopUnrollOp::applyToOne(scf::ForOp target, +transform::LoopUnrollOp::applyToOne(Operation *op, SmallVector &results, transform::TransformState &state) { - if (failed(loopUnrollByFactor(target, getFactor()))) { - Diagnostic diag(target->getLoc(), DiagnosticSeverity::Note); + LogicalResult result(failure()); + if (scf::ForOp scfFor = dyn_cast(op)) + result = loopUnrollByFactor(scfFor, getFactor()); + else if (AffineForOp affineFor = dyn_cast(op)) + result = loopUnrollByFactor(affineFor, getFactor()); + + if (failed(result)) { + Diagnostic diag(op->getLoc(), DiagnosticSeverity::Note); diag << "op failed to unroll"; return
DiagnosedSilenceableFailure::silenceableFailure(std::move(diag)); } return DiagnosedSilenceableFailure(success()); diff --git a/mlir/test/Dialect/Affine/transform-ops.mlir b/mlir/test/Dialect/Affine/transform-ops.mlir deleted file mode 100644 --- a/mlir/test/Dialect/Affine/transform-ops.mlir +++ /dev/null @@ -1,67 +0,0 @@ -// RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file -verify-diagnostics | FileCheck %s - -// CHECK-LABEL: @get_parent_for_op -func.func @get_parent_for_op(%arg0: index, %arg1: index, %arg2: index) { - // expected-remark @below {{first loop}} - affine.for %i = %arg0 to %arg1 { - // expected-remark @below {{second loop}} - affine.for %j = %arg0 to %arg1 { - // expected-remark @below {{third loop}} - affine.for %k = %arg0 to %arg1 { - arith.addi %i, %j : index - } - } - } - return -} - -transform.sequence failures(propagate) { -^bb1(%arg1: !pdl.operation): - %0 = transform.structured.match ops{["arith.addi"]} in %arg1 - // CHECK: = transform.affine.get_parent_for - %1 = transform.affine.get_parent_for %0 : (!pdl.operation) -> !transform.op<"affine.for"> - %2 = transform.affine.get_parent_for %0 { num_loops = 2 } : (!pdl.operation) -> !transform.op<"affine.for"> - %3 = transform.affine.get_parent_for %0 { num_loops = 3 } : (!pdl.operation) -> !transform.op<"affine.for"> - transform.test_print_remark_at_operand %1, "third loop" : !transform.op<"affine.for"> - transform.test_print_remark_at_operand %2, "second loop" : !transform.op<"affine.for"> - transform.test_print_remark_at_operand %3, "first loop" : !transform.op<"affine.for"> -} - -// ----- - -func.func @get_parent_for_op_no_loop(%arg0: index, %arg1: index) { - // expected-note @below {{target op}} - arith.addi %arg0, %arg1 : index - return -} - -transform.sequence failures(propagate) { -^bb1(%arg1: !pdl.operation): - %0 = transform.structured.match ops{["arith.addi"]} in %arg1 - // expected-error @below {{could not find an 'affine.for' parent}} - %1 =
transform.affine.get_parent_for %0 : (!pdl.operation) -> !transform.op<"affine.for"> -} - -// ----- - -func.func @loop_unroll_op() { - %c0 = arith.constant 0 : index - %c42 = arith.constant 42 : index - %c5 = arith.constant 5 : index - // CHECK: affine.for %[[I:.+]] = - // expected-remark @below {{affine for loop}} - affine.for %i = %c0 to %c42 { - // CHECK-COUNT-4: arith.addi - arith.addi %i, %i : index - } - return -} - -transform.sequence failures(propagate) { -^bb1(%arg1: !pdl.operation): - %0 = transform.structured.match ops{["arith.addi"]} in %arg1 - %1 = transform.affine.get_parent_for %0 : (!pdl.operation) -> !transform.op<"affine.for"> - transform.test_print_remark_at_operand %1, "affine for loop" : !transform.op<"affine.for"> - transform.affine.unroll %1 { factor = 4 } : !transform.op<"affine.for"> -} - diff --git a/mlir/test/Dialect/SCF/transform-ops.mlir b/mlir/test/Dialect/SCF/transform-ops.mlir --- a/mlir/test/Dialect/SCF/transform-ops.mlir +++ b/mlir/test/Dialect/SCF/transform-ops.mlir @@ -192,3 +192,94 @@ transform.loop.unroll %1 { factor = 4 } : !transform.op<"scf.for"> } +// ----- + +// CHECK-LABEL: @get_parent_for_op +func.func @get_parent_for_op(%arg0: index, %arg1: index, %arg2: index) { + // expected-remark @below {{first loop}} + affine.for %i = %arg0 to %arg1 { + // expected-remark @below {{second loop}} + affine.for %j = %arg0 to %arg1 { + // expected-remark @below {{third loop}} + affine.for %k = %arg0 to %arg1 { + arith.addi %i, %j : index + } + } + } + return +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["arith.addi"]} in %arg1 + // CHECK: = transform.loop.get_parent_for + %1 = transform.loop.get_parent_for %0 { affine = true } : (!pdl.operation) -> !transform.op<"affine.for"> + %2 = transform.loop.get_parent_for %0 { num_loops = 2, affine = true } : (!pdl.operation) -> !transform.op<"affine.for"> + %3 = transform.loop.get_parent_for %0 { num_loops = 3, affine = true }
: (!pdl.operation) -> !transform.op<"affine.for"> + transform.test_print_remark_at_operand %1, "third loop" : !transform.op<"affine.for"> + transform.test_print_remark_at_operand %2, "second loop" : !transform.op<"affine.for"> + transform.test_print_remark_at_operand %3, "first loop" : !transform.op<"affine.for"> +} + +// ----- + +func.func @get_parent_for_op_no_loop(%arg0: index, %arg1: index) { + // expected-note @below {{target op}} + arith.addi %arg0, %arg1 : index + return +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["arith.addi"]} in %arg1 + // expected-error @below {{could not find an 'affine.for' parent}} + %1 = transform.loop.get_parent_for %0 { affine = true } : (!pdl.operation) -> !transform.op<"affine.for"> +} + +// ----- + +func.func @loop_unroll_op() { + %c0 = arith.constant 0 : index + %c42 = arith.constant 42 : index + %c5 = arith.constant 5 : index + // CHECK: affine.for %[[I:.+]] = + // expected-remark @below {{affine for loop}} + affine.for %i = %c0 to %c42 { + // CHECK-COUNT-4: arith.addi + arith.addi %i, %i : index + } + return +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["arith.addi"]} in %arg1 + %1 = transform.loop.get_parent_for %0 { affine = true } : (!pdl.operation) -> !transform.op<"affine.for"> + transform.test_print_remark_at_operand %1, "affine for loop" : !transform.op<"affine.for"> + transform.loop.unroll %1 { factor = 4 } : !transform.op<"affine.for"> +} + +// ----- + +func.func @test_mixed_loops() { + %c0 = arith.constant 0 : index + %c42 = arith.constant 42 : index + %c5 = arith.constant 5 : index + scf.for %j = %c0 to %c42 step %c5 { + // CHECK: affine.for %[[I:.+]] = + // expected-remark @below {{affine for loop}} + affine.for %i = %c0 to %c42 { + // CHECK-COUNT-4: arith.addi + arith.addi %i, %i : index + } + } + return +} + +transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["arith.addi"]} in %arg1 + %1 = transform.loop.get_parent_for %0 { num_loops = 1, affine = true } : (!pdl.operation) -> !transform.op<"affine.for"> + transform.test_print_remark_at_operand %1, "affine for loop" : !transform.op<"affine.for"> + transform.loop.unroll %1 { factor = 4 } : !transform.op<"affine.for"> +}