diff --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td
--- a/mlir/include/mlir/Dialect/GPU/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td
@@ -165,6 +165,10 @@
     /// element type.
     Value addWorkgroupAttribution(ArrayRef<int64_t> shape, Type elementType);
 
+    /// Adds a private attribution of the MemRef type with the given shape and
+    /// element type.
+    Value addPrivateAttribution(ArrayRef<int64_t> shape, Type elementType);
+
     /// Returns `true` if the GPU function defined by this Op is a kernel, i.e.
     /// it is intended to be launched from host.
     bool isKernel() {
@@ -189,6 +193,12 @@
           .getInt();
     }
 
+    /// Returns the number of buffers located in the private memory.
+    unsigned getNumPrivateAttributions() {
+      return getAttrOfType<IntegerAttr>(getNumPrivateAttributionsAttrName())
+          .getInt();
+    }
+
     /// Returns a list of block arguments that correspond to buffers located in
     /// the workgroup memory
     ArrayRef<BlockArgument> getWorkgroupAttributions() {
@@ -214,7 +224,22 @@
       auto begin =
           std::next(getBody().front().args_begin(),
                     getType().getNumInputs() + getNumWorkgroupAttributions());
-      return {begin, getBody().front().args_end()};
+      auto end = std::next(begin, getNumPrivateAttributions());
+      return {begin, end};
+    }
+
+    // Adds a new block argument that corresponds to buffers located in
+    // private memory.
+    BlockArgument addPrivateAttribution(Type type) {
+      // Private attributions follow the function inputs and the workgroup
+      // attributions; compute the slot before bumping the stored count.
+      unsigned pos = getType().getNumInputs() + getNumWorkgroupAttributions() +
+                     getNumPrivateAttributions();
+
+      auto attrName = getNumPrivateAttributionsAttrName();
+      auto attr = getAttrOfType<IntegerAttr>(attrName);
+      setAttr(attrName, IntegerAttr::get(attr.getType(), attr.getValue() + 1));
+      return getBody().front().insertArgument(pos, type);
     }
 
     /// Returns the name of the attribute containing the number of buffers
@@ -223,6 +248,12 @@
       return "workgroup_attributions";
     }
 
+    /// Returns the name of the attribute containing the number of buffers
+    /// located in the private memory.
+    static StringRef getNumPrivateAttributionsAttrName() {
+      return "private_attributions";
+    }
+
     // FunctionLike trait needs access to the functions below.
     friend class OpTrait::FunctionLike;
 
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -457,24 +457,46 @@
 // GPUFuncOp
 //===----------------------------------------------------------------------===//
 
-/// Adds a workgroup attribution to "op" of the MemRef type with the given shape
-/// and element type.
-Value GPUFuncOp::addWorkgroupAttribution(ArrayRef<int64_t> shape,
-                                         Type elementType) {
-  unsigned pos = getNumFuncArguments() + getNumWorkgroupAttributions();
-  Block &bodyBlock = body().front();
+/// Adds an attribution to "op" of the MemRef type with the given shape and
+/// element type.
+static Value addAttribution(GPUFuncOp op, ArrayRef<int64_t> shape,
+                            Type elementType, unsigned pos, unsigned addrspace,
+                            StringRef attributionsAttrName) {
+  Block &bodyBlock = op.body().front();
   Value attribution = bodyBlock.insertArgument(
       std::next(bodyBlock.args_begin(), pos),
       MemRefType::get(shape, elementType, /*affineMapComposition=*/{},
-                      GPUDialect::getWorkgroupAddressSpace()));
-  auto numWorkgroupBuffersAttr =
-      getAttrOfType<IntegerAttr>(getNumWorkgroupAttributionsAttrName());
-  setAttr(getNumWorkgroupAttributionsAttrName(),
-          IntegerAttr::get(numWorkgroupBuffersAttr.getType(),
-                           numWorkgroupBuffersAttr.getValue() + 1));
+                      addrspace));
+  auto numBuffersAttr = op.getAttrOfType<IntegerAttr>(attributionsAttrName);
+  op.setAttr(attributionsAttrName,
+             IntegerAttr::get(numBuffersAttr.getType(),
+                              numBuffersAttr.getValue() + 1));
   return attribution;
 }
 
