diff --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td --- a/mlir/include/mlir/Dialect/GPU/GPUOps.td +++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td @@ -161,10 +161,6 @@ ]; let extraClassDeclaration = [{ - /// Adds a workgroup attribution of the MemRef type with the given shape and - /// element type. - Value addWorkgroupAttribution(ArrayRef<int64_t> shape, Type elementType); - /// Returns `true` if the GPU function defined by this Op is a kernel, i.e. /// it is intended to be launched from host. bool isKernel() { @@ -198,25 +194,31 @@ return {begin, end}; } - // Adds a new block argument that corresponds to buffers located in - // workgroup memory. - BlockArgument addWorkgroupAttribution(Type type) { - auto attrName = getNumWorkgroupAttributionsAttrName(); - auto attr = getAttrOfType<IntegerAttr>(attrName); - setAttr(attrName, IntegerAttr::get(attr.getType(), attr.getValue() + 1)); - return getBody().front().insertArgument( - getType().getNumInputs() + attr.getInt(), type); - } + /// Adds a new block argument that corresponds to buffers located in + /// workgroup memory. + BlockArgument addWorkgroupAttribution(Type type); + /// Returns the number of buffers located in the private memory. + unsigned getNumPrivateAttributions() { + return getBody().front().getNumArguments() - getType().getNumInputs() - + getNumWorkgroupAttributions(); + } + /// Returns a list of block arguments that correspond to buffers located in /// the private memory. ArrayRef<BlockArgument> getPrivateAttributions() { + // Buffers on the private memory always come after buffers on the workgroup + // memory. auto begin = std::next(getBody().front().args_begin(), getType().getNumInputs() + getNumWorkgroupAttributions()); return {begin, getBody().front().args_end()}; } + /// Adds a new block argument that corresponds to buffers located in + /// private memory. 
+ BlockArgument addPrivateAttribution(Type type); + /// Returns the name of the attribute containing the number of buffers /// located in the workgroup memory. static StringRef getNumWorkgroupAttributionsAttrName() { diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -457,22 +457,22 @@ // GPUFuncOp //===----------------------------------------------------------------------===// -/// Adds a workgroup attribution to "op" of the MemRef type with the given shape -/// and element type. -Value GPUFuncOp::addWorkgroupAttribution(ArrayRef<int64_t> shape, - Type elementType) { - unsigned pos = getNumFuncArguments() + getNumWorkgroupAttributions(); - Block &bodyBlock = body().front(); - Value attribution = bodyBlock.insertArgument( - std::next(bodyBlock.args_begin(), pos), - MemRefType::get(shape, elementType, /*affineMapComposition=*/{}, - GPUDialect::getWorkgroupAddressSpace())); - auto numWorkgroupBuffersAttr = - getAttrOfType<IntegerAttr>(getNumWorkgroupAttributionsAttrName()); - setAttr(getNumWorkgroupAttributionsAttrName(), - IntegerAttr::get(numWorkgroupBuffersAttr.getType(), - numWorkgroupBuffersAttr.getValue() + 1)); - return attribution; +/// Adds a new block argument that corresponds to buffers located in +/// workgroup memory. +BlockArgument GPUFuncOp::addWorkgroupAttribution(Type type) { + auto attrName = getNumWorkgroupAttributionsAttrName(); + auto attr = getAttrOfType<IntegerAttr>(attrName); + setAttr(attrName, IntegerAttr::get(attr.getType(), attr.getValue() + 1)); + return getBody().front().insertArgument( + getType().getNumInputs() + attr.getInt(), type); +} + +/// Adds a new block argument that corresponds to buffers located in +/// private memory. +BlockArgument GPUFuncOp::addPrivateAttribution(Type type) { + // Buffers on the private memory always come after buffers on the workgroup + // memory, so the new attribution is appended to the end of the arguments.
+ return getBody().front().addArgument(type); } void GPUFuncOp::build(OpBuilder &builder, OperationState &result, diff --git a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp --- a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp @@ -151,7 +151,8 @@ /// Adds type to funcOp's workgroup attributions. 
Value createWorkgroupBuffer() { - int workgroupMemoryAddressSpace = 3; + int workgroupMemoryAddressSpace = + gpu::GPUDialect::getWorkgroupAddressSpace(); auto bufferType = MemRefType::get({kSubgroupSize}, valueType, ArrayRef<AffineMap>{}, workgroupMemoryAddressSpace); diff --git a/mlir/lib/Dialect/GPU/Transforms/MemoryPromotion.cpp b/mlir/lib/Dialect/GPU/Transforms/MemoryPromotion.cpp --- a/mlir/lib/Dialect/GPU/Transforms/MemoryPromotion.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/MemoryPromotion.cpp @@ -160,8 +160,12 @@ auto type = value.getType().dyn_cast<MemRefType>(); assert(type && type.hasStaticShape() && "can only promote memrefs"); - Value attribution = - op.addWorkgroupAttribution(type.getShape(), type.getElementType()); + // Get the type of the buffer in the workgroup memory. + int workgroupMemoryAddressSpace = gpu::GPUDialect::getWorkgroupAddressSpace(); + auto bufferType = MemRefType::get(type.getShape(), type.getElementType(), {}, + workgroupMemoryAddressSpace); + + Value attribution = op.addWorkgroupAttribution(bufferType); // Replace the uses first since only the original uses are currently present. // Then insert the copies.