diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.h b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.h --- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.h +++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.h @@ -19,6 +19,7 @@ } // namespace func namespace scf { class ForOp; +class IfOp; } // namespace scf } // namespace mlir diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td --- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td +++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td @@ -215,4 +215,45 @@ }]; } +def TakeAssumedBranchOp : Op<Transform_Dialect, "scf.take_assumed_branch", [ + DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, + TransformOpInterface, TransformEachOpTrait]> { + let description = [{ + Given an scf.if conditional, inject user-defined information that it is + always safe to execute only the `then` branch or only the `else` branch. + + This is achieved by replacing the scf.if with the content of the `then` + branch, or of the `else` branch when `take_else_branch` is set. + + This is particularly useful for user-controlled rewriting of conditionals + that exist solely to guard against out-of-bounds behavior. + + At the moment, no assume or assert operation is emitted as it is not always + desirable. In the future, this may be controlled by a dedicated attribute. + + #### Return modes + + The transform only consumes its operand and does not produce any result. + The transform definitely fails if `take_else_branch` is specified and the + `else` region is empty (i.e. the scf.if has no `else` branch). + }]; + let arguments = (ins TransformHandleTypeInterface:$target, + OptionalAttr<UnitAttr>:$take_else_branch); + let results = (outs); + + let assemblyFormat = [{ + $target + (`take_else_branch` $take_else_branch^)?
+ attr-dict + `:` functional-type(operands, results) + }]; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::scf::IfOp ifOp, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + #endif // SCF_TRANSFORM_OPS diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp --- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp +++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp @@ -16,6 +16,7 @@ #include "mlir/Dialect/SCF/Utils/Utils.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" +#include "mlir/Dialect/Transform/IR/TransformOps.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" using namespace mlir; @@ -245,6 +246,46 @@ return DiagnosedSilenceableFailure::success(); } +//===----------------------------------------------------------------------===// +// TakeAssumedBranchOp +//===----------------------------------------------------------------------===// +/// Inlines the single block of `region` before `op`, mapping its arguments to +/// `blockArgs`, and replaces the results of `op` with the terminator operands.
+static void replaceOpWithRegion(RewriterBase &rewriter, Operation *op, + Region ®ion, ValueRange blockArgs = {}) { + assert(llvm::hasSingleElement(region) && "expected single-region block"); + Block *block = ®ion.front(); + Operation *terminator = block->getTerminator(); + ValueRange results = terminator->getOperands(); + rewriter.inlineBlockBefore(block, op, blockArgs); + rewriter.replaceOp(op, results); + rewriter.eraseOp(terminator); +} + +DiagnosedSilenceableFailure transform::TakeAssumedBranchOp::applyToOne( + scf::IfOp ifOp, transform::ApplyToEachResultList &results, + transform::TransformState &state) { + TrackingListener listener(state, *this); + IRRewriter rewriter(ifOp->getContext(), &listener); + rewriter.setInsertionPoint(ifOp); + + Region ®ion = + getTakeElseBranch() ? ifOp.getElseRegion() : ifOp.getThenRegion(); + if (!llvm::hasSingleElement(region)) { + return emitDefiniteFailure() + << "requires an scf.if op with a single-block " + << ((getTakeElseBranch()) ? "`else`" : "`then`") << " region"; + } + replaceOpWithRegion(rewriter, ifOp, region); + return DiagnosedSilenceableFailure::success(); +} + +void transform::TakeAssumedBranchOp::getEffects( + SmallVectorImpl &effects) { + consumesHandle(getTarget(), effects); + modifiesPayload(effects); +} + //===----------------------------------------------------------------------===// // Transform op registration //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SCF/transform-op-take-assumed-branch.mlir b/mlir/test/Dialect/SCF/transform-op-take-assumed-branch.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/SCF/transform-op-take-assumed-branch.mlir @@ -0,0 +1,50 @@ +// RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file -verify-diagnostics --allow-unregistered-dialect | FileCheck %s + +func.func @if_no_else(%cond: i1, %a: index, %b: memref, %c: i8) { + scf.if %cond { + "some_op"(%cond, %b) : (i1, memref) -> () + 
+ scf.yield + } + return + } + +transform.sequence failures(propagate) { +^bb0(%arg1: !transform.any_op): + %if = transform.structured.match ops{["scf.if"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + + // expected-error @+1 {{requires an scf.if op with a single-block `else` region}} + transform.scf.take_assumed_branch %if take_else_branch + : (!transform.any_op) -> () +} + +// ----- + +// CHECK-LABEL: tile_tensor_pad +func.func @tile_tensor_pad( + %arg0 : tensor<?x?xf32>, %cst : f32, %low: index, %high: index) + -> tensor<20x40xf32> +{ + // CHECK: scf.forall + // CHECK-NOT: scf.if + // CHECK-NOT: tensor.generate + // CHECK-NOT: else + // CHECK: tensor.pad {{.*}} nofold + %0 = tensor.pad %arg0 nofold low[%low, %low] high[%high, %high] { + ^bb0(%arg9: index, %arg10: index): + tensor.yield %cst : f32 + } : tensor<?x?xf32> to tensor<20x40xf32> + return %0 : tensor<20x40xf32> +} + +transform.sequence failures(propagate) { +^bb0(%arg1: !transform.any_op): + %0 = transform.structured.match ops{["tensor.pad"]} in %arg1 + : (!transform.any_op) -> !pdl.operation + transform.structured.tile_to_forall_op %0 tile_sizes[1, 1] // Presumably creates the scf.if matched below. + + %if = transform.structured.match ops{["scf.if"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + transform.scf.take_assumed_branch %if take_else_branch + : (!transform.any_op) -> () +}