+/// Adds a workgroup attribution to "op" of the MemRef type with the given shape
+/// and element type.
+Value GPUFuncOp::addWorkgroupAttribution(ArrayRef<int64_t> shape,
+                                         Type elementType) {
+  return addAttribution(*this, shape, elementType,
+                        getNumFuncArguments() + getNumWorkgroupAttributions(),
+                        GPUDialect::getWorkgroupAddressSpace(),
+                        getNumWorkgroupAttributionsAttrName());
+}
+
+/// Adds a private attribution to "op" of the MemRef type with the given shape
+/// and element type.
+///
+/// Note private attributions are always after workgroup attributions.
+Value GPUFuncOp::addPrivateAttribution(ArrayRef<int64_t> shape,
+                                       Type elementType) {
+  return addAttribution(*this, shape, elementType,
+                        getNumFuncArguments() + getNumWorkgroupAttributions() +
+                            getNumPrivateAttributions(),
+                        GPUDialect::getPrivateAddressSpace(),
+                        getNumPrivateAttributionsAttrName());
+}
+
 void GPUFuncOp::build(OpBuilder &builder, OperationState &result,
                       StringRef name, FunctionType type,
                       ArrayRef<Type> workgroupAttributions,
@@ -485,6 +507,8 @@
   result.addAttribute(getTypeAttrName(), TypeAttr::get(type));
   result.addAttribute(getNumWorkgroupAttributionsAttrName(),
                       builder.getI64IntegerAttr(workgroupAttributions.size()));
+  result.addAttribute(getNumPrivateAttributionsAttrName(),
+                      builder.getI64IntegerAttr(privateAttributions.size()));
   result.addAttributes(attrs);
   Region *body = result.addRegion();
   Block *entryBlock = new Block;
@@ -582,6 +606,13 @@
                                entryArgs, argTypes)))
     return failure();
 
+  // Store the number of operands we just parsed as the number of private
+  // memory attributions.
+  unsigned numPrivateAttrs =
+      argTypes.size() - type.getNumInputs() - numWorkgroupAttrs;
+  result.addAttribute(GPUFuncOp::getNumPrivateAttributionsAttrName(),
+                      builder.getI64IntegerAttr(numPrivateAttrs));
+
   // Parse the kernel attribute if present.
if (succeeded(parser.parseOptionalKeyword(GPUFuncOp::getKernelKeyword()))) result.addAttribute(GPUDialect::getKernelFuncAttrName(), @@ -626,6 +657,7 @@ impl::printFunctionAttributes(p, op.getOperation(), type.getNumInputs(), type.getNumResults(), {op.getNumWorkgroupAttributionsAttrName(), + op.getNumPrivateAttributionsAttrName(), GPUDialect::getKernelFuncAttrName()}); p.printRegion(op.getBody(), /*printEntryBlockArgs=*/false); } @@ -675,10 +707,13 @@ LogicalResult GPUFuncOp::verifyBody() { unsigned numFuncArguments = getNumArguments(); unsigned numWorkgroupAttributions = getNumWorkgroupAttributions(); + unsigned numPrivateAttributions = getNumPrivateAttributions(); unsigned numBlockArguments = front().getNumArguments(); - if (numBlockArguments < numFuncArguments + numWorkgroupAttributions) + if (numBlockArguments < + numFuncArguments + numWorkgroupAttributions + numPrivateAttributions) return emitOpError() << "expected at least " - << numFuncArguments + numWorkgroupAttributions + << numFuncArguments + numWorkgroupAttributions + + numPrivateAttributions << " arguments to body region"; ArrayRef funcArgTypes = getType().getInputs();