Changeset View
Changeset View
Standalone View
Standalone View
mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
Show First 20 Lines • Show All 422 Lines • ▼ Show 20 Lines | return failure(parser.parseRegion(*body, regionArgs, dataTypes) || | ||||
parser.parseOptionalAttrDict(result.attributes)); | parser.parseOptionalAttrDict(result.attributes)); | ||||
} | } | ||||
//===----------------------------------------------------------------------===// | //===----------------------------------------------------------------------===// | ||||
// LaunchFuncOp | // LaunchFuncOp | ||||
//===----------------------------------------------------------------------===// | //===----------------------------------------------------------------------===// | ||||
void LaunchFuncOp::build(OpBuilder &builder, OperationState &result, | void LaunchFuncOp::build(OpBuilder &builder, OperationState &result, | ||||
GPUFuncOp kernelFunc, Value gridSizeX, Value gridSizeY, | GPUFuncOp kernelFunc, KernelDim3 gridSize, | ||||
Value gridSizeZ, Value blockSizeX, Value blockSizeY, | KernelDim3 blockSize, ValueRange kernelOperands) { | ||||
Value blockSizeZ, ValueRange kernelOperands) { | |||||
// Add grid and block sizes as op operands, followed by the data operands. | // Add grid and block sizes as op operands, followed by the data operands. | ||||
result.addOperands( | result.addOperands({gridSize.x, gridSize.y, gridSize.z, blockSize.x, | ||||
{gridSizeX, gridSizeY, gridSizeZ, blockSizeX, blockSizeY, blockSizeZ}); | blockSize.y, blockSize.z}); | ||||
result.addOperands(kernelOperands); | result.addOperands(kernelOperands); | ||||
auto kernelModule = kernelFunc.getParentOfType<GPUModuleOp>(); | auto kernelModule = kernelFunc.getParentOfType<GPUModuleOp>(); | ||||
auto kernelSymbol = builder.getSymbolRefAttr( | auto kernelSymbol = builder.getSymbolRefAttr( | ||||
kernelModule.getName(), {builder.getSymbolRefAttr(kernelFunc.getName())}); | kernelModule.getName(), {builder.getSymbolRefAttr(kernelFunc.getName())}); | ||||
result.addAttribute(getKernelAttrName(), kernelSymbol); | result.addAttribute(getKernelAttrName(), kernelSymbol); | ||||
} | } | ||||
void LaunchFuncOp::build(OpBuilder &builder, OperationState &result, | |||||
GPUFuncOp kernelFunc, KernelDim3 gridSize, | |||||
KernelDim3 blockSize, ValueRange kernelOperands) { | |||||
build(builder, result, kernelFunc, gridSize.x, gridSize.y, gridSize.z, | |||||
blockSize.x, blockSize.y, blockSize.z, kernelOperands); | |||||
} | |||||
SymbolRefAttr LaunchFuncOp::kernel() { | |||||
return getAttrOfType<SymbolRefAttr>(getKernelAttrName()); | |||||
} | |||||
unsigned LaunchFuncOp::getNumKernelOperands() { | unsigned LaunchFuncOp::getNumKernelOperands() { | ||||
return getNumOperands() - kNumConfigOperands; | return getNumOperands() - kNumConfigOperands; | ||||
} | } | ||||
StringRef LaunchFuncOp::getKernelModuleName() { | StringRef LaunchFuncOp::getKernelModuleName() { | ||||
return kernel().getRootReference(); | return kernel().getRootReference(); | ||||
} | } | ||||
Show All 24 Lines | static LogicalResult verify(LaunchFuncOp op) { | ||||
auto kernelAttr = op.getAttrOfType<SymbolRefAttr>(op.getKernelAttrName()); | auto kernelAttr = op.getAttrOfType<SymbolRefAttr>(op.getKernelAttrName()); | ||||
if (!kernelAttr) | if (!kernelAttr) | ||||
return op.emitOpError("symbol reference attribute '" + | return op.emitOpError("symbol reference attribute '" + | ||||
op.getKernelAttrName() + "' must be specified"); | op.getKernelAttrName() + "' must be specified"); | ||||
return success(); | return success(); | ||||
} | } | ||||
static ParseResult | |||||
parseLaunchFuncOperands(OpAsmParser &parser, | |||||
SmallVectorImpl<OpAsmParser::OperandType> &argNames, | |||||
SmallVectorImpl<Type> &argTypes) { | |||||
if (parser.parseOptionalKeyword("args")) | |||||
herhut: Nit: This could be in the `assemblyFormat` | |||||
csigg (Author, Unsubmitted): The `args(...)` is optional, and I don't think `custom<>` inside an optional group is supported. I left it as is. | |||||
return success(); | |||||
SmallVector<NamedAttrList, 4> argAttrs; | |||||
bool isVariadic = false; | |||||
if (impl::parseFunctionArgumentList(parser, /*allowVariadic=*/false, argNames, | |||||
argTypes, argAttrs, isVariadic)) | |||||
return failure(); | |||||
for (auto pair : llvm::enumerate(argAttrs)) { | |||||
if (pair.value().empty()) | |||||
continue; | |||||
return parser.emitError(argNames[pair.index()].location, | |||||
"does not support kernel argument attributes"); | |||||
} | |||||
return success(); | |||||
} | |||||
static void printLaunchFuncOperands(OpAsmPrinter &printer, | |||||
OperandRange operands, TypeRange types) { | |||||
if (operands.empty()) | |||||
herhut: I'd prefer to expose the parsing of just the argument list instead of parsing a result type and then failing. | |||||
return; | |||||
printer << "args("; | |||||
llvm::interleaveComma(llvm::zip(operands, types), printer, | |||||
[&](const auto &pair) { | |||||
printer.printOperand(std::get<0>(pair)); | |||||
printer << " : "; | |||||
printer.printType(std::get<1>(pair)); | |||||
}); | |||||
printer << ")"; | |||||
} | |||||
//===----------------------------------------------------------------------===// | //===----------------------------------------------------------------------===// | ||||
// GPUFuncOp | // GPUFuncOp | ||||
//===----------------------------------------------------------------------===// | //===----------------------------------------------------------------------===// | ||||
/// Adds a new block argument that corresponds to buffers located in | /// Adds a new block argument that corresponds to buffers located in | ||||
/// workgroup memory. | /// workgroup memory. | ||||
BlockArgument GPUFuncOp::addWorkgroupAttribution(Type type) { | BlockArgument GPUFuncOp::addWorkgroupAttribution(Type type) { | ||||
auto attrName = getNumWorkgroupAttributionsAttrName(); | auto attrName = getNumWorkgroupAttributionsAttrName(); | ||||
▲ Show 20 Lines • Show All 346 Lines • Show Last 20 Lines |
Nit: This could be in the `assemblyFormat`