diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -495,11 +495,6 @@
   return newOp;
 }
 
-/// Return a contiguous MemRefType (i.e. with canonical/empty layout map)
-/// with the same shape as `shapedType` and specified `addressSpace`.
-MemRefType getContiguousMemRefType(ShapedType shapedType,
-                                   Attribute memorySpace = {});
-
 /// Return a MemRefType to which the `tensorType` can be bufferized in a
 /// composable fashion. The layout must be the most dynamic possible and
 /// canonicalize away once bufferization is finished.
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -445,6 +445,13 @@
   return success();
 }
 
+static MemRefType getContiguousMemRefType(ShapedType shapedType,
+                                          Attribute memorySpace = {}) {
+  MemRefLayoutAttrInterface layout = {};
+  return MemRefType::get(shapedType.getShape(), shapedType.getElementType(),
+                         layout, memorySpace);
+}
+
 /// Compute the type of the `memref` to use for allocating the buffer for
 /// `shapedValue`. Also returns (by reference in `dynShape`), the value for the
 /// dynamic dimensions in the returned `memref` type.
@@ -650,12 +657,5 @@
 }
 
-MemRefType bufferization::getContiguousMemRefType(ShapedType shapedType,
-                                                  Attribute memorySpace) {
-  MemRefLayoutAttrInterface layout = {};
-  return MemRefType::get(shapedType.getShape(), shapedType.getElementType(),
-                         layout, memorySpace);
-}
-
 BaseMemRefType bufferization::getMemRefType(TensorType tensorType,
                                             const BufferizationOptions &options,
                                             MemRefLayoutAttrInterface layout,
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -432,13 +432,12 @@
 
     // Allocate memory.
     Location loc = op->getLoc();
-    MemRefType memrefType =
-        getContiguousMemRefType(generateOp.getType().cast<RankedTensorType>());
     FailureOr<Value> maybeResult =
         state.createAlloc(rewriter, loc, generateOp.result());
     if (failed(maybeResult))
       return failure();
     Value result = *maybeResult;
+    MemRefType memrefType = result.getType().cast<MemRefType>();
 
     // Collect loop bounds.
     int64_t rank = memrefType.getRank();