diff --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td
--- a/mlir/include/mlir/Dialect/GPU/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td
@@ -472,8 +472,24 @@
   let verifier = [{ return ::verify(*this); }];
 }

-def GPU_ReturnOp : GPU_Op<"return", [Terminator]>, Arguments<(ins)>,
-    Results<(outs)> {
+def GPU_ReturnOp : GPU_Op<"return", [HasParent<"GPUFuncOp">, Terminator]>,
+    Arguments<(ins Variadic<AnyType>:$operands)>, Results<(outs)> {
+  let summary = "Terminator for GPU functions.";
+  let description = [{
+    A terminator operation for regions that appear in the body of `gpu.func`
+    functions. The operands to the `gpu.return` are the result values returned
+    by an invocation of the `gpu.func`.
+  }];
+
+  let builders = [OpBuilder<"Builder *builder, OperationState &result", " // empty">];
+
+  let parser = [{ return parseReturnOp(parser, result); }];
+  let printer = [{ p << getOperationName(); }];
+  let verifier = [{ return ::verify(*this); }];
+}
+
+def GPU_TerminatorOp : GPU_Op<"terminator", [HasParent<"LaunchOp">, Terminator]>,
+    Arguments<(ins)>, Results<(outs)> {
   let summary = "Terminator for GPU launch regions.";
   let description = [{
     A terminator operation for regions that appear in the body of `gpu.launch`
diff --git a/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp b/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp
--- a/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp
+++ b/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp
@@ -306,9 +306,9 @@
                              unsigned numBlockDims, unsigned numThreadDims) {
   OpBuilder::InsertionGuard bodyInsertionGuard(builder);
   builder.setInsertionPointToEnd(&launchOp.body().front());
-  auto returnOp = builder.create<gpu::ReturnOp>(launchOp.getLoc());
+  auto terminatorOp = builder.create<gpu::TerminatorOp>(launchOp.getLoc());

-  rootForOp.getOperation()->moveBefore(returnOp);
+  rootForOp.getOperation()->moveBefore(terminatorOp);
   SmallVector<Value, 3> workgroupID, numWorkGroups;
   packIdAndNumId(launchOp.getBlockIds(), launchOp.getGridSize(), numBlockDims,
                  workgroupID, numWorkGroups);
@@ -435,7 +435,7 @@
   Location terminatorLoc = terminator.getLoc();
   terminator.erase();
   builder.setInsertionPointToEnd(innermostForOp.getBody());
-  builder.create<gpu::ReturnOp>(terminatorLoc);
+  builder.create<gpu::TerminatorOp>(terminatorLoc, llvm::None);
   launchOp.body().front().getOperations().splice(
       launchOp.body().front().begin(),
       innermostForOp.getBody()->getOperations());
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -270,18 +270,19 @@
   }

   // Block terminators without successors are expected to exit the kernel region
-  // and must be `gpu.launch`.
+  // and must be `gpu.terminator`.
   for (Block &block : op.body()) {
     if (block.empty())
       continue;
     if (block.back().getNumSuccessors() != 0)
       continue;

-    if (!isa<gpu::ReturnOp>(&block.back())) {
+    if (!isa<gpu::TerminatorOp>(&block.back())) {
       return block.back()
-          .emitError("expected 'gpu.terminator' or a terminator with "
-                     "successors")
-          .attachNote(op.getLoc())
-          << "in '" << LaunchOp::getOperationName() << "' body region";
+          .emitError()
+          .append("expected '", gpu::TerminatorOp::getOperationName(),
+                  "' or a terminator with successors")
+          .attachNote(op.getLoc())
+          .append("in '", LaunchOp::getOperationName(), "' body region");
     }
   }

@@ -680,7 +681,7 @@
            << "gpu.func requires named arguments";

   // Construct the function type. More types will be added to the region, but
-  // not to the functiont type.
+  // not to the function type.
   Builder &builder = parser.getBuilder();
   auto type = builder.getFunctionType(argTypes, resultTypes);
   result.addAttribute(GPUFuncOp::getTypeAttrName(), TypeAttr::get(type));
@@ -767,6 +768,10 @@
   if (!type.isa<FunctionType>())
     return emitOpError("requires '" + getTypeAttrName() +
                        "' attribute of function type");
+
+  if (isKernel() && getType().getNumResults() != 0)
+    return emitOpError() << "expected void return type for kernel function";
+
   return success();
 }

@@ -814,6 +819,45 @@
   return success();
 }

+//===----------------------------------------------------------------------===//
+// ReturnOp
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseReturnOp(OpAsmParser &parser, OperationState &result) {
+  llvm::SmallVector<OpAsmParser::OperandType, 4> operands;
+  llvm::SmallVector<Type, 4> types;
+  if (parser.parseOperandList(operands) ||
+      parser.parseOptionalColonTypeList(types) ||
+      parser.resolveOperands(operands, types, parser.getCurrentLocation(),
+                             result.operands))
+    return failure();
+
+  return success();
+}
+
+static LogicalResult verify(gpu::ReturnOp returnOp) {
+  GPUFuncOp function = returnOp.getParentOfType<GPUFuncOp>();
+
+  FunctionType funType = function.getType();
+
+  if (funType.getNumResults() != returnOp.operands().size())
+    return returnOp.emitOpError()
+        .append("expected ", funType.getNumResults(), " result operands")
+        .attachNote(function.getLoc())
+        .append("return type declared here");
+
+  for (auto pair : llvm::enumerate(
+           llvm::zip(function.getType().getResults(), returnOp.operands()))) {
+    Type type;
+    Value operand;
+    std::tie(type, operand) = pair.value();
+    if (type != operand.getType())
+      return returnOp.emitOpError() << "unexpected type `" << operand.getType()
+                                    << "' for operand #" << pair.index();
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // GPUModuleOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
--- a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
@@ -99,7 +99,7 @@
 }

 // Outline the `gpu.launch` operation body into a kernel function. Replace
-// `gpu.return` operations by `std.return` in the generated function.
+// `gpu.terminator` operations by `gpu.return` in the generated function.
 static gpu::GPUFuncOp outlineKernelFunc(gpu::LaunchOp launchOp) {
   Location loc = launchOp.getLoc();
   // Create a builder with no insertion point, insertion will happen separately
@@ -116,6 +116,12 @@
                            builder.getUnitAttr());
   outlinedFunc.body().takeBody(launchOp.body());
   injectGpuIndexOperations(loc, outlinedFunc.body());
+  outlinedFunc.walk([](gpu::TerminatorOp op) {
+    OpBuilder replacer(op);
+    replacer.create<gpu::ReturnOp>(op.getLoc());
+    op.erase();
+  });
+
   return outlinedFunc;
 }
diff --git a/mlir/test/Conversion/LoopsToGPU/linalg_to_gpu.mlir b/mlir/test/Conversion/LoopsToGPU/linalg_to_gpu.mlir
--- a/mlir/test/Conversion/LoopsToGPU/linalg_to_gpu.mlir
+++ b/mlir/test/Conversion/LoopsToGPU/linalg_to_gpu.mlir
@@ -23,7 +23,7 @@
       // CHECK: %[[prod_j:.*]] = muli %{{.*}}, %{{.*}} : index
       // CHECK: addi %{{.*}}, %[[prod_j]] : index

-      // CHECK: gpu.return
+      // CHECK: gpu.terminator
     }
   }
   return
diff --git a/mlir/test/Conversion/LoopsToGPU/step_one.mlir b/mlir/test/Conversion/LoopsToGPU/step_one.mlir
--- a/mlir/test/Conversion/LoopsToGPU/step_one.mlir
+++ b/mlir/test/Conversion/LoopsToGPU/step_one.mlir
@@ -73,8 +73,8 @@
           // CHECK-22-NEXT: store {{.*}}, %{{.*}}[%[[i]], %[[j]], %[[ii]], %[[jj]]] : memref
           store %0, %B[%i, %j, %ii, %jj] : memref

-          // CHECK-11: gpu.return
-          // CHECK-22: gpu.return
+          // CHECK-11: gpu.terminator
+          // CHECK-22: gpu.terminator
        }
      }
    }
diff --git a/mlir/test/Dialect/GPU/canonicalize.mlir b/mlir/test/Dialect/GPU/canonicalize.mlir
--- a/mlir/test/Dialect/GPU/canonicalize.mlir
+++ b/mlir/test/Dialect/GPU/canonicalize.mlir
@@ -21,7 +21,7 @@
     // CHECK: "bar"(%[[inner_arg]])
     "bar"(%y) : (memref) -> ()

-    gpu.return
+    gpu.terminator
   }
   return
 }
diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir
--- a/mlir/test/Dialect/GPU/invalid.mlir
+++ b/mlir/test/Dialect/GPU/invalid.mlir
@@ -376,7 +376,7 @@

 module {
-  module @gpu_funcs attributes {gpu.kernel_module} {
+  gpu.module @gpu_funcs {
     // expected-error @+1 {{custom op 'gpu.func' gpu.func requires named arguments}}
     gpu.func @kernel_1(f32, f32) {
     ^bb0(%arg0: f32):
@@ -428,3 +428,39 @@
     }
   }
 }
+
+// -----
+
+module {
+  module @gpu_funcs attributes {gpu.kernel_module} {
+    // expected-error @+1 {{expected memory space 5 in attribution}}
+    gpu.func @kernel() private(%0: memref<4xf32>) {
+      gpu.return
+    }
+  }
+}
+
+// -----
+
+module {
+  gpu.module @gpu_funcs {
+    // expected-note @+1 {{return type declared here}}
+    gpu.func @kernel() {
+      %0 = constant 0 : index
+      // expected-error @+1 {{'gpu.return' op expected 0 result operands}}
+      gpu.return %0 : index
+    }
+  }
+}
+
+// -----
+
+module {
+  gpu.module @gpu_funcs {
+    // expected-error @+1 {{'gpu.func' op expected void return type for kernel function}}
+    gpu.func @kernel() -> index kernel {
+      %0 = constant 0 : index
+      gpu.return
+    }
+  }
+}
\ No newline at end of file
diff --git a/mlir/test/Dialect/GPU/ops.mlir b/mlir/test/Dialect/GPU/ops.mlir
--- a/mlir/test/Dialect/GPU/ops.mlir
+++ b/mlir/test/Dialect/GPU/ops.mlir
@@ -7,8 +7,8 @@
   // CHECK: gpu.launch blocks(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %{{.*}}, %{{.*}} = %{{.*}}, %{{.*}} = %{{.*}}) threads(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %{{.*}}, %{{.*}} = %{{.*}}, %{{.*}} = %{{.*}})
   gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %sz, %grid_y = %sz, %grid_z = %sz)
              threads(%tx, %ty, %tz) in (%block_x = %sz, %block_y = %sz, %block_z = %sz) {
-    // CHECK: gpu.return
-    gpu.return
+    // CHECK: gpu.terminator
+    gpu.terminator
   }
   return
 }
@@ -19,8 +19,8 @@
   gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %blk, %grid_y = %blk, %grid_z = %blk)
              threads(%tx, %ty, %tz) in (%block_x = %thrd, %block_y = %thrd, %block_z = %thrd)
              args(%kernel_arg0 = %float, %kernel_arg1 = %data) : f32, memref {
-    // CHECK: gpu.return
-    gpu.return
+    // CHECK: gpu.terminator
+    gpu.terminator
   }
   return
 }
@@ -34,8 +34,8 @@
              args(%kernel_arg0 = %float, %kernel_arg1 = %data) : f32, memref {
     // CHECK: "use"(%{{.*}})
     "use"(%kernel_arg0): (f32) -> ()
-    // CHECK: gpu.return
-    gpu.return
+    // CHECK: gpu.terminator
+    gpu.terminator
   }
   return
 }
@@ -54,8 +54,8 @@
         "use"(%val) : (index) -> ()
       }) : () -> ()
     }) : () -> ()
-    // CHECK: gpu.return
-    gpu.return
+    // CHECK: gpu.terminator
+    gpu.terminator
   }
   return
 }
@@ -118,11 +118,11 @@
 }

 module @gpu_funcs attributes {gpu.kernel_module} {
-  // CHECK-LABEL: gpu.func @kernel_1({{.*}}: f32) -> f32
+  // CHECK-LABEL: gpu.func @kernel_1({{.*}}: f32)
   // CHECK: workgroup
   // CHECK: private
   // CHECK: attributes
-  gpu.func @kernel_1(%arg0: f32) -> f32
+  gpu.func @kernel_1(%arg0: f32)
       workgroup(%arg1: memref<42xf32, 3>)
       private(%arg2: memref<2xf32, 5>, %arg3: memref<1xf32, 5>)
       kernel
diff --git a/mlir/test/Dialect/GPU/outlining.mlir b/mlir/test/Dialect/GPU/outlining.mlir
--- a/mlir/test/Dialect/GPU/outlining.mlir
+++ b/mlir/test/Dialect/GPU/outlining.mlir
@@ -31,7 +31,7 @@
     "use"(%arg0): (f32) -> ()
     "some_op"(%bx, %block_x) : (index, index) -> ()
     %42 = load %arg1[%tx] : memref
-    gpu.return
+    gpu.terminator
   }
   return
 }
@@ -68,14 +68,14 @@
                                        %grid_z = %cst)
             threads(%tx, %ty, %tz) in (%block_x = %cst, %block_y = %cst,
                                        %block_z = %cst) {
-    gpu.return
+    gpu.terminator
   }
   // CHECK: "gpu.launch_func"(%[[CST]], %[[CST]], %[[CST]], %[[CST]], %[[CST]], %[[CST]]) {kernel = "multiple_launches_kernel", kernel_module = @multiple_launches_kernel_0} : (index, index, index, index, index, index) -> ()
   gpu.launch blocks(%bx2, %by2, %bz2) in (%grid_x2 = %cst, %grid_y2 = %cst,
                                           %grid_z2 = %cst)
             threads(%tx2, %ty2, %tz2) in (%block_x2 = %cst, %block_y2 = %cst,
                                           %block_z2 = %cst) {
-    gpu.return
+    gpu.terminator
   }
   return
 }
@@ -99,7 +99,7 @@
                                       %block_z = %cst)
             args(%kernel_arg0 = %cst2, %kernel_arg1 = %arg0, %kernel_arg2 = %cst3) : index, memref, index {
     "use"(%kernel_arg0, %kernel_arg1, %kernel_arg2) : (index, memref, index) -> ()
-    gpu.return
+    gpu.terminator
   }
   return
 }
@@ -121,19 +121,19 @@
     call @device_function() : () -> ()
     call @device_function() : () -> ()
     %0 = llvm.mlir.addressof @global : !llvm<"i64*">
-    gpu.return
+    gpu.terminator
   }
   return
 }

 func @device_function() {
   call @recursive_device_function() : () -> ()
-  gpu.return
+  return
 }

 func @recursive_device_function() {
   call @recursive_device_function() : () -> ()
-  gpu.return
+  return
 }

 // CHECK: gpu.module @function_call_kernel {
@@ -141,6 +141,7 @@
 // CHECK:   call @device_function() : () -> ()
 // CHECK:   call @device_function() : () -> ()
 // CHECK:   llvm.mlir.addressof @global : !llvm<"i64*">
+// CHECK:   gpu.return
 //
 // CHECK: llvm.mlir.global internal @global(42 : i64) : !llvm.i64
 //
diff --git a/mlir/test/mlir-cuda-runner/all-reduce-op.mlir b/mlir/test/mlir-cuda-runner/all-reduce-op.mlir
--- a/mlir/test/mlir-cuda-runner/all-reduce-op.mlir
+++ b/mlir/test/mlir-cuda-runner/all-reduce-op.mlir
@@ -20,7 +20,7 @@
     %val = sitofp %t3 : i32 to f32
     %sum = "gpu.all_reduce"(%val) ({}) { op = "add" } : (f32) -> (f32)
     store %sum, %kernel_dst[%tz, %ty, %tx] : memref
-    gpu.return
+    gpu.terminator
   }
   %U = memref_cast %dst : memref to memref<*xf32>
   call @print_memref_f32(%U) : (memref<*xf32>) -> ()
diff --git a/mlir/test/mlir-cuda-runner/all-reduce-region.mlir b/mlir/test/mlir-cuda-runner/all-reduce-region.mlir
--- a/mlir/test/mlir-cuda-runner/all-reduce-region.mlir
+++ b/mlir/test/mlir-cuda-runner/all-reduce-region.mlir
@@ -18,7 +18,7 @@
     }) : (i32) -> (i32)
     %res = sitofp %xor : i32 to f32
     store %res, %kernel_dst[%tx] : memref
-    gpu.return
+    gpu.terminator
   }
   %U = memref_cast %dst : memref to memref<*xf32>
   call @print_memref_f32(%U) : (memref<*xf32>) -> ()
diff --git a/mlir/test/mlir-cuda-runner/gpu-to-cubin.mlir b/mlir/test/mlir-cuda-runner/gpu-to-cubin.mlir
--- a/mlir/test/mlir-cuda-runner/gpu-to-cubin.mlir
+++ b/mlir/test/mlir-cuda-runner/gpu-to-cubin.mlir
@@ -7,7 +7,7 @@
              threads(%tx, %ty, %tz) in (%block_x = %cst2, %block_y = %cst, %block_z = %cst)
              args(%kernel_arg0 = %arg0, %kernel_arg1 = %arg1) : f32, memref {
     store %kernel_arg0, %kernel_arg1[%tx] : memref
-    gpu.return
+    gpu.terminator
   }
   return
 }
diff --git a/mlir/test/mlir-cuda-runner/shuffle.mlir b/mlir/test/mlir-cuda-runner/shuffle.mlir
--- a/mlir/test/mlir-cuda-runner/shuffle.mlir
+++ b/mlir/test/mlir-cuda-runner/shuffle.mlir
@@ -21,7 +21,7 @@
     br ^bb1(%m1 : f32)
   ^bb1(%value : f32):
     store %value, %kernel_dst[%tx] : memref
-    gpu.return
+    gpu.terminator
   }
   %U = memref_cast %dst : memref to memref<*xf32>
   call @print_memref_f32(%U) : (memref<*xf32>) -> ()
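// ---------------------------------------------------------------------------
// Illustration only -- NOT part of the patch above. A minimal MLIR sketch of
// the terminator split the patch introduces: `gpu.return` terminates the body
// of a `gpu.func` and carries its result values, while the body of a
// `gpu.launch` region is now closed by `gpu.terminator`. The module, function
// names, and the `@double` body are invented for this example.
module {
  gpu.module @kernels {
    // A non-kernel gpu.func may declare results; gpu.return carries them.
    gpu.func @double(%arg0: f32) -> f32 {
      %0 = addf %arg0, %arg0 : f32
      gpu.return %0 : f32
    }
  }

  func @host(%sz: index) {
    gpu.launch blocks(%bx, %by, %bz) in (%gx = %sz, %gy = %sz, %gz = %sz)
               threads(%tx, %ty, %tz) in (%bw = %sz, %bh = %sz, %bd = %sz) {
      // Launch regions (and kernel gpu.func bodies) return no values;
      // the launch body simply ends with gpu.terminator.
      gpu.terminator
    }
    return
  }
}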