diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -296,13 +296,13 @@
     return InPlaceSpec::None;
   return inplaceAttr.getValue() ? InPlaceSpec::True : InPlaceSpec::False;
 }
-  // Interestingly, scf::ForOp's bbArg can **always** be viewed inplace from the
-  // perspective of ops nested under it:
+  // Interestingly, scf::ForOp's and TiledLoop's bbArg can **always** be viewed
+  // inplace from the perspective of ops nested under them:
   //   1. Either the matching iter operand is not bufferized inplace and an
   //      alloc + optional copy makes the bbArg itself inplaceable.
   //   2. Or the matching iter operand is bufferized inplace and bbArg just
   //      bufferizes to that too.
-  if (auto forOp = dyn_cast<scf::ForOp>(bbArg.getOwner()->getParentOp()))
+  if (isa<scf::ForOp, TiledLoopOp>(bbArg.getOwner()->getParentOp()))
     return InPlaceSpec::True;
   // Unknown cases.
   return InPlaceSpec::None;
@@ -359,19 +359,28 @@
       isa(op)
       // clang-format on
       || (none_of(op->getResultTypes(), isaTensor) &&
           none_of(op->getOperandTypes(), isaTensor));
 }
 
+/// Return the OpResult that may bufferize into the same buffer as `opOperand`
+/// when the op is bufferized inplace.
+/// Return null if no such result exists.
+static OpResult getInplaceableOpResult(TiledLoopOp op, OpOperand &opOperand) {
+  return op.getTiedOpResult(opOperand);
+}
+
 /// Return the OpResult that may bufferize into the same buffer as `opOperand`
 /// when the op is bufferized inplace.
 /// Return null if no such result exists.
@@ -441,8 +450,9 @@
       // result(s).
       .Case(
           [&](auto op) { return getInplaceableOpResult(op, opOperand); })
       // ExtractSliceOp is special, when bufferized inplace it just returns an
@@ -469,18 +479,23 @@
   return TypeSwitch(result.getDefiningOp())
       .Case([&](tensor::CastOp op) { return &op->getOpOperand(0); })
       .Case([&](ConstantOp op) { return &op->getOpOperand(0); })
-      .Case([&](LinalgOp op) {
-        return op.getOutputTensorOperands()[result.getResultNumber()];
-      })
      .Case([&](ExtractSliceOp op) { return &op->getOpOperand(0); })
-      .Case([&](InsertSliceOp op) { return &op->getOpOperand(1); })
-      .Case([&](vector::TransferWriteOp op) { return &op->getOpOperand(1); })
       // In the case of scf::ForOp, this currently assumes the iter_args / yield
       // are 1-1. This may fail and is verified at the end.
       // TODO: update this.
       .Case([&](scf::ForOp op) {
         return &op.getIterOpOperands()[result.getResultNumber()];
       })
+      .Case([&](InsertSliceOp op) { return &op->getOpOperand(1); })
+      .Case([&](LinalgOp op) {
+        return op.getOutputTensorOperands()[result.getResultNumber()];
+      })
+      .Case([&](TiledLoopOp op) {
+        // TODO: TiledLoopOp helper method to avoid leaking impl details.
+        return &op->getOpOperand(op.getNumControlOperands() +
+                                 op.getNumInputs() + result.getResultNumber());
+      })
+      .Case([&](vector::TransferWriteOp op) { return &op->getOpOperand(1); })
       .Default([&](Operation *op) {
         op->dump();
         llvm_unreachable("unexpected defining op");
@@ -528,6 +543,10 @@
   // matching bbArg may.
   if (isa<scf::ForOp>(opOperand.getOwner()))
     return false;
+  // TiledLoop alone doesn't bufferize to a memory read, one of the uses of its
+  // matching bbArg may.
+  if (isa<TiledLoopOp>(opOperand.getOwner()))
+    return false;
   // CallOpInterface alone doesn't bufferize to a memory read, one of the uses
   // of the matching bbArg may. It is the responsibility of the caller to
   // inspect bbArgs.
@@ -1361,7 +1380,7 @@
     if (getInPlace(opResult) == InPlaceSpec::True) {
       Value v = lookup(bvm, output);
       if (!v)
-        return failure();
+        return op.emitError() << "missing buffer";
       resultBuffers.push_back(v);
       continue;
     }
@@ -1398,7 +1417,7 @@
   // Ensure op has only tensors. Allow mixed tensor-buffer mode on a per-need
   // basis.
   if (!op.hasTensorSemantics())
-    return failure();
+    return op->emitError() << "op does not have tensor semantics";
 
   b.setInsertionPoint(op);
   Location loc = op.getLoc();
@@ -1411,13 +1430,14 @@
     }
     newInputBuffers.push_back(lookup(bvm, opOperand->get()));
     if (!newInputBuffers.back())
-      return failure();
+      return op->emitError()
+             << "missing buffer for operand #" << opOperand->getOperandNumber();
   }
   SmallVector<Value> newOutputBuffers;
   // Try to allocate new buffers depending on op's inplace semantics.
   if (failed(allocateBuffersForResults(b, loc, op, newOutputBuffers, bvm,
                                        aliasInfo)))
-    return failure();
+    return op->emitError() << "failed to allocate buffers for results";
 
   // Clone the newly bufferized op.
   SmallVector<Value> newOperands = newInputBuffers;
@@ -1609,7 +1629,7 @@
                                 BufferizationAliasInfo &aliasInfo,
                                 GlobalCreator &globalCreator) {
   if (!constantOp.getType().dyn_cast<RankedTensorType>())
-    return failure();
+    return constantOp->emitError() << "not a constant tensor";
 
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
@@ -1629,11 +1649,16 @@
 static LogicalResult bufferize(OpBuilder &b, tensor::DimOp dimOp,
                                BlockAndValueMapping &bvm,
                                BufferizationAliasInfo &aliasInfo) {
+  // Take a guard before anything else.
+  OpBuilder::InsertionGuard g(b);
+  b.setInsertionPoint(dimOp);
+
   if (dimOp.source().getType().isa<TensorType>()) {
     Value v = lookup(bvm, dimOp.source());
     if (!v)
-      return failure();
-    dimOp.sourceMutable().assign(v);
+      return dimOp.emitError() << "missing buffer";
+    dimOp.result().replaceAllUsesWith(
+        b.create<memref::DimOp>(dimOp.getLoc(), v, dimOp.index()));
   }
   return success();
 }
@@ -1649,10 +1674,12 @@
   // Otherwise alloc and copy.
   b.setInsertionPoint(forOp);
   for (OpResult opResult : forOp->getResults()) {
+    if (!opResult.getType().isa<TensorType>())
+      continue;
     // TODO: Atm we bail on unranked TensorType because we don't know how to
     // alloc an UnrankedMemRefType + its underlying ranked MemRefType.
     if (!opResult.getType().isa<RankedTensorType>())
-      return failure();
+      return forOp.emitError() << "unsupported unranked tensor";
     OpOperand &opOperand = forOp.getOpOperandForResult(opResult);
     Value operand = opOperand.get();
     Value operandBuffer = lookup(bvm, operand);
@@ -1731,7 +1758,8 @@
       continue;
     Value v = lookup(bvm, operand.get());
     if (!v)
-      return failure();
+      return returnOp->emitError()
+             << "missing buffer for result #" << operand.getOperandNumber();
     Value returnTensor = b.create<memref::TensorLoadOp>(returnOp.getLoc(), v);
     operand.set(returnTensor);
     aliasInfo.insertNewBufferEquivalence(returnTensor, v);
@@ -1740,6 +1768,96 @@
   return success();
 }
 
+/// Bufferization for TiledLoopOp.
+static LogicalResult bufferize(OpBuilder &b, TiledLoopOp tiledLoopOp,
+                               BlockAndValueMapping &bvm,
+                               BufferizationAliasInfo &aliasInfo) {
+
+  // Allocate output buffers if needed, forward output tensor args to the
+  // terminator.
+  Operation *yieldOp = tiledLoopOp.getBody()->getTerminator();
+
+  // Replace tensor `inputs` args with the buffer ones.
+  int numLoops = tiledLoopOp.getNumLoops();
+  int numInputs = tiledLoopOp.inputs().size();
+  int numOutputs = tiledLoopOp.outputs().size();
+  int afterInputs = tiledLoopOp.getNumControlOperands() + numInputs;
+  int bbAfterInputs = tiledLoopOp.getNumLoops() + numInputs;
+
+  auto oldInputs = llvm::to_vector<4>(tiledLoopOp.inputs());
+  auto oldOutputs = llvm::to_vector<4>(tiledLoopOp.outputs());
+
+  // Add buffers for inputs and the corresponding block arguments.
+  Block *body = tiledLoopOp.getBody();
+  for (auto en : llvm::enumerate(oldInputs)) {
+    if (!en.value().getType().isa<TensorType>())
+      continue;
+    Value inputBuffer = lookup(bvm, en.value());
+    assert(inputBuffer && "missing buffer for operand");
+    auto index = en.index();
+    tiledLoopOp->insertOperands(afterInputs + index, {inputBuffer});
+    auto oldTensorBBArg = body->getArgument(tiledLoopOp.getNumLoops() + index);
+    auto newBufferBBArg =
+        body->insertArgument(bbAfterInputs + index, inputBuffer.getType());
+    aliasInfo.createAliasInfoEntry(newBufferBBArg);
+    aliasInfo.insertNewBufferEquivalence(oldTensorBBArg, newBufferBBArg);
+    map(bvm, oldTensorBBArg, newBufferBBArg);
+    numInputs++;
+  }
+
+  // Add buffers for outputs and the corresponding block arguments.
+  int newBuffersCount = 0;
+  int outputsEndIndex =
+      tiledLoopOp.getNumControlOperands() + numInputs + numOutputs;
+  int bbAfterOutputs = tiledLoopOp.getNumLoops() + numInputs + numOutputs;
+
+  for (auto en : llvm::enumerate(llvm::zip(
+           oldOutputs, tiledLoopOp->getResults(), yieldOp->getOpOperands()))) {
+    Value outputTensor = std::get<0>(en.value());
+    // TODO: Atm we bail on unranked TensorType because we don't know how to
+    // alloc an UnrankedMemRefType + its underlying ranked MemRefType.
+    if (!outputTensor.getType().isa<RankedTensorType>())
+      continue;
+    Value operandBuffer = lookup(bvm, outputTensor);
+
+    const OpResult &opResult = std::get<1>(en.value());
+    OpOperand &yieldOperand = std::get<2>(en.value());
+    if (getInPlace(opResult) != InPlaceSpec::True) {
+      auto loc = tiledLoopOp.getLoc();
+      Value alloc = createNewAllocDeallocPairForShapedValue(
+          b, loc, outputTensor, aliasInfo);
+      // If the tensor comes from `linalg::InitTensorOp`, the value is
+      // uninitialized and we do not need to copy.
+      // TODO: "matching bbArg does not bufferize to a read" is a more general
+      // check.
+      if (!outputTensor.getDefiningOp<linalg::InitTensorOp>()) {
+        b.setInsertionPointAfter(alloc.getDefiningOp());
+        b.create<linalg::CopyOp>(loc, operandBuffer, alloc);
+      }
+      operandBuffer = alloc;
+    }
+    map(bvm, opResult, operandBuffer);
+
+    tiledLoopOp->insertOperands(outputsEndIndex + newBuffersCount,
+                                operandBuffer);
+    auto newBbArg = body->insertArgument(bbAfterOutputs + newBuffersCount,
+                                         operandBuffer.getType());
+    map(bvm, body->getArgument(bbAfterOutputs + newBuffersCount - 1), newBbArg);
+
+    // Set operand of `linalg.yield` to the output tensor.
+    yieldOperand.set(body->getArgument(bbAfterOutputs + newBuffersCount - 1));
+    ++newBuffersCount;
+  }
+
+  // Update segment sizes.
+  // TODO: Helper method to avoid leaking impl details.
+  tiledLoopOp->setAttr(
+      TiledLoopOp::getOperandSegmentSizeAttr(),
+      b.getI32VectorAttr({numLoops, numLoops, numLoops, numInputs,
+                          numOutputs + newBuffersCount}));
+  return success();
+}
+
 /// Bufferize ExtractSliceOp to subview with optional alloc + copy depending on
 /// whether or not it is marked inplaceable.
 /// Note that `getInplaceableOpResult` on a ExtractSliceOp always returns null.
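
Aside, not part of the patch: the TiledLoopOp handling above (both the getAliasingOpOperand case and the bufferize function) assumes linalg.tiled_loop's operand list is laid out as control operands (lower bounds, upper bounds, steps) followed by inputs and then outputs, so the output operand tied to result #i sits at index getNumControlOperands() + getNumInputs() + i. A minimal illustrative sketch of that indexing follows; the helper name getTiedOutputOperand is hypothetical (it mirrors the TODO about adding a TiledLoopOp helper) and it relies on the same accessors the patch itself uses.

// Hypothetical helper (not in the patch): return the output OpOperand tied to
// `result`, assuming the operand layout used above: control operands (lower
// bounds, upper bounds, steps), then inputs, then outputs.
static OpOperand &getTiedOutputOperand(TiledLoopOp op, OpResult result) {
  unsigned idx = op.getNumControlOperands() + op.getNumInputs() +
                 result.getResultNumber();
  return op->getOpOperand(idx);
}
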
@@ -1872,7 +1990,7 @@
   if (auto readOp = dyn_cast<vector::TransferReadOp>(op.getOperation())) {
     Value v = lookup(bvm, op.source());
     if (!v)
-      return failure();
+      return op.emitError() << "missing buffer";
     readOp.sourceMutable().assign(v);
     return success();
   }
@@ -1892,7 +2010,7 @@
     // canonicalize away with one of it uses.
     newInputBuffer = lookup(bvm, writeOp.source());
     if (!newInputBuffer)
-      return failure();
+      return op.emitError() << "missing buffer";
   }
 
   // Create a new transfer_write on buffer that doesn't have a return value.
@@ -1933,6 +2051,22 @@
   return success();
 }
 
+/// Bufferization for linalg::YieldOp either does not involve tensors or just
+/// results in later canonicalization. In either case it does nothing.
+static LogicalResult bufferize(OpBuilder &b, linalg::YieldOp yieldOp,
+                               BlockAndValueMapping &bvm,
+                               BufferizationAliasInfo &aliasInfo) {
+  // Take a guard before anything else.
+  OpBuilder::InsertionGuard g(b);
+  b.setInsertionPoint(yieldOp);
+  // No tensors -> success.
+  if (!llvm::any_of(yieldOp.getOperandTypes(), isaTensor))
+    return success();
+  // linalg::YieldOp nested under TiledLoop must just canonicalize.
+  if (yieldOp->getParentOfType<TiledLoopOp>())
+    return success();
+  return yieldOp->emitError() << "unexpected yieldOp";
+}
+
 //===----------------------------------------------------------------------===//
 // Bufferization analyses.
 //===----------------------------------------------------------------------===//
@@ -2043,7 +2177,7 @@
                               const BufferizationAliasInfo &aliasInfo) {
   auto parentForOp = yieldOp->getParentOfType<scf::ForOp>();
   if (!parentForOp)
-    return failure();
+    return yieldOp->emitError() << "not nested under ForOp";
 
   for (OpOperand &operand : yieldOp->getOpOperands()) {
     OpResult matchingForOpResult =
@@ -2057,11 +2191,10 @@
         parentForOp.getRegionIterArgForOpOperand(machingForOpOperand);
     if (!aliasInfo.areEquivalentBufferizedValues(matchingForOpIterArg,
                                                  operand.get())) {
-      yieldOp->emitError()
-          << "Yield operand #" << operand.getOperandNumber()
-          << " does not bufferize to an equivalent buffer to the matching"
-          << " enclosing scf::for operand -> Fail the pass\n";
-      return failure();
+      return yieldOp->emitError()
+             << "Yield operand #" << operand.getOperandNumber()
+             << " does not bufferize to an equivalent buffer to the matching"
+             << " enclosing scf::for operand -> Fail the pass\n";
     }
   }
@@ -2150,10 +2283,10 @@
   // Walk in PreOrder to ensure ops with regions are handled before their body.
   // Since walk has to be PreOrder, we need to erase ops that require it
   // separately: this is the case for CallOp
+  // clang-format off
   SmallVector<Operation *> toErase;
-  WalkResult result =
-      funcOp.walk([&](Operation *op) -> WalkResult {
-        // clang-format off
+  WalkResult result = funcOp.walk([&](Operation *op)
+      -> WalkResult {
     WalkResult result =
         TypeSwitch(op)
         // Skip BufferCast and TensorLoad ops.
@@ -2161,13 +2294,15 @@
               memref::TensorLoadOp>([&](auto) { return success(); })
          .Case([&](auto op) {
            LDBG("Begin bufferize:\n" << op << '\n');
            return bufferize(b, op, bvm, aliasInfo);
          })
@@ -2182,23 +2317,23 @@
            LDBG("Begin bufferize:\n" << op << '\n');
            return bufferize(b, op, bvm, aliasInfo, globalCreator);
          })
-          .Default([&](Operation *op) {
+          .Default([&](Operation *op) -> LogicalResult {
            auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
            if (any_of(op->getOperandTypes(), isaTensor) ||
                any_of(op->getResultTypes(), isaTensor))
-              return failure();
+              return op->emitError() << "unsupported op with tensors";
            return success();
          });
-        // clang-format on
 
-        // Register post-walk erasure, if necessary.
-        if (isa<CallOpInterface>(op))
-          if (llvm::any_of(op->getOperandTypes(), isaTensor) ||
-              llvm::any_of(op->getResultTypes(), isaTensor))
-            toErase.push_back(op);
+    // Register post-walk erasure, if necessary.
+    if (isa<CallOpInterface>(op))
+      if (llvm::any_of(op->getOperandTypes(), isaTensor) ||
+          llvm::any_of(op->getResultTypes(), isaTensor))
+        toErase.push_back(op);
 
-        return result;
-      });
+    return result;
+  });
+  // clang-format on
 
   LDBG("End BufferizeFuncOpInternals:\n" << funcOp << '\n');
 
   for (Operation *op : toErase)
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -498,3 +498,53 @@
 
 // CHECK: func private @print_memref_f32(memref<*xf32>)
 func private @print_memref_f32(tensor<*xf32>)
+
+// -----
+
+#TILE_MAP = affine_map<(d0)[s0] -> (3, -d0 + s0)>
+
+// CHECK-DAG: #[[$DYN_0D_MAP:.*]] = affine_map<()[s0] -> (s0)>
+// CHECK-DAG: #[[$DYN_1D_MAP:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
+// CHECK-DAG: #[[$TILE_MAP:.*]] = affine_map<(d0)[s0] -> (3, -d0 + s0)>
+
+//      CHECK: func @tiled_dot(
+// CHECK-SAME:   %[[A:[a-zA-Z0-9]*]]: memref<?xf32, #[[$DYN_1D_MAP]]>
+// CHECK-SAME:   %[[B:[a-zA-Z0-9]*]]: memref<?xf32, #[[$DYN_1D_MAP]]>
+// CHECK-SAME:   %[[c:[a-zA-Z0-9]*]]: memref<f32, #[[$DYN_0D_MAP]]>
+func @tiled_dot(%A: tensor<?xf32>, %B: tensor<?xf32>, %c: tensor<f32> {linalg.inplaceable = true}) -> tensor<f32> {
+  %c3 = constant 3 : index
+  %c0 = constant 0 : index
+
+  // CHECK: %[[M:.*]] = memref.dim %[[A]], {{.*}} : memref<?xf32, #[[$DYN_1D_MAP]]>
+  %0 = tensor.dim %A, %c0 : tensor<?xf32>
+
+  // CHECK: linalg.tiled_loop {{.*}} to (%[[M]]) {{.*}} %[[A]]{{.*}}%[[B]]{{.*}}outs{{.*}}%[[c]]
+  %1 = linalg.tiled_loop (%arg3) = (%c0) to (%0) step (%c3)
+       ins (%arg4 = %A: tensor<?xf32>, %arg5 = %B: tensor<?xf32>)
+       outs (%arg6 = %c: tensor<f32>)
+       iterators["reduction"]
+       {
+    // CHECK-NOT: alloc
+
+    %2 = tensor.dim %arg4, %c0 : tensor<?xf32>
+    %3 = affine.min #TILE_MAP(%arg3)[%2]
+    // CHECK: %[[SV_A:.*]] = memref.subview {{.*}}
+
+    %4 = tensor.extract_slice %arg4[%arg3] [%3] [1] : tensor<?xf32> to tensor<?xf32>
+    %5 = tensor.dim %arg5, %c0 : tensor<?xf32>
+    %6 = affine.min #TILE_MAP(%arg3)[%5]
+    // CHECK: %[[SV_B:.*]] = memref.subview {{.*}}
+
+    %7 = tensor.extract_slice %arg5[%arg3] [%6] [1] : tensor<?xf32> to tensor<?xf32>
+    %8 = linalg.dot ins(%4, %7 : tensor<?xf32>, tensor<?xf32>) outs(%arg6 : tensor<f32>) -> tensor<f32>
+    // CHECK: linalg.dot ins(%[[SV_A]], %[[SV_B]] : memref<?xf32, #[[$DYN_1D_MAP]]>, memref<?xf32, #[[$DYN_1D_MAP]]>) outs(%{{.*}} : memref<f32, #[[$DYN_0D_MAP]]>)
+
+    linalg.yield %8 : tensor<f32>
+    // CHECK: linalg.yield
+    // CHECK-NOT: tensor
+  }
+
+  // CHECK: return
+  // CHECK-NOT: tensor
+  return %1 : tensor<f32>
+}