diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp @@ -163,52 +163,22 @@ const BufferizationOptions &options) const { OpBuilder::InsertionGuard g(rewriter); auto ifOp = cast(op); - auto thenYieldOp = cast(ifOp.thenBlock()->getTerminator()); - auto elseYieldOp = cast(ifOp.elseBlock()->getTerminator()); - // Reconcile type mismatches between then/else branches by inserting memref - // casts. - SmallVector thenResults, elseResults; - bool insertedCast = false; - for (unsigned i = 0; i < thenYieldOp.getResults().size(); ++i) { - Value thenValue = thenYieldOp.getResults()[i]; - Value elseValue = elseYieldOp.getResults()[i]; - if (thenValue.getType() == elseValue.getType()) { - thenResults.push_back(thenValue); - elseResults.push_back(elseValue); + // Compute bufferized result types. + SmallVector newTypes; + for (Value result : ifOp.getResults()) { + if (!result.getType().isa()) { + newTypes.push_back(result.getType()); continue; } - - // Type mismatch between then/else yield value. Cast both to a memref type - // with a fully dynamic layout map. 
- auto thenMemrefType = thenValue.getType().cast(); - auto elseMemrefType = elseValue.getType().cast(); - if (thenMemrefType.getMemorySpaceAsInt() != - elseMemrefType.getMemorySpaceAsInt()) - return op->emitError("inconsistent memory space on then/else branches"); - rewriter.setInsertionPoint(thenYieldOp); - BaseMemRefType memrefType = getMemRefTypeWithFullyDynamicLayout( - ifOp.getResultTypes()[i].cast(), - thenMemrefType.getMemorySpaceAsInt()); - thenResults.push_back(rewriter.create( - thenYieldOp.getLoc(), memrefType, thenValue)); - rewriter.setInsertionPoint(elseYieldOp); - elseResults.push_back(rewriter.create( - elseYieldOp.getLoc(), memrefType, elseValue)); - insertedCast = true; - } - - if (insertedCast) { - rewriter.setInsertionPoint(thenYieldOp); - rewriter.replaceOpWithNewOp(thenYieldOp, thenResults); - rewriter.setInsertionPoint(elseYieldOp); - rewriter.replaceOpWithNewOp(elseYieldOp, elseResults); + auto bufferType = bufferization::getBufferType(result, options); + if (failed(bufferType)) + return failure(); + newTypes.push_back(*bufferType); } // Create new op. rewriter.setInsertionPoint(ifOp); - ValueRange resultsValueRange(thenResults); - TypeRange newTypes(resultsValueRange); auto newIfOp = rewriter.create(ifOp.getLoc(), newTypes, ifOp.getCondition(), /*withElseRegion=*/true); @@ -223,6 +193,55 @@ return success(); } + FailureOr + getBufferType(Operation *op, Value value, const BufferizationOptions &options, + const DenseMap &fixedTypes) const { + auto ifOp = cast(op); + auto thenYieldOp = cast(ifOp.thenBlock()->getTerminator()); + auto elseYieldOp = cast(ifOp.elseBlock()->getTerminator()); + assert(value.getDefiningOp() == op && "invalid value"); + + // Determine buffer types of the true/false branches. 
+ auto opResult = value.cast(); + auto thenValue = thenYieldOp.getOperand(opResult.getResultNumber()); + auto elseValue = elseYieldOp.getOperand(opResult.getResultNumber()); + BaseMemRefType thenBufferType, elseBufferType; + if (thenValue.getType().isa()) { + // True branch was already bufferized. + thenBufferType = thenValue.getType().cast(); + } else { + auto maybeBufferType = + bufferization::getBufferType(thenValue, options, fixedTypes); + if (failed(maybeBufferType)) + return failure(); + thenBufferType = *maybeBufferType; + } + if (elseValue.getType().isa()) { + // False branch was already bufferized. + elseBufferType = elseValue.getType().cast(); + } else { + auto maybeBufferType = + bufferization::getBufferType(elseValue, options, fixedTypes); + if (failed(maybeBufferType)) + return failure(); + elseBufferType = *maybeBufferType; + } + + // Best case: Both branches have the exact same buffer type. + if (thenBufferType == elseBufferType) + return thenBufferType; + + // Memory space mismatch. + if (thenBufferType.getMemorySpaceAsInt() != + elseBufferType.getMemorySpaceAsInt()) + return op->emitError("inconsistent memory space on then/else branches"); + + // Layout maps are different: Promote to fully dynamic layout map. + return getMemRefTypeWithFullyDynamicLayout( + opResult.getType().cast(), + thenBufferType.getMemorySpaceAsInt()); + } + BufferRelation bufferRelation(Operation *op, OpResult opResult, const AnalysisState &state) const { // IfOp results are equivalent to their corresponding yield values if both @@ -973,9 +992,12 @@ if (failed(maybeBuffer)) return failure(); Value buffer = *maybeBuffer; - if (auto forOp = dyn_cast(yieldOp->getParentOp())) { + // In case of scf::ForOp / scf::IfOp, we may have to cast the value + // before yielding it. + // TODO: Do the same for scf::WhileOp. 
+ if (isa(yieldOp->getParentOp())) { FailureOr resultType = bufferization::getBufferType( - forOp.getRegionIterArgs()[it.index()], options); + yieldOp->getParentOp()->getResult(it.index()), options); if (failed(resultType)) return failure(); buffer = castBuffer(rewriter, buffer, *resultType); diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize-invalid.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize-invalid.mlir --- a/mlir/test/Dialect/SCF/one-shot-bufferize-invalid.mlir +++ b/mlir/test/Dialect/SCF/one-shot-bufferize-invalid.mlir @@ -5,9 +5,9 @@ // bufferized. %0 = bufferization.alloc_tensor() {memory_space = 0 : ui64} : tensor<10xf32> %1 = bufferization.alloc_tensor() {memory_space = 1 : ui64} : tensor<10xf32> - // expected-error @+2 {{inconsistent memory space on then/else branches}} - // expected-error @+1 {{failed to bufferize op}} + // expected-error @+1 {{inconsistent memory space on then/else branches}} %r = scf.if %c -> tensor<10xf32> { + // expected-error @+1 {{failed to bufferize op}} scf.yield %0 : tensor<10xf32> } else { scf.yield %1 : tensor<10xf32>