diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -484,10 +484,14 @@
 FailureOr<Value> getBuffer(RewriterBase &rewriter, Value value,
                            const BufferizationOptions &options);
 
-/// Return the buffer type for a given Value (tensor) after bufferization.
+/// Return the buffer type for a given Value (tensor) after bufferization
+/// without bufferizing any IR.
 ///
-/// Note: Op implementations should preferrably call `getBuffer()->getType()`.
-/// This function should only be used if `getBuffer` cannot be used.
+/// Note: It should be sufficient to call `getBuffer()->getType()` in most
+/// cases. However, when a buffer type should be predicted without modifying
+/// any IR, this function can be used.
+///
+/// This function is a wrapper around BufferizableOpInterface::getBufferType.
 FailureOr<BaseMemRefType> getBufferType(Value value,
                                         const BufferizationOptions &options);
 
@@ -538,6 +542,18 @@
 BaseMemRefType getMemRefTypeWithStaticIdentityLayout(TensorType tensorType,
                                                      unsigned memorySpace = 0);
 
+/// Return the owner of the given value. In case of a BlockArgument, that is
+/// the owner of the block. In case of an OpResult, that is the defining op.
+Operation *getOwnerOfValue(Value value);
+
+namespace detail {
+/// This is the default implementation of
+/// BufferizableOpInterface::getBufferType. Should not be called from other
+/// places.
+FailureOr<BaseMemRefType>
+defaultGetBufferType(Value value, const BufferizationOptions &options);
+} // namespace detail
+
 } // namespace bufferization
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
@@ -340,39 +340,22 @@
       >,
       InterfaceMethod<
         /*desc=*/[{
-          Return the bufferized type of the given tensor block argument. The
-          block argument is guaranteed to belong to a block of this op.
+          Return the bufferized type of the given tensor value (without
+          bufferizing the IR). The value is either a BlockArgument of a block
+          that belongs to this op or an OpResult of the given op.
+
+          This method is useful when the bufferized type of a value must be
+          predicted before modifying any IR.
         }],
        /*retType=*/"FailureOr<BaseMemRefType>",
        /*methodName=*/"getBufferType",
-        /*args=*/(ins "BlockArgument":$bbArg,
+        /*args=*/(ins "Value":$value,
                       "const BufferizationOptions &":$options),
        /*methodBody=*/"",
        /*defaultImplementation=*/[{
-          assert(bbArg.getOwner()->getParentOp() == $_op &&
-                 "bbArg must belong to this op");
-          assert(bbArg.getType().isa<TensorType>() &&
-                 "expected tensor type");
-          return bufferization::getMemRefType(bbArg, options);
-        }]
-      >,
-      InterfaceMethod<
-        /*desc=*/[{
-          Return the memory space of the given tensor OpResult if specified on
-          this op. If not specified, return `failure`.
-
-          This method will never be called with OpResults that do not bufferize
-          to a memory allocation.
-        }],
-        /*retType=*/"FailureOr<unsigned>",
-        /*methodName=*/"getMemorySpace",
-        /*args=*/(ins "OpResult":$opResult),
-        /*methodBody=*/"",
-        /*defaultImplementation=*/[{
-          assert(cast<BufferizableOpInterface>($_op.getOperation())
-                     .bufferizesToAllocation(opResult)
-                 && "expected allocation");
-          return failure();
+          assert(getOwnerOfValue(value) == $_op.getOperation() &&
+                 "expected that value belongs to this op");
+          return bufferization::detail::defaultGetBufferType(value, options);
        }]
      >,
  ];
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -82,12 +82,6 @@
 
     bool bufferizesToAllocation(OpResult opResult) { return true; }
 
-    FailureOr<unsigned> getMemorySpace(OpResult opResult) {
-      if (getMemorySpace().has_value())
-        return static_cast<unsigned>(*getMemorySpace());
-      return failure();
-    }
-
     bool bufferizesToMemoryRead(OpOperand &opOperand,
                                 const AnalysisState &state);
 
@@ -97,6 +91,9 @@
     SmallVector<OpResult> getAliasingOpResult(
         OpOperand &opOperand, const AnalysisState &state);
 
+    FailureOr<BaseMemRefType> getBufferType(
+        Value value, const BufferizationOptions &options);
+
     RankedTensorType getType() {
       return getResult().getType().cast<RankedTensorType>();
     }
@@ -324,6 +321,11 @@
       // It is unknown whether the memref operand is writable or not.
       return false;
     }
+
+    FailureOr<BaseMemRefType> getBufferType(
+        Value value, const BufferizationOptions &options) {
+      return getMemref().getType().cast<BaseMemRefType>();
+    }
   }];
 
   let assemblyFormat = "$memref attr-dict `:` type($memref)";
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -38,8 +38,7 @@
 using namespace mlir;
 using namespace bufferization;
 
