diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -172,6 +172,16 @@
   return returnOp;
 }
 
+/// Return true if `value` is the result of an InitTensorOp or a cast thereof.
+/// Chains of tensor::CastOp are looked through, so callers can skip copying a
+/// value that is an InitTensorOp modulo casts.
+static bool isInitTensorOp(Value value) {
+  tensor::CastOp castOp;
+  while ((castOp = value.getDefiningOp<tensor::CastOp>()))
+    value = castOp.source();
+  return value.getDefiningOp<InitTensorOp>();
+}
+
 //===----------------------------------------------------------------------===//
 // Bufferization-specific BlockAndValueMapping support with debugging.
 //===----------------------------------------------------------------------===//
@@ -1781,7 +1791,7 @@
       // unitialized and we do not need to copy.
       // TODO: "matching bbArg does not bufferize to a read" is a more general
       // check.
-      if (!operand.getDefiningOp<InitTensorOp>())
+      if (!isInitTensorOp(operand))
         b.create<CopyOp>(forOp.getLoc(), operandBuffer, resultBuffer);
     }
     BlockArgument bbArg = forOp.getRegionIterArgForOpOperand(opOperand);
@@ -1908,7 +1918,7 @@
       // unitialized and we do not need to copy.
       // TODO: "matching bbArg does not bufferize to a read" is a more general
       // check.
-      if (!oldOutputTensor.getDefiningOp<InitTensorOp>()) {
+      if (!isInitTensorOp(oldOutputTensor)) {
         b.setInsertionPointAfter(alloc.getDefiningOp());
         b.create<CopyOp>(loc, outputBuffer, alloc);
       }