diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td @@ -2997,15 +2997,15 @@ // SPIR-V OpTrait definitions //===----------------------------------------------------------------------===// -// Check that an op can only be used within the scope of a FuncOp. +// Check that an op is inside a function-like op with no module-like op between. def InFunctionScope : PredOpTrait< - "op must appear in a 'func' block", - CPred<"($_op.getParentOfType())">>; + "op must appear in a function-like op's block", + CPred<"isNestedInFunctionLikeOp($_op.getParentOp())">>; -// Check that an op can only be used within the scope of a SPIR-V ModuleOp. +// Check that an op's direct parent is a module-like (symbol table) op. def InModuleScope : PredOpTrait< - "op must appear in a 'spv.module' block", - CPred<"llvm::isa_and_nonnull($_op.getParentOp())">>; + "op must appear in a module-like op's block", + CPred<"isDirectInModuleLikeOp($_op.getParentOp())">>; //===----------------------------------------------------------------------===// // SPIR-V opcode specification diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -55,6 +55,24 @@ // Common utility functions //===----------------------------------------------------------------------===// +/// Returns true if `op` is, or is nested in, a function-like op with no +/// module-like op in between (an op that is both counts as module-like). +static bool isNestedInFunctionLikeOp(Operation *op) { + if (!op) + return false; + if (op->hasTrait()) + return false; + if (op->hasTrait()) + return true; + return isNestedInFunctionLikeOp(op->getParentOp()); +} + +/// Returns true if the given op is a module-like op, i.e. one that maintains +/// a symbol table. Callers pass an op's parent to test direct nesting.
+static bool isDirectInModuleLikeOp(Operation *op) { + return op && op->hasTrait(); +} + static LogicalResult extractValueFromConstOp(Operation *op, int32_t &value) { auto constOp = dyn_cast_or_null(op); if (!constOp) { @@ -872,9 +890,9 @@ } static LogicalResult verify(spirv::AddressOfOp addressOfOp) { - auto moduleOp = addressOfOp.getParentOfType(); - auto varOp = - moduleOp.lookupSymbol(addressOfOp.variable()); + auto varOp = dyn_cast_or_null( + SymbolTable::lookupNearestSymbolFrom(addressOfOp.getParentOp(), + addressOfOp.variable())); if (!varOp) { return addressOfOp.emitOpError("expected spv.globalVariable symbol"); } @@ -1679,16 +1697,11 @@ static LogicalResult verify(spirv::FunctionCallOp functionCallOp) { auto fnName = functionCallOp.callee(); - auto moduleOp = functionCallOp.getParentOfType(); - if (!moduleOp) { - return functionCallOp.emitOpError( - "must appear in a function inside 'spv.module'"); - } - - auto funcOp = moduleOp.lookupSymbol(fnName); + auto funcOp = dyn_cast_or_null(SymbolTable::lookupNearestSymbolFrom( + functionCallOp.getParentOp(), fnName)); if (!funcOp) { return functionCallOp.emitOpError("callee function '") - << fnName << "' not found in 'spv.module'"; + << fnName << "' not found in nearest symbol table"; } auto functionType = funcOp.getType(); @@ -1837,8 +1850,8 @@ if (auto init = varOp.getAttrOfType(kInitializerAttrName)) { - auto moduleOp = varOp.getParentOfType(); - auto *initOp = moduleOp.lookupSymbol(init.getValue()); + Operation *initOp = SymbolTable::lookupNearestSymbolFrom( + varOp.getParentOp(), init.getValue()); // TODO: Currently only variable initialization with specialization // constants and other variables is supported. They could be normal // constants in the module scope as well.
@@ -2534,9 +2547,9 @@ } static LogicalResult verify(spirv::ReferenceOfOp referenceOfOp) { - auto moduleOp = referenceOfOp.getParentOfType(); - auto specConstOp = - moduleOp.lookupSymbol(referenceOfOp.spec_const()); + auto specConstOp = dyn_cast_or_null( + SymbolTable::lookupNearestSymbolFrom(referenceOfOp.getParentOp(), + referenceOfOp.spec_const())); if (!specConstOp) { return referenceOfOp.emitOpError("expected spv.specConstant symbol"); } diff --git a/mlir/test/Dialect/SPIRV/control-flow-ops.mlir b/mlir/test/Dialect/SPIRV/control-flow-ops.mlir --- a/mlir/test/Dialect/SPIRV/control-flow-ops.mlir +++ b/mlir/test/Dialect/SPIRV/control-flow-ops.mlir @@ -186,6 +186,19 @@ // ----- +// Allow calling functions defined in module-like ops other than spv.module +func @callee() { + spv.Return +} + +func @caller() { + // CHECK: spv.FunctionCall + spv.FunctionCall @callee() : () -> () + spv.Return +} + +// ----- + spv.module "Logical" "GLSL450" { func @f_invalid_result_type(%arg0 : i32, %arg1 : i32) -> () { // expected-error @+1 {{expected callee function to have 0 or 1 result, but provided 2}} @@ -239,7 +252,7 @@ spv.module "Logical" "GLSL450" { func @f_foo(%arg0 : i32, %arg1 : i32) -> i32 { - // expected-error @+1 {{op callee function 'f_undefined' not found in 'spv.module'}} + // expected-error @+1 {{op callee function 'f_undefined' not found in nearest symbol table}} %0 = spv.FunctionCall @f_undefined(%arg0, %arg0) : (i32, i32) -> i32 spv.Return } @@ -247,14 +260,6 @@ // ----- -func @f_foo(%arg0 : i32, %arg1 : i32) -> i32 { - // expected-error @+1 {{must appear in a function inside 'spv.module'}} - %0 = spv.FunctionCall @f_foo(%arg0, %arg0) : (i32, i32) -> i32 - spv.Return -} - -// ----- - //===----------------------------------------------------------------------===// // spv.loop //===----------------------------------------------------------------------===// @@ -497,8 +502,16 @@ // ----- +// CHECK-LABEL: in_other_func_like_op +func @in_other_func_like_op() { + // CHECK: spv.Return + spv.Return
+} + +// ----- + "foo.function"() ({ - // expected-error @+1 {{op must appear in a 'func' block}} + // expected-error @+1 {{op must appear in a function-like op's block}} spv.Return }) : () -> () @@ -562,7 +575,7 @@ "foo.function"() ({ %0 = spv.constant true - // expected-error @+1 {{op must appear in a 'func' block}} + // expected-error @+1 {{op must appear in a function-like op's block}} spv.ReturnValue %0 : i1 }) : () -> () diff --git a/mlir/test/Dialect/SPIRV/structure-ops.mlir b/mlir/test/Dialect/SPIRV/structure-ops.mlir --- a/mlir/test/Dialect/SPIRV/structure-ops.mlir +++ b/mlir/test/Dialect/SPIRV/structure-ops.mlir @@ -18,6 +18,16 @@ // ----- +// Allow spv._address_of on global variables in module-like ops besides spv.module +spv.globalVariable @var : !spv.ptr>, Input> +func @address_of() -> () { + // CHECK: spv._address_of @var + %1 = spv._address_of @var : !spv.ptr>, Input> + return +} + +// ----- + spv.module "Logical" "GLSL450" { spv.globalVariable @var1 : !spv.ptr>, Input> func @foo() -> () { @@ -174,7 +184,7 @@ spv.module "Logical" "GLSL450" { func @do_nothing() -> () { - // expected-error @+1 {{'spv.EntryPoint' op failed to verify that op must appear in a 'spv.module' block}} + // expected-error @+1 {{op must appear in a module-like op's block}} spv.EntryPoint "GLCompute" @do_something } } @@ -229,6 +239,13 @@ // ----- +// Allow initializers defined in module-like ops other than spv.module +spv.specConstant @sc = 4.0 : f32 +// CHECK: spv.globalVariable @var initializer(@sc) +spv.globalVariable @var initializer(@sc) : !spv.ptr + +// ----- + spv.module "Logical" "GLSL450" { // CHECK: spv.globalVariable @var0 bind(1, 2) : !spv.ptr spv.globalVariable @var0 bind(1, 2) : !spv.ptr @@ -252,6 +269,14 @@ // ----- +// Allow spv.globalVariable directly in module-like ops other than spv.module +module { + // CHECK: spv.globalVariable + spv.globalVariable @var0 : !spv.ptr +} + +// ----- + spv.module
"Logical" "GLSL450" { func @foo() { - // expected-error @+1 {{op failed to verify that op must appear in a 'spv.module' block}} + // expected-error @+1 {{op must appear in a module-like op's block}} spv.globalVariable @var0 : !spv.ptr spv.Return } @@ -418,7 +443,7 @@ //===----------------------------------------------------------------------===// func @module_end_not_in_module() -> () { - // expected-error @+1 {{op must appear in a 'spv.module' block}} + // expected-error @+1 {{op must appear in a module-like op's block}} spv._module_end } @@ -461,6 +486,16 @@ // ----- +// Allow spv._reference_of on spec constants in module-like ops besides spv.module +spv.specConstant @sc = 5 : i32 +func @reference_of() { + // CHECK: spv._reference_of @sc + %0 = spv._reference_of @sc : i32 + return +} + +// ----- + spv.module "Logical" "GLSL450" { func @foo() -> () { // expected-error @+1 {{expected spv.specConstant symbol}} @@ -519,7 +554,7 @@ // ----- func @use_in_function() -> () { - // expected-error @+1 {{op must appear in a 'spv.module' block}} + // expected-error @+1 {{op must appear in a module-like op's block}} spv.specConstant @sc = false return }