diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td --- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td +++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td @@ -430,14 +430,19 @@ let hasVerifier = 1; } -def GPU_LaunchFuncOp : GPU_Op<"launch_func", - [GPU_AsyncOpInterface, AttrSizedOperandSegments]>, +def LaunchIndx : AnyTypeOf<[Index, I32, I64]>; + +def GPU_LaunchFuncOp :GPU_Op<"launch_func", [ + GPU_AsyncOpInterface, AttrSizedOperandSegments, + AllTypesMatch<["gridSizeX", "gridSizeY", "gridSizeZ", "blockSizeX", + "blockSizeY", "blockSizeZ"]>]>, Arguments<(ins Variadic:$asyncDependencies, SymbolRefAttr:$kernel, - Index:$gridSizeX, Index:$gridSizeY, Index:$gridSizeZ, - Index:$blockSizeX, Index:$blockSizeY, Index:$blockSizeZ, + LaunchIndx:$gridSizeX, LaunchIndx:$gridSizeY, LaunchIndx:$gridSizeZ, + LaunchIndx:$blockSizeX, LaunchIndx:$blockSizeY, LaunchIndx:$blockSizeZ, Optional:$dynamicSharedMemorySize, - Variadic:$kernelOperands)>, + Variadic:$kernelOperands, + Optional:$asyncObject)>, Results<(outs Optional:$asyncToken)> { let summary = "Launches a function as a GPU kernel"; @@ -528,7 +533,11 @@ "KernelDim3":$blockSize, "Value":$dynamicSharedMemorySize, "ValueRange":$kernelOperands, CArg<"Type", "nullptr">:$asyncTokenType, - CArg<"ValueRange", "{}">:$asyncDependencies)> + CArg<"ValueRange", "{}">:$asyncDependencies)>, + OpBuilder<(ins "SymbolRefAttr":$kernel, "KernelDim3":$gridSize, + "KernelDim3":$blockSize, "Value":$dynamicSharedMemorySize, + "ValueRange":$kernelOperands, + CArg<"Value", "nullptr">:$asyncObject)> ]; let extraClassDeclaration = [{ @@ -558,9 +567,11 @@ let assemblyFormat = [{ custom(type($asyncToken), $asyncDependencies) + (`<` $asyncObject^ type($asyncObject) `>`)? $kernel - `blocks` `in` ` ` `(`$gridSizeX`,` $gridSizeY`,` $gridSizeZ`)` - `threads` `in` ` ` `(`$blockSizeX`,` $blockSizeY`,` $blockSizeZ`)` + `blocks` `in` ` ` `(` $gridSizeX `,` $gridSizeY `,` $gridSizeZ `)` + `threads` `in` ` ` `(` $blockSizeX `,` $blockSizeY `,` $blockSizeZ `)` + custom(type($gridSizeX)) (`dynamic_shared_memory_size` $dynamicSharedMemorySize^)? custom($kernelOperands, type($kernelOperands)) attr-dict }]; diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -316,12 +316,22 @@ LaunchFuncOp::getKernelAttrName(launchOp->getName()))) return success(); - // Check that `launch_func` refers to a well-formed GPU kernel module. - StringAttr kernelModuleName = launchOp.getKernelModuleName(); - auto kernelModule = module.lookupSymbol(kernelModuleName); + // Check that `launch_func` refers to a well-formed GPU kernel container. + StringAttr kernelContainerName = launchOp.getKernelModuleName(); + Operation *kernelContainer = module.lookupSymbol(kernelContainerName); + if (!kernelContainer) + return launchOp.emitOpError() + << "kernel container '" << kernelContainerName.getValue() + << "' is undefined"; + + // If the container is a GPU binary op return success. + if (isa(kernelContainer)) + return success(); + + auto kernelModule = dyn_cast(kernelContainer); if (!kernelModule) return launchOp.emitOpError() - << "kernel module '" << kernelModuleName.getValue() + << "kernel module '" << kernelContainerName.getValue() << "' is undefined"; // Check that `launch_func` refers to a well-formed kernel function. @@ -978,13 +988,45 @@ auto kernelSymbol = SymbolRefAttr::get(kernelModule.getNameAttr(), {SymbolRefAttr::get(kernelFunc.getNameAttr())}); - result.addAttribute(getKernelAttrName(result.name), kernelSymbol); - SmallVector segmentSizes(9, 1); - segmentSizes.front() = asyncDependencies.size(); - segmentSizes[segmentSizes.size() - 2] = dynamicSharedMemorySize ? 1 : 0; - segmentSizes.back() = static_cast(kernelOperands.size()); - result.addAttribute(getOperandSegmentSizeAttr(), - builder.getDenseI32ArrayAttr(segmentSizes)); + + Properties &prop = result.getOrAddProperties(); + prop.kernel = kernelSymbol; + size_t segmentSizesLen = std::size(prop.odsOperandSegmentSizes); + // Initialize the segment sizes to 1. + for (auto &sz : prop.odsOperandSegmentSizes) + sz = 1; + prop.odsOperandSegmentSizes[0] = asyncDependencies.size(); + prop.odsOperandSegmentSizes[segmentSizesLen - 3] = + dynamicSharedMemorySize ? 1 : 0; + prop.odsOperandSegmentSizes[segmentSizesLen - 2] = + static_cast(kernelOperands.size()); + prop.odsOperandSegmentSizes[segmentSizesLen - 1] = 0; +} + +void LaunchFuncOp::build(OpBuilder &builder, OperationState &result, + SymbolRefAttr kernel, KernelDim3 gridSize, + KernelDim3 getBlockSize, Value dynamicSharedMemorySize, + ValueRange kernelOperands, Value asyncObject) { + // Add grid and block sizes as op operands, followed by the data operands. + result.addOperands({gridSize.x, gridSize.y, gridSize.z, getBlockSize.x, + getBlockSize.y, getBlockSize.z}); + if (dynamicSharedMemorySize) + result.addOperands(dynamicSharedMemorySize); + result.addOperands(kernelOperands); + if (asyncObject) + result.addOperands(asyncObject); + Properties &prop = result.getOrAddProperties(); + prop.kernel = kernel; + size_t segmentSizesLen = std::size(prop.odsOperandSegmentSizes); + // Initialize the segment sizes to 1. + for (auto &sz : prop.odsOperandSegmentSizes) + sz = 1; + prop.odsOperandSegmentSizes[0] = 0; + prop.odsOperandSegmentSizes[segmentSizesLen - 3] = + dynamicSharedMemorySize ? 1 : 0; + prop.odsOperandSegmentSizes[segmentSizesLen - 2] = + static_cast(kernelOperands.size()); + prop.odsOperandSegmentSizes[segmentSizesLen - 1] = asyncObject ? 1 : 0; } StringAttr LaunchFuncOp::getKernelModuleName() { @@ -1027,6 +1069,22 @@ return success(); } +static ParseResult parseLaunchDimType(OpAsmParser &parser, Type &dimTy) { + if (succeeded(parser.parseOptionalColon())) { + if (parser.parseType(dimTy)) + return failure(); + } else { + dimTy = IndexType::get(parser.getContext()); + } + return success(); +} + +static void printLaunchDimType(OpAsmPrinter &printer, Operation *op, + Type dimTy) { + if (!dimTy.isIndex()) + printer << ": " << dimTy; +} + static ParseResult parseLaunchFuncOperands( OpAsmParser &parser, SmallVectorImpl &argNames, diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir --- a/mlir/test/Dialect/GPU/invalid.mlir +++ b/mlir/test/Dialect/GPU/invalid.mlir @@ -57,7 +57,7 @@ func.func @launch_func_missing_callee_attribute(%sz : index) { // expected-error@+1 {{'gpu.launch_func' op requires attribute 'kernel'}} "gpu.launch_func"(%sz, %sz, %sz, %sz, %sz, %sz) - {operand_segment_sizes = array} + {operand_segment_sizes = array} : (index, index, index, index, index, index) -> () return } @@ -77,7 +77,7 @@ module attributes {gpu.container_module} { func.func @launch_func_undefined_module(%sz : index) { - // expected-error@+1 {{kernel module 'kernels' is undefined}} + // expected-error@+1 {{kernel container 'kernels' is undefined}} gpu.launch_func @kernels::@kernel_1 blocks in (%sz, %sz, %sz) threads in (%sz, %sz, %sz) return }