diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -296,13 +296,13 @@
       return InPlaceSpec::None;
     return inplaceAttr.getValue() ? InPlaceSpec::True : InPlaceSpec::False;
   }
-  // Interestingly, scf::ForOp's bbArg can **always** be viewed inplace from the
-  // perspective of ops nested under it:
+  // Interestingly, scf::ForOp's and TiledLoop's bbArg can **always** be viewed
+  // inplace from the perspective of ops nested under:
   // 1. Either the matching iter operand is not bufferized inplace and an
   //    alloc + optional copy makes the bbArg itself inplaceable.
   // 2. Or the matching iter operand is bufferized inplace and bbArg just
   //    bufferizes to that too.
-  if (auto forOp = dyn_cast<scf::ForOp>(bbArg.getOwner()->getParentOp()))
+  if (isa<TiledLoopOp, scf::ForOp>(bbArg.getOwner()->getParentOp()))
     return InPlaceSpec::True;
   // Unknown cases.
   return InPlaceSpec::None;
@@ -359,19 +359,28 @@
       isa(op)
       // clang-format on
       || (none_of(op->getResultTypes(), isaTensor) &&
          none_of(op->getOperandTypes(), isaTensor));
 }
 
+/// Return the OpResult that may bufferize into the same buffer as `opOperand`
+/// when the op is bufferized inplace.
+/// Return null if no such result exists.
+static OpResult getInplaceableOpResult(TiledLoopOp op, OpOperand &opOperand) {
+  return op.getTiedOpResult(opOperand);
+}
+
 /// Return the OpResult that may bufferize into the same buffer as `opOperand`
 /// when the op is bufferized inplace.
 /// Return null if no such result exists.
@@ -441,8 +450,9 @@
           // result(s).
           .Case(
              [&](auto op) { return getInplaceableOpResult(op, opOperand); })
           // ExtractSliceOp is special, when bufferized inplace it just returns an
@@ -469,18 +479,23 @@
   return TypeSwitch<Operation *, OpOperand *>(result.getDefiningOp())
       .Case([&](tensor::CastOp op) { return &op->getOpOperand(0); })
       .Case([&](ConstantOp op) { return &op->getOpOperand(0); })
-      .Case([&](LinalgOp op) {
-        return op.getOutputTensorOperands()[result.getResultNumber()];
-      })
       .Case([&](ExtractSliceOp op) { return &op->getOpOperand(0); })
-      .Case([&](InsertSliceOp op) { return &op->getOpOperand(1); })
-      .Case([&](vector::TransferWriteOp op) { return &op->getOpOperand(1); })
       // In the case of scf::ForOp, this currently assumes the iter_args / yield
       // are 1-1. This may fail and is verified at the end.
       // TODO: update this.
       .Case([&](scf::ForOp op) {
        return &op.getIterOpOperands()[result.getResultNumber()];
       })
+      .Case([&](InsertSliceOp op) { return &op->getOpOperand(1); })
+      .Case([&](LinalgOp op) {
+        return op.getOutputTensorOperands()[result.getResultNumber()];
+      })
+      .Case([&](TiledLoopOp op) {
+        // TODO: TiledLoopOp helper method to avoid leaking impl details.
+        return &op->getOpOperand(op.getNumControlOperands() +
+                                 op.getNumInputs() + result.getResultNumber());
+      })
+      .Case([&](vector::TransferWriteOp op) { return &op->getOpOperand(1); })
       .Default([&](Operation *op) {
         op->dump();
         llvm_unreachable("unexpected defining op");
@@ -528,6 +543,10 @@
   // matching bbArg may.
   if (isa<scf::ForOp>(opOperand.getOwner()))
     return false;
+  // TiledLoop alone doesn't bufferize to a memory read, one of the uses of its
+  // matching bbArg may.
+  if (isa<TiledLoopOp>(opOperand.getOwner()))
+    return false;
   // CallOpInterface alone doesn't bufferize to a memory read, one of the uses
   // of the matching bbArg may. It is the responsibility of the caller to
   // inspect bbArgs. In the absence of a BufferizationAliasInfo, we need to be
@@ -1361,7 +1380,7 @@
     if (getInPlace(opResult) == InPlaceSpec::True) {
       Value v = lookup(bvm, output);
       if (!v)
-        return failure();
+        return op.emitError() << "missing buffer";
       resultBuffers.push_back(v);
       continue;
     }
@@ -1398,7 +1417,7 @@
   // Ensure op has only tensors. Allow mixed tensor-buffer mode on a per-need
   // basis.
   if (!op.hasTensorSemantics())
-    return failure();
+    return op->emitError() << "op does not have tensor semantics";
 
   b.setInsertionPoint(op);
   Location loc = op.getLoc();
@@ -1411,13 +1430,14 @@
     }
     newInputBuffers.push_back(lookup(bvm, opOperand->get()));
     if (!newInputBuffers.back())
-      return failure();
+      return op->emitError()
+             << "missing buffer for operand #" << opOperand->getOperandNumber();
   }
   SmallVector<Value> newOutputBuffers;
   // Try to allocate new buffers depending on op's inplace semantics.
   if (failed(allocateBuffersForResults(b, loc, op, newOutputBuffers, bvm,
                                        aliasInfo)))
-    return failure();
+    return op->emitError() << "failed to allocate buffers for results";
 
   // Clone the newly bufferized op.
   SmallVector<Value> newOperands = newInputBuffers;
@@ -1609,7 +1629,7 @@
                                    BufferizationAliasInfo &aliasInfo,
                                    GlobalCreator &globalCreator) {
   if (!constantOp.getType().dyn_cast<RankedTensorType>())
-    return failure();
+    return constantOp->emitError() << "not a constant tensor";
 
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
@@ -1629,11 +1649,16 @@
 static LogicalResult bufferize(OpBuilder &b, tensor::DimOp dimOp,
                                BlockAndValueMapping &bvm,
                                BufferizationAliasInfo &aliasInfo) {
+  // Take a guard before anything else.
+  OpBuilder::InsertionGuard g(b);
+  b.setInsertionPoint(dimOp);
+
   if (dimOp.source().getType().isa<TensorType>()) {
     Value v = lookup(bvm, dimOp.source());
     if (!v)
-      return failure();
-    dimOp.sourceMutable().assign(v);
+      return dimOp.emitError() << "missing buffer";
+    dimOp.result().replaceAllUsesWith(
+        b.create<memref::DimOp>(dimOp.getLoc(), v, dimOp.index()));
   }
   return success();
 }
@@ -1649,10 +1674,12 @@
   // Otherwise alloc and copy.
   b.setInsertionPoint(forOp);
   for (OpResult opResult : forOp->getResults()) {
+    if (!opResult.getType().isa<TensorType>())
+      continue;
     // TODO: Atm we bail on unranked TensorType because we don't know how to
     // alloc an UnrankedMemRefType + its underlying ranked MemRefType.
     if (!opResult.getType().isa<RankedTensorType>())
-      return failure();
+      return forOp.emitError() << "unsupported unranked tensor";
     OpOperand &opOperand = forOp.getOpOperandForResult(opResult);
     Value operand = opOperand.get();
     Value operandBuffer = lookup(bvm, operand);
@@ -1731,7 +1758,8 @@
       continue;
     Value v = lookup(bvm, operand.get());
     if (!v)
-      return failure();
+      return returnOp->emitError()
+             << "missing buffer for result #" << operand.getOperandNumber();
     Value returnTensor = b.create<memref::TensorLoadOp>(returnOp.getLoc(), v);
     operand.set(returnTensor);
     aliasInfo.insertNewBufferEquivalence(returnTensor, v);
@@ -1740,6 +1768,135 @@
   return success();
 }
 
+/// Bufferization for TiledLoopOp.
+static LogicalResult bufferize(OpBuilder &b, TiledLoopOp tiledLoopOp,
+                               BlockAndValueMapping &bvm,
+                               BufferizationAliasInfo &aliasInfo) {
+  // Allocate output buffers if needed, forward output tensor args to the
+  // terminator.
+  Operation *yieldOp = tiledLoopOp.getBody()->getTerminator();
+  Block *body = tiledLoopOp.getBody();
+
+  // Take copies of the old input and output operands, so we can insert inplace
+  // easily.
+  auto oldInputs = llvm::to_vector<4>(tiledLoopOp.inputs());
+  auto oldOutputs = llvm::to_vector<4>(tiledLoopOp.outputs());
+
+  int numLoops = tiledLoopOp.getNumLoops();
+  int numControlOperands = tiledLoopOp.getNumControlOperands();
+
+  // Add buffers for outputs and the corresponding block arguments.
+  // Keep separate iterators to increment without further leaking impl. details.
+  // Start with outputs to avoid interference from new input buffers.
+  int numNewOutputBuffers = 0;
+  int resultIndex = 0;
+  int oldOutputBBArgIndex = numLoops + oldInputs.size();
+  int nextOutputBBArgIndex = numLoops + oldInputs.size() + oldOutputs.size();
+  int nextOutputOperandIndex =
+      numControlOperands + oldInputs.size() + oldOutputs.size();
+  for (Value oldOutputTensor : oldOutputs) {
+    if (!oldOutputTensor.getType().isa<TensorType>()) {
+      // Skip and increment the old bbArg index only.
+      ++oldOutputBBArgIndex;
+      // Do not increment resultIndex as only tensors are returned.
+      // TODO: better interface to avoid leaking such impl details.
+      continue;
+    }
+
+    assert(oldOutputTensor.getType().isa<RankedTensorType>() &&
+           "bufferizable output must be a ranked tensor");
+
+    Value outputBuffer = lookup(bvm, oldOutputTensor);
+    const OpResult &opResult = tiledLoopOp->getResult(resultIndex);
+    OpOperand &yieldOperand = yieldOp->getOpOperand(resultIndex);
+    // If the result is not inplaceable, we need to allocate a copy for it.
+    if (getInPlace(opResult) != InPlaceSpec::True) {
+      auto loc = tiledLoopOp.getLoc();
+      Value alloc = createNewAllocDeallocPairForShapedValue(
+          b, loc, oldOutputTensor, aliasInfo);
+      // If the tensor comes from `linalg::InitTensorOp`, the value is
+      // uninitialized and we do not need to copy.
+      // TODO: "matching bbArg does not bufferize to a read" is a more general
+      // check.
+      if (!oldOutputTensor.getDefiningOp<linalg::InitTensorOp>()) {
+        b.setInsertionPointAfter(alloc.getDefiningOp());
+        b.create<linalg::CopyOp>(loc, outputBuffer, alloc);
+      }
+      outputBuffer = alloc;
+    }
+    // Insert mapping and aliasing info.
+    aliasInfo.createAliasInfoEntry(outputBuffer);
+    aliasInfo.insertNewBufferEquivalence(opResult, outputBuffer);
+    map(bvm, opResult, outputBuffer);
+
+    // Insert new operand and bbArg.
+    tiledLoopOp->insertOperands(nextOutputOperandIndex, outputBuffer);
+    BlockArgument newBufferBBArg =
+        body->insertArgument(nextOutputBBArgIndex, outputBuffer.getType());
+    BlockArgument oldTensorBBArg = body->getArgument(oldOutputBBArgIndex);
+    // Insert mapping and aliasing info.
+    aliasInfo.createAliasInfoEntry(newBufferBBArg);
+    aliasInfo.insertNewBufferEquivalence(oldTensorBBArg, newBufferBBArg);
+    map(bvm, oldTensorBBArg, newBufferBBArg);
+
+    // Set operand of `linalg.yield` to the bbArg so it just canonicalizes away
+    // later.
+    yieldOperand.set(oldTensorBBArg);
+
+    // Increment indices.
+    ++numNewOutputBuffers;
+    ++resultIndex;
+    ++oldOutputBBArgIndex;
+    ++nextOutputBBArgIndex;
+    ++nextOutputOperandIndex;
+  }
+
+  // Add buffers for inputs and the corresponding block arguments.
+  // Keep separate iterators to increment without further leaking impl. details.
+  int numNewInputBuffers = 0;
+  int oldInputBBArgIndex = numLoops;
+  int nextInputBBArgIndex = numLoops + oldInputs.size();
+  int nextInputOperandIndex = numControlOperands + oldInputs.size();
+  for (Value oldInputTensor : oldInputs) {
+    if (!oldInputTensor.getType().isa<TensorType>()) {
+      // Skip and increment the old bbArg index only.
+      ++oldInputBBArgIndex;
+      continue;
+    }
+
+    Value inputBuffer = lookup(bvm, oldInputTensor);
+    assert(inputBuffer && "missing buffer for operand");
+
+    // Insert new operand and bbArg.
+    tiledLoopOp->insertOperands(nextInputOperandIndex, inputBuffer);
+    BlockArgument newBufferBBArg =
+        body->insertArgument(nextInputBBArgIndex, inputBuffer.getType());
+    BlockArgument oldTensorBBArg = body->getArgument(oldInputBBArgIndex);
+
+    // Insert mapping and aliasing info.
+    aliasInfo.createAliasInfoEntry(newBufferBBArg);
+    aliasInfo.insertNewBufferEquivalence(oldTensorBBArg, newBufferBBArg);
+    map(bvm, oldTensorBBArg, newBufferBBArg);
+
+    // Increment indices.
+    ++numNewInputBuffers;
+    ++oldInputBBArgIndex;
+    ++nextInputBBArgIndex;
+    ++nextInputOperandIndex;
+  }
+
+  // Update segment sizes.
+  // TODO: Helper method to avoid leaking impl details.
+  tiledLoopOp->setAttr(
+      TiledLoopOp::getOperandSegmentSizeAttr(),
+      b.getI32VectorAttr(
+          {numLoops, numLoops, numLoops,
+           static_cast<int32_t>(oldInputs.size()) + numNewInputBuffers,
+           static_cast<int32_t>(oldOutputs.size()) + numNewOutputBuffers}));
+
+  return success();
+}
+
 /// Bufferize ExtractSliceOp to subview with optional alloc + copy depending on
 /// whether or not it is marked inplaceable.
 /// Note that `getInplaceableOpResult` on a ExtractSliceOp always returns null.
@@ -1872,7 +2029,7 @@
   if (auto readOp = dyn_cast<vector::TransferReadOp>(op.getOperation())) {
     Value v = lookup(bvm, op.source());
     if (!v)
-      return failure();
+      return op.emitError() << "missing buffer";
     readOp.sourceMutable().assign(v);
     return success();
   }
@@ -1892,7 +2049,7 @@
     // canonicalize away with one of it uses.
     newInputBuffer = lookup(bvm, writeOp.source());
     if (!newInputBuffer)
-      return failure();
+      return op.emitError() << "missing buffer";
   }
 
   // Create a new transfer_write on buffer that doesn't have a return value.
@@ -1933,6 +2090,22 @@
   return success();
 }
 
+/// Bufferization for linalg::YieldOp either does not involve tensors or just
+/// results in later canonicalization. In either case it does nothing.
+static LogicalResult bufferize(OpBuilder &b, linalg::YieldOp yieldOp,
+                               BlockAndValueMapping &bvm,
+                               BufferizationAliasInfo &aliasInfo) {
+  // Take a guard before anything else.
+  OpBuilder::InsertionGuard g(b);
+  b.setInsertionPoint(yieldOp);
+  // No tensors -> success.
+  if (!llvm::any_of(yieldOp.getOperandTypes(), isaTensor))
+    return success();
+  // linalg::YieldOp nested under TiledLoop must just canonicalize.
+  if (yieldOp->getParentOfType<TiledLoopOp>())
+    return success();
+  return yieldOp->emitError() << "unexpected yieldOp";
+}
 //===----------------------------------------------------------------------===//
 // Bufferization analyses.
 //===----------------------------------------------------------------------===//
@@ -2043,7 +2216,7 @@
                              const BufferizationAliasInfo &aliasInfo) {
   auto parentForOp = yieldOp->getParentOfType<scf::ForOp>();
   if (!parentForOp)
-    return failure();
+    return yieldOp->emitError() << "not nested under ForOp";
 
   for (OpOperand &operand : yieldOp->getOpOperands()) {
     OpResult matchingForOpResult =
@@ -2057,11 +2230,10 @@
         parentForOp.getRegionIterArgForOpOperand(machingForOpOperand);
     if (!aliasInfo.areEquivalentBufferizedValues(matchingForOpIterArg,
                                                  operand.get())) {
-      yieldOp->emitError()
-          << "Yield operand #" << operand.getOperandNumber()
-          << " does not bufferize to an equivalent buffer to the matching"
-          << " enclosing scf::for operand -> Fail the pass\n";
-      return failure();
+      return yieldOp->emitError()
+             << "Yield operand #" << operand.getOperandNumber()
+             << " does not bufferize to an equivalent buffer to the matching"
+             << " enclosing scf::for operand -> Fail the pass\n";
     }
   }
 
@@ -2150,10 +2322,10 @@
   // Walk in PreOrder to ensure ops with regions are handled before their body.
   // Since walk has to be PreOrder, we need to erase ops that require it
   // separately: this is the case for CallOp
+  // clang-format off
   SmallVector<Operation *> toErase;
-  WalkResult result =
-      funcOp.walk<WalkOrder::PreOrder>([&](Operation *op) -> WalkResult {
-        // clang-format off
+  WalkResult result = funcOp.walk<WalkOrder::PreOrder>([&](Operation *op)
+                                                           -> WalkResult {
         WalkResult result =
             TypeSwitch<Operation *, LogicalResult>(op)
                 // Skip BufferCast and TensorLoad ops.
@@ -2161,13 +2333,15 @@
                       memref::TensorLoadOp>([&](auto) { return success(); })
                 .Case([&](auto op) {
                   LDBG("Begin bufferize:\n" << op << '\n');
                   return bufferize(b, op, bvm, aliasInfo);
@@ -2182,23 +2356,23 @@
                   LDBG("Begin bufferize:\n" << op << '\n');
                   return bufferize(b, op, bvm, aliasInfo, globalCreator);
                 })
-                .Default([&](Operation *op) {
+                .Default([&](Operation *op) -> LogicalResult {
                   auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
                   if (any_of(op->getOperandTypes(), isaTensor) ||
                       any_of(op->getResultTypes(), isaTensor))
-                    return failure();
+                    return op->emitError() << "unsupported op with tensors";
                   return success();
                 });
-        // clang-format on
-        // Register post-walk erasure, if necessary.
-        if (isa<CallOpInterface>(op))
-          if (llvm::any_of(op->getOperandTypes(), isaTensor) ||
-              llvm::any_of(op->getResultTypes(), isaTensor))
-            toErase.push_back(op);
+        // Register post-walk erasure, if necessary.
+        if (isa<CallOpInterface>(op))
+          if (llvm::any_of(op->getOperandTypes(), isaTensor) ||
+              llvm::any_of(op->getResultTypes(), isaTensor))
+            toErase.push_back(op);
 
-        return result;
-      });
+        return result;
+      });
+  // clang-format on
 
   LDBG("End BufferizeFuncOpInternals:\n" << funcOp << '\n');
 
   for (Operation *op : toErase)
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -498,3 +498,60 @@
 // CHECK: func private @print_memref_f32(memref<*xf32>)
 func private @print_memref_f32(tensor<*xf32>)
+
+// -----
+
+func private @some_use(memref<f32>)
+
+#TILE_MAP = affine_map<(d0)[s0] -> (3, -d0 + s0)>
+
+// CHECK-DAG: #[[$DYN_0D_MAP:.*]] = affine_map<()[s0] -> (s0)>
+// CHECK-DAG: #[[$DYN_1D_MAP:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
+// CHECK-DAG: #[[$TILE_MAP:.*]] = affine_map<(d0)[s0] -> (3, -d0 + s0)>
+
+// CHECK: func @tiled_dot(
+// CHECK-SAME:   %[[A:[a-zA-Z0-9]*]]: memref<?xf32, #[[$DYN_1D_MAP]]>
+// CHECK-SAME:   %[[B:[a-zA-Z0-9]*]]: memref<?xf32, #[[$DYN_1D_MAP]]>
+// CHECK-SAME:   %[[c:[a-zA-Z0-9]*]]: memref<f32, #[[$DYN_0D_MAP]]>
+func @tiled_dot(%A: tensor<?xf32>, %B: tensor<?xf32>,
+                %c: tensor<f32> {linalg.inplaceable = true},
+                %effecting: memref<f32>) -> tensor<f32> {
+  %c3 = constant 3 : index
+  %c0 = constant 0 : index
+
+  // CHECK: %[[M:.*]] = memref.dim %[[A]], {{.*}} : memref<?xf32, #[[$DYN_1D_MAP]]>
+  %0 = tensor.dim %A, %c0 : tensor<?xf32>
+
+  // CHECK: linalg.tiled_loop {{.*}} to (%[[M]]) {{.*}} %[[A]]{{.*}}%[[B]]{{.*}}outs{{.*}}%[[c]]
+  %1 = linalg.tiled_loop (%arg3) = (%c0) to (%0) step (%c3)
+       ins (%arg4 = %A: tensor<?xf32>, %use = %effecting : memref<f32>, %arg5 = %B: tensor<?xf32>)
+       outs (%arg6 = %c: tensor<f32>)
+       iterators["reduction"]
+  {
+    // CHECK-NOT: alloc
+
+    %2 = tensor.dim %arg4, %c0 : tensor<?xf32>
+    %3 = affine.min #TILE_MAP(%arg3)[%2]
+
+    // CHECK: %[[SV_A:.*]] = memref.subview {{.*}}
+    %4 = tensor.extract_slice %arg4[%arg3] [%3] [1] : tensor<?xf32> to tensor<?xf32>
+    %5 = tensor.dim %arg5, %c0 : tensor<?xf32>
+    %6 = affine.min #TILE_MAP(%arg3)[%5]
+
+    // CHECK: %[[SV_B:.*]] = memref.subview {{.*}}
+    %7 = tensor.extract_slice %arg5[%arg3] [%6] [1] : tensor<?xf32> to tensor<?xf32>
+
+    // CHECK: linalg.dot ins(%[[SV_A]], %[[SV_B]] : memref<?xf32, #[[$DYN_1D_MAP]]>, memref<?xf32, #[[$DYN_1D_MAP]]>) outs(%{{.*}} : memref<f32, #[[$DYN_0D_MAP]]>)
+    %8 = linalg.dot ins(%4, %7 : tensor<?xf32>, tensor<?xf32>)
+                    outs(%arg6 : tensor<f32>) -> tensor<f32>
+
+    // CHECK: call @some_use(%{{.*}}) : (memref<f32>) -> ()
+    call @some_use(%use) : (memref<f32>) -> ()
+
+    linalg.yield %8 : tensor<f32>
+    // CHECK: linalg.yield
+    // CHECK-NOT: tensor
+  }
+
+  // CHECK: return
+  // CHECK-NOT: tensor
+  return %1 : tensor<f32>
+}
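
For illustration only (not part of the patch): a minimal sketch, with made-up names and simplified memref types, of the rewrite the new bufferize(OpBuilder &, TiledLoopOp, ...) hook performs on a linalg.tiled_loop with one inplaceable tensor output. The looked-up (or freshly allocated) buffer is appended after the existing ins/outs operands, a matching memref block argument is inserted, and the linalg.yield operand is redirected to the old tensor bbArg so the remaining tensor plumbing can canonicalize away later.

  // Before bufferization (tensor form):
  %res = linalg.tiled_loop (%i) = (%c0) to (%n) step (%c3)
         ins (%in = %A : tensor<?xf32>)
         outs (%out = %C : tensor<f32>)
         iterators["reduction"] {
    ...
    linalg.yield %partial : tensor<f32>
  }

  // After bufferization: %A_mem and %C_mem are the buffers found via the
  // BlockAndValueMapping; they are appended as extra ins/outs operands with
  // matching bbArgs, and the yield forwards the old tensor bbArg. The dead
  // tensor operands and results are cleaned up by later canonicalization.
  %res = linalg.tiled_loop (%i) = (%c0) to (%n) step (%c3)
         ins (%in = %A : tensor<?xf32>, %in_buf = %A_mem : memref<?xf32>)
         outs (%out = %C : tensor<f32>, %out_buf = %C_mem : memref<f32>)
         iterators["reduction"] {
    ...
    linalg.yield %out : tensor<f32>
  }

The @tiled_dot test above exercises this path, including a non-tensor memref input that the input-buffer loop simply skips.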