diff --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td --- a/mlir/include/mlir/Dialect/GPU/GPUOps.td +++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td @@ -200,7 +200,7 @@ /// Returns the number of buffers located in the private memory. unsigned getNumPrivateAttributions() { - return getOperation()->getNumOperands() - getType().getNumInputs() - + return getBody().front().getNumArguments() - getType().getNumInputs() - getNumWorkgroupAttributions(); } diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -675,13 +675,10 @@ LogicalResult GPUFuncOp::verifyBody() { unsigned numFuncArguments = getNumArguments(); unsigned numWorkgroupAttributions = getNumWorkgroupAttributions(); - unsigned numPrivateAttributions = getNumPrivateAttributions(); unsigned numBlockArguments = front().getNumArguments(); - if (numBlockArguments < - numFuncArguments + numWorkgroupAttributions + numPrivateAttributions) + if (numBlockArguments < numFuncArguments + numWorkgroupAttributions) return emitOpError() << "expected at least " - << numFuncArguments + numWorkgroupAttributions + - numPrivateAttributions + << numFuncArguments + numWorkgroupAttributions << " arguments to body region"; ArrayRef funcArgTypes = getType().getInputs(); diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir --- a/mlir/test/Dialect/GPU/invalid.mlir +++ b/mlir/test/Dialect/GPU/invalid.mlir @@ -423,3 +423,15 @@ } } } + +// ----- + +module { + gpu.module @gpu_funcs { + // expected-error @+1 {{'gpu.func' op expected at least 5 arguments to body region}} + "gpu.func"() ( { + ^bb0(%arg0: f32, %arg1: memref, %arg2: memref<5xf32, 3>, %arg3: memref<5xf32, 5>): + "gpu.return"() : () -> () + } ) {gpu.kernel, sym_name = "kernel_1", type = (f32, memref) -> (), workgroup_attributions = 3: i64} : () -> () + } +} diff --git a/mlir/test/Dialect/GPU/ops.mlir b/mlir/test/Dialect/GPU/ops.mlir --- a/mlir/test/Dialect/GPU/ops.mlir +++ b/mlir/test/Dialect/GPU/ops.mlir @@ -132,4 +132,11 @@ } } + gpu.module @explicit_attributions { + // CHECK-LABEL: gpu.func @kernel_1({{.*}}: f32, {{.*}}: memref) workgroup({{.*}}: memref<5xf32, 3>) private({{.*}}: memref<5xf32, 5>) + "gpu.func"() ( { + ^bb0(%arg0: f32, %arg1: memref, %arg2: memref<5xf32, 3>, %arg3: memref<5xf32, 5>): + "gpu.return"() : () -> () + } ) {gpu.kernel, sym_name = "kernel_1", type = (f32, memref) -> (), workgroup_attributions = 1: i64} : () -> () + } }