diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -431,14 +431,19 @@
   let hasVerifier = 1;
 }
 
-def GPU_LaunchFuncOp : GPU_Op<"launch_func",
-    [GPU_AsyncOpInterface, AttrSizedOperandSegments]>,
+def LaunchIndx : AnyTypeOf<[Index, I32, I64]>;
+
+def GPU_LaunchFuncOp :GPU_Op<"launch_func", [
+      GPU_AsyncOpInterface, AttrSizedOperandSegments,
+      AllTypesMatch<["gridSizeX", "gridSizeY", "gridSizeZ", "blockSizeX",
+                     "blockSizeY", "blockSizeZ"]>]>,
     Arguments<(ins Variadic<GPU_AsyncToken>:$asyncDependencies,
                SymbolRefAttr:$kernel,
-               Index:$gridSizeX, Index:$gridSizeY, Index:$gridSizeZ,
-               Index:$blockSizeX, Index:$blockSizeY, Index:$blockSizeZ,
+               LaunchIndx:$gridSizeX, LaunchIndx:$gridSizeY, LaunchIndx:$gridSizeZ,
+               LaunchIndx:$blockSizeX, LaunchIndx:$blockSizeY, LaunchIndx:$blockSizeZ,
                Optional<I32>:$dynamicSharedMemorySize,
-               Variadic<AnyType>:$kernelOperands)>,
+               Variadic<AnyType>:$kernelOperands,
+               Optional<AnyType>:$asyncObject)>,
     Results<(outs Optional<GPU_AsyncToken>:$asyncToken)> {
   let summary = "Launches a function as a GPU kernel";
 
@@ -529,7 +534,11 @@
       "KernelDim3":$blockSize, "Value":$dynamicSharedMemorySize,
       "ValueRange":$kernelOperands,
       CArg<"Type", "nullptr">:$asyncTokenType,
-      CArg<"ValueRange", "{}">:$asyncDependencies)>
+      CArg<"ValueRange", "{}">:$asyncDependencies)>,
+    OpBuilder<(ins "SymbolRefAttr":$kernel, "KernelDim3":$gridSize,
+      "KernelDim3":$blockSize, "Value":$dynamicSharedMemorySize,
+      "ValueRange":$kernelOperands,
+      CArg<"Value", "nullptr">:$asyncObject)>
   ];
 
   let extraClassDeclaration = [{
@@ -559,9 +568,11 @@
 
   let assemblyFormat = [{
       custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
+      (`<` $asyncObject^ `:` type($asyncObject) `>`)?
       $kernel
-      `blocks` `in` ` ` `(`$gridSizeX`,` $gridSizeY`,` $gridSizeZ`)`
-      `threads` `in` ` ` `(`$blockSizeX`,` $blockSizeY`,` $blockSizeZ`)`
+      `blocks` `in` ` ` `(` $gridSizeX `,` $gridSizeY `,` $gridSizeZ `)`
+      `threads` `in` ` ` `(` $blockSizeX `,` $blockSizeY `,` $blockSizeZ `)`
+      custom<LaunchDimType>(type($gridSizeX))
       (`dynamic_shared_memory_size` $dynamicSharedMemorySize^)?
       custom<LaunchFuncOperands>($kernelOperands, type($kernelOperands)) attr-dict
   }];
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -318,12 +318,22 @@
             LaunchFuncOp::getKernelAttrName(launchOp->getName())))
       return success();
 
-    // Check that `launch_func` refers to a well-formed GPU kernel module.
-    StringAttr kernelModuleName = launchOp.getKernelModuleName();
-    auto kernelModule = module.lookupSymbol<GPUModuleOp>(kernelModuleName);
+    // Check that `launch_func` refers to a well-formed GPU kernel container.
+    StringAttr kernelContainerName = launchOp.getKernelModuleName();
+    Operation *kernelContainer = module.lookupSymbol(kernelContainerName);
+    if (!kernelContainer)
+      return launchOp.emitOpError()
+             << "kernel container '" << kernelContainerName.getValue()
+             << "' is undefined";
+
+    // If the container is a GPU binary op return success.
+    if (isa<BinaryOp>(kernelContainer))
+      return success();
+
+    auto kernelModule = dyn_cast<GPUModuleOp>(kernelContainer);
     if (!kernelModule)
       return launchOp.emitOpError()
-             << "kernel module '" << kernelModuleName.getValue()
+             << "kernel module '" << kernelContainerName.getValue()
              << "' is undefined";
 
     // Check that `launch_func` refers to a well-formed kernel function.
@@ -980,13 +990,45 @@
   auto kernelSymbol =
       SymbolRefAttr::get(kernelModule.getNameAttr(),
                          {SymbolRefAttr::get(kernelFunc.getNameAttr())});
-  result.addAttribute(getKernelAttrName(result.name), kernelSymbol);
-  SmallVector<int32_t, 9> segmentSizes(9, 1);
-  segmentSizes.front() = asyncDependencies.size();
-  segmentSizes[segmentSizes.size() - 2] = dynamicSharedMemorySize ? 1 : 0;
-  segmentSizes.back() = static_cast<int32_t>(kernelOperands.size());
-  result.addAttribute(getOperandSegmentSizeAttr(),
-                      builder.getDenseI32ArrayAttr(segmentSizes));
+
+  Properties &prop = result.getOrAddProperties<Properties>();
+  prop.kernel = kernelSymbol;
+  size_t segmentSizesLen = std::size(prop.odsOperandSegmentSizes);
+  // Initialize the segment sizes to 1.
+  for (auto &sz : prop.odsOperandSegmentSizes)
+    sz = 1;
+  prop.odsOperandSegmentSizes[0] = asyncDependencies.size();
+  prop.odsOperandSegmentSizes[segmentSizesLen - 3] =
+      dynamicSharedMemorySize ? 1 : 0;
+  prop.odsOperandSegmentSizes[segmentSizesLen - 2] =
+      static_cast<int32_t>(kernelOperands.size());
+  prop.odsOperandSegmentSizes[segmentSizesLen - 1] = 0;
+}
+
+void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
+                         SymbolRefAttr kernel, KernelDim3 gridSize,
+                         KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
+                         ValueRange kernelOperands, Value asyncObject) {
+  // Add grid and block sizes as op operands, followed by the data operands.
+  result.addOperands({gridSize.x, gridSize.y, gridSize.z, getBlockSize.x,
+                      getBlockSize.y, getBlockSize.z});
+  if (dynamicSharedMemorySize)
+    result.addOperands(dynamicSharedMemorySize);
+  result.addOperands(kernelOperands);
+  if (asyncObject)
+    result.addOperands(asyncObject);
+  Properties &prop = result.getOrAddProperties<Properties>();
+  prop.kernel = kernel;
+  size_t segmentSizesLen = std::size(prop.odsOperandSegmentSizes);
+  // Initialize the segment sizes to 1.
+  for (auto &sz : prop.odsOperandSegmentSizes)
+    sz = 1;
+  prop.odsOperandSegmentSizes[0] = 0;
+  prop.odsOperandSegmentSizes[segmentSizesLen - 3] =
+      dynamicSharedMemorySize ? 1 : 0;
+  prop.odsOperandSegmentSizes[segmentSizesLen - 2] =
+      static_cast<int32_t>(kernelOperands.size());
+  prop.odsOperandSegmentSizes[segmentSizesLen - 1] = asyncObject ? 1 : 0;
 }
 
 StringAttr LaunchFuncOp::getKernelModuleName() {
@@ -1029,6 +1071,22 @@
   return success();
 }
 
+static ParseResult parseLaunchDimType(OpAsmParser &parser, Type &dimTy) {
+  if (succeeded(parser.parseOptionalColon())) {
+    if (parser.parseType(dimTy))
+      return failure();
+  } else {
+    dimTy = IndexType::get(parser.getContext());
+  }
+  return success();
+}
+
+static void printLaunchDimType(OpAsmPrinter &printer, Operation *op,
+                               Type dimTy) {
+  if (!dimTy.isIndex())
+    printer << ": " << dimTy;
+}
+
 static ParseResult parseLaunchFuncOperands(
     OpAsmParser &parser,
     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &argNames,
diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir
--- a/mlir/test/Dialect/GPU/invalid.mlir
+++ b/mlir/test/Dialect/GPU/invalid.mlir
@@ -57,7 +57,7 @@
 func.func @launch_func_missing_callee_attribute(%sz : index) {
   // expected-error@+1 {{'gpu.launch_func' op requires attribute 'kernel'}}
   "gpu.launch_func"(%sz, %sz, %sz, %sz, %sz, %sz)
-      {operand_segment_sizes = array<i32: 0, 1, 1, 1, 1, 1, 1, 0, 0>}
+      {operand_segment_sizes = array<i32: 0, 1, 1, 1, 1, 1, 1, 0, 0, 0>}
       : (index, index, index, index, index, index) -> ()
   return
 }
@@ -77,7 +77,7 @@
 
 module attributes {gpu.container_module} {
   func.func @launch_func_undefined_module(%sz : index) {
-    // expected-error@+1 {{kernel module 'kernels' is undefined}}
+    // expected-error@+1 {{kernel container 'kernels' is undefined}}
     gpu.launch_func @kernels::@kernel_1 blocks in (%sz, %sz, %sz) threads in (%sz, %sz, %sz)
     return
   }
diff --git a/mlir/test/Dialect/GPU/ops.mlir b/mlir/test/Dialect/GPU/ops.mlir
--- a/mlir/test/Dialect/GPU/ops.mlir
+++ b/mlir/test/Dialect/GPU/ops.mlir
@@ -128,8 +128,10 @@
   %1 = "op"() : () -> (memref<?xf32, 1>)
   // CHECK: %{{.*}} = arith.constant 8
   %cst = arith.constant 8 : index
+  %cstI64 = arith.constant 8 : i64
   %c0 = arith.constant 0 : i32
   %t0 = gpu.wait async
+  %lowStream = llvm.mlir.null : !llvm.ptr
 
   // CHECK: gpu.launch_func @kernels::@kernel_1 blocks in (%{{.*}}, %{{.*}}, %{{.*}}) threads in (%{{.*}}, %{{.*}}, %{{.*}}) args(%{{.*}} : f32, %{{.*}} : memref<?xf32, 1>)
   gpu.launch_func @kernels::@kernel_1 blocks in (%cst, %cst, %cst) threads in (%cst, %cst, %cst) args(%0 : f32, %1 : memref<?xf32, 1>)
@@ -142,6 +144,12 @@
   // CHECK: %{{.*}} = gpu.launch_func async [%{{.*}}] @kernels::@kernel_2 blocks in (%{{.*}}, %{{.*}}, %{{.*}}) threads in (%{{.*}}, %{{.*}}, %{{.*}})
   %t1 = gpu.launch_func async [%t0] @kernels::@kernel_2 blocks in (%cst, %cst, %cst) threads in (%cst, %cst, %cst)
 
+  // CHECK: gpu.launch_func <%{{.*}} : !llvm.ptr> @kernels::@kernel_1 blocks in (%{{.*}}, %{{.*}}, %{{.*}}) threads in (%{{.*}}, %{{.*}}, %{{.*}}) : i64 args(%{{.*}} : f32, %{{.*}} : memref<?xf32, 1>)
+  gpu.launch_func <%lowStream : !llvm.ptr> @kernels::@kernel_1 blocks in (%cstI64, %cstI64, %cstI64) threads in (%cstI64, %cstI64, %cstI64) : i64 args(%0 : f32, %1 : memref<?xf32, 1>)
+
+  // CHECK: gpu.launch_func @kernels::@kernel_1 blocks in (%{{.*}}, %{{.*}}, %{{.*}}) threads in (%{{.*}}, %{{.*}}, %{{.*}}) : i32 args(%{{.*}} : f32, %{{.*}} : memref<?xf32, 1>)
+  gpu.launch_func @kernels::@kernel_1 blocks in (%c0, %c0, %c0) threads in (%c0, %c0, %c0) : i32 args(%0 : f32, %1 : memref<?xf32, 1>)
+
   // CHECK: %[[VALUES:.*]]:2 = call
   %values:2 = func.call @two_value_generator() : () -> (f32, memref<?xf32, 1>)
   // CHECK: gpu.launch_func @kernels::@kernel_1 {{.*}} args(%[[VALUES]]#0 : f32, %[[VALUES]]#1 : memref<?xf32, 1>)
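
Note (not part of the patch): a minimal sketch of how the new C++ builder overload introduced above might be called; the names `loc`, `kernelSym`, `gx`/`gy`/`gz`, `bx`/`by`/`bz`, `kernelArgs`, and `stream` are hypothetical placeholders, with `stream` being e.g. an `!llvm.ptr` value used as the async object.

  // Build a gpu.launch_func that dispatches on an arbitrary async object
  // (e.g. a lowered stream pointer) instead of a gpu.async.token.
  builder.create<gpu::LaunchFuncOp>(
      loc, kernelSym, gpu::KernelDim3{gx, gy, gz},
      gpu::KernelDim3{bx, by, bz}, /*dynamicSharedMemorySize=*/Value(),
      kernelArgs, /*asyncObject=*/stream);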