diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp @@ -442,6 +442,7 @@ ConstantOp, tensor::DimOp, ExtractSliceOp, + scf::IfOp, scf::ForOp, InsertSliceOp, InitTensorOp, @@ -550,6 +551,16 @@ // clang-format on } +/// Either one of the corresponding yield values from the then/else branches +/// may alias with the result. +static void populateAliasingOpOperands(scf::IfOp op, OpResult result, + SmallVector &operands) { + size_t resultNum = std::distance(op->getOpResults().begin(), + llvm::find(op->getOpResults(), result)); + operands.push_back(&op.thenYield()->getOpOperand(resultNum)); + operands.push_back(&op.elseYield()->getOpOperand(resultNum)); +} + /// Determine which OpOperand* will alias with `result` if the op is bufferized /// in place. Note that multiple OpOperands can may potentially alias with an /// OpResult. E.g.: std.select in the future. @@ -561,6 +572,7 @@ TypeSwitch(result.getDefiningOp()) .Case([&](tensor::CastOp op) { r.push_back(&op->getOpOperand(0)); }) .Case([&](ExtractSliceOp op) { r.push_back(&op->getOpOperand(0)); }) + .Case([&](scf::IfOp op) { populateAliasingOpOperands(op, result, r); }) // In the case of scf::ForOp, this currently assumes the iter_args / yield // are 1-1. This may fail and is verified at the end. // TODO: update this. @@ -712,6 +724,19 @@ if (bbArg.getType().isa()) createAliasInfoEntry(bbArg); }); + + // The return value of an scf::IfOp aliases with both yield values. 
+ rootOp->walk([&](scf::IfOp ifOp) { + if (ifOp->getNumResults() > 0) { + for (auto it : llvm::zip(ifOp.thenYield().results(), + ifOp.elseYield().results(), ifOp.results())) { + aliasInfo.unionSets(std::get<0>(it), std::get<1>(it)); + aliasInfo.unionSets(std::get<0>(it), std::get<2>(it)); + equivalentInfo.unionSets(std::get<0>(it), std::get<1>(it)); + equivalentInfo.unionSets(std::get<0>(it), std::get<2>(it)); + } + } + }); } /// Add a new entry for `v` in the `aliasInfo` and `equivalentInfo`. In the @@ -815,13 +840,28 @@ } /// Starting from `value`, follow the use-def chain in reverse, always selecting -/// the corresponding aliasing OpOperand. Try to find and return a Value for -/// which `condition` evaluates to true. +/// the aliasing OpOperands. Find and return Values for which `condition` +/// evaluates to true. OpOperands of such matching Values are not traversed any +/// further. +/// +/// When reaching the end of a chain (BlockArgument or Value without aliasing +/// OpOperands), also return the last Value of that chain. +/// +/// Example: /// -/// When reaching the end of the chain (BlockArgument or Value without aliasing -/// OpOperands), return the last Value of the chain. +/// 8 +/// | +/// 6* 7* +-----+----+ +/// | | | | +/// 2* 3 4* 5 +/// | | | | +/// +----------+----------+----------+ +/// | +/// 1 /// -/// Note: The returned SetVector contains exactly one element. +/// In the above example, Values with a star satisfy the condition. When +/// starting the traversal from Value 1, the resulting SetVector is: +/// { 2, 7, 8, 5 } static llvm::SetVector findValueInReverseUseDefChain(Value value, std::function condition) { @@ -842,18 +882,22 @@ continue; } - assert(opOperands.size() == 1 && "multiple OpOperands not supported yet"); - workingSet.insert(opOperands.front()->get()); + for (OpOperand *o : opOperands) + workingSet.insert(o->get()); } return result; } -/// Find the Value (result) of the last preceding write of a given Value. 
+/// Find the Value of the last preceding write of a given Value. /// /// Note: Unknown ops are handled conservatively and assumed to be writes. /// Furthermore, BlockArguments are also assumed to be writes. There is no /// analysis across block boundaries. +/// +/// Note: To simplify the analysis, scf.if ops are considered writes. Treating +/// a non-writing op as a writing op may introduce unnecessary out-of-place +/// bufferizations, but is always safe from a correctness point of view. static Value findLastPrecedingWrite(Value value) { SetVector result = findValueInReverseUseDefChain(value, [](Value value) { @@ -862,6 +906,8 @@ return true; if (!hasKnownBufferizationAliasingBehavior(op)) return true; + if (isa(op)) + return true; SmallVector opOperands = getAliasingOpOperand(value.cast()); @@ -892,6 +938,44 @@ condition); } +/// Return true if the two given operations are in mutually exclusive scf::IfOp +/// branches. +static bool insideMutuallyExclusiveBlocks(Operation *a, Operation *b) { + Block *block = a->getBlock(); + Operation *ancestorA = a; + Operation *ancestorB; + while (!(ancestorB = block->findAncestorOpInBlock(*b))) { + ancestorA = block->getParentOp(); + assert(ancestorA && "could not get parent"); + block = ancestorA->getBlock(); + } + + if (ancestorA != ancestorB) + return false; + + auto ifOp = dyn_cast(ancestorA); + if (!ifOp) + return false; + + return static_cast(ifOp.thenBlock()->findAncestorOpInBlock(*a)) != + static_cast(ifOp.thenBlock()->findAncestorOpInBlock(*b)); +} + +/// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors +/// properly dominates `b` and `b` is not inside `a`. 
+static bool happensBefore(Operation *a, Operation *b, + const DominanceInfo &domInfo) { + do { + // TODO: Instead of isProperAncestor + properlyDominates, we should use + // properlyDominatesImpl(a, b, /*enclosingOpOk=*/false) + if (a->isProperAncestor(b)) + return false; + if (domInfo.properlyDominates(a, b)) + return true; + } while ((a = a->getParentOp())); + return false; +} + /// Given sets of uses and writes, return true if there is a RaW conflict under /// the assumption that all given reads/writes alias the same buffer and that /// all given writes bufferize inplace. @@ -916,7 +1000,6 @@ // In the above example, if uRead is the OpOperand of reading_op, lastWrite // is %0. Note that operations that create an alias but do not write (such // as ExtractSliceOp) are skipped. - // TODO: With branches this should probably be a list of Values. Value lastWrite = findLastPrecedingWrite(uRead->get()); // Look for conflicting memory writes. Potential conflicts are writes to an @@ -930,21 +1013,34 @@ LDBG("Found potential conflict:\n"); LDBG("READ = #" << uRead->getOperandNumber() << " of " << printOperationInfo(readingOp) << "\n"); - LDBG("WRITE = #" << printValueInfo(lastWrite) << "\n"); LDBG("CONFLICTING WRITE = #" << uConflictingWrite->getOperandNumber() << " of " << printOperationInfo(conflictingWritingOp) << "\n"); // No conflict if the readingOp dominates conflictingWritingOp, i.e., the // write is not visible when reading. - if (domInfo.properlyDominates(readingOp, conflictingWritingOp)) + if (happensBefore(readingOp, conflictingWritingOp, domInfo)) continue; - // No conflict if the conflicting write happens before the last write. + // No conflict if the same use is the read and the conflicting write. A 
+ if (uConflictingWrite == uRead) + continue; + + if (insideMutuallyExclusiveBlocks(readingOp, conflictingWritingOp)) + continue; + + LDBG("WRITE = #" << printValueInfo(lastWrite) << "\n"); + + // No conflict if the conflicting write happens before the last + // write. if (Operation *writingOp = lastWrite.getDefiningOp()) { - if (domInfo.properlyDominates(conflictingWritingOp, writingOp)) + if (happensBefore(conflictingWritingOp, writingOp, domInfo)) // conflictingWritingOp happens before writingOp. No conflict. continue; + // No conflict if conflictingWritingOp is contained in writingOp. + if (writingOp->isProperAncestor(conflictingWritingOp)) + continue; } else { auto bbArg = lastWrite.cast(); Block *block = bbArg.getOwner(); @@ -959,11 +1055,6 @@ if (getAliasingOpResult(*uConflictingWrite) == lastWrite) continue; - // No conflict is the same use is the read and the conflicting write. A - // use cannot conflict with itself. - if (uConflictingWrite == uRead) - continue; - // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If // uRead is an InsertSliceOp... if (auto insertSliceOp = dyn_cast(readingOp)) { @@ -1410,15 +1501,27 @@ OpBuilder::InsertionGuard guard(b); Operation *op = result.getOwner(); SmallVector aliasingOperands = getAliasingOpOperand(result); - // TODO: Support multiple OpOperands. - assert(aliasingOperands.size() == 1 && - "more than 1 OpOperand not supported yet"); + assert(!aliasingOperands.empty() && "could not get aliasing OpOperand"); Value operand = aliasingOperands.front()->get(); Value operandBuffer = lookup(bvm, operand); assert(operandBuffer && "operand buffer not found"); + // Make sure that all OpOperands are the same buffer. If this is not the case, + // we would have to materialize a memref value. 
+ if (!llvm::all_of(aliasingOperands, [&](OpOperand *o) { + return lookup(bvm, o->get()) == operandBuffer; + })) { + op->emitError("result buffer is ambiguous"); + return Value(); + } // If bufferizing out-of-place, allocate a new buffer. - if (getInPlace(result) != InPlaceSpec::True) { + bool needCopy = + getInPlace(result) != InPlaceSpec::True && !isa(op); + if (needCopy) { + // Ops such as scf::IfOp can currently not bufferize out-of-place. + assert( + aliasingOperands.size() == 1 && + "ops with multiple aliasing OpOperands cannot bufferize out-of-place"); Location loc = op->getLoc(); // Allocate the result buffer. Value resultBuffer = @@ -1752,6 +1855,31 @@ return success(); } +static LogicalResult bufferize(OpBuilder &b, scf::IfOp ifOp, + BlockAndValueMapping &bvm, + BufferizationAliasInfo &aliasInfo) { + // Take a guard before anything else. + OpBuilder::InsertionGuard g(b); + + for (OpResult opResult : ifOp->getResults()) { + if (!opResult.getType().isa()) + continue; + // TODO: Atm we bail on unranked TensorType because we don't know how to + // alloc an UnrankedMemRefType + its underlying ranked MemRefType. + assert(opResult.getType().isa() && + "unsupported unranked tensor"); + + Value resultBuffer = getResultBuffer(b, opResult, bvm, aliasInfo); + if (!resultBuffer) + return failure(); + + aliasInfo.createAliasInfoEntry(resultBuffer); + map(bvm, opResult, resultBuffer); + } + + return success(); +} + /// FuncOp always creates TensorToMemRef ops. 
static LogicalResult bufferize(OpBuilder &b, FuncOp funcOp, BlockAndValueMapping &bvm, @@ -2013,7 +2141,6 @@ getResultBuffer(b, insertSliceOp->getResult(0), bvm, aliasInfo); if (!dstMemref) return failure(); - auto dstMemrefType = dstMemref.getType().cast(); Value srcMemref = lookup(bvm, insertSliceOp.source()); @@ -2102,6 +2229,9 @@ return success(); } + if (auto ifOp = dyn_cast(yieldOp->getParentOp())) + return success(); + scf::ForOp forOp = dyn_cast(yieldOp->getParentOp()); if (!forOp) return yieldOp->emitError("expected scf::ForOp parent for scf::YieldOp"); @@ -2348,7 +2478,7 @@ return success(); }) .Case( [&](auto op) { LDBG("Begin bufferize:\n" << op << '\n'); diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir @@ -912,3 +912,292 @@ return %15 : tensor<62x90xf32> } + +// ----- + +//===----------------------------------------------------------------------===// +// scf.if cases +//===----------------------------------------------------------------------===// + +// This example passes analysis, but it fails when bufferizing. 
+// CHECK-LABEL: func @scf_if_inplace1 +func @scf_if_inplace1(%t1: tensor {linalg.inplaceable = true}, + %t2: tensor {linalg.inplaceable = true}, + %cond: i1) -> tensor { + %r = scf.if %cond -> (tensor) { + scf.yield %t1 : tensor + } else { + scf.yield %t2 : tensor + } + return %r : tensor +} + +// ----- + +// CHECK-LABEL: func @scf_if_inplace2 +func @scf_if_inplace2(%t1: tensor {linalg.inplaceable = true}, + %v: vector<5xf32>, %idx: index, + %cond: i1) -> tensor { + %r = scf.if %cond -> (tensor) { + scf.yield %t1 : tensor + } else { + // CHECK: vector.transfer_write + // CHECK-SAME: {__inplace_results_attr__ = ["true"] + %t2 = vector.transfer_write %v, %t1[%idx] : vector<5xf32>, tensor + scf.yield %t2 : tensor + } + return %r : tensor +} + +// ----- + +// CHECK-LABEL: func @scf_if_inplace3 +func @scf_if_inplace3(%t1: tensor {linalg.inplaceable = true}, + %v1: vector<5xf32>, %v2: vector<5xf32>, %idx: index, + %cond: i1) -> tensor { + // CHECK: tensor.extract_slice + // CHECK-SAME: {__inplace_results_attr__ = ["true"] + %e = tensor.extract_slice %t1[%idx][%idx][1] : tensor to tensor + %r = scf.if %cond -> (tensor) { + // CHECK: vector.transfer_write + // CHECK-SAME: {__inplace_results_attr__ = ["true"] + %t2 = vector.transfer_write %v1, %e[%idx] : vector<5xf32>, tensor + scf.yield %t2 : tensor + } else { + // Writing the same tensor through an alias. This is OK. 
+ // CHECK: vector.transfer_write + // CHECK-SAME: {__inplace_results_attr__ = ["true"] + %t3 = vector.transfer_write %v2, %t1[%idx] : vector<5xf32>, tensor + scf.yield %t3 : tensor + } + return %r : tensor +} + +// ----- + +// CHECK-LABEL: func @scf_if_in_place4 +func @scf_if_in_place4(%t1: tensor {linalg.inplaceable = true}, + %v: vector<5xf32>, %idx: index, + %cond: i1, %cond2: i1) -> (tensor, vector<10xf32>) { + %cst = constant 0.0 : f32 + %r = scf.if %cond -> (tensor) { + scf.yield %t1 : tensor + } else { + // CHECK: vector.transfer_write + // CHECK-SAME: {__inplace_results_attr__ = ["true"] + %t2 = vector.transfer_write %v, %t1[%idx] : vector<5xf32>, tensor + scf.yield %t2 : tensor + } + %r_alias = scf.if %cond2 -> (tensor) { + // Reading %r is OK. No conflict. + scf.yield %r : tensor + } else { + scf.yield %r : tensor + } + %v2 = vector.transfer_read %r_alias[%idx], %cst : tensor, vector<10xf32> + return %r_alias, %v2 : tensor, vector<10xf32> +} + +// ----- + +// CHECK-LABEL: func @scf_if_inplace5 +func @scf_if_inplace5(%t1: tensor {linalg.inplaceable = true}, + %idx: index, %cond: i1) -> tensor { + %r = scf.if %cond -> (tensor) { + // CHECK: tensor.extract_slice + // CHECK-SAME: {__inplace_results_attr__ = ["true"] + %e = tensor.extract_slice %t1[%idx][%idx][1] : tensor to tensor + scf.yield %e : tensor + } else { + // CHECK: tensor.extract_slice + // CHECK-SAME: {__inplace_results_attr__ = ["true"] + %f = tensor.extract_slice %t1[%idx][%idx][1] : tensor to tensor + scf.yield %f : tensor + } + + // Inserting into an equivalent tensor at the same offset. This bufferizes + // inplace. 
+ // CHECK: tensor.insert_slice + // CHECK-SAME: {__inplace_results_attr__ = ["true"] + %r2 = tensor.insert_slice %r into %t1[%idx][%idx][1] : tensor into tensor + return %r2 : tensor +} + +// ----- + +// CHECK-LABEL: func @scf_if_inplace6 +func @scf_if_inplace6(%t1: tensor {linalg.inplaceable = true}, + %v1: vector<5xf32>, %v2: vector<5xf32>, + %v3: vector<5xf32>, %idx: index, + %cond: i1, %cond2: i1) -> tensor { + // Test nested scf.if ops. + %r = scf.if %cond -> (tensor) { + %t2 = scf.if %cond2 -> (tensor) { + // CHECK: vector.transfer_write + // CHECK-SAME: {__inplace_results_attr__ = ["true"] + %t3 = vector.transfer_write %v1, %t1[%idx] : vector<5xf32>, tensor + scf.yield %t3 : tensor + } else { + // CHECK: vector.transfer_write + // CHECK-SAME: {__inplace_results_attr__ = ["true"] + %t4 = vector.transfer_write %v3, %t1[%idx] : vector<5xf32>, tensor + scf.yield %t4 : tensor + } + scf.yield %t2 : tensor + } else { + // CHECK: vector.transfer_write + // CHECK-SAME: {__inplace_results_attr__ = ["true"] + %t3 = vector.transfer_write %v2, %t1[%idx] : vector<5xf32>, tensor + scf.yield %t3 : tensor + } + return %r : tensor +} + +// ----- + +// CHECK-LABEL: func @scf_if_inplace7 +func @scf_if_inplace7(%t1: tensor {linalg.inplaceable = true}, + %v1: vector<5xf32>, %v2: vector<5xf32>, %idx: index, + %idx2: index, %cond: i1) -> (tensor, vector<5xf32>) { + %cst = constant 0.0 : f32 + %r, %v_r2 = scf.if %cond -> (tensor, vector<5xf32>) { + // CHECK: vector.transfer_write + // CHECK-SAME: {__inplace_results_attr__ = ["true"] + %t2 = vector.transfer_write %v1, %t1[%idx] : vector<5xf32>, tensor + scf.yield %t2, %v1 : tensor, vector<5xf32> + } else { + // Writing the same tensor through an alias. + // CHECK: vector.transfer_write + // CHECK-SAME: {__inplace_results_attr__ = ["false"] + %t3 = vector.transfer_write %v2, %t1[%idx] : vector<5xf32>, tensor + // Read the original value of %t1. This requires the write in this branch + // to be out-of-place. 
But the write in the other branch can still be + // inplace. + %v_r = vector.transfer_read %t1[%idx2], %cst : tensor, vector<5xf32> + scf.yield %t3, %v_r : tensor, vector<5xf32> + } + return %r, %v_r2 : tensor, vector<5xf32> +} + +// ----- + +// CHECK-LABEL: func @scf_if_out_of_place1a +func @scf_if_out_of_place1a(%t1: tensor {linalg.inplaceable = true}, + %idx: index, %idx2: index, + %cond: i1) -> tensor { + %r = scf.if %cond -> (tensor) { + // CHECK: tensor.extract_slice + // CHECK-SAME: {__inplace_results_attr__ = ["true"] + %e = tensor.extract_slice %t1[%idx][%idx][1] : tensor to tensor + scf.yield %e : tensor + } else { + scf.yield %t1 : tensor + } + + // Reading from and writing to the same tensor via different args. This is a + // conflict. + // CHECK: tensor.insert_slice + // CHECK-SAME: {__inplace_results_attr__ = ["false"] + %r2 = tensor.insert_slice %r into %t1[%idx2][%idx2][1] : tensor into tensor + return %r2 : tensor +} + +// ----- + +// CHECK-LABEL: func @scf_if_out_of_place1b +func @scf_if_out_of_place1b(%t1: tensor {linalg.inplaceable = true}, + %idx: index, %idx2: index, %idx3: index, + %cond: i1) -> tensor { + %r = scf.if %cond -> (tensor) { + // CHECK: tensor.extract_slice + // CHECK-SAME: {__inplace_results_attr__ = ["false"] + %e = tensor.extract_slice %t1[%idx][%idx][1] : tensor to tensor + scf.yield %e : tensor + } else { + // CHECK: tensor.extract_slice + // CHECK-SAME: {__inplace_results_attr__ = ["false"] + %f = tensor.extract_slice %t1[%idx2][%idx2][1] : tensor to tensor + scf.yield %f : tensor + } + + // Reading from and writing to the same tensor via different args. This is a + // conflict. In contrast to scf_if_out_of_place1a, the fact that %r aliases + // with %t1 is only detected when analyzing the tensor.extract_slices. That's + // why the tensor.insert_slice is inplace and the two extract_slices are + // out-of-place. 
+ // CHECK: tensor.insert_slice + // CHECK-SAME: {__inplace_results_attr__ = ["true"] + %r2 = tensor.insert_slice %r into %t1[%idx3][%idx3][1] : tensor into tensor + return %r2 : tensor +} + +// ----- + +// CHECK-LABEL: func @scf_if_out_of_place1c +func @scf_if_out_of_place1c(%t1: tensor {linalg.inplaceable = true}, + %idx: index, %idx2: index, %cond: i1) -> tensor { + %r = scf.if %cond -> (tensor) { + // CHECK: tensor.extract_slice + // CHECK-SAME: {__inplace_results_attr__ = ["false"] + %e = tensor.extract_slice %t1[%idx][%idx][1] : tensor to tensor + scf.yield %e : tensor + } else { + // TODO: This one could bufferize inplace, but the analysis is too restrictive. + // CHECK: tensor.extract_slice + // CHECK-SAME: {__inplace_results_attr__ = ["false"] + %f = tensor.extract_slice %t1[%idx2][%idx2][1] : tensor to tensor + scf.yield %f : tensor + } + + // CHECK: tensor.insert_slice + // CHECK-SAME: {__inplace_results_attr__ = ["true"] + %r2 = tensor.insert_slice %r into %t1[%idx2][%idx2][1] : tensor into tensor + return %r2 : tensor +} + +// ----- + +// CHECK-LABEL: func @scf_if_out_of_place2 +func @scf_if_out_of_place2(%t1: tensor {linalg.inplaceable = true}, + %v: vector<5xf32>, %idx: index, + %cond: i1) -> (tensor, vector<10xf32>) { + %cst = constant 0.0 : f32 + %r = scf.if %cond -> (tensor) { + scf.yield %t1 : tensor + } else { + // CHECK: vector.transfer_write + // CHECK-SAME: {__inplace_results_attr__ = ["false"] + %t2 = vector.transfer_write %v, %t1[%idx] : vector<5xf32>, tensor + scf.yield %t2 : tensor + } + + // Read the old value of %t1. Forces the transfer_write to bufferize + // out-of-place. 
+ %v2 = vector.transfer_read %t1[%idx], %cst : tensor, vector<10xf32> + return %r, %v2 : tensor, vector<10xf32> +} + +// ----- + +// CHECK-LABEL: func @scf_if_out_of_place3 +func @scf_if_out_of_place3(%t1: tensor {linalg.inplaceable = true}, + %v: vector<5xf32>, %idx: index, + %cond: i1, %cond2: i1) -> (tensor, vector<10xf32>) { + %cst = constant 0.0 : f32 + %r = scf.if %cond -> (tensor) { + scf.yield %t1 : tensor + } else { + // CHECK: vector.transfer_write + // CHECK-SAME: {__inplace_results_attr__ = ["false"] + %t2 = vector.transfer_write %v, %t1[%idx] : vector<5xf32>, tensor + scf.yield %t2 : tensor + } + %t1_alias = scf.if %cond2 -> (tensor) { + // scf.yield bufferizes to a read. That is a conflict in this example. + scf.yield %t1 : tensor + } else { + scf.yield %t1 : tensor + } + %v2 = vector.transfer_read %t1_alias[%idx], %cst : tensor, vector<10xf32> + return %r, %v2 : tensor, vector<10xf32> +} diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir @@ -151,3 +151,17 @@ } return %r: i32 } + +// ----- + +func @scf_if_inplace1(%t1: tensor {linalg.inplaceable = true}, + %t2: tensor {linalg.inplaceable = true}, + %cond: i1) -> tensor { + // expected-error @+1 {{result buffer is ambiguous}} + %r = scf.if %cond -> (tensor) { + scf.yield %t1 : tensor + } else { + scf.yield %t2 : tensor + } + return %r : tensor +} diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir @@ -755,3 +755,25 @@ return %r1 : tensor } +// ----- + +// CHECK-LABEL: func @scf_if_inplace( +// CHECK-SAME: %[[cond:.*]]: i1, %[[t1:.*]]: 
memref, %[[v:.*]]: vector +func @scf_if_inplace(%cond: i1, + %t1: tensor {linalg.inplaceable = true}, + %v: vector<5xf32>, %idx: index) -> tensor { + + // CHECK: scf.if %[[cond]] { + // CHECK-NEXT: } else { + // CHECK-NEXT: vector.transfer_write %[[v]], %[[t1]] + // CHECK-NEXT: } + // CHECK-NEXT: return + %r = scf.if %cond -> (tensor) { + scf.yield %t1 : tensor + } else { + %t2 = vector.transfer_write %v, %t1[%idx] : vector<5xf32>, tensor + scf.yield %t2 : tensor + } + return %r : tensor +} +