diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h @@ -609,17 +609,18 @@ const BufferizationOptions &options); /// Return the buffer type for a given Value (tensor) after bufferization -/// without bufferizing any IR. If at any point during the type computation, the -/// type of a value in `fixedTypes` in required, the mapped type is used. +/// without bufferizing any IR. This function (and not the other overload +/// without `invocationStack`) can be used from `getBufferType` implementations +/// of the `BufferizableOpInterface`. /// /// Note: It should be sufficient to call `getBuffer()->getType()` in most /// cases. However, when a buffer type should be predicted without modifying any /// IR, this function can be used. /// -/// This function is a wrapper around BufferizableOpInterface::getBufferType. -FailureOr -getBufferType(Value value, const BufferizationOptions &options, - const DenseMap &fixedTypes); +/// This function is a wrapper around `BufferizableOpInterface::getBufferType`. +FailureOr getBufferType(Value value, + const BufferizationOptions &options, + SmallVector &invocationStack); /// Replace an op with replacement values. The op is deleted. Tensor OpResults /// must be replaced with memref values. @@ -691,7 +692,7 @@ /// places. FailureOr defaultGetBufferType(Value value, const BufferizationOptions &options, - const DenseMap &fixedTypes); + SmallVector &invocationStack); /// This is the default implementation of /// BufferizableOpInterface::resultBufferizesToMemoryWrite. 
Should not be called diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td @@ -494,18 +494,28 @@ This method is useful when the bufferized type of value must be predicted before modifying any IR. + + Implementations may call `bufferization::getBufferType` to compute the + bufferized type of another SSA value. The same (unmodified) + `invocationStack` must be passed to that function. The stack contains + all SSA values for which a buffer type computation is currently in + progress. Implementations may inspect the stack to detect repetitive + computations for the same SSA value. (E.g., when computing the + bufferized type of an iter_arg of a loop.) }], /*retType=*/"::mlir::FailureOr<::mlir::BaseMemRefType>", /*methodName=*/"getBufferType", /*args=*/(ins "::mlir::Value":$value, "const ::mlir::bufferization::BufferizationOptions &":$options, - "const ::mlir::DenseMap<::mlir::Value, ::mlir::BaseMemRefType>":$fixedTypes), + "::llvm::SmallVector<::mlir::Value> &":$invocationStack), /*methodBody=*/"", /*defaultImplementation=*/[{ assert(getOwnerOfValue(value) == $_op.getOperation() && "expected that value belongs to this op"); + assert(invocationStack.back() == value && + "inconsistant invocation stack"); return ::mlir::bufferization::detail::defaultGetBufferType( - value, options, fixedTypes); + value, options, invocationStack); }] >, InterfaceMethod< diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td @@ -108,7 +108,7 @@ FailureOr getBufferType( Value value, const BufferizationOptions &options, - const DenseMap &fixedTypes); + 
SmallVector &invocationStack); RankedTensorType getType() { return ::llvm::cast(getResult().getType()); @@ -388,7 +388,7 @@ FailureOr getBufferType( Value value, const BufferizationOptions &options, - const DenseMap &fixedTypes) { + SmallVector &invocationStack) { return ::llvm::cast(getMemref().getType()); } }]; diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp @@ -170,13 +170,13 @@ FailureOr getBufferType(Operation *op, Value value, const BufferizationOptions &options, - const DenseMap &fixedTypes) const { + SmallVector &invocationStack) const { auto selectOp = cast(op); assert(value == selectOp.getResult() && "invalid value"); auto trueType = bufferization::getBufferType(selectOp.getTrueValue(), - options, fixedTypes); + options, invocationStack); auto falseType = bufferization::getBufferType(selectOp.getFalseValue(), - options, fixedTypes); + options, invocationStack); if (failed(trueType) || failed(falseType)) return failure(); if (*trueType == *falseType) diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -18,6 +18,7 @@ #include "mlir/IR/TypeUtilities.h" #include "mlir/IR/Value.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "llvm/ADT/ScopeExit.h" #include "llvm/Support/Debug.h" //===----------------------------------------------------------------------===// @@ -728,27 +729,25 @@ /// Return the buffer type for a given Value (tensor) after bufferization. 
FailureOr bufferization::getBufferType(Value value, const BufferizationOptions &options) { - DenseMap fixedTypes; - return getBufferType(value, options, fixedTypes); + SmallVector invocationStack; + return getBufferType(value, options, invocationStack); } /// Return the buffer type for a given Value (tensor) after bufferization. -FailureOr bufferization::getBufferType( - Value value, const BufferizationOptions &options, - const DenseMap &fixedTypes) { +FailureOr +bufferization::getBufferType(Value value, const BufferizationOptions &options, + SmallVector &invocationStack) { assert(llvm::isa(value.getType()) && "unexpected non-tensor type"); - - // If the `value` is in `fixedTypes`, return the mapped type. - const auto &it = fixedTypes.find(value); - if (it != fixedTypes.end()) - return it->second; + invocationStack.push_back(value); + auto popFromStack = + llvm::make_scope_exit([&]() { invocationStack.pop_back(); }); // Try querying BufferizableOpInterface. Operation *op = getOwnerOfValue(value); auto bufferizableOp = options.dynCastBufferizableOp(op); if (bufferizableOp) - return bufferizableOp.getBufferType(value, options, fixedTypes); + return bufferizableOp.getBufferType(value, options, invocationStack); // Op is not bufferizable. if (!options.defaultMemorySpace.has_value()) @@ -996,7 +995,7 @@ FailureOr bufferization::detail::defaultGetBufferType( Value value, const BufferizationOptions &options, - const DenseMap &fixedTypes) { + SmallVector &invocationStack) { assert(llvm::isa(value.getType()) && "expected tensor type"); // No further analysis is possible for a block argument. @@ -1013,7 +1012,7 @@ // If the OpResult has an equivalent OpOperand, both OpResult and // OpOperand bufferize to the exact same buffer type. 
Value equivalentOperand = aliases.getAliases().front().opOperand->get(); - return getBufferType(equivalentOperand, options, fixedTypes); + return getBufferType(equivalentOperand, options, invocationStack); } // If we do not know the memory space and there is no default memory space, diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp --- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp @@ -230,9 +230,9 @@ return {}; } -FailureOr AllocTensorOp::getBufferType( - Value value, const BufferizationOptions &options, - const DenseMap &fixedTypes) { +FailureOr +AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options, + SmallVector &invocationStack) { assert(value == getResult() && "invalid value"); // Compute memory space of this allocation. @@ -241,7 +241,7 @@ memorySpace = *getMemorySpace(); } else if (getCopy()) { auto copyBufferType = - bufferization::getBufferType(getCopy(), options, fixedTypes); + bufferization::getBufferType(getCopy(), options, invocationStack); if (failed(copyBufferType)) return failure(); memorySpace = copyBufferType->getMemorySpace(); diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp @@ -199,7 +199,7 @@ FailureOr getBufferType(Operation *op, Value value, const BufferizationOptions &options, - const DenseMap &fixedTypes) const { + SmallVector &invocationStack) const { auto callOp = cast(op); FuncOp funcOp = getCalledFunction(callOp); assert(funcOp && "expected CallOp to a FuncOp"); @@ -321,7 +321,7 @@ : public BufferizableOpInterface::ExternalModel { FailureOr getBufferType(Operation *op, Value value, 
const BufferizationOptions &options, - const DenseMap &fixedTypes) const { + SmallVector &invocationStack) const { auto funcOp = cast(op); auto bbArg = cast(value); // Unstructured control flow is not supported. diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp @@ -211,7 +211,7 @@ FailureOr getBufferType(Operation *op, Value value, const BufferizationOptions &options, - const DenseMap &fixedTypes) const { + SmallVector &invocationStack) const { auto ifOp = cast(op); auto thenYieldOp = cast(ifOp.thenBlock()->getTerminator()); auto elseYieldOp = cast(ifOp.elseBlock()->getTerminator()); @@ -227,7 +227,7 @@ thenBufferType = cast(thenValue.getType()); } else { auto maybeBufferType = - bufferization::getBufferType(thenValue, options, fixedTypes); + bufferization::getBufferType(thenValue, options, invocationStack); if (failed(maybeBufferType)) return failure(); thenBufferType = *maybeBufferType; @@ -237,7 +237,7 @@ elseBufferType = cast(elseValue.getType()); } else { auto maybeBufferType = - bufferization::getBufferType(elseValue, options, fixedTypes); + bufferization::getBufferType(elseValue, options, invocationStack); if (failed(maybeBufferType)) return failure(); elseBufferType = *maybeBufferType; @@ -331,33 +331,34 @@ /// /// This function uses bufferization::getBufferType to compute the bufferized /// type of the init_arg and of the yielded value. (The computation of the -/// usually requires computing the bufferized type of the corresponding -/// iter_arg; the implementation of getBufferType traces back the use-def chain -/// of the given value and computes a buffer type along the way.) If both buffer -/// types are equal, no casts are needed the computed buffer type can be used -/// directly. 
Otherwise, the buffer types can only differ in their layout map -/// and a cast must be inserted. +/// bufferized yielded value type usually requires computing the bufferized type +/// of the iter_arg again; the implementation of getBufferType traces back the +/// use-def chain of the given value and computes a buffer type along the way.) +/// If both buffer types are equal, no casts are needed and the computed buffer +/// type can be used directly. Otherwise, the buffer types can only differ in +/// their layout map and a cast must be inserted. static FailureOr computeLoopRegionIterArgBufferType( BlockArgument iterArg, Value initArg, Value yieldedValue, - const BufferizationOptions &options, - const DenseMap &fixedTypes) { + const BufferizationOptions &options, SmallVector &invocationStack) { // Determine the buffer type of the init_arg. auto initArgBufferType = - bufferization::getBufferType(initArg, options, fixedTypes); + bufferization::getBufferType(initArg, options, invocationStack); if (failed(initArgBufferType)) return failure(); - // Fix the iter_arg type, so that recursive lookups return the buffer type - // of the init_arg. This is to avoid infinite loops when calculating the - // buffer type of the yielded value. - // - // Note: For more precise layout map computation, a fixpoint iteration could - // be done (i.e., re-computing the yielded buffer type until the bufferized - // iter_arg type no longer changes). This current implementation immediately - // switches to a fully dynamic layout map when a mismatch between bufferized - // init_arg type and bufferized yield value type is detected. - DenseMap newFixedTypes(fixedTypes); - newFixedTypes[iterArg] = *initArgBufferType; + if (llvm::count(invocationStack, iterArg) >= 2) { + // If the iter_arg is already twice on the invocation stack, just take the + // type of the init_arg. This is to avoid infinite loops when calculating + // the buffer type. 
This will most likely result in computing a memref type + // with a fully dynamic layout map. + + // Note: For more precise layout map computation, a fixpoint iteration could + // be done (i.e., re-computing the yielded buffer type until the bufferized + // iter_arg type no longer changes). This current implementation immediately + // switches to a fully dynamic layout map when a mismatch between bufferized + // init_arg type and bufferized yield value type is detected. + return *initArgBufferType; + } // Compute the buffer type of the yielded value. BaseMemRefType yieldedValueBufferType; @@ -365,8 +366,10 @@ // scf.yield was already bufferized. yieldedValueBufferType = cast(yieldedValue.getType()); } else { + // Note: This typically triggers a recursive call for the buffer type of + // the iter_arg. auto maybeBufferType = - bufferization::getBufferType(yieldedValue, options, newFixedTypes); + bufferization::getBufferType(yieldedValue, options, invocationStack); if (failed(maybeBufferType)) return failure(); yieldedValueBufferType = *maybeBufferType; @@ -376,20 +379,21 @@ if (*initArgBufferType == yieldedValueBufferType) return yieldedValueBufferType; - // If there is a mismatch between the yielded buffer type and the iter_arg + // If there is a mismatch between the yielded buffer type and the init_arg // buffer type, the buffer type must be promoted to a fully dynamic layout // map. 
auto yieldedRanked = cast(yieldedValueBufferType); + auto iterRanked = cast(iterArg.getType()); #ifndef NDEBUG - auto iterRanked = llvm::cast(*initArgBufferType); - assert(llvm::equal(yieldedRanked.getShape(), iterRanked.getShape()) && + auto initRanked = llvm::cast(*initArgBufferType); + assert(llvm::all_equal({yieldedRanked.getShape(), initRanked.getShape(), + iterRanked.getShape()}) && "expected same shape"); - assert(yieldedRanked.getMemorySpace() == iterRanked.getMemorySpace() && + assert(yieldedRanked.getMemorySpace() == initRanked.getMemorySpace() && "expected same memory space"); #endif // NDEBUG - return getMemRefTypeWithFullyDynamicLayout( - cast(iterArg.getType()), - yieldedRanked.getMemorySpace()); + return getMemRefTypeWithFullyDynamicLayout(iterRanked, + yieldedRanked.getMemorySpace()); } /// Return `true` if the given loop may have 0 iterations. @@ -513,21 +517,24 @@ FailureOr getBufferType(Operation *op, Value value, const BufferizationOptions &options, - const DenseMap &fixedTypes) const { + SmallVector &invocationStack) const { auto forOp = cast(op); assert(getOwnerOfValue(value) == op && "invalid value"); assert(isa(value.getType()) && "expected tensor type"); - // Get result/argument number. - unsigned resultNum; - if (auto bbArg = dyn_cast(value)) { - resultNum = - forOp.getResultForOpOperand(forOp.getOpOperandForRegionIterArg(bbArg)) - .getResultNumber(); - } else { - resultNum = cast(value).getResultNumber(); + if (auto opResult = dyn_cast(value)) { + // The type of an OpResult must match the corresponding iter_arg type. + BlockArgument bbArg = forOp.getRegionIterArgForOpOperand( + forOp.getOpOperandForResult(opResult)); + return bufferization::getBufferType(bbArg, options, invocationStack); } + // Compute result/argument number. + BlockArgument bbArg = cast(value); + unsigned resultNum = + forOp.getResultForOpOperand(forOp.getOpOperandForRegionIterArg(bbArg)) + .getResultNumber(); + // Compute the bufferized type. 
auto yieldOp = cast(forOp.getLoopBody().front().getTerminator()); @@ -535,7 +542,7 @@ BlockArgument iterArg = forOp.getRegionIterArgs()[resultNum]; Value initArg = forOp.getInitArgs()[resultNum]; return computeLoopRegionIterArgBufferType(iterArg, initArg, yieldedValue, - options, fixedTypes); + options, invocationStack); } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, @@ -838,7 +845,7 @@ FailureOr getBufferType(Operation *op, Value value, const BufferizationOptions &options, - const DenseMap &fixedTypes) const { + SmallVector &invocationStack) const { auto whileOp = cast(op); assert(getOwnerOfValue(value) == op && "invalid value"); assert(isa(value.getType()) && "expected tensor type"); @@ -850,7 +857,7 @@ auto yieldOp = whileOp.getYieldOp(); Value yieldedValue = yieldOp.getOperand(bbArg.getArgNumber()); return computeLoopRegionIterArgBufferType(bbArg, initArg, yieldedValue, - options, fixedTypes); + options, invocationStack); } } @@ -872,7 +879,7 @@ return cast(conditionYieldedVal.getType()); } return bufferization::getBufferType(conditionYieldedVal, options, - fixedTypes); + invocationStack); } /// Assert that yielded values of an scf.while op are equivalent to their @@ -1104,20 +1111,20 @@ FailureOr getBufferType(Operation *op, Value value, const BufferizationOptions &options, - const DenseMap &fixedTypes) const { + SmallVector &invocationStack) const { auto forallOp = cast(op); if (auto bbArg = dyn_cast(value)) // A tensor block argument has the same bufferized type as the // corresponding output operand. return bufferization::getBufferType( - forallOp.getTiedOpOperand(bbArg)->get(), options, fixedTypes); + forallOp.getTiedOpOperand(bbArg)->get(), options, invocationStack); // The bufferized result type is the same as the bufferized type of the // corresponding output operand. 
return bufferization::getBufferType( forallOp.getOutputs()[cast(value).getResultNumber()], options, - fixedTypes); + invocationStack); } bool isRepetitiveRegion(Operation *op, unsigned index) const { diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp @@ -48,10 +48,10 @@ FailureOr getBufferType(Operation *op, Value value, const BufferizationOptions &options, - const DenseMap &fixedTypes) const { + SmallVector &invocationStack) const { auto castOp = cast(op); - auto maybeSrcBufferType = - bufferization::getBufferType(castOp.getSource(), options, fixedTypes); + auto maybeSrcBufferType = bufferization::getBufferType( + castOp.getSource(), options, invocationStack); if (failed(maybeSrcBufferType)) return failure(); Attribute memorySpace = maybeSrcBufferType->getMemorySpace(); @@ -133,10 +133,10 @@ FailureOr getBufferType(Operation *op, Value value, const BufferizationOptions &options, - const DenseMap &fixedTypes) const { + SmallVector &invocationStack) const { auto collapseShapeOp = cast(op); auto maybeSrcBufferType = bufferization::getBufferType( - collapseShapeOp.getSrc(), options, fixedTypes); + collapseShapeOp.getSrc(), options, invocationStack); if (failed(maybeSrcBufferType)) return failure(); auto srcBufferType = llvm::cast(*maybeSrcBufferType); @@ -302,10 +302,10 @@ FailureOr getBufferType(Operation *op, Value value, const BufferizationOptions &options, - const DenseMap &fixedTypes) const { + SmallVector &invocationStack) const { auto expandShapeOp = cast(op); auto maybeSrcBufferType = bufferization::getBufferType( - expandShapeOp.getSrc(), options, fixedTypes); + expandShapeOp.getSrc(), options, invocationStack); if (failed(maybeSrcBufferType)) return failure(); auto srcBufferType = llvm::cast(*maybeSrcBufferType); @@ 
-383,11 +383,11 @@ FailureOr getBufferType(Operation *op, Value value, const BufferizationOptions &options, - const DenseMap &fixedTypes) const { + SmallVector &invocationStack) const { auto extractSliceOp = cast(op); assert(value == extractSliceOp.getResult() && "invalid value"); auto srcMemrefType = bufferization::getBufferType( - extractSliceOp.getSource(), options, fixedTypes); + extractSliceOp.getSource(), options, invocationStack); if (failed(srcMemrefType)) return failure(); SmallVector mixedOffsets = extractSliceOp.getMixedOffsets(); @@ -853,11 +853,11 @@ FailureOr getBufferType(Operation *op, Value value, const BufferizationOptions &options, - const DenseMap &fixedTypes) const { + SmallVector &invocationStack) const { // Infer memory space from the source tensor. auto padOp = cast(op); - auto maybeSrcBufferType = - bufferization::getBufferType(padOp.getSource(), options, fixedTypes); + auto maybeSrcBufferType = bufferization::getBufferType( + padOp.getSource(), options, invocationStack); if (failed(maybeSrcBufferType)) return failure(); MemRefLayoutAttrInterface layout; @@ -1002,11 +1002,11 @@ FailureOr getBufferType(Operation *op, Value value, const BufferizationOptions &options, - const DenseMap &fixedTypes) const { + SmallVector &invocationStack) const { auto reshapeOp = cast(op); assert(value == reshapeOp.getResult() && "unexpected value provided"); auto maybeSourceBufferType = bufferization::getBufferType( - reshapeOp.getSource(), options, fixedTypes); + reshapeOp.getSource(), options, invocationStack); if (failed(maybeSourceBufferType)) return failure(); return getMemRefTypeWithStaticIdentityLayout(