-/// Return the owner of the given value.
-static Operation *getOwnerOfValue(Value value) {
+Operation *bufferization::getOwnerOfValue(Value value) {
   if (auto opResult = value.dyn_cast<OpResult>())
     return opResult.getDefiningOp();
   return value.cast<BlockArgument>().getOwner()->getParentOp();
@@ -568,47 +567,53 @@
       .getResult();
 }
 
+FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
+    Value value, const BufferizationOptions &options) {
+  assert(value.getType().isa<TensorType>() && "expected tensor type");
+
+  // No further analysis is possible for a block argument.
+  if (value.isa<BlockArgument>())
+    return bufferization::getMemRefType(value, options);
+
+  // Value is an OpResult.
+  Operation *op = getOwnerOfValue(value);
+  auto opResult = value.cast<OpResult>();
+  auto bufferizableOp = cast<BufferizableOpInterface>(op);
+  AnalysisState state(options);
+  auto aliasingOperands = bufferizableOp.getAliasingOpOperand(opResult, state);
+  if (!aliasingOperands.empty() &&
+      bufferizableOp.bufferRelation(opResult, state) ==
+          BufferRelation::Equivalent) {
+    // If the OpResult has an equivalent OpOperand, both OpResult and
+    // OpOperand bufferize to the exact same buffer type.
+    Value equivalentOperand = aliasingOperands.front()->get();
+    return getBufferType(equivalentOperand, options);
+  }
+
+  // If we do not know the memory space and there is no default memory space,
+  // report a failure.
+  if (!options.defaultMemorySpace.has_value())
+    return op->emitError("could not infer memory space");
+
+  return getMemRefType(value, options, /*layout=*/{},
+                       *options.defaultMemorySpace);
+}
+
 /// Return the buffer type for a given Value (tensor) after bufferization.
 FailureOr<BaseMemRefType>
 bufferization::getBufferType(Value value,
                              const BufferizationOptions &options) {
   assert(value.getType().isa<TensorType>() && "unexpected non-tensor type");
   Operation *op = getOwnerOfValue(value);
+  auto bufferizableOp = options.dynCastBufferizableOp(op);
+  if (bufferizableOp)
+    return bufferizableOp.getBufferType(value, options);
 
-  // ToTensorOp: Take buffer type directly from the op.
-  if (auto toTensorOp = value.getDefiningOp<ToTensorOp>())
-    return toTensorOp.getMemref().getType().cast<BaseMemRefType>();
-
-  // If value is a bbArg of a bufferizable op: query op interface.
-  if (auto bbArg = value.dyn_cast<BlockArgument>())
-    if (auto bufferizableOp =
-            options.dynCastBufferizableOp(bbArg.getOwner()->getParentOp()))
-      return bufferizableOp.getBufferType(bbArg, options);
-
-  // Check value is a new buffer allocation with a memory space attribute. In
-  // that case we can at least infer the memory space.
-  Optional<unsigned> memorySpace;
-  if (auto opResult = value.dyn_cast<OpResult>()) {
-    if (auto bufferizableOp =
-            options.dynCastBufferizableOp(opResult.getDefiningOp())) {
-      if (bufferizableOp.bufferizesToAllocation(opResult)) {
-        FailureOr<unsigned> queriedMemorySpace =
-            bufferizableOp.getMemorySpace(opResult);
-        if (!failed(queriedMemorySpace))
-          memorySpace = *queriedMemorySpace;
-      }
-    }
-  }
-
-  // If we still do not know the memory space, use the default memory space (if
-  // any).
-  if (!memorySpace.has_value())
-    memorySpace = options.defaultMemorySpace;
-
-  // If we still do not know the memory space, report a failure.
-  if (!memorySpace.has_value())
+  // Op is not bufferizable.
+  if (!options.defaultMemorySpace.has_value())
     return op->emitError("could not infer memory space");
 
-  return getMemRefType(value, options, /*layout=*/{}, *memorySpace);
+  return getMemRefType(value, options, /*layout=*/{},
+                       *options.defaultMemorySpace);
 }
 
 void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -152,7 +152,6 @@
 LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
                                        const BufferizationOptions &options) {
   OpBuilder::InsertionGuard g(rewriter);
-  Operation *op = this->getOperation();
   Location loc = getLoc();
 
   // Nothing to do for dead AllocTensorOps.
@@ -170,30 +169,17 @@
     copyBuffer = *maybeCopyBuffer;
   }
 
