diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" @@ -241,115 +242,88 @@ BufferizationState &state) const { auto tiledLoopOp = cast(op); - // Take a guard before anything else. - OpBuilder::InsertionGuard g(b); - - // Allocate output buffers if needed, forward output tensor args to the - // terminator. - Operation *yieldOp = tiledLoopOp.getBody()->getTerminator(); - Block *body = tiledLoopOp.getBody(); - - // Take copies of the old input and output operands, so we can insert - // inplace easily. - auto oldInputs = llvm::to_vector<4>(tiledLoopOp.inputs()); - auto oldOutputs = llvm::to_vector<4>(tiledLoopOp.outputs()); - - int numLoops = tiledLoopOp.getNumLoops(); - int numControlOperands = tiledLoopOp.getNumControlOperands(); - - // Add buffers for outputs and the corresponding block arguments. - // Keep separate iterators to increment without further leaking impl. - // details. Start with outputs to avoid interference from new input buffers. - int numNewOutputBuffers = 0; - int resultIndex = 0; - int oldOutputBBArgIndex = numLoops + oldInputs.size(); - int nextOutputBBArgIndex = numLoops + oldInputs.size() + oldOutputs.size(); - int nextOutputOperandIndex = - numControlOperands + oldInputs.size() + oldOutputs.size(); - for (Value oldOutputTensor : oldOutputs) { - if (!oldOutputTensor.getType().isa()) { - // Skip and increment the old bbarg index only. - ++oldOutputBBArgIndex; - // Do not increment resultIndex as only tensors are returned. - // TODO: better interface to avoid leaking such impl details. - continue; + // Use IRRewriter instead of OpBuilder because it has additional helper + // functions. + IRRewriter rewriter(op->getContext()); + rewriter.setInsertionPoint(tiledLoopOp); + + // Compute new inputs, outputs and results. + SmallVector newInputs, newOutputs, newResults; + for (Value value : tiledLoopOp.inputs()) { + if (value.getType().isa()) { + newInputs.push_back(state.lookupBuffer(value)); + } else { + newInputs.push_back(value); } - - assert(oldOutputTensor.getType().isa() && - "bufferizable output must be a ranked tensor"); - - const OpResult &opResult = tiledLoopOp->getResult(resultIndex); - OpOperand &yieldOperand = yieldOp->getOpOperand(resultIndex); - Value resultBuffer = state.getResultBuffer(opResult); - if (!resultBuffer) - return failure(); - - // Insert mapping and aliasing info. - state.mapBuffer(opResult, resultBuffer); - - // Insert new operand and bbArg. - tiledLoopOp->insertOperands(nextOutputOperandIndex, resultBuffer); - BlockArgument newBufferBBArg = - body->insertArgument(nextOutputBBArgIndex, resultBuffer.getType()); - BlockArgument oldTensorBBArg = body->getArgument(oldOutputBBArgIndex); - // Insert mapping and aliasing info. - state.mapBuffer(oldTensorBBArg, newBufferBBArg); - - // Set operand of `linalg.yield` to the bbArg so it just canonicalizes - // away later. - yieldOperand.set(oldTensorBBArg); - - // Increment indices. - ++numNewOutputBuffers; - ++resultIndex; - ++oldOutputBBArgIndex; - ++nextOutputBBArgIndex; - ++nextOutputOperandIndex; } - - // Add buffers for inputs and the corresponding block arguments. - // Keep separate iterators to increment without further leaking impl. - // details. - int numNewInputBuffers = 0; - int oldInputBBArgIndex = numLoops; - int nextInputBBArgIndex = numLoops + oldInputs.size(); - int nextInputOperandIndex = numControlOperands + oldInputs.size(); - for (Value oldInputTensor : oldInputs) { - if (!oldInputTensor.getType().isa()) { - // Skip and increment the old bbarg index only. - ++oldInputBBArgIndex; - continue; + int nextResultNum = 0; + for (Value value : tiledLoopOp.outputs()) { + if (value.getType().isa()) { + Value buffer = + state.getResultBuffer(tiledLoopOp->getResult(nextResultNum++)); + newOutputs.push_back(buffer); + newResults.push_back(buffer); + } else { + newOutputs.push_back(value); } + } - Value inputBuffer = state.lookupBuffer(oldInputTensor); - - // Insert new operand and bbArg. - tiledLoopOp->insertOperands(nextInputOperandIndex, inputBuffer); - BlockArgument newBufferBBArg = - body->insertArgument(nextInputBBArgIndex, inputBuffer.getType()); - BlockArgument oldTensorBBArg = body->getArgument(oldInputBBArgIndex); + // Create new TiledLoopOp. + auto newTiledLoopOp = rewriter.create( + tiledLoopOp.getLoc(), tiledLoopOp.lowerBound(), + tiledLoopOp.upperBound(), tiledLoopOp.step(), newInputs, newOutputs, + tiledLoopOp.iterator_types(), tiledLoopOp.distribution_types()); + + // Remove terminator. + if (!newTiledLoopOp.getBody()->empty()) + rewriter.eraseOp(tiledLoopOp.getBody()->getTerminator()); + + // Compute new loop body arguments. + SmallVector newBlockArgs, newRegionInOutArgs, oldRegionInOutArgs; + ValueRange newInductionVars = newTiledLoopOp.getInductionVars(); + newBlockArgs.append(newInductionVars.begin(), newInductionVars.end()); + + ValueRange newRegionInArgs = newTiledLoopOp.getRegionInputArgs(); + ValueRange newRegionOutArgs = newTiledLoopOp.getRegionOutputArgs(); + newRegionInOutArgs.append(newRegionInArgs.begin(), newRegionInArgs.end()); + newRegionInOutArgs.append(newRegionOutArgs.begin(), newRegionOutArgs.end()); + + ValueRange oldRegionInArgs = tiledLoopOp.getRegionInputArgs(); + ValueRange oldRegionOutArgs = tiledLoopOp.getRegionOutputArgs(); + oldRegionInOutArgs.append(oldRegionInArgs.begin(), oldRegionInArgs.end()); + oldRegionInOutArgs.append(oldRegionOutArgs.begin(), oldRegionOutArgs.end()); + assert(newRegionInArgs.size() == oldRegionInArgs.size() && + "expected same number of input args"); + assert(newRegionOutArgs.size() == oldRegionOutArgs.size() && + "expected same number of output args"); + + for (auto it : llvm::zip(oldRegionInOutArgs, newRegionInOutArgs)) { + Value oldArg = std::get<0>(it); + Value newArg = std::get<1>(it); + rewriter.setInsertionPointToStart(newTiledLoopOp->getBlock()); + if (oldArg.getType().isa()) { + newBlockArgs.push_back(rewriter.create( + oldArg.getLoc(), newArg)); + } else { + newBlockArgs.push_back(newArg); + } + } - // Insert mapping and aliasing info. - state.mapBuffer(oldTensorBBArg, newBufferBBArg); + // Move old body into new loop. + rewriter.mergeBlocks(tiledLoopOp.getBody(), newTiledLoopOp.getBody(), + newBlockArgs); - // Increment indices. - ++numNewInputBuffers; - ++oldInputBBArgIndex; - ++nextInputBBArgIndex; - ++nextInputOperandIndex; - } + // Replace previous terminator with a new one that does not yield anything. + Operation *oldTerminator = newTiledLoopOp.getBody()->getTerminator(); + rewriter.setInsertionPointToEnd(newTiledLoopOp.getBody()); + rewriter.create(oldTerminator->getLoc()); + rewriter.eraseOp(oldTerminator); - // Update segment sizes. - // TODO: Helper method to avoid leaking impl details. - tiledLoopOp->setAttr( - TiledLoopOp::getOperandSegmentSizeAttr(), - b.getI32VectorAttr( - {numLoops, numLoops, numLoops, - static_cast(oldInputs.size()) + numNewInputBuffers, - static_cast(oldOutputs.size()) + numNewOutputBuffers})); + // Replace results and delete old op. + state.replaceOp(op, newResults); // Bufferize loop body. - return comprehensive_bufferize::bufferize(&tiledLoopOp.region(), state); + return comprehensive_bufferize::bufferize(newTiledLoopOp.getBody(), state); } }; @@ -372,14 +346,14 @@ BufferizationState &state) const { auto yieldOp = cast(op); - // No tensors -> success. - if (!llvm::any_of(yieldOp.getOperandTypes(), - [](Type t) { return t.isa(); })) - return success(); - // linalg::YieldOp nested under TiledLoop must just canonicalize. - if (yieldOp->getParentOfType()) - return success(); - llvm_unreachable("unexpected yieldOp"); + if (!yieldOp->getParentOfType()) + return yieldOp->emitError( + "expected that linalg.yield terminates a tiled_loop"); + + assert(yieldOp->getOpOperands().empty() && + "expected that linalg.yield was bufferized together with" + " tiled_loop"); + return success(); } }; diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -6393,6 +6393,7 @@ includes = ["include"], deps = [ ":BufferizableOpInterface", + ":BufferizationDialect", ":IR", ":LinalgOps", ":LinalgStructuredOpsIncGen",