diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp @@ -1397,6 +1397,38 @@ // Bufferization as simple BlockAndValueMapping rewrites. //===----------------------------------------------------------------------===// +/// Return the result buffer (memref) for `result`. Its aliasing OpOperand must +/// exist with a buffer in `bvm`. In-place: return that buffer. Out-of-place: +/// allocate a new one and copy into it unless the operand is an init_tensor. +static Value getResultBuffer(OpBuilder &b, OpResult result, + const BlockAndValueMapping &bvm, + BufferizationAliasInfo &aliasInfo) { + OpBuilder::InsertionGuard guard(b); + Operation *op = result.getOwner(); + Optional maybeOperand = getAliasingOpOperand(result); + assert(maybeOperand && "corresponding OpOperand not found"); + Value operand = (*maybeOperand)->get(); + Value operandBuffer = lookup(bvm, operand); + assert(operandBuffer && "operand buffer not found"); + + // If bufferizing out-of-place, allocate a new buffer. + if (getInPlace(result) != InPlaceSpec::True) { + Location loc = op->getLoc(); + // Allocate the result buffer. + Value resultBuffer = + createNewAllocDeallocPairForShapedValue(b, loc, operand, aliasInfo); + if (!isInitTensorOp(operand)) { + // Alloc/dealloc may have moved the insertion point; copy before `op`. + b.setInsertionPoint(op); + b.create(loc, operandBuffer, resultBuffer); + } + return resultBuffer; + } + + // Bufferizing in-place. No need to allocate a new buffer. + return operandBuffer; +} + /// Helper function for LinalgOp bufferization. /// Examines each result and determines whether it bufferizes inplace on an /// operand. @@ -1639,27 +1671,8 @@ OpBuilder::InsertionGuard g(b); b.setInsertionPoint(castOp); - // If castOp is not inPlace, allocate a new buffer.
- auto inPlace = getInPlace(castOp->getResult(0)); - Value newBuffer; - if (inPlace != InPlaceSpec::True) { - Location loc = castOp.getLoc(); - // Alloc a copy for `writeOp.source()`, it will become the result buffer. - newBuffer = createNewAllocDeallocPairForShapedValue(b, loc, castOp.source(), - aliasInfo); - if (!isInitTensorOp(castOp.source())) { - // Set insertion point now that potential alloc/dealloc are introduced. - b.setInsertionPoint(castOp); - b.create(loc, lookup(bvm, castOp.source()), newBuffer); - } - } else { - // InPlace write will result in memref.tensor_load(x) which must - // canonicalize away with one of it uses. - newBuffer = lookup(bvm, castOp.source()); - assert(newBuffer && "missing buffer"); - } - - Type sourceType = newBuffer.getType(); + Value resultBuffer = getResultBuffer(b, castOp->getResult(0), bvm, aliasInfo); + Type sourceType = resultBuffer.getType(); auto rankedMemRefType = sourceType.dyn_cast(); auto unrankedMemRefType = sourceType.dyn_cast(); assert(rankedMemRefType || unrankedMemRefType); @@ -1673,7 +1686,8 @@ : ArrayRef{}; Type memRefType = getContiguousOrUnrankedMemRefType( castOp.getResult().getType(), affineMaps, memorySpace); - Value res = b.create(castOp.getLoc(), memRefType, newBuffer); + Value res = + b.create(castOp.getLoc(), memRefType, resultBuffer); aliasInfo.insertNewBufferEquivalence(res, castOp.getResult()); map(bvm, castOp.getResult(), res); return success(); @@ -1723,9 +1737,6 @@ // Take a guard before anything else. OpBuilder::InsertionGuard g(b); - // If inPlace, just forward the buffer. - // Otherwise alloc and copy. - Location loc = forOp.getLoc(); for (OpResult opResult : forOp->getResults()) { if (!opResult.getType().isa()) continue; @@ -1733,29 +1744,11 @@ // alloc an UnrankedMemRefType + its underlying ranked MemRefType. assert(opResult.getType().isa() && "unsupported unranked tensor"); + + // TODO: Skip the copy if the matching bbArg does not bufferize to a read.
+ Value resultBuffer = getResultBuffer(b, opResult, bvm, aliasInfo); + OpOperand &opOperand = forOp.getOpOperandForResult(opResult); - Value operand = opOperand.get(); - Value operandBuffer = lookup(bvm, operand); - Value resultBuffer = operandBuffer; - if (getInPlace(opResult) != InPlaceSpec::True) { - resultBuffer = - createNewAllocDeallocPairForShapedValue(b, loc, operand, aliasInfo); - // If the tensor comes from either: - // - linalg.init_tensor - // - tensor.cast(linalg.init_tensor()) - // Then the value is unitialized and we do not need to copy. This is a - // pragmatic simplification of "matching bbArg does not bufferize to a - // read". - // TODO: "matching bbArg does not bufferize to a read" is a more general - // check. - if (!isInitTensorOp(operand)) { - OpBuilder::InsertionGuard g(b); - // Set insertion point now that potential alloc/dealloc are introduced. - // Copy is inserted just before the forOp. - b.setInsertionPoint(forOp); - b.create(forOp.getLoc(), operandBuffer, resultBuffer); - } - } BlockArgument bbArg = forOp.getRegionIterArgForOpOperand(opOperand); aliasInfo.createAliasInfoEntry(resultBuffer); aliasInfo.insertNewBufferEquivalence(bbArg, resultBuffer); @@ -1872,40 +1865,19 @@ assert(oldOutputTensor.getType().isa() && "bufferizable output must be a ranked tensor"); - Value outputBuffer = lookup(bvm, oldOutputTensor); const OpResult &opResult = tiledLoopOp->getResult(resultIndex); OpOperand &yieldOperand = yieldOp->getOpOperand(resultIndex); - // If the result is not inplaceable, need to allocate a copy for it. - if (getInPlace(opResult) != InPlaceSpec::True) { - auto loc = tiledLoopOp.getLoc(); - Value alloc = createNewAllocDeallocPairForShapedValue( - b, loc, oldOutputTensor, aliasInfo); - // If the tensor comes from either: - // - linalg.init_tensor - // - tensor.cast(linalg.init_tensor()) - // Then the value is unitialized and we do not need to copy.
This is a - // pragmatic simplification of "matching bbArg does not bufferize to a - // read". - // TODO: "matching bbArg does not bufferize to a read" is a more general - // check. - if (!isInitTensorOp(oldOutputTensor)) { - OpBuilder::InsertionGuard g(b); - // Set insertion point now that potential alloc/dealloc are introduced. - // Copy is inserted just before the tiledLoopOp. - b.setInsertionPoint(tiledLoopOp); - b.create(loc, outputBuffer, alloc); - } - outputBuffer = alloc; - } + Value resultBuffer = getResultBuffer(b, opResult, bvm, aliasInfo); + + // Insert mapping and aliasing info. - aliasInfo.createAliasInfoEntry(outputBuffer); - aliasInfo.insertNewBufferEquivalence(opResult, outputBuffer); - map(bvm, opResult, outputBuffer); + aliasInfo.createAliasInfoEntry(resultBuffer); + aliasInfo.insertNewBufferEquivalence(opResult, resultBuffer); + map(bvm, opResult, resultBuffer); // Insert new operand and bbArg. - tiledLoopOp->insertOperands(nextOutputOperandIndex, outputBuffer); + tiledLoopOp->insertOperands(nextOutputOperandIndex, resultBuffer); BlockArgument newBufferBBArg = - body->insertArgument(nextOutputBBArgIndex, outputBuffer.getType()); + body->insertArgument(nextOutputBBArgIndex, resultBuffer.getType()); BlockArgument oldTensorBBArg = body->getArgument(oldOutputBBArgIndex); // Insert mapping and aliasing info. aliasInfo.createAliasInfoEntry(newBufferBBArg); @@ -2035,25 +2007,15 @@ LDBG("bufferize: " << *insertSliceOp << '\n'); Location loc = insertSliceOp.getLoc(); - Value dstMemref = lookup(bvm, insertSliceOp.dest()); - if (!dstMemref) - return failure(); - auto inPlace = getInPlace(insertSliceOp->getResult(0)); - if (inPlace != InPlaceSpec::True) { - // Since insert_slice arise from tiling and introducing loops, this - // case is generally a deal breaker. When used with loops, this ends up - // cloning the whole tensor on every single iteration and is a symptom - // of a catastrophically bad scheduling decision.
- // TODO: be very loud about it or even consider failing the pass. - // Alloc a copy for `insertSliceOp.dest()`, it will become the result - // buffer. - Value newDstMemref = createNewAllocDeallocPairForShapedValue( - b, loc, insertSliceOp.dest(), aliasInfo); - // Set insertion point now that potential alloc/dealloc are introduced. - b.setInsertionPoint(insertSliceOp); - b.create(insertSliceOp.getLoc(), dstMemref, newDstMemref); - dstMemref = newDstMemref; - } + // Out of place, getResultBuffer allocates a new buffer and copies the dest. + // As insert_slice arise from tiling and loops, this clones the whole tensor + // on every single iteration: a symptom of a catastrophically bad scheduling + // decision. TODO: be very loud about it or even consider failing the pass. + // getResultBuffer only asserts that the dest buffer exists; fail instead. + if (!lookup(bvm, insertSliceOp.dest())) + return failure(); + Value dstMemref = + getResultBuffer(b, insertSliceOp->getResult(0), bvm, aliasInfo); auto dstMemrefType = dstMemref.getType().cast(); Value srcMemref = lookup(bvm, insertSliceOp.source()); @@ -2071,6 +2033,7 @@ // slice is computed out of place into the inplace full tensor. // - The result is not inplace. This is the case where the whole tensor is // cloned and the clone needs to be updated. + auto inPlace = getInPlace(insertSliceOp->getResult(0)); if (!aliasInfo.isSourceEquivalentToAMatchingInplaceExtractSliceOp( insertSliceOp) || inPlace != InPlaceSpec::True) { @@ -2109,38 +2072,16 @@ return success(); } - auto inPlace = getInPlace(op->getResult(0)); - auto writeOp = cast(op.getOperation()); - - // If transfer_write is not inPlace, allocate a new buffer. - Value newInputBuffer; - Location loc = op.getLoc(); - if (inPlace != InPlaceSpec::True) { - // Alloc a copy for `writeOp.source()`, it will become the result buffer.
- newInputBuffer = createNewAllocDeallocPairForShapedValue( - b, loc, writeOp.source(), aliasInfo); - Value v = lookup(bvm, writeOp.source()); - if (!isInitTensorOp(writeOp.source())) { - // Set insertion point now that potential alloc/dealloc are introduced. - b.setInsertionPoint(op); - b.create(loc, v, newInputBuffer); - } - } else { - // InPlace write will result in memref.tensor_load(x) which must - // canonicalize away with one of it uses. - newInputBuffer = lookup(bvm, writeOp.source()); - assert(newInputBuffer && "missing buffer"); - } - // Create a new transfer_write on buffer that doesn't have a return value. // Leave the previous transfer_write to dead code as it still has uses at // this point. + auto writeOp = cast(op.getOperation()); + Value resultBuffer = getResultBuffer(b, op->getResult(0), bvm, aliasInfo); b.create( - loc, writeOp.vector(), newInputBuffer, writeOp.indices(), + op.getLoc(), writeOp.vector(), resultBuffer, writeOp.indices(), writeOp.permutation_map(), writeOp.in_bounds() ? *writeOp.in_bounds() : ArrayAttr()); - - map(bvm, op->getResult(0), newInputBuffer); + map(bvm, op->getResult(0), resultBuffer); return success(); }