-  // Compute memory space of this allocation.
-  unsigned memorySpace;
-  if (getMemorySpace().has_value()) {
-    memorySpace = *getMemorySpace();
-  } else if (getCopy()) {
-    memorySpace =
-        copyBuffer.getType().cast<BaseMemRefType>().getMemorySpaceAsInt();
-  } else if (options.defaultMemorySpace.has_value()) {
-    memorySpace = *options.defaultMemorySpace;
-  } else {
-    return op->emitError("could not infer memory space");
-  }
-
   // Create memory allocation.
-  auto allocType =
-      MemRefType::get(getType().getShape(), getType().getElementType(),
-                      AffineMap(), memorySpace);
+  auto allocType = getBufferType(getResult(), options);
+  if (failed(allocType))
+    return failure();
   SmallVector<Value> dynamicDims = getDynamicSizes();
   if (getCopy()) {
     assert(dynamicDims.empty() && "expected either `copy` or `dynamicDims`");
     populateDynamicDimSizes(rewriter, loc, copyBuffer, dynamicDims);
   }
-  FailureOr<Value> alloc =
-      options.createAlloc(rewriter, loc, allocType, dynamicDims);
+  FailureOr<Value> alloc = options.createAlloc(
+      rewriter, loc, allocType->cast<MemRefType>(), dynamicDims);
   if (failed(alloc))
     return failure();
 
@@ -247,6 +233,28 @@
   return {};
 }
 
+FailureOr<BaseMemRefType>
+AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options) {
+  assert(value == getResult() && "invalid value");
+
+  // Compute memory space of this allocation.
+  unsigned memorySpace;
+  if (getMemorySpace().has_value()) {
+    memorySpace = *getMemorySpace();
+  } else if (getCopy()) {
+    auto copyBufferType = bufferization::getBufferType(getCopy(), options);
+    if (failed(copyBufferType))
+      return failure();
+    memorySpace = copyBufferType->getMemorySpaceAsInt();
+  } else if (options.defaultMemorySpace.has_value()) {
+    memorySpace = *options.defaultMemorySpace;
+  } else {
+    return getOperation()->emitError("could not infer memory space");
+  }
+
+  return getMemRefTypeWithStaticIdentityLayout(getType(), memorySpace);
+}
+
 LogicalResult AllocTensorOp::verify() {
   if (getCopy() && !getDynamicSizes().empty())
     return emitError("dynamic sizes not needed when copying a tensor");
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -472,9 +472,13 @@
   }
 
   FailureOr<BaseMemRefType>
-  getBufferType(Operation *op, BlockArgument bbArg,
+  getBufferType(Operation *op, Value value,
                 const BufferizationOptions &options) const {
     auto forOp = cast<scf::ForOp>(op);
+    // TODO: Only block arguments supported at the moment.
+    if (value.isa<OpResult>())
+      return failure();
+    auto bbArg = value.cast<BlockArgument>();
     return bufferization::getBufferType(
         forOp.getOpOperandForRegionIterArg(bbArg).get(), options);
   }
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -290,21 +290,38 @@
         getBuffer(rewriter, extractSliceOp.getSource(), options);
     if (failed(srcMemref))
       return failure();
-    auto srcMemrefType = srcMemref->getType().cast<MemRefType>();
 
     // Take a subview of the source buffer.
-    auto subviewMemRefType =
-        memref::SubViewOp::inferRankReducedResultType(
-            extractSliceOp.getType().getShape(), srcMemrefType, mixedOffsets,
-            mixedSizes, mixedStrides)
-            .cast<MemRefType>();
+    auto resultMemrefType =
+        getBufferType(op, extractSliceOp.getResult(), options);
+    if (failed(resultMemrefType))
+      return failure();
     Value subView = rewriter.create<memref::SubViewOp>(
-        loc, subviewMemRefType, *srcMemref, mixedOffsets, mixedSizes,
-        mixedStrides);
+        loc, resultMemrefType->cast<MemRefType>(), *srcMemref, mixedOffsets,
+        mixedSizes, mixedStrides);
 
     replaceOpWithBufferizedValues(rewriter, op, subView);
     return success();
   }
+
+  FailureOr<BaseMemRefType>
+  getBufferType(Operation *op, Value value,
+                const BufferizationOptions &options) const {
+    auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
+    assert(value == extractSliceOp.getResult() && "invalid value");
+    auto srcMemrefType =
+        bufferization::getBufferType(extractSliceOp.getSource(), options);
+    if (failed(srcMemrefType))
+      return failure();
+    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
+    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
+    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
+    return memref::SubViewOp::inferRankReducedResultType(
+               extractSliceOp.getType().getShape(),
+               srcMemrefType->cast<MemRefType>(), mixedOffsets, mixedSizes,
+               mixedStrides)
+        .cast<BaseMemRefType>();
+  }
 };
 
 /// Bufferization of tensor.extract. Replace with memref.load.