Please use GitHub pull requests for new patches. Phabricator shutdown timeline
Changeset View
Changeset View
Standalone View
Standalone View
mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
Show First 20 Lines • Show All 141 Lines • ▼ Show 20 Lines | bool BufferPlacementTransformationBase::isLoop(Operation *op) { | ||||
return false; | return false; | ||||
} | } | ||||
//===----------------------------------------------------------------------===// | //===----------------------------------------------------------------------===// | ||||
// BufferPlacementTransformationBase | // BufferPlacementTransformationBase | ||||
//===----------------------------------------------------------------------===// | //===----------------------------------------------------------------------===// | ||||
FailureOr<memref::GlobalOp> | FailureOr<memref::GlobalOp> | ||||
bufferization::getGlobalFor(arith::ConstantOp constantOp, uint64_t alignment) { | bufferization::getGlobalFor(arith::ConstantOp constantOp, uint64_t alignment, | ||||
Attribute memorySpace) { | |||||
auto type = constantOp.getType().cast<RankedTensorType>(); | auto type = constantOp.getType().cast<RankedTensorType>(); | ||||
auto moduleOp = constantOp->getParentOfType<ModuleOp>(); | auto moduleOp = constantOp->getParentOfType<ModuleOp>(); | ||||
if (!moduleOp) | if (!moduleOp) | ||||
return failure(); | return failure(); | ||||
// If we already have a global for this constant value, no need to do | // If we already have a global for this constant value, no need to do | ||||
// anything else. | // anything else. | ||||
for (Operation &op : moduleOp.getRegion().getOps()) { | for (Operation &op : moduleOp.getRegion().getOps()) { | ||||
Show All 20 Lines | bufferization::getGlobalFor(arith::ConstantOp constantOp, uint64_t alignment, | ||||
os << "x" << type.getElementType(); | os << "x" << type.getElementType(); | ||||
// Add an optional alignment to the global memref. | // Add an optional alignment to the global memref. | ||||
IntegerAttr memrefAlignment = | IntegerAttr memrefAlignment = | ||||
alignment > 0 ? IntegerAttr::get(globalBuilder.getI64Type(), alignment) | alignment > 0 ? IntegerAttr::get(globalBuilder.getI64Type(), alignment) | ||||
: IntegerAttr(); | : IntegerAttr(); | ||||
BufferizeTypeConverter typeConverter; | BufferizeTypeConverter typeConverter; | ||||
auto memrefType = typeConverter.convertType(type).cast<MemRefType>(); | |||||
if (memorySpace) | |||||
memrefType = MemRefType::Builder(memrefType).setMemorySpace(memorySpace); | |||||
auto global = globalBuilder.create<memref::GlobalOp>( | auto global = globalBuilder.create<memref::GlobalOp>( | ||||
constantOp.getLoc(), (Twine("__constant_") + os.str()).str(), | constantOp.getLoc(), (Twine("__constant_") + os.str()).str(), | ||||
/*sym_visibility=*/globalBuilder.getStringAttr("private"), | /*sym_visibility=*/globalBuilder.getStringAttr("private"), | ||||
/*type=*/typeConverter.convertType(type).cast<MemRefType>(), | /*type=*/memrefType, | ||||
/*initial_value=*/constantOp.getValue().cast<ElementsAttr>(), | /*initial_value=*/constantOp.getValue().cast<ElementsAttr>(), | ||||
/*constant=*/true, | /*constant=*/true, | ||||
/*alignment=*/memrefAlignment); | /*alignment=*/memrefAlignment); | ||||
symbolTable.insert(global); | symbolTable.insert(global); | ||||
// The symbol table inserts at the end of the module, but globals are a bit | // The symbol table inserts at the end of the module, but globals are a bit | ||||
// nicer if they are at the beginning. | // nicer if they are at the beginning. | ||||
global->moveBefore(&moduleOp.front()); | global->moveBefore(&moduleOp.front()); | ||||
return global; | return global; | ||||
} | } |