Please use GitHub pull requests for new patches. Avoid migrating existing patches. Phabricator shutdown timeline
Changeset View
Changeset View
Standalone View
Standalone View
mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===// | //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===// | ||||
// | // | ||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||||
// See https://llvm.org/LICENSE.txt for license information. | // See https://llvm.org/LICENSE.txt for license information. | ||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||||
// | // | ||||
//===----------------------------------------------------------------------===// | //===----------------------------------------------------------------------===// | ||||
#include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h" | #include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h" | ||||
#include "mlir/Dialect/Arith/IR/Arith.h" | #include "mlir/Dialect/Arith/IR/Arith.h" | ||||
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" | #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" | ||||
#include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h" | #include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h" | ||||
#include "mlir/Dialect/MemRef/IR/MemRef.h" | #include "mlir/Dialect/MemRef/IR/MemRef.h" | ||||
#include "mlir/IR/Attributes.h" | |||||
#include "mlir/IR/Dialect.h" | #include "mlir/IR/Dialect.h" | ||||
#include "mlir/IR/Operation.h" | #include "mlir/IR/Operation.h" | ||||
using namespace mlir; | using namespace mlir; | ||||
using namespace mlir::bufferization; | using namespace mlir::bufferization; | ||||
namespace { | namespace { | ||||
/// Bufferization of arith.constant. Replace with memref.get_global. | /// Bufferization of arith.constant. Replace with memref.get_global. | ||||
struct ConstantOpInterface | struct ConstantOpInterface | ||||
: public BufferizableOpInterface::ExternalModel<ConstantOpInterface, | : public BufferizableOpInterface::ExternalModel<ConstantOpInterface, | ||||
arith::ConstantOp> { | arith::ConstantOp> { | ||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter, | LogicalResult bufferize(Operation *op, RewriterBase &rewriter, | ||||
const BufferizationOptions &options) const { | const BufferizationOptions &options) const { | ||||
auto constantOp = cast<arith::ConstantOp>(op); | auto constantOp = cast<arith::ConstantOp>(op); | ||||
// TODO: Implement memory space for this op. E.g., by adding a memory_space | Attribute memorySpace; | ||||
// attribute to ConstantOp. | if (options.defaultMemorySpace.has_value()) | ||||
if (options.defaultMemorySpace != Attribute()) | memorySpace = *options.defaultMemorySpace; | ||||
return op->emitError("memory space not implemented yet"); | else | ||||
return constantOp->emitError("could not infer memory space"); | |||||
// Only ranked tensors are supported. | // Only ranked tensors are supported. | ||||
if (!constantOp.getType().isa<RankedTensorType>()) | if (!constantOp.getType().isa<RankedTensorType>()) | ||||
return failure(); | return failure(); | ||||
// Only constants inside a module are supported. | // Only constants inside a module are supported. | ||||
auto moduleOp = constantOp->getParentOfType<ModuleOp>(); | auto moduleOp = constantOp->getParentOfType<ModuleOp>(); | ||||
if (!moduleOp) | if (!moduleOp) | ||||
return failure(); | return failure(); | ||||
// Create global memory segment and replace tensor with memref pointing to | // Create global memory segment and replace tensor with memref pointing to | ||||
// that memory segment. | // that memory segment. | ||||
FailureOr<memref::GlobalOp> globalOp = | FailureOr<memref::GlobalOp> globalOp = | ||||
getGlobalFor(constantOp, options.bufferAlignment); | getGlobalFor(constantOp, options.bufferAlignment, memorySpace); | ||||
if (failed(globalOp)) | if (failed(globalOp)) | ||||
return failure(); | return failure(); | ||||
memref::GlobalOp globalMemref = *globalOp; | memref::GlobalOp globalMemref = *globalOp; | ||||
replaceOpWithNewBufferizedOp<memref::GetGlobalOp>( | replaceOpWithNewBufferizedOp<memref::GetGlobalOp>( | ||||
rewriter, op, globalMemref.getType(), globalMemref.getName()); | rewriter, op, globalMemref.getType(), globalMemref.getName()); | ||||
return success(); | return success(); | ||||
} | } | ||||
▲ Show 20 Lines • Show All 148 Lines • Show Last 20 Lines |