diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
@@ -2950,15 +2950,15 @@
 // SPIR-V OpTrait definitions
 //===----------------------------------------------------------------------===//
 
-// Check that an op can only be used within the scope of a FuncOp.
+// Check that an op can only be used within the scope of a function-like op.
 def InFunctionScope : PredOpTrait<
-  "op must appear in a 'func' block",
-  CPred<"($_op.getParentOfType<FuncOp>())">>;
+  "op must appear in a function-like op's block",
+  CPred<"isNestedInFunctionLikeOp($_op.getParentOp())">>;
 
-// Check that an op can only be used within the scope of a SPIR-V ModuleOp.
+// Check that an op can only be used within the scope of a module-like op.
 def InModuleScope : PredOpTrait<
-  "op must appear in a 'spv.module' block",
-  CPred<"llvm::isa_and_nonnull<spirv::ModuleOp>($_op.getParentOp())">>;
+  "op must appear in a module-like op's block",
+  CPred<"isDirectInModuleLikeOp($_op.getParentOp())">>;
 
 //===----------------------------------------------------------------------===//
 // SPIR-V opcode specification
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -53,6 +53,24 @@
 // Common utility functions
 //===----------------------------------------------------------------------===//
 
+/// Returns true if the given op is a function-like op or nested in a
+/// function-like op without a module-like op in the middle.
+static bool isNestedInFunctionLikeOp(Operation *op) {
+  if (!op)
+    return false;
+  if (op->hasTrait<OpTrait::SymbolTable>())
+    return false;
+  if (op->hasTrait<OpTrait::FunctionLike>())
+    return true;
+  return isNestedInFunctionLikeOp(op->getParentOp());
+}
+
+/// Returns true if the given op is a module-like op that maintains a symbol
+/// table.
+static bool isDirectInModuleLikeOp(Operation *op) {
+  return op && op->hasTrait<OpTrait::SymbolTable>();
+}
+
 static LogicalResult extractValueFromConstOp(Operation *op,
                                              int32_t &indexValue) {
   auto constOp = dyn_cast<spirv::ConstantOp>(op);
diff --git a/mlir/test/Dialect/SPIRV/control-flow-ops.mlir b/mlir/test/Dialect/SPIRV/control-flow-ops.mlir
--- a/mlir/test/Dialect/SPIRV/control-flow-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/control-flow-ops.mlir
@@ -497,8 +497,16 @@
 
 // -----
 
+// CHECK-LABEL: in_other_func_like_op
+func @in_other_func_like_op() {
+  // CHECK: spv.Return
+  spv.Return
+}
+
+// -----
+
 "foo.function"() ({
-  // expected-error @+1 {{op must appear in a 'func' block}}
+  // expected-error @+1 {{op must appear in a function-like op's block}}
   spv.Return
 }) : () -> ()
 
@@ -562,7 +570,7 @@
 
 "foo.function"() ({
   %0 = spv.constant true
-  // expected-error @+1 {{op must appear in a 'func' block}}
+  // expected-error @+1 {{op must appear in a function-like op's block}}
   spv.ReturnValue %0 : i1
 }) : () -> ()
 
diff --git a/mlir/test/Dialect/SPIRV/structure-ops.mlir b/mlir/test/Dialect/SPIRV/structure-ops.mlir
--- a/mlir/test/Dialect/SPIRV/structure-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/structure-ops.mlir
@@ -174,7 +174,7 @@
 
 spv.module "Logical" "GLSL450" {
   func @do_nothing() -> () {
-    // expected-error @+1 {{'spv.EntryPoint' op failed to verify that op must appear in a 'spv.module' block}}
+    // expected-error @+1 {{op must appear in a module-like op's block}}
     spv.EntryPoint "GLCompute" @do_something
   }
 }
@@ -252,6 +252,14 @@
 
 // -----
 
+// Allow in other module-like ops
+module {
+  // CHECK: spv.globalVariable
+  spv.globalVariable @var0 : !spv.ptr<f32, Input>
+}
+
+// -----
+
 spv.module "Logical" "GLSL450" {
   // expected-error @+1 {{expected spv.ptr type}}
   spv.globalVariable @var0 : f32
@@ -275,7 +283,7 @@
 
 spv.module "Logical" "GLSL450" {
   func @foo() {
-    // expected-error @+1 {{op failed to verify that op must appear in a 'spv.module' block}}
+    // expected-error @+1 {{op must appear in a module-like op's block}}
     spv.globalVariable @var0 : !spv.ptr<f32, Input>
     spv.Return
   }
@@ -418,7 +426,7 @@
 //===----------------------------------------------------------------------===//
 
 func @module_end_not_in_module() -> () {
-  // expected-error @+1 {{op must appear in a 'spv.module' block}}
+  // expected-error @+1 {{op must appear in a module-like op's block}}
   spv._module_end
 }
 
@@ -519,7 +527,7 @@
 // -----
 
 func @use_in_function() -> () {
-  // expected-error @+1 {{op must appear in a 'spv.module' block}}
+  // expected-error @+1 {{op must appear in a module-like op's block}}
   spv.specConstant @sc = false
   return
 }