diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h @@ -503,6 +503,9 @@ Optional customCopyInsertionPoint = None); /// Return the buffer type for a given OpOperand (tensor) after bufferization. + /// + /// Note: Op implementations should preferably call `getBuffer()->getType()`. + /// This function should only be used if `getBuffer` cannot be used. BaseMemRefType getBufferType(OpOperand &opOperand) const; /// Return a reference to the BufferizationOptions. @@ -546,9 +549,18 @@ return newOp; } -/// Return a MemRefType to which the `tensorType` can be bufferized in a -/// composable fashion. The layout must be the most dynamic possible and -/// canonicalize away once bufferization is finished. +/// Return a memref type to which the `tensorType` can be bufferized. +/// +/// If possible, op bufferization implementations should not use this function +/// and instead infer precise memref types for tensor results by themselves. +/// +/// Unless a layout map was specified, `options` flags determine what kind of +/// layout map will be used. For best composability (without copies), the fully +/// dynamic layout map is used by default. +/// +/// Note: Canonicalization patterns could clean up layout maps and infer more +/// precise layout maps after bufferization. However, many possible +/// canonicalizations are currently not implemented.
BaseMemRefType getMemRefType(TensorType tensorType, const BufferizationOptions &options, MemRefLayoutAttrInterface layout = {}, diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp @@ -82,17 +82,22 @@ LogicalResult bufferize(Operation *op, RewriterBase &rewriter, BufferizationState &state) const { auto castOp = cast(op); + auto resultTensorType = castOp.getType().cast(); Value source = *state.getBuffer(rewriter, op->getOpOperand(0) /*in*/); auto sourceType = source.getType().cast(); // Result type should have same layout and address space as the source type. - MemRefLayoutAttrInterface layout = {}; - if (auto rankedMemRefType = sourceType.dyn_cast()) - layout = rankedMemRefType.getLayout(); - Type resultType = - getMemRefType(castOp.getType().cast(), state.getOptions(), - layout, sourceType.getMemorySpace()); + BaseMemRefType resultType; + if (auto rankedMemRefType = sourceType.dyn_cast()) { + resultType = MemRefType::get( + rankedMemRefType.getShape(), resultTensorType.getElementType(), + rankedMemRefType.getLayout(), rankedMemRefType.getMemorySpace()); + } else { + auto unrankedMemrefType = sourceType.cast(); + resultType = UnrankedMemRefType::get(resultTensorType.getElementType(), + unrankedMemrefType.getMemorySpace()); + } replaceOpWithNewBufferizedOp(rewriter, op, resultType, source); @@ -146,15 +151,14 @@ // both of them to the most dynamic MemRef type. 
if (trueBuffer.getType() != falseBuffer.getType()) { auto trueType = trueBuffer.getType().cast(); - auto tensorType = selectOp.getTrueValue().getType().cast(); int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset; - SmallVector dynamicStrides(tensorType.getRank(), + SmallVector dynamicStrides(trueType.getRank(), ShapedType::kDynamicStrideOrOffset); AffineMap stridedLayout = makeStridedLinearLayoutMap( dynamicStrides, dynamicOffset, op->getContext()); - BaseMemRefType castedType = bufferization::getMemRefType( - tensorType, state.getOptions(), AffineMapAttr::get(stridedLayout), - trueType.getMemorySpace()); + auto castedType = + MemRefType::get(trueType.getShape(), trueType.getElementType(), + stridedLayout, trueType.getMemorySpaceAsInt()); trueBuffer = rewriter.create(loc, castedType, trueBuffer); falseBuffer = rewriter.create(loc, castedType, falseBuffer); diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp @@ -79,6 +79,7 @@ SmallVector newResultTypes; for (Type type : executeRegionOp->getResultTypes()) { if (auto tensorType = type.dyn_cast()) { + // TODO: Infer the result type instead of computing it. newResultTypes.push_back(getMemRefType(tensorType, state.getOptions())); } else { newResultTypes.push_back(type); @@ -188,6 +189,7 @@ SmallVector newTypes; for (Type returnType : ifOp->getResultTypes()) { if (auto tensorType = returnType.dyn_cast()) { + // TODO: Infer the result type instead of computing it. 
newTypes.push_back(getMemRefType(tensorType, state.getOptions())); } else { newTypes.push_back(returnType); diff --git a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp @@ -66,6 +66,7 @@ SmallVector newResultTypes; for (Type type : assumingOp->getResultTypes()) { if (auto tensorType = type.dyn_cast()) { + // TODO: Infer the result type instead of computing it. newResultTypes.push_back(getMemRefType(tensorType, state.getOptions())); } else { newResultTypes.push_back(type);