Change the declared effects of `transform.scf.take_assumed_branch` on its
target handle from `consumesHandle` to `onlyReadsHandle`. Under the Transform
dialect's handle-invalidation rules, consuming the target would also
invalidate handles to payload ops nested inside the `scf.if`.

The new `if_no_else` test case matches `some_op`, which is nested in the taken
`then` branch. It applies the transform and then prints the `some_op` handle,
checking that the handle is still valid afterwards. The existing
`take_else_branch` error test only loses a blank line.

diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -282,7 +282,7 @@
 
 void transform::TakeAssumedBranchOp::getEffects(
     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
-  consumesHandle(getTarget(), effects);
+  onlyReadsHandle(getTarget(), effects);
   modifiesPayload(effects);
 }
 
diff --git a/mlir/test/Dialect/SCF/transform-op-take-assumed-branch.mlir b/mlir/test/Dialect/SCF/transform-op-take-assumed-branch.mlir
--- a/mlir/test/Dialect/SCF/transform-op-take-assumed-branch.mlir
+++ b/mlir/test/Dialect/SCF/transform-op-take-assumed-branch.mlir
@@ -12,7 +12,6 @@
 ^bb0(%arg1: !transform.any_op):
   %if = transform.structured.match ops{["scf.if"]} in %arg1
     : (!transform.any_op) -> !transform.any_op
-
   // expected-error @+1 {{requires an scf.if op with a single-block `else` region}}
   transform.scf.take_assumed_branch %if take_else_branch
     : (!transform.any_op) -> ()
@@ -20,6 +19,30 @@
 
 // -----
 
+// CHECK-LABEL: if_no_else
+func.func @if_no_else(%cond: i1, %a: index, %b: memref<?xf32>, %c: i8) {
+  scf.if %cond {
+    "some_op"(%cond, %b) : (i1, memref<?xf32>) -> ()
+    scf.yield
+  }
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb0(%arg1: !transform.any_op):
+  %if = transform.structured.match ops{["scf.if"]} in %arg1
+    : (!transform.any_op) -> !transform.any_op
+  %some_op = transform.structured.match ops{["some_op"]} in %arg1
+    : (!transform.any_op) -> !transform.any_op
+
+  transform.scf.take_assumed_branch %if : (!transform.any_op) -> ()
+
+  // Handle to formerly nested `some_op` is still valid after the transform.
+  transform.print %some_op: !transform.any_op
+}
+
+// -----
+
 // CHECK-LABEL: tile_tensor_pad
 func.func @tile_tensor_pad(
   %arg0 : tensor<?x?xf32>, %cst : f32, %low: index, %high: index)