diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp @@ -314,6 +314,23 @@ auto bufferizableOp = cast(op); Block *oldLoopBody = &forOp.getLoopBody().front(); + // Helper function for casting MemRef buffers. + auto castBuffer = [&](Value buffer, Type type) { + assert(type.isa() && "expected BaseMemRefType"); + assert(buffer.getType().isa() && + "expected BaseMemRefType"); + // If the buffer already has the correct type, no cast is needed. + if (buffer.getType() == type) + return buffer; + // TODO: In case `type` has a layout map that is not the fully dynamic + // one, we may not be able to cast the buffer. In that case, the loop + // iter_arg's layout map must be changed (see uses of `castBuffer`). + assert(memref::CastOp::areCastCompatible(buffer.getType(), type) && + "scf.for op bufferization: cast incompatible"); + return rewriter.create(buffer.getLoc(), type, buffer) + .getResult(); + }; + // Indices of all iter_args that have tensor type. These are the ones that // are bufferized. DenseSet indices; @@ -382,9 +399,10 @@ rewriter.setInsertionPoint(yieldOp); SmallVector yieldValues = convert(yieldOp.getResults(), [&](Value val, int64_t index) { - ensureToMemrefOpIsValid(val, initArgs[index].getType()); - Value yieldedVal = rewriter.create( - val.getLoc(), initArgs[index].getType(), val); + Type initArgType = initArgs[index].getType(); + ensureToMemrefOpIsValid(val, initArgType); + Value yieldedVal = + bufferization::lookupBuffer(rewriter, val, state.getOptions()); if (equivalentYields[index]) // Yielded value is equivalent to the corresponding iter_arg bbArg. @@ -392,7 +410,7 @@ // else must be resolved with copies and is potentially inefficient. // By default, such problematic IR would already have been rejected // during `verifyAnalysis`, unless `allow-return-allocs`. - return yieldedVal; + return castBuffer(yieldedVal, initArgType); // It is not certain that the yielded value and the iter_arg bbArg // have the same buffer. Allocate a new buffer and copy. The yielded @@ -412,21 +430,9 @@ (void)copyStatus; assert(succeeded(copyStatus) && "could not create memcpy"); - if (yieldedVal.getType() == yieldedAlloc->getType()) - return *yieldedAlloc; - - // The iter_arg memref type has a layout map. Cast the new buffer to - // the same type. - // TODO: In case the iter_arg has a layout map that is not the fully - // dynamic one, we cannot cast the new buffer. In that case, the - // iter_arg must be changed to the fully dynamic layout map. (And then - // the new buffer can be casted.) - assert(memref::CastOp::areCastCompatible(yieldedAlloc->getType(), - yieldedVal.getType()) && - "scf.for op bufferization: cast incompatible"); - Value casted = rewriter.create( - val.getLoc(), yieldedVal.getType(), *yieldedAlloc); - return casted; + // The iter_arg memref type may have a layout map. Cast the new buffer + // to the same type if needed. + return castBuffer(*yieldedAlloc, initArgType); }); yieldOp.getResultsMutable().assign(yieldValues);