diff --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td --- a/mlir/include/mlir/Dialect/GPU/GPUOps.td +++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td @@ -128,11 +128,6 @@ nullptr; } - /// Returns the type of the function this Op defines. - FunctionType getType() { - return getTypeAttr().getValue().cast(); - } - /// Change the type of this function in place. This is an extremely /// dangerous operation and it is up to the caller to ensure that this is /// legal for this function, and to restore invariants: diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h @@ -25,7 +25,7 @@ /// For composite types, this converter additionally performs type wrapping to /// satisfy shader interface requirements: shader interface types must be /// pointers to structs. -class SPIRVTypeConverter final : public TypeConverter { +class SPIRVTypeConverter : public TypeConverter { public: using TypeConverter::TypeConverter; @@ -59,6 +59,7 @@ namespace spirv { class AccessChainOp; +class FuncOp; class SPIRVConversionTarget : public ConversionTarget { public: @@ -104,7 +105,8 @@ /// Sets the InterfaceVarABIAttr and EntryPointABIAttr for a function and its /// arguments. -LogicalResult setABIAttrs(FuncOp funcOp, EntryPointABIAttr entryPointInfo, +LogicalResult setABIAttrs(spirv::FuncOp funcOp, + EntryPointABIAttr entryPointInfo, ArrayRef argABIInfo); } // namespace spirv } // namespace mlir diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.h --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.h +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.h @@ -43,4 +43,40 @@ } // end namespace spirv } // end namespace mlir +namespace llvm { + +/// spirv::Function ops hash just like pointers. +template <> +struct DenseMapInfo { + static mlir::spirv::FuncOp getEmptyKey() { + auto pointer = llvm::DenseMapInfo::getEmptyKey(); + return mlir::spirv::FuncOp::getFromOpaquePointer(pointer); + } + static mlir::spirv::FuncOp getTombstoneKey() { + auto pointer = llvm::DenseMapInfo::getTombstoneKey(); + return mlir::spirv::FuncOp::getFromOpaquePointer(pointer); + } + static unsigned getHashValue(mlir::spirv::FuncOp val) { + return hash_value(val.getAsOpaquePointer()); + } + static bool isEqual(mlir::spirv::FuncOp LHS, mlir::spirv::FuncOp RHS) { + return LHS == RHS; + } +}; + +/// Allow stealing the low bits of spirv::Function ops. +template <> +struct PointerLikeTypeTraits { +public: + static inline void *getAsVoidPointer(mlir::spirv::FuncOp I) { + return const_cast(I.getAsOpaquePointer()); + } + static inline mlir::spirv::FuncOp getFromVoidPointer(void *P) { + return mlir::spirv::FuncOp::getFromOpaquePointer(P); + } + static constexpr int NumLowBitsAvailable = 3; +}; + +} // namespace llvm + #endif // MLIR_DIALECT_SPIRV_SPIRVOPS_H_ diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td @@ -210,7 +210,7 @@ let autogenSerialization = 0; let builders = [OpBuilder<[{Builder *builder, OperationState &state, - FuncOp function, + spirv::FuncOp function, spirv::ExecutionMode executionMode, ArrayRef params}]>]; } diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td @@ -15,8 +15,11 @@ #ifndef SPIRV_STRUCTURE_OPS #define SPIRV_STRUCTURE_OPS +include "mlir/Analysis/CallInterfaces.td" include "mlir/Dialect/SPIRV/SPIRVBase.td" +// ----- + def SPV_AddressOfOp : SPV_Op<"_address_of", [InFunctionScope, NoSideEffect]> { let summary = "Get the address of a global variable."; @@ -61,6 +64,8 @@ let assemblyFormat = "$variable attr-dict `:` type($pointer)"; } +// ----- + def SPV_ConstantOp : SPV_Op<"constant", [NoSideEffect]> { let summary = "The op that declares a SPIR-V normal constant"; @@ -125,6 +130,8 @@ let autogenSerialization = 0; } +// ----- + def SPV_EntryPointOp : SPV_Op<"EntryPoint", [InModuleScope]> { let summary = [{ Declare an entry point, its execution model, and its interface. @@ -182,10 +189,93 @@ let builders = [OpBuilder<[{Builder *builder, OperationState &state, spirv::ExecutionModel executionModel, - FuncOp function, + spirv::FuncOp function, ArrayRef interfaceVars}]>]; } +// ----- + +def SPV_FuncOp : SPV_Op<"func", [ + DeclareOpInterfaceMethods, + FunctionLike, InModuleScope, IsolatedFromAbove, Symbol + ]> { + let summary = "Declare or define a function"; + + let description = [{ + This op declares or defines a SPIR-V function using one region, which + contains one or more blocks. + + Different from the SPIR-V binary format, this op is not allowed to + implicitly capture global values, and all external references must use + function arguments or symbol references. This op itself defines a symbol + that is unique in the enclosing module op. + + This op itself takes no operands and generates no results. Its region + can take zero or more arguments and return zero or one values. + + ### Custom assembly form + + ``` + spv-function-control ::= "None" | "Inline" | "DontInline" | ... + spv-function-op ::= `spv.func` function-signature + spv-function-control region + ``` + + For example: + + ``` + spv.func @foo() -> () "None" { ... } + spv.func @bar() -> () "Inline|Pure" { ... } + ``` + }]; + + let arguments = (ins + TypeAttr:$type, + StrAttr:$sym_name, + SPV_FunctionControlAttr:$function_control + ); + + let results = (outs); + + let regions = (region AnyRegion:$body); + + let verifier = [{ return success(); }]; + + let builders = [OpBuilder<[{ + Builder *, OperationState &state, + StringRef name, FunctionType type, + spirv::FunctionControl control = spirv::FunctionControl::None, + ArrayRef attrs = {} + }]>]; + + let hasOpcode = 0; + + let autogenSerialization = 0; + + let extraClassDeclaration = [{ + private: + // This trait needs access to the hooks defined below. + friend class OpTrait::FunctionLike; + + /// Returns the number of arguments. Hook for OpTrait::FunctionLike. + unsigned getNumFuncArguments() { return getType().getNumInputs(); } + + /// Returns the number of results. Hook for OpTrait::FunctionLike. + unsigned getNumFuncResults() { return getType().getNumResults(); } + + /// Hook for OpTrait::FunctionLike, called after verifying that the 'type' + /// attribute is present and checks if it holds a function type. Ensures + /// getType, getNumFuncArguments, and getNumFuncResults can be called safely + LogicalResult verifyType(); + + /// Hook for OpTrait::FunctionLike, called after verifying the function + /// type and the presence of the (potentially empty) function body. + /// Ensures SPIR-V specific semantics. + LogicalResult verifyBody(); + }]; +} + +// ----- def SPV_GlobalVariableOp : SPV_Op<"globalVariable", [InModuleScope, Symbol]> { let summary = [{ @@ -238,6 +328,8 @@ OptionalAttr:$initializer ); + let results = (outs); + let builders = [ OpBuilder<"Builder *builder, OperationState &state, " "TypeAttr type, ArrayRef namedAttrs", [{ @@ -251,8 +343,6 @@ Type type, StringRef name, spirv::BuiltIn builtin}]> ]; - let results = (outs); - let hasOpcode = 0; let autogenSerialization = 0; @@ -264,10 +354,12 @@ }]; } +// ----- + def SPV_ModuleOp : SPV_Op<"module", [IsolatedFromAbove, SingleBlockImplicitTerminator<"ModuleEndOp">, - NativeOpTrait<"SymbolTable">]> { + SymbolTable]> { let summary = "The top-level op that defines a SPIR-V module"; let description = [{ @@ -351,6 +443,8 @@ }]; } +// ----- + def SPV_ModuleEndOp : SPV_Op<"_module_end", [InModuleScope, Terminator]> { let summary = "The pseudo op that ends a SPIR-V module"; @@ -375,6 +469,8 @@ let autogenSerialization = 0; } +// ----- + def SPV_ReferenceOfOp : SPV_Op<"_reference_of", [NoSideEffect]> { let summary = "Reference a specialization constant."; @@ -415,6 +511,8 @@ let assemblyFormat = "$spec_const attr-dict `:` type($reference)"; } +// ----- + def SPV_SpecConstantOp : SPV_Op<"specConstant", [InModuleScope, Symbol]> { let summary = "The op that declares a SPIR-V specialization constant"; @@ -462,4 +560,6 @@ let autogenSerialization = 0; } +// ----- + #endif // SPIRV_STRUCTURE_OPS diff --git a/mlir/include/mlir/IR/Function.h b/mlir/include/mlir/IR/Function.h --- a/mlir/include/mlir/IR/Function.h +++ b/mlir/include/mlir/IR/Function.h @@ -63,34 +63,6 @@ /// `argIndices` is allowed to have duplicates and can be in any order. void eraseArguments(ArrayRef argIndices); - /// Returns the type of this function. - FunctionType getType() { - return getAttrOfType(getTypeAttrName()) - .getValue() - .cast(); - } - - /// Change the type of this function in place. This is an extremely dangerous - /// operation and it is up to the caller to ensure that this is legal for this - /// function, and to restore invariants: - /// - the entry block args must be updated to match the function params. - /// - the argument/result attributes may need an update: if the new type has - /// less parameters we drop the extra attributes, if there are more - /// parameters they won't have any attributes. - void setType(FunctionType newType) { - SmallVector nameBuf; - auto oldType = getType(); - for (int i = newType.getNumInputs(), e = oldType.getNumInputs(); i < e; - i++) { - removeAttr(getArgAttrName(i, nameBuf)); - } - for (int i = newType.getNumResults(), e = oldType.getNumResults(); i < e; - i++) { - removeAttr(getResultAttrName(i, nameBuf)); - } - setAttr(getTypeAttrName(), TypeAttr::get(newType)); - } - /// Create a deep copy of this function and all of its blocks, remapping /// any operands that use values outside of the function using the map that is /// provided (leaving them alone if no entry is present). If the mapper @@ -105,19 +77,6 @@ /// the attributes of the current function and dest are compatible. void cloneInto(FuncOp dest, BlockAndValueMapping &mapper); - //===--------------------------------------------------------------------===// - // Body Handling - //===--------------------------------------------------------------------===// - - /// Add an entry block to an empty function, and set up the block arguments - /// to match the signature of the function. The newly inserted entry block is - /// returned. - Block *addEntryBlock(); - - /// Add a normal block to the end of the function's block list. The function - /// should at least already have an entry block. - Block *addBlock(); - //===--------------------------------------------------------------------===// // CallableOpInterface //===--------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/FunctionSupport.h b/mlir/include/mlir/IR/FunctionSupport.h --- a/mlir/include/mlir/IR/FunctionSupport.h +++ b/mlir/include/mlir/IR/FunctionSupport.h @@ -137,6 +137,19 @@ Block &back() { return getBody().back(); } Block &front() { return getBody().front(); } + /// Add an entry block to an empty function, and set up the block arguments + /// to match the signature of the function. The newly inserted entry block + /// is returned. + /// + /// Note that the concrete class must define a method with the same name to + /// hide this one if the concrete class does not use FunctionType for the + /// function type under the hood. + Block *addEntryBlock(); + + /// Add a normal block to the end of the function's block list. The function + /// should at least already have an entry block. + Block *addBlock(); + /// Hook for concrete ops to verify the contents of the body. Called as a /// part of trait verification, after type verification and ensuring that a /// region exists. @@ -154,6 +167,15 @@ getTypeAttrName()); } + /// Return the type of this function. + /// + /// Note that the concrete class must define a method with the same name to + /// hide this one if the concrete class does not use FunctionType for the + /// function type under the hood. + FunctionType getType() { + return getTypeAttr().getValue().template cast(); + } + bool isTypeAttrValid() { auto typeAttr = getTypeAttr(); if (!typeAttr) @@ -161,6 +183,19 @@ return typeAttr.getValue() != Type{}; } + /// Change the type of this function in place. This is an extremely dangerous + /// operation and it is up to the caller to ensure that this is legal for this + /// function, and to restore invariants: + /// - the entry block args must be updated to match the function params. + /// - the argument/result attributes may need an update: if the new type + /// has less parameters we drop the extra attributes, if there are more + /// parameters they won't have any attributes. + /// + /// Note that the concrete class must define a method with the same name to + /// hide this one if the concrete class does not use FunctionType for the + /// function type under the hood. + void setType(FunctionType newType); + //===--------------------------------------------------------------------===// // Argument Handling //===--------------------------------------------------------------------===// @@ -417,6 +452,43 @@ return funcOp.verifyBody(); } +//===----------------------------------------------------------------------===// +// Function Body. +//===----------------------------------------------------------------------===// + +template +Block *FunctionLike::addEntryBlock() { + assert(empty() && "function already has an entry block"); + auto *entry = new Block(); + push_back(entry); + entry->addArguments(getType().getInputs()); + return entry; +} + +template +Block *FunctionLike::addBlock() { + assert(!empty() && "function should at least have an entry block"); + push_back(new Block()); + return &back(); +} + +//===----------------------------------------------------------------------===// +// Function Type Attribute. +//===----------------------------------------------------------------------===// + +template +void FunctionLike::setType(FunctionType newType) { + SmallVector nameBuf; + auto oldType = getType(); + auto *concreteOp = static_cast(this); + + for (int i = newType.getNumInputs(), e = oldType.getNumInputs(); i < e; i++) + concreteOp->removeAttr(getArgAttrName(i, nameBuf)); + for (int i = newType.getNumResults(), e = oldType.getNumResults(); i < e; i++) + concreteOp->removeAttr(getResultAttrName(i, nameBuf)); + concreteOp->setAttr(getTypeAttrName(), TypeAttr::get(newType)); +} + //===----------------------------------------------------------------------===// // Function Argument Attribute. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp --- a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp @@ -273,7 +273,7 @@ //===----------------------------------------------------------------------===// // Legalizes a GPU function as an entry SPIR-V function. -static FuncOp +static spirv::FuncOp lowerAsEntryFunction(gpu::GPUFuncOp funcOp, SPIRVTypeConverter &typeConverter, ConversionPatternRewriter &rewriter, spirv::EntryPointABIAttr entryPointInfo, @@ -299,11 +299,10 @@ signatureConverter.addInputs(argType.index(), convertedType); } } - auto newFuncOp = rewriter.create( + auto newFuncOp = rewriter.create( funcOp.getLoc(), funcOp.getName(), rewriter.getFunctionType(signatureConverter.getConvertedTypes(), - llvm::None), - ArrayRef()); + llvm::None)); for (const auto &namedAttr : funcOp.getAttrs()) { if (namedAttr.first.is(impl::getTypeAttrName()) || namedAttr.first.is(SymbolTable::getSymbolAttrName())) @@ -336,8 +335,8 @@ auto context = rewriter.getContext(); auto entryPointAttr = spirv::getEntryPointABIAttr(workGroupSizeAsInt32, context); - FuncOp newFuncOp = lowerAsEntryFunction(funcOp, typeConverter, rewriter, - entryPointAttr, argABI); + spirv::FuncOp newFuncOp = lowerAsEntryFunction( + funcOp, typeConverter, rewriter, entryPointAttr, argABI); if (!newFuncOp) { return matchFailure(); } diff --git a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp --- a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp @@ -75,8 +75,6 @@ std::unique_ptr target = spirv::SPIRVConversionTarget::get( spirv::lookupTargetEnvOrDefault(module), context); - target->addDynamicallyLegalOp( - [&](FuncOp op) { return typeConverter.isSignatureLegal(op.getType()); }); if (failed(applyFullConversion(kernelModules, *target, patterns, &typeConverter))) { diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp --- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp @@ -38,8 +38,6 @@ std::unique_ptr target = spirv::SPIRVConversionTarget::get( spirv::lookupTargetEnvOrDefault(module), context); - target->addDynamicallyLegalOp( - [&](FuncOp op) { return typeConverter.isSignatureLegal(op.getType()); }); if (failed(applyPartialConversion(module, *target, patterns))) { return signalPassFailure(); diff --git a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp @@ -59,10 +59,11 @@ /// 'dest' that is attached to an operation registered to the current dialect. bool isLegalToInline(Region *dest, Region *src, BlockAndValueMapping &) const final { - // Return true here when inlining into spv.selection and spv.loop - // operations. + // Return true here when inlining into spv.func, spv.selection, and + // spv.loop operations. auto op = dest->getParentOp(); - return isa(op) || isa(op); + return isa(op) || isa(op) || + isa(op); } /// Returns true if the given operation 'op', that is registered to this diff --git a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp @@ -191,6 +191,7 @@ FuncOpConversion::matchAndRewrite(FuncOp funcOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { auto fnType = funcOp.getType(); + // TODO(antiagainst): support converting functions with one result. if (fnType.getNumResults()) return matchFailure(); @@ -202,12 +203,23 @@ signatureConverter.addInputs(argType.index(), convertedType); } - rewriter.updateRootInPlace(funcOp, [&] { - funcOp.setType(rewriter.getFunctionType( - signatureConverter.getConvertedTypes(), llvm::None)); - rewriter.applySignatureConversion(&funcOp.getBody(), signatureConverter); - }); + // Create the converted spv.func op. + auto newFuncOp = rewriter.create( + funcOp.getLoc(), funcOp.getName(), + rewriter.getFunctionType(signatureConverter.getConvertedTypes(), + llvm::None)); + + // Copy over all attributes other than the function name and type. + for (const auto &namedAttr : funcOp.getAttrs()) { + if (!namedAttr.first.is(impl::getTypeAttrName()) && + !namedAttr.first.is(SymbolTable::getSymbolAttrName())) + newFuncOp.setAttr(namedAttr.first, namedAttr.second); + } + rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), + newFuncOp.end()); + rewriter.applySignatureConversion(&newFuncOp.getBody(), signatureConverter); + rewriter.eraseOp(funcOp); return matchSuccess(); } @@ -308,7 +320,8 @@ auto indexType = typeConverter.getIndexType(builder.getContext()); Value ptrLoc = nullptr; - assert(indices.size() == strides.size()); + assert(indices.size() == strides.size() && + "must provide indices for all dimensions"); for (auto index : enumerate(indices)) { Value strideVal = builder.create( loc, indexType, IntegerAttr::get(indexType, strides[index.index()])); @@ -330,7 +343,8 @@ //===----------------------------------------------------------------------===// LogicalResult -mlir::spirv::setABIAttrs(FuncOp funcOp, spirv::EntryPointABIAttr entryPointInfo, +mlir::spirv::setABIAttrs(spirv::FuncOp funcOp, + spirv::EntryPointABIAttr entryPointInfo, ArrayRef argABIInfo) { // Set the attributes for argument and the function. StringRef argABIAttrName = spirv::getInterfaceVarABIAttrName(); diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -18,6 +18,7 @@ #include "mlir/Dialect/SPIRV/SPIRVTypes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Function.h" +#include "mlir/IR/FunctionImplementation.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/PatternMatch.h" @@ -1546,7 +1547,7 @@ void spirv::EntryPointOp::build(Builder *builder, OperationState &state, spirv::ExecutionModel executionModel, - FuncOp function, + spirv::FuncOp function, ArrayRef interfaceVars) { build(builder, state, builder->getI32IntegerAttr(static_cast(executionModel)), @@ -1607,7 +1608,7 @@ //===----------------------------------------------------------------------===// void spirv::ExecutionModeOp::build(Builder *builder, OperationState &state, - FuncOp function, + spirv::FuncOp function, spirv::ExecutionMode executionMode, ArrayRef params) { build(builder, state, builder->getSymbolRefAttr(function), @@ -1640,9 +1641,10 @@ } static void print(spirv::ExecutionModeOp execModeOp, OpAsmPrinter &printer) { - printer << spirv::ExecutionModeOp::getOperationName() << " @" - << execModeOp.fn() << " \"" - << stringifyExecutionMode(execModeOp.execution_mode()) << "\""; + printer << spirv::ExecutionModeOp::getOperationName() << " "; + printer.printSymbolName(execModeOp.fn()); + printer << " \"" << stringifyExecutionMode(execModeOp.execution_mode()) + << "\""; auto values = execModeOp.values(); if (!values.size()) return; @@ -1652,6 +1654,136 @@ }); } +//===----------------------------------------------------------------------===// +// spv.func +//===----------------------------------------------------------------------===// + +static ParseResult parseFuncOp(OpAsmParser &parser, OperationState &state) { + SmallVector entryArgs; + SmallVector, 4> argAttrs; + SmallVector, 4> resultAttrs; + SmallVector argTypes; + SmallVector resultTypes; + auto &builder = parser.getBuilder(); + + // Parse the name as a symbol. + StringAttr nameAttr; + if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(), + state.attributes)) + return failure(); + + // Parse the function signature. + bool isVariadic = false; + if (impl::parseFunctionSignature(parser, /*allowVariadic=*/false, entryArgs, + argTypes, argAttrs, isVariadic, resultTypes, + resultAttrs)) + return failure(); + + auto fnType = builder.getFunctionType(argTypes, resultTypes); + state.addAttribute(impl::getTypeAttrName(), TypeAttr::get(fnType)); + + // Parse the optional function control keyword. + spirv::FunctionControl fnControl; + if (parseEnumAttribute(fnControl, parser, state)) + return failure(); + + // If additional attributes are present, parse them. + if (parser.parseOptionalAttrDictWithKeyword(state.attributes)) + return failure(); + + // Add the attributes to the function arguments. + assert(argAttrs.size() == argTypes.size()); + assert(resultAttrs.size() == resultTypes.size()); + impl::addArgAndResultAttrs(builder, state, argAttrs, resultAttrs); + + // Parse the optional function body. + auto *body = state.addRegion(); + return parser.parseOptionalRegion( + *body, entryArgs, entryArgs.empty() ? ArrayRef() : argTypes); +} + +static void print(spirv::FuncOp fnOp, OpAsmPrinter &printer) { + // Print function name, signature, and control. + printer << spirv::FuncOp::getOperationName() << " "; + printer.printSymbolName(fnOp.sym_name()); + auto fnType = fnOp.getType(); + impl::printFunctionSignature(printer, fnOp, fnType.getInputs(), + /*isVariadic=*/false, fnType.getResults()); + printer << " \"" << spirv::stringifyFunctionControl(fnOp.function_control()) + << "\""; + impl::printFunctionAttributes( + printer, fnOp, fnType.getNumInputs(), fnType.getNumResults(), + {spirv::attributeName()}); + + // Print the body if this is not an external function. + Region &body = fnOp.body(); + if (!body.empty()) + printer.printRegion(body, /*printEntryBlockArgs=*/false, + /*printBlockTerminators=*/true); +} + +LogicalResult spirv::FuncOp::verifyType() { + auto type = getTypeAttr().getValue(); + if (!type.isa()) + return emitOpError("requires '" + getTypeAttrName() + + "' attribute of function type"); + if (getType().getNumResults() > 1) + return emitOpError("cannot have more than one result"); + return success(); +} + +LogicalResult spirv::FuncOp::verifyBody() { + FunctionType fnType = getType(); + + auto walkResult = walk([fnType](Operation *op) -> WalkResult { + if (auto retOp = dyn_cast(op)) { + if (fnType.getNumResults() != 0) + return retOp.emitOpError("cannot be used in functions returning value"); + } else if (auto retOp = dyn_cast(op)) { + if (fnType.getNumResults() != 1) + return retOp.emitOpError( + "returns 1 value but enclosing function requires ") + << fnType.getNumResults() << " results"; + + auto retOperandType = retOp.value().getType(); + auto fnResultType = fnType.getResult(0); + if (retOperandType != fnResultType) + return retOp.emitOpError(" return value's type (") + << retOperandType << ") mismatch with function's result type (" + << fnResultType << ")"; + } + return WalkResult::advance(); + }); + + // TODO(antiagainst): verify other bits like linkage type. + + return failure(walkResult.wasInterrupted()); +} + +void spirv::FuncOp::build(Builder *builder, OperationState &state, + StringRef name, FunctionType type, + spirv::FunctionControl control, + ArrayRef attrs) { + state.addAttribute(SymbolTable::getSymbolAttrName(), + builder->getStringAttr(name)); + state.addAttribute(getTypeAttrName(), TypeAttr::get(type)); + state.addAttribute( + spirv::attributeName(), + builder->getI32IntegerAttr(static_cast(control))); + state.attributes.append(attrs.begin(), attrs.end()); + state.addRegion(); +} + +// CallableOpInterface +Region *spirv::FuncOp::getCallableRegion() { + return isExternal() ? nullptr : &body(); +} + +// CallableOpInterface +ArrayRef spirv::FuncOp::getCallableResults() { + return getType().getResults(); +} + //===----------------------------------------------------------------------===// // spv.FunctionCall //===----------------------------------------------------------------------===// @@ -1659,8 +1791,9 @@ static LogicalResult verify(spirv::FunctionCallOp functionCallOp) { auto fnName = functionCallOp.callee(); - auto funcOp = dyn_cast_or_null(SymbolTable::lookupNearestSymbolFrom( - functionCallOp.getParentOp(), fnName)); + auto funcOp = + dyn_cast_or_null(SymbolTable::lookupNearestSymbolFrom( + functionCallOp.getParentOp(), fnName)); if (!funcOp) { return functionCallOp.emitOpError("callee function '") << fnName << "' not found in nearest symbol table"; @@ -2322,69 +2455,61 @@ auto &op = *moduleOp.getOperation(); auto *dialect = op.getDialect(); auto &body = op.getRegion(0).front(); - DenseMap, spirv::EntryPointOp> + DenseMap, spirv::EntryPointOp> entryPoints; SymbolTable table(moduleOp); for (auto &op : body) { - if (op.getDialect() == dialect) { - // For EntryPoint op, check that the function and execution model is not - // duplicated in EntryPointOps. Also verify that the interface specified - // comes from globalVariables here to make this check cheaper. - if (auto entryPointOp = dyn_cast(op)) { - auto funcOp = table.lookup(entryPointOp.fn()); - if (!funcOp) { - return entryPointOp.emitError("function '") - << entryPointOp.fn() << "' not found in 'spv.module'"; - } - if (auto interface = entryPointOp.interface()) { - for (Attribute varRef : interface) { - auto varSymRef = varRef.dyn_cast(); - if (!varSymRef) { - return entryPointOp.emitError( - "expected symbol reference for interface " - "specification instead of '") - << varRef; - } - auto variableOp = - table.lookup(varSymRef.getValue()); - if (!variableOp) { - return entryPointOp.emitError("expected spv.globalVariable " - "symbol reference instead of'") - << varSymRef << "'"; - } + if (op.getDialect() != dialect) + return op.emitError("'spv.module' can only contain spv.* ops"); + + // For EntryPoint op, check that the function and execution model is not + // duplicated in EntryPointOps. Also verify that the interface specified + // comes from globalVariables here to make this check cheaper. + if (auto entryPointOp = dyn_cast(op)) { + auto funcOp = table.lookup(entryPointOp.fn()); + if (!funcOp) { + return entryPointOp.emitError("function '") + << entryPointOp.fn() << "' not found in 'spv.module'"; + } + if (auto interface = entryPointOp.interface()) { + for (Attribute varRef : interface) { + auto varSymRef = varRef.dyn_cast(); + if (!varSymRef) { + return entryPointOp.emitError( + "expected symbol reference for interface " + "specification instead of '") + << varRef; + } + auto variableOp = + table.lookup(varSymRef.getValue()); + if (!variableOp) { + return entryPointOp.emitError("expected spv.globalVariable " + "symbol reference instead of'") + << varSymRef << "'"; } } - - auto key = std::pair( - funcOp, entryPointOp.execution_model()); - auto entryPtIt = entryPoints.find(key); - if (entryPtIt != entryPoints.end()) { - return entryPointOp.emitError("duplicate of a previous EntryPointOp"); - } - entryPoints[key] = entryPointOp; } - continue; - } - - auto funcOp = dyn_cast(op); - if (!funcOp) - return op.emitError("'spv.module' can only contain func and spv.* ops"); - - if (funcOp.isExternal()) - return op.emitError("'spv.module' cannot contain external functions"); - for (auto &block : funcOp) - for (auto &op : block) { - if (op.getDialect() == dialect) - continue; - - if (isa(op)) - return op.emitError("'spv.module' cannot contain nested functions"); - - return op.emitError( - "functions in 'spv.module' can only contain spv.* ops"); + auto key = std::pair( + funcOp, entryPointOp.execution_model()); + auto entryPtIt = entryPoints.find(key); + if (entryPtIt != entryPoints.end()) { + return entryPointOp.emitError("duplicate of a previous EntryPointOp"); } + entryPoints[key] = entryPointOp; + } else if (auto funcOp = dyn_cast(op)) { + if (funcOp.isExternal()) + return op.emitError("'spv.module' cannot contain external functions"); + + // TODO(antiagainst): move this check to spv.func. + for (auto &block : funcOp) + for (auto &op : block) { + if (op.getDialect() != dialect) + return op.emitError( + "functions in 'spv.module' can only contain spv.* ops"); + } + } } // Verify capabilities. ODS already guarantees that we have an array of @@ -2434,12 +2559,7 @@ //===----------------------------------------------------------------------===// static LogicalResult verify(spirv::ReturnOp returnOp) { - auto funcOp = returnOp.getParentOfType(); - auto numOutputs = funcOp.getType().getNumResults(); - if (numOutputs != 0) - return returnOp.emitOpError("cannot be used in functions returning value") - << (numOutputs > 1 ? "s" : ""); - + // Verification is performed in spv.func op. return success(); } @@ -2448,20 +2568,7 @@ //===----------------------------------------------------------------------===// static LogicalResult verify(spirv::ReturnValueOp retValOp) { - auto funcOp = retValOp.getParentOfType(); - auto numFnResults = funcOp.getType().getNumResults(); - if (numFnResults != 1) - return retValOp.emitOpError( - "returns 1 value but enclosing function requires ") - << numFnResults << " results"; - - auto operandType = retValOp.value().getType(); - auto fnResultType = funcOp.getType().getResult(0); - if (operandType != fnResultType) - return retValOp.emitOpError(" return value's type (") - << operandType << ") mismatch with function's result type (" - << fnResultType << ")"; - + // Verification is performed in spv.func op. return success(); } diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp --- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp @@ -48,7 +48,8 @@ /// Returns true if the given `block` is a function entry block. static inline bool isFnEntryBlock(Block *block) { - return block->isEntryBlock() && isa_and_nonnull(block->getParentOp()); + return block->isEntryBlock() && + isa_and_nonnull(block->getParentOp()); } namespace { @@ -134,8 +135,8 @@ /// Processes an OpMemberName instruction. LogicalResult processMemberName(ArrayRef words); - /// Gets the FuncOp associated with a result of OpFunction. - FuncOp getFunction(uint32_t id) { return funcMap.lookup(id); } + /// Gets the function op associated with a result of OpFunction. + spirv::FuncOp getFunction(uint32_t id) { return funcMap.lookup(id); } /// Processes the SPIR-V function at the current `offset` into `binary`. /// The operands to the OpFunction instruction is passed in as ``operands`. @@ -392,7 +393,7 @@ Optional module; /// The current function under construction. - Optional curFunction; + Optional curFunction; /// The current block under construction. Block *curBlock = nullptr; @@ -425,7 +426,7 @@ DenseMap globalVariableMap; // Result to function mapping. - DenseMap funcMap; + DenseMap funcMap; // Result to block mapping. DenseMap blockMap; @@ -775,8 +776,8 @@ } std::string fnName = getFunctionSymbol(operands[1]); - auto funcOp = opBuilder.create(unknownLoc, fnName, functionType, - ArrayRef()); + auto funcOp = + opBuilder.create(unknownLoc, fnName, functionType); curFunction = funcMap[operands[1]] = funcOp; LLVM_DEBUG(llvm::dbgs() << "-- start function " << fnName << " (type = " << fnType << ", id = " << operands[1] << ") --\n"); diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp --- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp @@ -179,7 +179,7 @@ LogicalResult processName(uint32_t resultID, StringRef name); /// Processes a SPIR-V function op. - LogicalResult processFuncOp(FuncOp op); + LogicalResult processFuncOp(spirv::FuncOp op); LogicalResult processVariableOp(spirv::VariableOp op); @@ -682,7 +682,7 @@ } } // namespace -LogicalResult Serializer::processFuncOp(FuncOp op) { +LogicalResult Serializer::processFuncOp(spirv::FuncOp op) { LLVM_DEBUG(llvm::dbgs() << "-- start function '" << op.getName() << "' --\n"); assert(functionHeader.empty() && functionBody.empty()); @@ -1642,7 +1642,7 @@ return processBranchConditionalOp(op); }) .Case([&](spirv::ConstantOp op) { return processConstantOp(op); }) - .Case([&](FuncOp op) { return processFuncOp(op); }) + .Case([&](spirv::FuncOp op) { return processFuncOp(op); }) .Case([&](spirv::GlobalVariableOp op) { return processGlobalVariableOp(op); }) diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp --- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp @@ -16,7 +16,6 @@ #include "mlir/Dialect/SPIRV/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/SPIRVLowering.h" #include "mlir/Dialect/SPIRV/SPIRVOps.h" -#include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/SetVector.h" @@ -30,7 +29,8 @@ /// Creates a global variable for an argument based on the ABI info. static spirv::GlobalVariableOp -createGlobalVariableForArg(FuncOp funcOp, OpBuilder &builder, unsigned argNum, +createGlobalVariableForArg(spirv::FuncOp funcOp, OpBuilder &builder, + unsigned argNum, spirv::InterfaceVarABIAttr abiInfo) { auto spirvModule = funcOp.getParentOfType(); if (!spirvModule) { @@ -70,7 +70,7 @@ /// Gets the global variables that need to be specified as interface variable /// with an spv.EntryPointOp. Traverses the body of a entry function to do so. static LogicalResult -getInterfaceVariables(FuncOp funcOp, +getInterfaceVariables(spirv::FuncOp funcOp, SmallVectorImpl &interfaceVars) { auto module = funcOp.getParentOfType(); if (!module) { @@ -97,7 +97,8 @@ } /// Lowers the entry point attribute. -static LogicalResult lowerEntryPointABIAttr(FuncOp funcOp, OpBuilder &builder) { +static LogicalResult lowerEntryPointABIAttr(spirv::FuncOp funcOp, + OpBuilder &builder) { auto entryPointAttrName = spirv::getEntryPointABIAttrName(); auto entryPointAttr = funcOp.getAttrOfType(entryPointAttrName); @@ -127,13 +128,18 @@ } namespace { -/// Pattern rewriter for changing function signature to match the ABI specified -/// in attributes. -class FuncOpLowering final : public SPIRVOpLowering { +/// A pattern to convert function signature according to interface variable ABI +/// attributes. +/// +/// Specifically, this pattern creates global variables according to interface +/// variable ABI attributes attached to function arguments and converts all +/// function argument uses to those global variables. This is necessary because +/// Vulkan requires all shader entry points to be of void(void) type. +class ProcessInterfaceVarABI final : public SPIRVOpLowering { public: - using SPIRVOpLowering::SPIRVOpLowering; + using SPIRVOpLowering::SPIRVOpLowering; PatternMatchResult - matchAndRewrite(FuncOp funcOp, ArrayRef operands, + matchAndRewrite(spirv::FuncOp funcOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; @@ -145,9 +151,9 @@ }; } // namespace -PatternMatchResult -FuncOpLowering::matchAndRewrite(FuncOp funcOp, ArrayRef operands, - ConversionPatternRewriter &rewriter) const { +PatternMatchResult ProcessInterfaceVarABI::matchAndRewrite( + spirv::FuncOp funcOp, ArrayRef operands, + ConversionPatternRewriter &rewriter) const { if (!funcOp.getAttrOfType( spirv::getEntryPointABIAttrName())) { // TODO(ravishankarm) : Non-entry point functions are not handled. @@ -185,8 +191,7 @@ // before the use. There might be multiple loads and currently there is no // easy way to replace all uses with a sequence of operations. if (isScalarOrVectorType(argType.value())) { - auto indexType = - typeConverter.convertType(IndexType::get(funcOp.getContext())); + auto indexType = SPIRVTypeConverter::getIndexType(funcOp.getContext()); auto zero = spirv::ConstantOp::getZero(indexType, funcOp.getLoc(), &rewriter); auto loadPtr = rewriter.create( @@ -213,26 +218,33 @@ SPIRVTypeConverter typeConverter; OwningRewritePatternList patterns; - patterns.insert(context, typeConverter); + patterns.insert(context, typeConverter); - std::unique_ptr target = spirv::SPIRVConversionTarget::get( - spirv::lookupTargetEnvOrDefault(module), context); - auto entryPointAttrName = spirv::getEntryPointABIAttrName(); - target->addDynamicallyLegalOp([&](FuncOp op) { - return op.getAttrOfType(entryPointAttrName) && - op.getNumResults() == 0 && op.getNumArguments() == 0; + ConversionTarget target(*context); + // "Legal" function ops should have no interface variable ABI attributes. + target.addDynamicallyLegalOp([&](spirv::FuncOp op) { + StringRef attrName = spirv::getInterfaceVarABIAttrName(); + for (unsigned i = 0, e = op.getNumArguments(); i < e; ++i) + if (op.getArgAttr(i, attrName)) + return false; + return true; + }); + // All other SPIR-V ops are legal. + target.markUnknownOpDynamicallyLegal([](Operation *op) { + return op->getDialect()->getNamespace() == + spirv::SPIRVDialect::getDialectNamespace(); }); - target->addLegalOp(); if (failed( - applyPartialConversion(module, *target, patterns, &typeConverter))) { + applyPartialConversion(module, target, patterns, &typeConverter))) { return signalPassFailure(); } // Walks over all the FuncOps in spirv::ModuleOp to lower the entry point // attributes. OpBuilder builder(context); - SmallVector entryPointFns; - module.walk([&](FuncOp funcOp) { + SmallVector entryPointFns; + auto entryPointAttrName = spirv::getEntryPointABIAttrName(); + module.walk([&](spirv::FuncOp funcOp) { if (funcOp.getAttrOfType(entryPointAttrName)) { entryPointFns.push_back(funcOp); } diff --git a/mlir/lib/IR/Function.cpp b/mlir/lib/IR/Function.cpp --- a/mlir/lib/IR/Function.cpp +++ b/mlir/lib/IR/Function.cpp @@ -139,24 +139,6 @@ entry.eraseArgument(originalNumArgs - i - 1); } -/// Add an entry block to an empty function, and set up the block arguments -/// to match the signature of the function. -Block *FuncOp::addEntryBlock() { - assert(empty() && "function already has an entry block"); - auto *entry = new Block(); - push_back(entry); - entry->addArguments(getType().getInputs()); - return entry; -} - -/// Add a normal block to the end of the function's block list. The function -/// should at least already have an entry block. -Block *FuncOp::addBlock() { - assert(!empty() && "function should at least have an entry block"); - push_back(new Block()); - return &back(); -} - /// Clone the internal blocks from this function into dest and all attributes /// from this function to dest. void FuncOp::cloneInto(FuncOp dest, BlockAndValueMapping &mapper) { diff --git a/mlir/test/Conversion/GPUToSPIRV/load-store.mlir b/mlir/test/Conversion/GPUToSPIRV/load-store.mlir --- a/mlir/test/Conversion/GPUToSPIRV/load-store.mlir +++ b/mlir/test/Conversion/GPUToSPIRV/load-store.mlir @@ -21,7 +21,7 @@ // CHECK-DAG: spv.globalVariable [[NUMWORKGROUPSVAR:@.*]] built_in("NumWorkgroups") : !spv.ptr, Input> // CHECK-DAG: spv.globalVariable [[LOCALINVOCATIONIDVAR:@.*]] built_in("LocalInvocationId") : !spv.ptr, Input> // CHECK-DAG: spv.globalVariable [[WORKGROUPIDVAR:@.*]] built_in("WorkgroupId") : !spv.ptr, Input> - // CHECK-LABEL: func @load_store_kernel + // CHECK-LABEL: spv.func @load_store_kernel // CHECK-SAME: [[ARG0:%.*]]: !spv.ptr [0]>, StorageBuffer> {spv.interface_var_abi = {binding = 0 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}} // CHECK-SAME: [[ARG1:%.*]]: !spv.ptr [0]>, StorageBuffer> {spv.interface_var_abi = {binding = 1 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}} // CHECK-SAME: [[ARG2:%.*]]: !spv.ptr [0]>, StorageBuffer> {spv.interface_var_abi = {binding = 2 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}} diff --git a/mlir/test/Conversion/GPUToSPIRV/simple.mlir b/mlir/test/Conversion/GPUToSPIRV/simple.mlir --- a/mlir/test/Conversion/GPUToSPIRV/simple.mlir +++ b/mlir/test/Conversion/GPUToSPIRV/simple.mlir @@ -1,14 +1,13 @@ // RUN: mlir-opt -pass-pipeline='convert-gpu-to-spirv{workgroup-size=32,4}' %s -o - | FileCheck %s module attributes {gpu.container_module} { - gpu.module @kernels { // CHECK: spv.module "Logical" "GLSL450" { - // CHECK-LABEL: func @kernel_1 + // CHECK-LABEL: spv.func @kernel // CHECK-SAME: {{%.*}}: f32 {spv.interface_var_abi = {binding = 0 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}} // CHECK-SAME: {{%.*}}: !spv.ptr [0]>, StorageBuffer> {spv.interface_var_abi = {binding = 1 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}} // CHECK-SAME: spv.entry_point_abi = {local_size = dense<[32, 4, 1]> : vector<3xi32>} - gpu.func @kernel_1(%arg0 : f32, %arg1 : memref<12xf32>) attributes {gpu.kernel} { + gpu.func @kernel(%arg0 : f32, %arg1 : memref<12xf32>) attributes {gpu.kernel} { // CHECK: spv.Return gpu.return } @@ -19,7 +18,7 @@ %0 = "op"() : () -> (f32) %1 = "op"() : () -> (memref<12xf32>) %cst = constant 1 : index - "gpu.launch_func"(%cst, %cst, %cst, %cst, %cst, %cst, %0, %1) { kernel = "kernel_1", kernel_module = @kernels } + "gpu.launch_func"(%cst, %cst, %cst, %cst, %cst, %cst, %0, %1) { kernel = "kernel", kernel_module = @kernels } : (index, index, index, index, index, index, f32, memref<12xf32>) -> () return } diff --git a/mlir/test/Conversion/LinalgToSPIRV/linalg-to-spirv.mlir b/mlir/test/Conversion/LinalgToSPIRV/linalg-to-spirv.mlir --- a/mlir/test/Conversion/LinalgToSPIRV/linalg-to-spirv.mlir +++ b/mlir/test/Conversion/LinalgToSPIRV/linalg-to-spirv.mlir @@ -26,7 +26,7 @@ // CHECK: spv.globalVariable // CHECK-SAME: built_in("LocalInvocationId") -// CHECK: func @single_workgroup_reduction +// CHECK: @single_workgroup_reduction // CHECK-SAME: (%[[INPUT:.+]]: !spv.ptr{{.+}}, %[[OUTPUT:.+]]: !spv.ptr{{.+}}) // CHECK: %[[ZERO:.+]] = spv.constant 0 : i32 diff --git a/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir --- a/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir @@ -14,75 +14,77 @@ } // CHECK-LABEL: @fadd_scalar -func @fadd_scalar(%arg: f32) -> f32 { +func @fadd_scalar(%arg: f32) { // CHECK: spv.FAdd %0 = addf %arg, %arg : f32 - return %0 : f32 + return } // CHECK-LABEL: @fdiv_scalar -func @fdiv_scalar(%arg: f32) -> f32 { +func @fdiv_scalar(%arg: f32) { // CHECK: spv.FDiv %0 = divf %arg, %arg : f32 - return %0 : f32 + return } // CHECK-LABEL: @fmul_scalar -func @fmul_scalar(%arg: f32) -> f32 { +func @fmul_scalar(%arg: f32) { // CHECK: spv.FMul %0 = mulf %arg, %arg : f32 - return %0 : f32 + return } // CHECK-LABEL: @fmul_vector2 -func @fmul_vector2(%arg: vector<2xf32>) -> vector<2xf32> { +func @fmul_vector2(%arg: vector<2xf32>) { // CHECK: spv.FMul %0 = mulf %arg, %arg : vector<2xf32> - return %0 : vector<2xf32> + return } // CHECK-LABEL: @fmul_vector3 -func @fmul_vector3(%arg: vector<3xf32>) -> vector<3xf32> { +func @fmul_vector3(%arg: vector<3xf32>) { // CHECK: spv.FMul %0 = mulf %arg, %arg : vector<3xf32> - return %0 : vector<3xf32> + return } // CHECK-LABEL: @fmul_vector4 -func @fmul_vector4(%arg: vector<4xf32>) -> vector<4xf32> { +func @fmul_vector4(%arg: vector<4xf32>) { // CHECK: spv.FMul %0 = mulf %arg, %arg : vector<4xf32> - return %0 : vector<4xf32> + return } // CHECK-LABEL: @fmul_vector5 -func @fmul_vector5(%arg: vector<5xf32>) -> vector<5xf32> { - // Vector length of only 2, 3, and 4 is valid for SPIR-V +func @fmul_vector5(%arg: vector<5xf32>) { + // Vector length of only 2, 3, and 4 is valid for SPIR-V. // CHECK: mulf %0 = mulf %arg, %arg : vector<5xf32> - return %0 : vector<5xf32> + return } -// CHECK-LABEL: @fmul_tensor -func @fmul_tensor(%arg: tensor<4xf32>) -> tensor<4xf32> { - // For tensors mulf cannot be lowered directly to spv.FMul - // CHECK: mulf - %0 = mulf %arg, %arg : tensor<4xf32> - return %0 : tensor<4xf32> -} +// TODO(antiagainst): enable this once we support converting binary ops +// needing type conversion. +// XXXXX-LABEL: @fmul_tensor +//func @fmul_tensor(%arg: tensor<4xf32>) { + // For tensors mulf cannot be lowered directly to spv.FMul. + // XXXXX: mulf + //%0 = mulf %arg, %arg : tensor<4xf32> + //return +//} // CHECK-LABEL: @frem_scalar -func @frem_scalar(%arg: f32) -> f32 { +func @frem_scalar(%arg: f32) { // CHECK: spv.FRem %0 = remf %arg, %arg : f32 - return %0 : f32 + return } // CHECK-LABEL: @fsub_scalar -func @fsub_scalar(%arg: f32) -> f32 { +func @fsub_scalar(%arg: f32) { // CHECK: spv.FSub %0 = subf %arg, %arg : f32 - return %0 : f32 + return } // CHECK-LABEL: @div_rem @@ -306,7 +308,7 @@ // memref type //===----------------------------------------------------------------------===// -// CHECK-LABEL: func @memref_type({{%.*}}: memref<3xi1>) { +// CHECK-LABEL: func @memref_type({{%.*}}: memref<3xi1>) func @memref_type(%arg0: memref<3xi1>) { return } diff --git a/mlir/test/Dialect/SPIRV/Serialization/arithmetic-ops.mlir b/mlir/test/Dialect/SPIRV/Serialization/arithmetic-ops.mlir --- a/mlir/test/Dialect/SPIRV/Serialization/arithmetic-ops.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/arithmetic-ops.mlir @@ -1,77 +1,77 @@ // RUN: mlir-translate -test-spirv-roundtrip %s | FileCheck %s spv.module "Logical" "GLSL450" { - func @fmul(%arg0 : f32, %arg1 : f32) { + spv.func @fmul(%arg0 : f32, %arg1 : f32) "None" { // CHECK: {{%.*}}= spv.FMul {{%.*}}, {{%.*}} : f32 %0 = spv.FMul %arg0, %arg1 : f32 spv.Return } - func @fadd(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) { + spv.func @fadd(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) "None" { // CHECK: {{%.*}} = spv.FAdd {{%.*}}, {{%.*}} : vector<4xf32> %0 = spv.FAdd %arg0, %arg1 : vector<4xf32> spv.Return } - func @fdiv(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) { + spv.func @fdiv(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) "None" { // CHECK: {{%.*}} = spv.FDiv {{%.*}}, {{%.*}} : vector<4xf32> %0 = spv.FDiv %arg0, %arg1 : vector<4xf32> spv.Return } - func @fmod(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) { + spv.func @fmod(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) "None" { // CHECK: {{%.*}} = spv.FMod {{%.*}}, {{%.*}} : vector<4xf32> %0 = spv.FMod %arg0, %arg1 : vector<4xf32> spv.Return } - func @fnegate(%arg0 : vector<4xf32>) { + spv.func @fnegate(%arg0 : vector<4xf32>) "None" { // CHECK: {{%.*}} = spv.FNegate {{%.*}} : vector<4xf32> %0 = spv.FNegate %arg0 : vector<4xf32> spv.Return } - func @fsub(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) { + spv.func @fsub(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) "None" { // CHECK: {{%.*}} = spv.FSub {{%.*}}, {{%.*}} : vector<4xf32> %0 = spv.FSub %arg0, %arg1 : vector<4xf32> spv.Return } - func @frem(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) { + spv.func @frem(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) "None" { // CHECK: {{%.*}} = spv.FRem {{%.*}}, {{%.*}} : vector<4xf32> %0 = spv.FRem %arg0, %arg1 : vector<4xf32> spv.Return } - func @iadd(%arg0 : vector<4xi32>, %arg1 : vector<4xi32>) { + spv.func @iadd(%arg0 : vector<4xi32>, %arg1 : vector<4xi32>) "None" { // CHECK: {{%.*}} = spv.IAdd {{%.*}}, {{%.*}} : vector<4xi32> %0 = spv.IAdd %arg0, %arg1 : vector<4xi32> spv.Return } - func @isub(%arg0 : vector<4xi32>, %arg1 : vector<4xi32>) { + spv.func @isub(%arg0 : vector<4xi32>, %arg1 : vector<4xi32>) "None" { // CHECK: {{%.*}} = spv.ISub {{%.*}}, {{%.*}} : vector<4xi32> %0 = spv.ISub %arg0, %arg1 : vector<4xi32> spv.Return } - func @imul(%arg0 : vector<4xi32>, %arg1 : vector<4xi32>) { + spv.func @imul(%arg0 : vector<4xi32>, %arg1 : vector<4xi32>) "None" { // CHECK: {{%.*}} = spv.IMul {{%.*}}, {{%.*}} : vector<4xi32> %0 = spv.IMul %arg0, %arg1 : vector<4xi32> spv.Return } - func @udiv(%arg0 : vector<4xi32>, %arg1 : vector<4xi32>) { + spv.func @udiv(%arg0 : vector<4xi32>, %arg1 : vector<4xi32>) "None" { // CHECK: {{%.*}} = spv.UDiv {{%.*}}, {{%.*}} : vector<4xi32> %0 = spv.UDiv %arg0, %arg1 : vector<4xi32> spv.Return } - func @umod(%arg0 : vector<4xi32>, %arg1 : vector<4xi32>) { + spv.func @umod(%arg0 : vector<4xi32>, %arg1 : vector<4xi32>) "None" { // CHECK: {{%.*}} = spv.UMod {{%.*}}, {{%.*}} : vector<4xi32> %0 = spv.UMod %arg0, %arg1 : vector<4xi32> spv.Return } - func @sdiv(%arg0 : vector<4xi32>, %arg1 : vector<4xi32>) { + spv.func @sdiv(%arg0 : vector<4xi32>, %arg1 : vector<4xi32>) "None" { // CHECK: {{%.*}} = spv.SDiv {{%.*}}, {{%.*}} : vector<4xi32> %0 = spv.SDiv %arg0, %arg1 : vector<4xi32> spv.Return } - func @smod(%arg0 : vector<4xi32>, %arg1 : vector<4xi32>) { + spv.func @smod(%arg0 : vector<4xi32>, %arg1 : vector<4xi32>) "None" { // CHECK: {{%.*}} = spv.SMod {{%.*}}, {{%.*}} : vector<4xi32> %0 = spv.SMod %arg0, %arg1 : vector<4xi32> spv.Return } - func @srem(%arg0 : vector<4xi32>, %arg1 : vector<4xi32>) { + spv.func @srem(%arg0 : vector<4xi32>, %arg1 : vector<4xi32>) "None" { // CHECK: {{%.*}} = spv.SRem {{%.*}}, {{%.*}} : vector<4xi32> %0 = spv.SRem %arg0, %arg1 : vector<4xi32> spv.Return diff --git a/mlir/test/Dialect/SPIRV/Serialization/array.mlir b/mlir/test/Dialect/SPIRV/Serialization/array.mlir --- a/mlir/test/Dialect/SPIRV/Serialization/array.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/array.mlir @@ -1,8 +1,7 @@ // RUN: mlir-translate -split-input-file -test-spirv-roundtrip %s | FileCheck %s spv.module "Logical" "GLSL450" { - func @array_stride(%arg0 : !spv.ptr [128]>, StorageBuffer>, - %arg1 : i32, %arg2 : i32) { + spv.func @array_stride(%arg0 : !spv.ptr [128]>, StorageBuffer>, %arg1 : i32, %arg2 : i32) "None" { // CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr [128]>, StorageBuffer> %2 = spv.AccessChain %arg0[%arg1, %arg2] : !spv.ptr [128]>, StorageBuffer> spv.Return diff --git a/mlir/test/Dialect/SPIRV/Serialization/atomic-ops.mlir b/mlir/test/Dialect/SPIRV/Serialization/atomic-ops.mlir --- a/mlir/test/Dialect/SPIRV/Serialization/atomic-ops.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/atomic-ops.mlir @@ -2,7 +2,7 @@ spv.module "Logical" "GLSL450" { // CHECK-LABEL: @atomic_compare_exchange_weak - func @atomic_compare_exchange_weak(%ptr: !spv.ptr, %value: i32, %comparator: i32) -> i32 { + spv.func @atomic_compare_exchange_weak(%ptr: !spv.ptr, %value: i32, %comparator: i32) -> i32 "None" { // CHECK: spv.AtomicCompareExchangeWeak "Workgroup" "Release" "Acquire" %{{.*}}, %{{.*}}, %{{.*}} : !spv.ptr %0 = spv.AtomicCompareExchangeWeak "Workgroup" "Release" "Acquire" %ptr, %value, %comparator: !spv.ptr // CHECK: spv.AtomicAnd "Device" "None" %{{.*}}, %{{.*}} : !spv.ptr diff --git a/mlir/test/Dialect/SPIRV/Serialization/barrier.mlir b/mlir/test/Dialect/SPIRV/Serialization/barrier.mlir --- a/mlir/test/Dialect/SPIRV/Serialization/barrier.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/barrier.mlir @@ -1,22 +1,22 @@ // RUN: mlir-translate -test-spirv-roundtrip %s | FileCheck %s spv.module "Logical" "GLSL450" { - func @memory_barrier_0() -> () { + spv.func @memory_barrier_0() -> () "None" { // CHECK: spv.MemoryBarrier "Device", "Release|UniformMemory" spv.MemoryBarrier "Device", "Release|UniformMemory" spv.Return } - func @memory_barrier_1() -> () { + spv.func @memory_barrier_1() -> () "None" { // CHECK: spv.MemoryBarrier "Subgroup", "AcquireRelease|SubgroupMemory" spv.MemoryBarrier "Subgroup", "AcquireRelease|SubgroupMemory" spv.Return } - func @control_barrier_0() -> () { + spv.func @control_barrier_0() -> () "None" { // CHECK: spv.ControlBarrier "Device", "Workgroup", "Release|UniformMemory" spv.ControlBarrier "Device", "Workgroup", "Release|UniformMemory" spv.Return } - func @control_barrier_1() -> () { + spv.func @control_barrier_1() -> () "None" { // CHECK: spv.ControlBarrier "Workgroup", "Invocation", "AcquireRelease|UniformMemory" spv.ControlBarrier "Workgroup", "Invocation", "AcquireRelease|UniformMemory" spv.Return diff --git a/mlir/test/Dialect/SPIRV/Serialization/bit-ops.mlir b/mlir/test/Dialect/SPIRV/Serialization/bit-ops.mlir --- a/mlir/test/Dialect/SPIRV/Serialization/bit-ops.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/bit-ops.mlir @@ -1,37 +1,37 @@ // RUN: mlir-translate -test-spirv-roundtrip -split-input-file %s | FileCheck %s spv.module "Logical" "GLSL450" { - func @bitcount(%arg: i32) -> i32 { + spv.func @bitcount(%arg: i32) -> i32 "None" { // CHECK: spv.BitCount {{%.*}} : i32 %0 = spv.BitCount %arg : i32 spv.ReturnValue %0 : i32 } - func @bit_field_insert(%base: vector<3xi32>, %insert: vector<3xi32>, %offset: i32, %count: i16) -> vector<3xi32> { + spv.func @bit_field_insert(%base: vector<3xi32>, %insert: vector<3xi32>, %offset: i32, %count: i16) -> vector<3xi32> "None" { // CHECK: {{%.*}} = spv.BitFieldInsert {{%.*}}, {{%.*}}, {{%.*}}, {{%.*}} : vector<3xi32>, i32, i16 %0 = spv.BitFieldInsert %base, %insert, %offset, %count : vector<3xi32>, i32, i16 spv.ReturnValue %0 : vector<3xi32> } - func @bit_field_s_extract(%base: vector<3xi32>, %offset: i8, %count: i8) -> vector<3xi32> { + spv.func @bit_field_s_extract(%base: vector<3xi32>, %offset: i8, %count: i8) -> vector<3xi32> "None" { // CHECK: {{%.*}} = spv.BitFieldSExtract {{%.*}}, {{%.*}}, {{%.*}} : vector<3xi32>, i8, i8 %0 = spv.BitFieldSExtract %base, %offset, %count : vector<3xi32>, i8, i8 spv.ReturnValue %0 : vector<3xi32> } - func @bit_field_u_extract(%base: vector<3xi32>, %offset: i8, %count: i8) -> vector<3xi32> { + spv.func @bit_field_u_extract(%base: vector<3xi32>, %offset: i8, %count: i8) -> vector<3xi32> "None" { // CHECK: {{%.*}} = spv.BitFieldUExtract {{%.*}}, {{%.*}}, {{%.*}} : vector<3xi32>, i8, i8 %0 = spv.BitFieldUExtract %base, %offset, %count : vector<3xi32>, i8, i8 spv.ReturnValue %0 : vector<3xi32> } - func @bitreverse(%arg: i32) -> i32 { + spv.func @bitreverse(%arg: i32) -> i32 "None" { // CHECK: spv.BitReverse {{%.*}} : i32 %0 = spv.BitReverse %arg : i32 spv.ReturnValue %0 : i32 } - func @not(%arg: i32) -> i32 { + spv.func @not(%arg: i32) -> i32 "None" { // CHECK: spv.Not {{%.*}} : i32 %0 = spv.Not %arg : i32 spv.ReturnValue %0 : i32 } - func @bitwise_scalar(%arg0 : i32, %arg1 : i32) { + spv.func @bitwise_scalar(%arg0 : i32, %arg1 : i32) "None" { // CHECK: spv.BitwiseAnd %0 = spv.BitwiseAnd %arg0, %arg1 : i32 // CHECK: spv.BitwiseOr @@ -40,17 +40,17 @@ %2 = spv.BitwiseXor %arg0, %arg1 : i32 spv.Return } - func @shift_left_logical(%arg0: i32, %arg1 : i16) -> i32 { + spv.func @shift_left_logical(%arg0: i32, %arg1 : i16) -> i32 "None" { // CHECK: {{%.*}} = spv.ShiftLeftLogical {{%.*}}, {{%.*}} : i32, i16 %0 = spv.ShiftLeftLogical %arg0, %arg1: i32, i16 spv.ReturnValue %0 : i32 } - func @shift_right_arithmetic(%arg0: vector<4xi32>, %arg1 : vector<4xi8>) -> vector<4xi32> { + spv.func @shift_right_arithmetic(%arg0: vector<4xi32>, %arg1 : vector<4xi8>) -> vector<4xi32> "None" { // CHECK: {{%.*}} = spv.ShiftRightArithmetic {{%.*}}, {{%.*}} : vector<4xi32>, vector<4xi8> %0 = spv.ShiftRightArithmetic %arg0, %arg1: vector<4xi32>, vector<4xi8> spv.ReturnValue %0 : vector<4xi32> } - func @shift_right_logical(%arg0: vector<2xi32>, %arg1 : vector<2xi8>) -> vector<2xi32> { + spv.func @shift_right_logical(%arg0: vector<2xi32>, %arg1 : vector<2xi8>) -> vector<2xi32> "None" { // CHECK: {{%.*}} = spv.ShiftRightLogical {{%.*}}, {{%.*}} : vector<2xi32>, vector<2xi8> %0 = spv.ShiftRightLogical %arg0, %arg1: vector<2xi32>, vector<2xi8> spv.ReturnValue %0 : vector<2xi32> diff --git a/mlir/test/Dialect/SPIRV/Serialization/cast-ops.mlir b/mlir/test/Dialect/SPIRV/Serialization/cast-ops.mlir --- a/mlir/test/Dialect/SPIRV/Serialization/cast-ops.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/cast-ops.mlir @@ -1,7 +1,7 @@ // RUN: mlir-translate -test-spirv-roundtrip -split-input-file %s | FileCheck %s spv.module "Logical" "GLSL450" { - func @bit_cast(%arg0 : f32) { + spv.func @bit_cast(%arg0 : f32) "None" { // CHECK: {{%.*}} = spv.Bitcast {{%.*}} : f32 to i32 %0 = spv.Bitcast %arg0 : f32 to i32 spv.Return @@ -11,37 +11,37 @@ // ----- spv.module "Logical" "GLSL450" { - func @convert_f_to_s(%arg0 : f32) -> i32 { + spv.func @convert_f_to_s(%arg0 : f32) -> i32 "None" { // CHECK: {{%.*}} = spv.ConvertFToS {{%.*}} : f32 to i32 %0 = spv.ConvertFToS %arg0 : f32 to i32 spv.ReturnValue %0 : i32 } - func @convert_f_to_u(%arg0 : f32) -> i32 { + spv.func @convert_f_to_u(%arg0 : f32) -> i32 "None" { // CHECK: {{%.*}} = spv.ConvertFToU {{%.*}} : f32 to i32 %0 = spv.ConvertFToU %arg0 : f32 to i32 spv.ReturnValue %0 : i32 } - func @convert_s_to_f(%arg0 : i32) -> f32 { + spv.func @convert_s_to_f(%arg0 : i32) -> f32 "None" { // CHECK: {{%.*}} = spv.ConvertSToF {{%.*}} : i32 to f32 %0 = spv.ConvertSToF %arg0 : i32 to f32 spv.ReturnValue %0 : f32 } - func @convert_u_to_f(%arg0 : i32) -> f32 { + spv.func @convert_u_to_f(%arg0 : i32) -> f32 "None" { // CHECK: {{%.*}} = spv.ConvertUToF {{%.*}} : i32 to f32 %0 = spv.ConvertUToF %arg0 : i32 to f32 spv.ReturnValue %0 : f32 } - func @f_convert(%arg0 : f32) -> f64 { + spv.func @f_convert(%arg0 : f32) -> f64 "None" { // CHECK: {{%.*}} = spv.FConvert {{%.*}} : f32 to f64 %0 = spv.FConvert %arg0 : f32 to f64 spv.ReturnValue %0 : f64 } - func @s_convert(%arg0 : i32) -> i64 { + spv.func @s_convert(%arg0 : i32) -> i64 "None" { // CHECK: {{%.*}} = spv.SConvert {{%.*}} : i32 to i64 %0 = spv.SConvert %arg0 : i32 to i64 spv.ReturnValue %0 : i64 } - func @u_convert(%arg0 : i32) -> i64 { + spv.func @u_convert(%arg0 : i32) -> i64 "None" { // CHECK: {{%.*}} = spv.UConvert {{%.*}} : i32 to i64 %0 = spv.UConvert %arg0 : i32 to i64 spv.ReturnValue %0 : i64 diff --git a/mlir/test/Dialect/SPIRV/Serialization/composite-op.mlir b/mlir/test/Dialect/SPIRV/Serialization/composite-op.mlir --- a/mlir/test/Dialect/SPIRV/Serialization/composite-op.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/composite-op.mlir @@ -1,12 +1,12 @@ // RUN: mlir-translate -split-input-file -test-spirv-roundtrip %s | FileCheck %s spv.module "Logical" "GLSL450" { - func @composite_insert(%arg0 : !spv.struct, f32>>, %arg1: !spv.array<4xf32>) -> !spv.struct, f32>> { + spv.func @composite_insert(%arg0 : !spv.struct, f32>>, %arg1: !spv.array<4xf32>) -> !spv.struct, f32>> "None" { // CHECK: spv.CompositeInsert {{%.*}}, {{%.*}}[1 : i32, 0 : i32] : !spv.array<4 x f32> into !spv.struct, f32>> %0 = spv.CompositeInsert %arg1, %arg0[1 : i32, 0 : i32] : !spv.array<4xf32> into !spv.struct, f32>> spv.ReturnValue %0: !spv.struct, f32>> } - func @composite_construct_vector(%arg0: f32, %arg1: f32, %arg2 : f32) -> vector<3xf32> { + spv.func @composite_construct_vector(%arg0: f32, %arg1: f32, %arg2 : f32) -> vector<3xf32> "None" { // CHECK: spv.CompositeConstruct {{%.*}}, {{%.*}}, {{%.*}} : vector<3xf32> %0 = spv.CompositeConstruct %arg0, %arg1, %arg2 : vector<3xf32> spv.ReturnValue %0: vector<3xf32> diff --git a/mlir/test/Dialect/SPIRV/Serialization/constant.mlir b/mlir/test/Dialect/SPIRV/Serialization/constant.mlir --- a/mlir/test/Dialect/SPIRV/Serialization/constant.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/constant.mlir @@ -2,7 +2,7 @@ spv.module "Logical" "GLSL450" { // CHECK-LABEL: @bool_const - func @bool_const() -> () { + spv.func @bool_const() -> () "None" { // CHECK: spv.constant true %0 = spv.constant true // CHECK: spv.constant false @@ -14,7 +14,7 @@ } // CHECK-LABEL: @i32_const - func @i32_const() -> () { + spv.func @i32_const() -> () "None" { // CHECK: spv.constant 0 : i32 %0 = spv.constant 0 : i32 // CHECK: spv.constant 10 : i32 @@ -28,7 +28,7 @@ } // CHECK-LABEL: @i64_const - func @i64_const() -> () { + spv.func @i64_const() -> () "None" { // CHECK: spv.constant 4294967296 : i64 %0 = spv.constant 4294967296 : i64 // 2^32 // CHECK: spv.constant -4294967296 : i64 @@ -44,7 +44,7 @@ } // CHECK-LABEL: @i16_const - func @i16_const() -> () { + spv.func @i16_const() -> () "None" { // CHECK: spv.constant -32768 : i16 %0 = spv.constant -32768 : i16 // -2^15 // CHECK: spv.constant 32767 : i16 @@ -55,7 +55,7 @@ } // CHECK-LABEL: @float_const - func @float_const() -> () { + spv.func @float_const() -> () "None" { // CHECK: spv.constant 0.000000e+00 : f32 %0 = spv.constant 0. : f32 // CHECK: spv.constant 1.000000e+00 : f32 @@ -76,7 +76,7 @@ } // CHECK-LABEL: @double_const - func @double_const() -> () { + spv.func @double_const() -> () "None" { // TODO(antiagainst): test range boundary values // CHECK: spv.constant 1.024000e+03 : f64 %0 = spv.constant 1024. : f64 @@ -88,7 +88,7 @@ } // CHECK-LABEL: @half_const - func @half_const() -> () { + spv.func @half_const() -> () "None" { // CHECK: spv.constant 5.120000e+02 : f16 %0 = spv.constant 512. : f16 // CHECK: spv.constant -5.120000e+02 : f16 @@ -99,7 +99,7 @@ } // CHECK-LABEL: @bool_vector_const - func @bool_vector_const() -> () { + spv.func @bool_vector_const() -> () "None" { // CHECK: spv.constant dense : vector<2xi1> %0 = spv.constant dense : vector<2xi1> // CHECK: spv.constant dense<[true, true, true]> : vector<3xi1> @@ -114,7 +114,7 @@ } // CHECK-LABEL: @int_vector_const - func @int_vector_const() -> () { + spv.func @int_vector_const() -> () "None" { // CHECK: spv.constant dense<0> : vector<3xi32> %0 = spv.constant dense<0> : vector<3xi32> // CHECK: spv.constant dense<1> : vector<3xi32> @@ -128,7 +128,7 @@ } // CHECK-LABEL: @fp_vector_const - func @fp_vector_const() -> () { + spv.func @fp_vector_const() -> () "None" { // CHECK: spv.constant dense<0.000000e+00> : vector<4xf32> %0 = spv.constant dense<0.> : vector<4xf32> // CHECK: spv.constant dense<-1.500000e+01> : vector<4xf32> @@ -142,7 +142,7 @@ } // CHECK-LABEL: @array_const - func @array_const() -> (!spv.array<2 x vector<2xf32>>) { + spv.func @array_const() -> (!spv.array<2 x vector<2xf32>>) "None" { // CHECK: spv.constant [dense<3.000000e+00> : vector<2xf32>, dense<[4.000000e+00, 5.000000e+00]> : vector<2xf32>] : !spv.array<2 x vector<2xf32>> %0 = spv.constant [dense<3.0> : vector<2xf32>, dense<[4., 5.]> : vector<2xf32>] : !spv.array<2 x vector<2xf32>> @@ -150,14 +150,14 @@ } // CHECK-LABEL: @ignore_not_used_const - func @ignore_not_used_const() -> () { + spv.func @ignore_not_used_const() -> () "None" { %0 = spv.constant false // CHECK-NEXT: spv.Return spv.Return } // CHECK-LABEL: @materialize_const_at_each_use - func @materialize_const_at_each_use() -> (i32) { + spv.func @materialize_const_at_each_use() -> (i32) "None" { // CHECK: %[[USE1:.*]] = spv.constant 42 : i32 // CHECK: %[[USE2:.*]] = spv.constant 42 : i32 // CHECK: spv.IAdd %[[USE1]], %[[USE2]] @@ -167,7 +167,7 @@ } // CHECK-LABEL: @const_variable - func @const_variable(%arg0 : i32, %arg1 : i32) -> () { + spv.func @const_variable(%arg0 : i32, %arg1 : i32) -> () "None" { // CHECK: %[[CONST:.*]] = spv.constant 5 : i32 // CHECK: spv.Variable init(%[[CONST]]) : !spv.ptr // CHECK: spv.IAdd %arg0, %arg1 @@ -180,14 +180,14 @@ } // CHECK-LABEL: @multi_dimensions_const - func @multi_dimensions_const() -> (!spv.array<2 x !spv.array<2 x !spv.array<3 x i32 [4]> [12]> [24]>) { + spv.func @multi_dimensions_const() -> (!spv.array<2 x !spv.array<2 x !spv.array<3 x i32 [4]> [12]> [24]>) "None" { // CHECK: spv.constant {{\[}}{{\[}}[1 : i32, 2 : i32, 3 : i32], [4 : i32, 5 : i32, 6 : i32]], {{\[}}[7 : i32, 8 : i32, 9 : i32], [10 : i32, 11 : i32, 12 : i32]]] : !spv.array<2 x !spv.array<2 x !spv.array<3 x i32 [4]> [12]> [24]> %0 = spv.constant dense<[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]> : tensor<2x2x3xi32> : !spv.array<2 x !spv.array<2 x !spv.array<3 x i32 [4]> [12]> [24]> spv.ReturnValue %0 : !spv.array<2 x !spv.array<2 x !spv.array<3 x i32 [4]> [12]> [24]> } // CHECK-LABEL: @multi_dimensions_splat_const - func @multi_dimensions_splat_const() -> (!spv.array<2 x !spv.array<2 x !spv.array<3 x i32 [4]> [12]> [24]>) { + spv.func @multi_dimensions_splat_const() -> (!spv.array<2 x !spv.array<2 x !spv.array<3 x i32 [4]> [12]> [24]>) "None" { // CHECK: spv.constant {{\[}}{{\[}}[1 : i32, 1 : i32, 1 : i32], [1 : i32, 1 : i32, 1 : i32]], {{\[}}[1 : i32, 1 : i32, 1 : i32], [1 : i32, 1 : i32, 1 : i32]]] : !spv.array<2 x !spv.array<2 x !spv.array<3 x i32 [4]> [12]> [24]> %0 = spv.constant dense<1> : tensor<2x2x3xi32> : !spv.array<2 x !spv.array<2 x !spv.array<3 x i32 [4]> [12]> [24]> spv.ReturnValue %0 : !spv.array<2 x !spv.array<2 x !spv.array<3 x i32 [4]> [12]> [24]> diff --git a/mlir/test/Dialect/SPIRV/Serialization/entry-point.mlir b/mlir/test/Dialect/SPIRV/Serialization/entry-point.mlir --- a/mlir/test/Dialect/SPIRV/Serialization/entry-point.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/entry-point.mlir @@ -1,7 +1,7 @@ // RUN: mlir-translate -test-spirv-roundtrip -split-input-file %s | FileCheck %s spv.module "Logical" "GLSL450" { - func @noop() -> () { + spv.func @noop() -> () "None" { spv.Return } // CHECK: spv.EntryPoint "GLCompute" @noop @@ -15,11 +15,11 @@ spv.module "Logical" "GLSL450" { // CHECK: spv.globalVariable @var2 : !spv.ptr // CHECK-NEXT: spv.globalVariable @var3 : !spv.ptr - // CHECK-NEXT: func @noop({{%.*}}: !spv.ptr, {{%.*}}: !spv.ptr) + // CHECK-NEXT: spv.func @noop({{%.*}}: !spv.ptr, {{%.*}}: !spv.ptr) "None" // CHECK: spv.EntryPoint "GLCompute" @noop, @var2, @var3 spv.globalVariable @var2 : !spv.ptr spv.globalVariable @var3 : !spv.ptr - func @noop(%arg0 : !spv.ptr, %arg1 : !spv.ptr) -> () { + spv.func @noop(%arg0 : !spv.ptr, %arg1 : !spv.ptr) -> () "None" { spv.Return } spv.EntryPoint "GLCompute" @noop, @var2, @var3 diff --git a/mlir/test/Dialect/SPIRV/Serialization/execution-mode.mlir b/mlir/test/Dialect/SPIRV/Serialization/execution-mode.mlir --- a/mlir/test/Dialect/SPIRV/Serialization/execution-mode.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/execution-mode.mlir @@ -1,7 +1,7 @@ // RUN: mlir-translate -test-spirv-roundtrip %s | FileCheck %s spv.module "Logical" "GLSL450" { - func @foo() -> () { + spv.func @foo() -> () "None" { spv.Return } spv.EntryPoint "GLCompute" @foo diff --git a/mlir/test/Dialect/SPIRV/Serialization/function-call.mlir b/mlir/test/Dialect/SPIRV/Serialization/function-call.mlir --- a/mlir/test/Dialect/SPIRV/Serialization/function-call.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/function-call.mlir @@ -2,7 +2,7 @@ spv.module "Logical" "GLSL450" { spv.globalVariable @var1 : !spv.ptr, Input> - func @fmain() -> i32 { + spv.func @fmain() -> i32 "None" { %0 = spv.constant 16 : i32 %1 = spv._address_of @var1 : !spv.ptr, Input> // CHECK: {{%.*}} = spv.FunctionCall @f_0({{%.*}}) : (i32) -> i32 @@ -13,17 +13,17 @@ %4 = spv.FunctionCall @f_2(%1) : (!spv.ptr, Input>) -> !spv.ptr, Input> spv.ReturnValue %3 : i32 } - func @f_0(%arg0 : i32) -> i32 { + spv.func @f_0(%arg0 : i32) -> i32 "None" { spv.ReturnValue %arg0 : i32 } - func @f_1(%arg0 : i32, %arg1 : !spv.ptr, Input>) -> () { + spv.func @f_1(%arg0 : i32, %arg1 : !spv.ptr, Input>) -> () "None" { spv.Return } - func @f_2(%arg0 : !spv.ptr, Input>) -> !spv.ptr, Input> { + spv.func @f_2(%arg0 : !spv.ptr, Input>) -> !spv.ptr, Input> "None" { spv.ReturnValue %arg0 : !spv.ptr, Input> } - func @f_loop_with_function_call(%count : i32) -> () { + spv.func @f_loop_with_function_call(%count : i32) -> () "None" { %zero = spv.constant 0: i32 %var = spv.Variable init(%zero) : !spv.ptr spv.loop { @@ -43,7 +43,7 @@ } spv.Return } - func @f_inc(%arg0 : !spv.ptr) -> () { + spv.func @f_inc(%arg0 : !spv.ptr) -> () "None" { %one = spv.constant 1 : i32 %0 = spv.Load "Function" %arg0 : i32 %1 = spv.IAdd %0, %one : i32 diff --git a/mlir/test/Dialect/SPIRV/Serialization/global-variable.mlir b/mlir/test/Dialect/SPIRV/Serialization/global-variable.mlir --- a/mlir/test/Dialect/SPIRV/Serialization/global-variable.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/global-variable.mlir @@ -25,7 +25,7 @@ spv.module "Logical" "GLSL450" { spv.globalVariable @globalInvocationID built_in("GlobalInvocationId") : !spv.ptr, Input> - func @foo() { + spv.func @foo() "None" { // CHECK: %[[ADDR:.*]] = spv._address_of @globalInvocationID : !spv.ptr, Input> %0 = spv._address_of @globalInvocationID : !spv.ptr, Input> %1 = spv.constant 0: i32 diff --git a/mlir/test/Dialect/SPIRV/Serialization/glsl-ops.mlir b/mlir/test/Dialect/SPIRV/Serialization/glsl-ops.mlir --- a/mlir/test/Dialect/SPIRV/Serialization/glsl-ops.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/glsl-ops.mlir @@ -1,7 +1,7 @@ // RUN: mlir-translate -test-spirv-roundtrip %s | FileCheck %s spv.module "Logical" "GLSL450" { - func @fmul(%arg0 : f32, %arg1 : f32) { + spv.func @fmul(%arg0 : f32, %arg1 : f32) "None" { // CHECK: {{%.*}} = spv.GLSL.Exp {{%.*}} : f32 %0 = spv.GLSL.Exp %arg0 : f32 // CHECK: {{%.*}} = spv.GLSL.FMax {{%.*}}, {{%.*}} : f32 diff --git a/mlir/test/Dialect/SPIRV/Serialization/group-ops.mlir b/mlir/test/Dialect/SPIRV/Serialization/group-ops.mlir --- a/mlir/test/Dialect/SPIRV/Serialization/group-ops.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/group-ops.mlir @@ -2,7 +2,7 @@ spv.module "Logical" "GLSL450" { // CHECK-LABEL: @subgroup_ballot - func @subgroup_ballot(%predicate: i1) -> vector<4xi32> { + spv.func @subgroup_ballot(%predicate: i1) -> vector<4xi32> "None" { // CHECK: %{{.*}} = spv.SubgroupBallotKHR %{{.*}}: vector<4xi32> %0 = spv.SubgroupBallotKHR %predicate: vector<4xi32> spv.ReturnValue %0: vector<4xi32> diff --git a/mlir/test/Dialect/SPIRV/Serialization/logical-ops.mlir b/mlir/test/Dialect/SPIRV/Serialization/logical-ops.mlir --- a/mlir/test/Dialect/SPIRV/Serialization/logical-ops.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/logical-ops.mlir @@ -1,57 +1,57 @@ // RUN: mlir-translate -split-input-file -test-spirv-roundtrip %s | FileCheck %s spv.module "Logical" "GLSL450" { - func @iequal_scalar(%arg0: i32, %arg1: i32) { + spv.func @iequal_scalar(%arg0: i32, %arg1: i32) "None" { // CHECK: {{.*}} = spv.IEqual {{.*}}, {{.*}} : i32 %0 = spv.IEqual %arg0, %arg1 : i32 spv.Return } - func @inotequal_vector(%arg0: vector<4xi32>, %arg1: vector<4xi32>) { + spv.func @inotequal_vector(%arg0: vector<4xi32>, %arg1: vector<4xi32>) "None" { // CHECK: {{.*}} = spv.INotEqual {{.*}}, {{.*}} : vector<4xi32> %0 = spv.INotEqual %arg0, %arg1 : vector<4xi32> spv.Return } - func @sgt_vector(%arg0: vector<4xi32>, %arg1: vector<4xi32>) { + spv.func @sgt_vector(%arg0: vector<4xi32>, %arg1: vector<4xi32>) "None" { // CHECK: {{.*}} = spv.SGreaterThan {{.*}}, {{.*}} : vector<4xi32> %0 = spv.SGreaterThan %arg0, %arg1 : vector<4xi32> spv.Return } - func @sge_vector(%arg0: vector<4xi32>, %arg1: vector<4xi32>) { + spv.func @sge_vector(%arg0: vector<4xi32>, %arg1: vector<4xi32>) "None" { // CHECK: {{.*}} = spv.SGreaterThanEqual {{.*}}, {{.*}} : vector<4xi32> %0 = spv.SGreaterThanEqual %arg0, %arg1 : vector<4xi32> spv.Return } - func @slt_vector(%arg0: vector<4xi32>, %arg1: vector<4xi32>) { + spv.func @slt_vector(%arg0: vector<4xi32>, %arg1: vector<4xi32>) "None" { // CHECK: {{.*}} = spv.SLessThan {{.*}}, {{.*}} : vector<4xi32> %0 = spv.SLessThan %arg0, %arg1 : vector<4xi32> spv.Return } - func @slte_vector(%arg0: vector<4xi32>, %arg1: vector<4xi32>) { + spv.func @slte_vector(%arg0: vector<4xi32>, %arg1: vector<4xi32>) "None" { // CHECK: {{.*}} = spv.SLessThanEqual {{.*}}, {{.*}} : vector<4xi32> %0 = spv.SLessThanEqual %arg0, %arg1 : vector<4xi32> spv.Return } - func @ugt_vector(%arg0: vector<4xi32>, %arg1: vector<4xi32>) { + spv.func @ugt_vector(%arg0: vector<4xi32>, %arg1: vector<4xi32>) "None" { // CHECK: {{.*}} = spv.UGreaterThan {{.*}}, {{.*}} : vector<4xi32> %0 = spv.UGreaterThan %arg0, %arg1 : vector<4xi32> spv.Return } - func @ugte_vector(%arg0: vector<4xi32>, %arg1: vector<4xi32>) { + spv.func @ugte_vector(%arg0: vector<4xi32>, %arg1: vector<4xi32>) "None" { // CHECK: {{.*}} = spv.UGreaterThanEqual {{.*}}, {{.*}} : vector<4xi32> %0 = spv.UGreaterThanEqual %arg0, %arg1 : vector<4xi32> spv.Return } - func @ult_vector(%arg0: vector<4xi32>, %arg1: vector<4xi32>) { + spv.func @ult_vector(%arg0: vector<4xi32>, %arg1: vector<4xi32>) "None" { // CHECK: {{.*}} = spv.ULessThan {{.*}}, {{.*}} : vector<4xi32> %0 = spv.ULessThan %arg0, %arg1 : vector<4xi32> spv.Return } - func @ulte_vector(%arg0: vector<4xi32>, %arg1: vector<4xi32>) { + spv.func @ulte_vector(%arg0: vector<4xi32>, %arg1: vector<4xi32>) "None" { // CHECK: {{.*}} = spv.ULessThanEqual {{.*}}, {{.*}} : vector<4xi32> %0 = spv.ULessThanEqual %arg0, %arg1 : vector<4xi32> spv.Return } - func @cmpf(%arg0 : f32, %arg1 : f32) { + spv.func @cmpf(%arg0 : f32, %arg1 : f32) "None" { // CHECK: spv.FOrdEqual %1 = spv.FOrdEqual %arg0, %arg1 : f32 // CHECK: spv.FOrdGreaterThan @@ -84,7 +84,7 @@ spv.module "Logical" "GLSL450" { spv.specConstant @condition_scalar = true - func @select() -> () { + spv.func @select() -> () "None" { %0 = spv.constant 4.0 : f32 %1 = spv.constant 5.0 : f32 %2 = spv._reference_of @condition_scalar : i1 diff --git a/mlir/test/Dialect/SPIRV/Serialization/loop.mlir b/mlir/test/Dialect/SPIRV/Serialization/loop.mlir --- a/mlir/test/Dialect/SPIRV/Serialization/loop.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/loop.mlir @@ -4,7 +4,7 @@ spv.module "Logical" "GLSL450" { // for (int i = 0; i < count; ++i) {} - func @loop(%count : i32) -> () { + spv.func @loop(%count : i32) -> () "None" { %zero = spv.constant 0: i32 %one = spv.constant 1: i32 %var = spv.Variable init(%zero) : !spv.ptr @@ -51,7 +51,7 @@ spv.Return } - func @main() -> () { + spv.func @main() -> () "None" { spv.Return } spv.EntryPoint "GLCompute" @main @@ -64,7 +64,7 @@ spv.module "Logical" "GLSL450" { spv.globalVariable @GV1 bind(0, 0) : !spv.ptr [0]>, StorageBuffer> spv.globalVariable @GV2 bind(0, 1) : !spv.ptr [0]>, StorageBuffer> - func @loop_kernel() { + spv.func @loop_kernel() "None" { %0 = spv._address_of @GV1 : !spv.ptr [0]>, StorageBuffer> %1 = spv.constant 0 : i32 %2 = spv.AccessChain %0[%1] : !spv.ptr [0]>, StorageBuffer> @@ -113,7 +113,7 @@ // for (int i = 0; i < count; ++i) { // for (int j = 0; j < count; ++j) { } // } - func @loop(%count : i32) -> () { + spv.func @loop(%count : i32) -> () "None" { %zero = spv.constant 0: i32 %one = spv.constant 1: i32 %ivar = spv.Variable init(%zero) : !spv.ptr @@ -203,7 +203,7 @@ spv.Return } - func @main() -> () { + spv.func @main() -> () "None" { spv.Return } spv.EntryPoint "GLCompute" @main diff --git a/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir b/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir --- a/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir @@ -1,11 +1,11 @@ // RUN: mlir-translate -test-spirv-roundtrip -split-input-file %s | FileCheck %s -// CHECK: func {{@.*}}([[ARG1:%.*]]: !spv.ptr, [[ARG2:%.*]]: !spv.ptr) { +// CHECK: spv.func {{@.*}}([[ARG1:%.*]]: !spv.ptr, [[ARG2:%.*]]: !spv.ptr) "None" { // CHECK-NEXT: [[VALUE:%.*]] = spv.Load "Input" [[ARG1]] : f32 // CHECK-NEXT: spv.Store "Output" [[ARG2]], [[VALUE]] : f32 spv.module "Logical" "GLSL450" { - func @load_store(%arg0 : !spv.ptr, %arg1 : !spv.ptr) { + spv.func @load_store(%arg0 : !spv.ptr, %arg1 : !spv.ptr) "None" { %1 = spv.Load "Input" %arg0 : f32 spv.Store "Output" %arg1, %1 : f32 spv.Return @@ -15,8 +15,7 @@ // ----- spv.module "Logical" "GLSL450" { - func @access_chain(%arg0 : !spv.ptr>, Function>, - %arg1 : i32, %arg2 : i32) { + spv.func @access_chain(%arg0 : !spv.ptr>, Function>, %arg1 : i32, %arg2 : i32) "None" { // CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}] : !spv.ptr>, Function> // CHECK-NEXT: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr>, Function> %1 = spv.AccessChain %arg0[%arg1] : !spv.ptr>, Function> diff --git a/mlir/test/Dialect/SPIRV/Serialization/module.mlir b/mlir/test/Dialect/SPIRV/Serialization/module.mlir --- a/mlir/test/Dialect/SPIRV/Serialization/module.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/module.mlir @@ -1,13 +1,13 @@ // RUN: mlir-translate -test-spirv-roundtrip -split-input-file %s | FileCheck %s // CHECK: spv.module "Logical" "GLSL450" { -// CHECK-NEXT: func @foo() { +// CHECK-NEXT: spv.func @foo() "None" { // CHECK-NEXT: spv.Return // CHECK-NEXT: } // CHECK-NEXT: } attributes {major_version = 1 : i32, minor_version = 0 : i32} spv.module "Logical" "GLSL450" { - func @foo() -> () { + spv.func @foo() -> () "None" { spv.Return } } diff --git a/mlir/test/Dialect/SPIRV/Serialization/non-uniform-ops.mlir b/mlir/test/Dialect/SPIRV/Serialization/non-uniform-ops.mlir --- a/mlir/test/Dialect/SPIRV/Serialization/non-uniform-ops.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/non-uniform-ops.mlir @@ -2,56 +2,56 @@ spv.module "Logical" "GLSL450" { // CHECK-LABEL: @group_non_uniform_ballot - func @group_non_uniform_ballot(%predicate: i1) -> vector<4xi32> { + spv.func @group_non_uniform_ballot(%predicate: i1) -> vector<4xi32> "None" { // CHECK: %{{.*}} = spv.GroupNonUniformBallot "Workgroup" %{{.*}}: vector<4xi32> %0 = spv.GroupNonUniformBallot "Workgroup" %predicate : vector<4xi32> spv.ReturnValue %0: vector<4xi32> } // CHECK-LABEL: @group_non_uniform_elect - func @group_non_uniform_elect() -> i1 { + spv.func @group_non_uniform_elect() -> i1 "None" { // CHECK: %{{.+}} = spv.GroupNonUniformElect "Workgroup" : i1 %0 = spv.GroupNonUniformElect "Workgroup" : i1 spv.ReturnValue %0: i1 } // CHECK-LABEL: @group_non_uniform_fadd_reduce - func @group_non_uniform_fadd_reduce(%val: f32) -> f32 { + spv.func @group_non_uniform_fadd_reduce(%val: f32) -> f32 "None" { // CHECK: %{{.+}} = spv.GroupNonUniformFAdd "Workgroup" "Reduce" %{{.+}} : f32 %0 = spv.GroupNonUniformFAdd "Workgroup" "Reduce" %val : f32 spv.ReturnValue %0: f32 } // CHECK-LABEL: @group_non_uniform_fmax_reduce - func @group_non_uniform_fmax_reduce(%val: f32) -> f32 { + spv.func @group_non_uniform_fmax_reduce(%val: f32) -> f32 "None" { // CHECK: %{{.+}} = spv.GroupNonUniformFMax "Workgroup" "Reduce" %{{.+}} : f32 %0 = spv.GroupNonUniformFMax "Workgroup" "Reduce" %val : f32 spv.ReturnValue %0: f32 } // CHECK-LABEL: @group_non_uniform_fmin_reduce - func @group_non_uniform_fmin_reduce(%val: f32) -> f32 { + spv.func @group_non_uniform_fmin_reduce(%val: f32) -> f32 "None" { // CHECK: %{{.+}} = spv.GroupNonUniformFMin "Workgroup" "Reduce" %{{.+}} : f32 %0 = spv.GroupNonUniformFMin "Workgroup" "Reduce" %val : f32 spv.ReturnValue %0: f32 } // CHECK-LABEL: @group_non_uniform_fmul_reduce - func @group_non_uniform_fmul_reduce(%val: f32) -> f32 { + spv.func @group_non_uniform_fmul_reduce(%val: f32) -> f32 "None" { // CHECK: %{{.+}} = spv.GroupNonUniformFMul "Workgroup" "Reduce" %{{.+}} : f32 %0 = spv.GroupNonUniformFMul "Workgroup" "Reduce" %val : f32 spv.ReturnValue %0: f32 } // CHECK-LABEL: @group_non_uniform_iadd_reduce - func @group_non_uniform_iadd_reduce(%val: i32) -> i32 { + spv.func @group_non_uniform_iadd_reduce(%val: i32) -> i32 "None" { // CHECK: %{{.+}} = spv.GroupNonUniformIAdd "Workgroup" "Reduce" %{{.+}} : i32 %0 = spv.GroupNonUniformIAdd "Workgroup" "Reduce" %val : i32 spv.ReturnValue %0: i32 } // CHECK-LABEL: @group_non_uniform_iadd_clustered_reduce - func @group_non_uniform_iadd_clustered_reduce(%val: vector<2xi32>) -> vector<2xi32> { + spv.func @group_non_uniform_iadd_clustered_reduce(%val: vector<2xi32>) -> vector<2xi32> "None" { %four = spv.constant 4 : i32 // CHECK: %{{.+}} = spv.GroupNonUniformIAdd "Workgroup" "ClusteredReduce" %{{.+}} cluster_size(%{{.+}}) : vector<2xi32> %0 = spv.GroupNonUniformIAdd "Workgroup" "ClusteredReduce" %val cluster_size(%four) : vector<2xi32> @@ -59,35 +59,35 @@ } // CHECK-LABEL: @group_non_uniform_imul_reduce - func @group_non_uniform_imul_reduce(%val: i32) -> i32 { + spv.func @group_non_uniform_imul_reduce(%val: i32) -> i32 "None" { // CHECK: %{{.+}} = spv.GroupNonUniformIMul "Workgroup" "Reduce" %{{.+}} : i32 %0 = spv.GroupNonUniformIMul "Workgroup" "Reduce" %val : i32 spv.ReturnValue %0: i32 } // CHECK-LABEL: @group_non_uniform_smax_reduce - func @group_non_uniform_smax_reduce(%val: i32) -> i32 { + spv.func @group_non_uniform_smax_reduce(%val: i32) -> i32 "None" { // CHECK: %{{.+}} = spv.GroupNonUniformSMax "Workgroup" "Reduce" %{{.+}} : i32 %0 = spv.GroupNonUniformSMax "Workgroup" "Reduce" %val : i32 spv.ReturnValue %0: i32 } // CHECK-LABEL: @group_non_uniform_smin_reduce - func @group_non_uniform_smin_reduce(%val: i32) -> i32 { + spv.func @group_non_uniform_smin_reduce(%val: i32) -> i32 "None" { // CHECK: %{{.+}} = spv.GroupNonUniformSMin "Workgroup" "Reduce" %{{.+}} : i32 %0 = spv.GroupNonUniformSMin "Workgroup" "Reduce" %val : i32 spv.ReturnValue %0: i32 } // CHECK-LABEL: @group_non_uniform_umax_reduce - func @group_non_uniform_umax_reduce(%val: i32) -> i32 { + spv.func @group_non_uniform_umax_reduce(%val: i32) -> i32 "None" { // CHECK: %{{.+}} = spv.GroupNonUniformUMax "Workgroup" "Reduce" %{{.+}} : i32 %0 = spv.GroupNonUniformUMax "Workgroup" "Reduce" %val : i32 spv.ReturnValue %0: i32 } // CHECK-LABEL: @group_non_uniform_umin_reduce - func @group_non_uniform_umin_reduce(%val: i32) -> i32 { + spv.func @group_non_uniform_umin_reduce(%val: i32) -> i32 "None" { // CHECK: %{{.+}} = spv.GroupNonUniformUMin "Workgroup" "Reduce" %{{.+}} : i32 %0 = spv.GroupNonUniformUMin "Workgroup" "Reduce" %val : i32 spv.ReturnValue %0: i32 diff --git a/mlir/test/Dialect/SPIRV/Serialization/phi.mlir b/mlir/test/Dialect/SPIRV/Serialization/phi.mlir --- a/mlir/test/Dialect/SPIRV/Serialization/phi.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/phi.mlir @@ -3,7 +3,7 @@ // Test branch with one block argument spv.module "Logical" "GLSL450" { - func @foo() -> () { + spv.func @foo() -> () "None" { // CHECK: %[[CST:.*]] = spv.constant 0 %zero = spv.constant 0 : i32 // CHECK-NEXT: spv.Branch ^bb1(%[[CST]] : i32) @@ -13,7 +13,7 @@ spv.Return } - func @main() -> () { + spv.func @main() -> () "None" { spv.Return } spv.EntryPoint "GLCompute" @main @@ -26,7 +26,7 @@ // Test branch with multiple block arguments spv.module "Logical" "GLSL450" { - func @foo() -> () { + spv.func @foo() -> () "None" { // CHECK: %[[ZERO:.*]] = spv.constant 0 %zero = spv.constant 0 : i32 // CHECK-NEXT: %[[ONE:.*]] = spv.constant 1 @@ -39,7 +39,7 @@ spv.Return } - func @main() -> () { + spv.func @main() -> () "None" { spv.Return } spv.EntryPoint "GLCompute" @main @@ -52,7 +52,7 @@ // Test using block arguments within branch spv.module "Logical" "GLSL450" { - func @foo() -> () { + spv.func @foo() -> () "None" { // CHECK: %[[CST0:.*]] = spv.constant 0 %zero = spv.constant 0 : i32 // CHECK-NEXT: spv.Branch ^bb1(%[[CST0]] : i32) @@ -71,7 +71,7 @@ spv.Return } - func @main() -> () { + spv.func @main() -> () "None" { spv.Return } spv.EntryPoint "GLCompute" @main @@ -84,7 +84,7 @@ // Test block not following domination order spv.module "Logical" "GLSL450" { - func @foo() -> () { + spv.func @foo() -> () "None" { // CHECK: spv.Branch ^bb1 spv.Branch ^bb1 @@ -105,7 +105,7 @@ spv.Branch ^bb2(%zero, %one : i32, f32) } - func @main() -> () { + spv.func @main() -> () "None" { spv.Return } spv.EntryPoint "GLCompute" @main @@ -118,7 +118,7 @@ // Test multiple predecessors spv.module "Logical" "GLSL450" { - func @foo() -> () { + spv.func @foo() -> () "None" { %var = spv.Variable : !spv.ptr // CHECK: spv.selection @@ -156,7 +156,7 @@ spv.Return } - func @main() -> () { + spv.func @main() -> () "None" { spv.Return } spv.EntryPoint "GLCompute" @main @@ -171,7 +171,7 @@ spv.module "Logical" "GLSL450" { spv.globalVariable @__builtin_var_NumWorkgroups__ built_in("NumWorkgroups") : !spv.ptr, Input> spv.globalVariable @__builtin_var_WorkgroupId__ built_in("WorkgroupId") : !spv.ptr, Input> - func @fmul_kernel() { + spv.func @fmul_kernel() "None" { %3 = spv.constant 12 : i32 %4 = spv.constant 32 : i32 %5 = spv.constant 4 : i32 diff --git a/mlir/test/Dialect/SPIRV/Serialization/selection.mlir b/mlir/test/Dialect/SPIRV/Serialization/selection.mlir --- a/mlir/test/Dialect/SPIRV/Serialization/selection.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/selection.mlir @@ -3,7 +3,7 @@ // Selection with both then and else branches spv.module "Logical" "GLSL450" { - func @selection(%cond: i1) -> () { + spv.func @selection(%cond: i1) -> () "None" { // CHECK: spv.Branch ^bb1 // CHECK-NEXT: ^bb1: %zero = spv.constant 0: i32 @@ -43,7 +43,7 @@ spv.Return } - func @main() -> () { + spv.func @main() -> () "None" { spv.Return } spv.EntryPoint "GLCompute" @main @@ -58,8 +58,8 @@ // Selection in function entry block spv.module "Logical" "GLSL450" { -// CHECK: func @selection(%[[ARG:.*]]: i1 - func @selection(%cond: i1) -> (i32) { +// CHECK: spv.func @selection(%[[ARG:.*]]: i1 + spv.func @selection(%cond: i1) -> (i32) "None" { // CHECK: spv.Branch ^bb1 // CHECK-NEXT: ^bb1: // CHECK-NEXT: spv.selection @@ -82,7 +82,7 @@ spv.ReturnValue %one : i32 } - func @main() -> () { + spv.func @main() -> () "None" { spv.Return } spv.EntryPoint "GLCompute" @main diff --git a/mlir/test/Dialect/SPIRV/Serialization/spec-constant.mlir b/mlir/test/Dialect/SPIRV/Serialization/spec-constant.mlir --- a/mlir/test/Dialect/SPIRV/Serialization/spec-constant.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/spec-constant.mlir @@ -13,7 +13,7 @@ spv.specConstant @sc_float spec_id(5) = 1. : f32 // CHECK-LABEL: @use - func @use() -> (i32) { + spv.func @use() -> (i32) "None" { // We materialize a `spv._reference_of` op at every use of a // specialization constant in the deserializer. So two ops here. // CHECK: %[[USE1:.*]] = spv._reference_of @sc_int : i32 diff --git a/mlir/test/Dialect/SPIRV/Serialization/struct.mlir b/mlir/test/Dialect/SPIRV/Serialization/struct.mlir --- a/mlir/test/Dialect/SPIRV/Serialization/struct.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/struct.mlir @@ -27,7 +27,7 @@ // CHECK: !spv.ptr [0]>, Input>, // CHECK-SAME: !spv.ptr [0]>, Output> - func @kernel_1(%arg0: !spv.ptr [0]>, Input>, %arg1: !spv.ptr [0]>, Output>) -> () { + spv.func @kernel_1(%arg0: !spv.ptr [0]>, Input>, %arg1: !spv.ptr [0]>, Output>) -> () "None" { spv.Return } } diff --git a/mlir/test/Dialect/SPIRV/Serialization/terminator.mlir b/mlir/test/Dialect/SPIRV/Serialization/terminator.mlir --- a/mlir/test/Dialect/SPIRV/Serialization/terminator.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/terminator.mlir @@ -2,13 +2,13 @@ spv.module "Logical" "GLSL450" { // CHECK-LABEL: @ret - func @ret() -> () { + spv.func @ret() -> () "None" { // CHECK: spv.Return spv.Return } // CHECK-LABEL: @ret_val - func @ret_val() -> (i32) { + spv.func @ret_val() -> (i32) "None" { %0 = spv.Variable : !spv.ptr %1 = spv.Load "Function" %0 : i32 // CHECK: spv.ReturnValue {{.*}} : i32 @@ -16,7 +16,7 @@ } // CHECK-LABEL: @unreachable - func @unreachable() { + spv.func @unreachable() "None" { spv.Return // CHECK-NOT: ^bb ^bb1: diff --git a/mlir/test/Dialect/SPIRV/Serialization/undef.mlir b/mlir/test/Dialect/SPIRV/Serialization/undef.mlir --- a/mlir/test/Dialect/SPIRV/Serialization/undef.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/undef.mlir @@ -1,7 +1,7 @@ // RUN: mlir-translate -split-input-file -test-spirv-roundtrip %s | FileCheck %s spv.module "Logical" "GLSL450" { - func @foo() -> () { + spv.func @foo() -> () "None" { // CHECK: {{%.*}} = spv.undef : f32 // CHECK-NEXT: {{%.*}} = spv.undef : f32 %0 = spv.undef : f32 @@ -24,10 +24,10 @@ // ----- spv.module "Logical" "GLSL450" { - // CHECK: func {{@.*}} - func @ignore_unused_undef() -> () { + // CHECK: spv.func {{@.*}} + spv.func @ignore_unused_undef() -> () "None" { // CHECK-NEXT: spv.Return %0 = spv.undef : f32 spv.Return } -} \ No newline at end of file +} diff --git a/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir b/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir --- a/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir @@ -17,35 +17,36 @@ // CHECK-DAG: spv.globalVariable [[VAR4:@.*]] bind(0, 4) : !spv.ptr, StorageBuffer> // CHECK-DAG: spv.globalVariable [[VAR5:@.*]] bind(0, 5) : !spv.ptr, StorageBuffer> // CHECK-DAG: spv.globalVariable [[VAR6:@.*]] bind(0, 6) : !spv.ptr, StorageBuffer> - // CHECK: func [[FN:@.*]]() - func @load_store_kernel(%arg0: !spv.ptr>>, StorageBuffer> - {spv.interface_var_abi = {binding = 0 : i32, - descriptor_set = 0 : i32, - storage_class = 12 : i32}}, - %arg1: !spv.ptr>>, StorageBuffer> - {spv.interface_var_abi = {binding = 1 : i32, - descriptor_set = 0 : i32, - storage_class = 12 : i32}}, - %arg2: !spv.ptr>>, StorageBuffer> - {spv.interface_var_abi = {binding = 2 : i32, - descriptor_set = 0 : i32, - storage_class = 12 : i32}}, - %arg3: i32 - {spv.interface_var_abi = {binding = 3 : i32, - descriptor_set = 0 : i32, - storage_class = 12 : i32}}, - %arg4: i32 - {spv.interface_var_abi = {binding = 4 : i32, - descriptor_set = 0 : i32, - storage_class = 12 : i32}}, - %arg5: i32 - {spv.interface_var_abi = {binding = 5 : i32, - descriptor_set = 0 : i32, - storage_class = 12 : i32}}, - %arg6: i32 - {spv.interface_var_abi = {binding = 6 : i32, - descriptor_set = 0 : i32, - storage_class = 12 : i32}}) + // CHECK: spv.func [[FN:@.*]]() + spv.func @load_store_kernel( + %arg0: !spv.ptr>>, StorageBuffer> + {spv.interface_var_abi = {binding = 0 : i32, + descriptor_set = 0 : i32, + storage_class = 12 : i32}}, + %arg1: !spv.ptr>>, StorageBuffer> + {spv.interface_var_abi = {binding = 1 : i32, + descriptor_set = 0 : i32, + storage_class = 12 : i32}}, + %arg2: !spv.ptr>>, StorageBuffer> + {spv.interface_var_abi = {binding = 2 : i32, + descriptor_set = 0 : i32, + storage_class = 12 : i32}}, + %arg3: i32 + {spv.interface_var_abi = {binding = 3 : i32, + descriptor_set = 0 : i32, + storage_class = 12 : i32}}, + %arg4: i32 + {spv.interface_var_abi = {binding = 4 : i32, + descriptor_set = 0 : i32, + storage_class = 12 : i32}}, + %arg5: i32 + {spv.interface_var_abi = {binding = 5 : i32, + descriptor_set = 0 : i32, + storage_class = 12 : i32}}, + %arg6: i32 + {spv.interface_var_abi = {binding = 6 : i32, + descriptor_set = 0 : i32, + storage_class = 12 : i32}}) "None" attributes {spv.entry_point_abi = {local_size = dense<[32, 1, 1]> : vector<3xi32>}} { // CHECK: [[ADDRESSARG6:%.*]] = spv._address_of [[VAR6]] // CHECK: [[CONST6:%.*]] = spv.constant 0 : i32 diff --git a/mlir/test/Dialect/SPIRV/Transforms/abi-simple.mlir b/mlir/test/Dialect/SPIRV/Transforms/abi-simple.mlir --- a/mlir/test/Dialect/SPIRV/Transforms/abi-simple.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/abi-simple.mlir @@ -4,16 +4,16 @@ spv.module "Logical" "GLSL450" { // CHECK-DAG: spv.globalVariable [[VAR0:@.*]] bind(0, 0) : !spv.ptr, StorageBuffer> // CHECK-DAG: spv.globalVariable [[VAR1:@.*]] bind(0, 1) : !spv.ptr [0]>, StorageBuffer> - // CHECK: func [[FN:@.*]]() - func @kernel_1(%arg0: f32 + // CHECK: spv.func [[FN:@.*]]() + spv.func @kernel(%arg0: f32 {spv.interface_var_abi = {binding = 0 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32}}, %arg1: !spv.ptr>, StorageBuffer> {spv.interface_var_abi = {binding = 1 : i32, descriptor_set = 0 : i32, - storage_class = 12 : i32}}) - attributes {spv.entry_point_abi = {local_size = dense<[32, 1, 1]> : vector<3xi32>}} { + storage_class = 12 : i32}}) "None" + attributes {spv.entry_point_abi = {local_size = dense<[32, 1, 1]> : vector<3xi32>}} { // CHECK: [[ARG1:%.*]] = spv._address_of [[VAR1]] // CHECK: [[ADDRESSARG0:%.*]] = spv._address_of [[VAR0]] // CHECK: [[CONST0:%.*]] = spv.constant 0 : i32 diff --git a/mlir/test/Dialect/SPIRV/Transforms/inlining.mlir b/mlir/test/Dialect/SPIRV/Transforms/inlining.mlir --- a/mlir/test/Dialect/SPIRV/Transforms/inlining.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/inlining.mlir @@ -1,12 +1,12 @@ // RUN: mlir-opt %s -split-input-file -pass-pipeline='spv.module(inline)' -mlir-disable-inline-simplify | FileCheck %s spv.module "Logical" "GLSL450" { - func @callee() { + spv.func @callee() "None" { spv.Return } - // CHECK-LABEL: func @calling_single_block_ret_func - func @calling_single_block_ret_func() { + // CHECK-LABEL: @calling_single_block_ret_func + spv.func @calling_single_block_ret_func() "None" { // CHECK-NEXT: spv.Return spv.FunctionCall @callee() : () -> () spv.Return @@ -16,13 +16,13 @@ // ----- spv.module "Logical" "GLSL450" { - func @callee() -> i32 { + spv.func @callee() -> i32 "None" { %0 = spv.constant 42 : i32 spv.ReturnValue %0 : i32 } - // CHECK-LABEL: func @calling_single_block_retval_func - func @calling_single_block_retval_func() -> i32 { + // CHECK-LABEL: @calling_single_block_retval_func + spv.func @calling_single_block_retval_func() -> i32 "None" { // CHECK-NEXT: %[[CST:.*]] = spv.constant 42 %0 = spv.FunctionCall @callee() : () -> (i32) // CHECK-NEXT: spv.ReturnValue %[[CST]] @@ -34,7 +34,7 @@ spv.module "Logical" "GLSL450" { spv.globalVariable @data bind(0, 0) : !spv.ptr [0]>, StorageBuffer> - func @callee() { + spv.func @callee() "None" { %0 = spv._address_of @data : !spv.ptr [0]>, StorageBuffer> %1 = spv.constant 0: i32 %2 = spv.AccessChain %0[%1, %1] : !spv.ptr [0]>, StorageBuffer> @@ -46,8 +46,8 @@ spv.Return } - // CHECK-LABEL: func @calling_multi_block_ret_func - func @calling_multi_block_ret_func() { + // CHECK-LABEL: @calling_multi_block_ret_func + spv.func @calling_multi_block_ret_func() "None" { // CHECK-NEXT: spv._address_of // CHECK-NEXT: spv.constant 0 // CHECK-NEXT: spv.AccessChain @@ -68,7 +68,7 @@ // ----- spv.module "Logical" "GLSL450" { - func @callee(%cond : i1) -> () { + spv.func @callee(%cond : i1) -> () "None" { spv.selection { spv.BranchConditional %cond, ^then, ^merge ^then: @@ -79,8 +79,8 @@ spv.Return } - // CHECK-LABEL: calling_selection_ret_func - func @calling_selection_ret_func() { + // CHECK-LABEL: @calling_selection_ret_func + spv.func @calling_selection_ret_func() "None" { %0 = spv.constant true // CHECK: spv.FunctionCall spv.FunctionCall @callee(%0) : (i1) -> () @@ -91,7 +91,7 @@ // ----- spv.module "Logical" "GLSL450" { - func @callee(%cond : i1) -> () { + spv.func @callee(%cond : i1) -> () "None" { spv.selection { spv.BranchConditional %cond, ^then, ^merge ^then: @@ -102,8 +102,8 @@ spv.Return } - // CHECK-LABEL: calling_selection_no_ret_func - func @calling_selection_no_ret_func() { + // CHECK-LABEL: @calling_selection_no_ret_func + spv.func @calling_selection_no_ret_func() "None" { // CHECK-NEXT: %[[TRUE:.*]] = spv.constant true %0 = spv.constant true // CHECK-NEXT: spv.selection @@ -120,7 +120,7 @@ // ----- spv.module "Logical" "GLSL450" { - func @callee(%cond : i1) -> () { + spv.func @callee(%cond : i1) -> () "None" { spv.loop { spv.Branch ^header ^header: @@ -135,8 +135,8 @@ spv.Return } - // CHECK-LABEL: calling_loop_ret_func - func @calling_loop_ret_func() { + // CHECK-LABEL: @calling_loop_ret_func + spv.func @calling_loop_ret_func() "None" { %0 = spv.constant true // CHECK: spv.FunctionCall spv.FunctionCall @callee(%0) : (i1) -> () @@ -147,7 +147,7 @@ // ----- spv.module "Logical" "GLSL450" { - func @callee(%cond : i1) -> () { + spv.func @callee(%cond : i1) -> () "None" { spv.loop { spv.Branch ^header ^header: @@ -162,8 +162,8 @@ spv.Return } - // CHECK-LABEL: calling_loop_no_ret_func - func @calling_loop_no_ret_func() { + // CHECK-LABEL: @calling_loop_no_ret_func + spv.func @calling_loop_no_ret_func() "None" { // CHECK-NEXT: %[[TRUE:.*]] = spv.constant true %0 = spv.constant true // CHECK-NEXT: spv.loop @@ -186,8 +186,9 @@ spv.module "Logical" "GLSL450" { spv.globalVariable @arg_0 bind(0, 0) : !spv.ptr, StorageBuffer> spv.globalVariable @arg_1 bind(0, 1) : !spv.ptr, StorageBuffer> - // CHECK: func @inline_into_selection_region - func @inline_into_selection_region() { + + // CHECK: @inline_into_selection_region + spv.func @inline_into_selection_region() "None" { %1 = spv.constant 0 : i32 // CHECK-DAG: [[ADDRESS_ARG0:%.*]] = spv._address_of @arg_0 // CHECK-DAG: [[ADDRESS_ARG1:%.*]] = spv._address_of @arg_1 @@ -215,7 +216,7 @@ // CHECK: spv.Return spv.Return } - func @atomic_add(%arg0: i32, %arg1: !spv.ptr) { + spv.func @atomic_add(%arg0: i32, %arg1: !spv.ptr) "None" { %0 = spv.AtomicIAdd "Device" "AcquireRelease" %arg1, %arg0 : !spv.ptr spv.Return } @@ -224,4 +225,4 @@ } attributes {capabilities = ["Shader"], extensions = ["SPV_KHR_storage_buffer_storage_class"]} // TODO: Add tests for inlining structured control flow into -// structured control flow. \ No newline at end of file +// structured control flow. diff --git a/mlir/test/Dialect/SPIRV/Transforms/layout-decoration.mlir b/mlir/test/Dialect/SPIRV/Transforms/layout-decoration.mlir --- a/mlir/test/Dialect/SPIRV/Transforms/layout-decoration.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/layout-decoration.mlir @@ -19,7 +19,7 @@ // CHECK: spv.globalVariable @var5 bind(1, 3) : !spv.ptr [0]>, StorageBuffer> spv.globalVariable @var5 bind(1,3) : !spv.ptr>, StorageBuffer> - func @kernel() -> () { + spv.func @kernel() -> () "None" { %c0 = spv.constant 0 : i32 // CHECK: {{%.*}} = spv._address_of @var0 : !spv.ptr [4], f32 [12]>, Uniform> %0 = spv._address_of @var0 : !spv.ptr, f32>, Uniform> diff --git a/mlir/test/Dialect/SPIRV/control-flow-ops.mlir b/mlir/test/Dialect/SPIRV/control-flow-ops.mlir --- a/mlir/test/Dialect/SPIRV/control-flow-ops.mlir +++ b/mlir/test/Dialect/SPIRV/control-flow-ops.mlir @@ -155,7 +155,7 @@ //===----------------------------------------------------------------------===// spv.module "Logical" "GLSL450" { - func @fmain(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>, %arg2 : i32) -> i32 { + spv.func @fmain(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>, %arg2 : i32) -> i32 "None" { // CHECK: {{%.*}} = spv.FunctionCall @f_0({{%.*}}, {{%.*}}) : (vector<4xf32>, vector<4xf32>) -> vector<4xf32> %0 = spv.FunctionCall @f_0(%arg0, %arg1) : (vector<4xf32>, vector<4xf32>) -> vector<4xf32> // CHECK: spv.FunctionCall @f_1({{%.*}}, {{%.*}}) : (vector<4xf32>, vector<4xf32>) -> () @@ -167,19 +167,19 @@ spv.ReturnValue %1 : i32 } - func @f_0(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) -> (vector<4xf32>) { + spv.func @f_0(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) -> (vector<4xf32>) "None" { spv.ReturnValue %arg0 : vector<4xf32> } - func @f_1(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) -> () { + spv.func @f_1(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) -> () "None" { spv.Return } - func @f_2() -> () { + spv.func @f_2() -> () "None" { spv.Return } - func @f_3(%arg0 : i32) -> (i32) { + spv.func @f_3(%arg0 : i32) -> (i32) "None" { spv.ReturnValue %arg0 : i32 } } @@ -187,7 +187,7 @@ // ----- // Allow calling functions in other module-like ops -func @callee() { +spv.func @callee() "None" { spv.Return } @@ -200,7 +200,7 @@ // ----- spv.module "Logical" "GLSL450" { - func @f_invalid_result_type(%arg0 : i32, %arg1 : i32) -> () { + spv.func @f_invalid_result_type(%arg0 : i32, %arg1 : i32) -> () "None" { // expected-error @+1 {{expected callee function to have 0 or 1 result, but provided 2}} %0:2 = spv.FunctionCall @f_invalid_result_type(%arg0, %arg1) : (i32, i32) -> (i32, i32) spv.Return @@ -210,7 +210,7 @@ // ----- spv.module "Logical" "GLSL450" { - func @f_result_type_mismatch(%arg0 : i32, %arg1 : i32) -> () { + spv.func @f_result_type_mismatch(%arg0 : i32, %arg1 : i32) -> () "None" { // expected-error @+1 {{has incorrect number of results has for callee: expected 0, but provided 1}} %1 = spv.FunctionCall @f_result_type_mismatch(%arg0, %arg0) : (i32, i32) -> (i32) spv.Return @@ -220,7 +220,7 @@ // ----- spv.module "Logical" "GLSL450" { - func @f_type_mismatch(%arg0 : i32, %arg1 : i32) -> () { + spv.func @f_type_mismatch(%arg0 : i32, %arg1 : i32) -> () "None" { // expected-error @+1 {{has incorrect number of operands for callee: expected 2, but provided 1}} spv.FunctionCall @f_type_mismatch(%arg0) : (i32) -> () spv.Return @@ -230,7 +230,7 @@ // ----- spv.module "Logical" "GLSL450" { - func @f_type_mismatch(%arg0 : i32, %arg1 : i32) -> () { + spv.func @f_type_mismatch(%arg0 : i32, %arg1 : i32) -> () "None" { %0 = spv.constant 2.0 : f32 // expected-error @+1 {{operand type mismatch: expected operand type 'i32', but provided 'f32' for operand number 1}} spv.FunctionCall @f_type_mismatch(%arg0, %0) : (i32, f32) -> () @@ -241,20 +241,21 @@ // ----- spv.module "Logical" "GLSL450" { - func @f_type_mismatch(%arg0 : i32, %arg1 : i32) -> i32 { + spv.func @f_type_mismatch(%arg0 : i32, %arg1 : i32) -> i32 "None" { + %cst = spv.constant 0: i32 // expected-error @+1 {{result type mismatch: expected 'i32', but provided 'f32'}} %0 = spv.FunctionCall @f_type_mismatch(%arg0, %arg0) : (i32, i32) -> f32 - spv.Return + spv.ReturnValue %cst: i32 } } // ----- spv.module "Logical" "GLSL450" { - func @f_foo(%arg0 : i32, %arg1 : i32) -> i32 { + spv.func @f_foo(%arg0 : i32, %arg1 : i32) -> i32 "None" { // expected-error @+1 {{op callee function 'f_undefined' not found in nearest symbol table}} %0 = spv.FunctionCall @f_undefined(%arg0, %arg0) : (i32, i32) -> i32 - spv.Return + spv.ReturnValue %0: i32 } } @@ -500,11 +501,9 @@ spv.Return } -// ----- - // CHECK-LABEL: in_other_func_like_op func @in_other_func_like_op() { - // CHECK: spv.Return + // CHECK: spv.Return spv.Return } @@ -519,7 +518,7 @@ // Return mismatches function signature spv.module "Logical" "GLSL450" { - func @work() -> (i32) { + spv.func @work() -> (i32) "None" { // expected-error @+1 {{cannot be used in functions returning value}} spv.Return } @@ -527,6 +526,24 @@ // ----- +spv.module "Logical" "GLSL450" { + spv.func @in_nested_region(%cond: i1) -> (i32) "None" { + spv.selection { + spv.BranchConditional %cond, ^then, ^merge + ^then: + // expected-error @+1 {{cannot be used in functions returning value}} + spv.Return + ^merge: + spv._merge + } + + %zero = spv.constant 0: i32 + spv.ReturnValue %zero: i32 + } +} + +// ----- + //===----------------------------------------------------------------------===// // spv.ReturnValue //===----------------------------------------------------------------------===// @@ -571,6 +588,12 @@ spv.ReturnValue %one : i32 } +// CHECK-LABEL: in_other_func_like_op +func @in_other_func_like_op(%arg: i32) -> i32 { + // CHECK: spv.ReturnValue + spv.ReturnValue %arg: i32 +} + // ----- "foo.function"() ({ @@ -581,18 +604,40 @@ // ----- -func @value_count_mismatch() -> () { - %0 = spv.constant 42 : i32 - // expected-error @+1 {{op returns 1 value but enclosing function requires 0 results}} - spv.ReturnValue %0 : i32 +spv.module "Logical" "GLSL450" { + spv.func @value_count_mismatch() -> () "None" { + %0 = spv.constant 42 : i32 + // expected-error @+1 {{op returns 1 value but enclosing function requires 0 results}} + spv.ReturnValue %0 : i32 + } } // ----- -func @value_type_mismatch() -> (f32) { - %0 = spv.constant 42 : i32 - // expected-error @+1 {{return value's type ('i32') mismatch with function's result type ('f32')}} - spv.ReturnValue %0 : i32 +spv.module "Logical" "GLSL450" { + spv.func @value_type_mismatch() -> (f32) "None" { + %0 = spv.constant 42 : i32 + // expected-error @+1 {{return value's type ('i32') mismatch with function's result type ('f32')}} + spv.ReturnValue %0 : i32 + } +} + +// ----- + +spv.module "Logical" "GLSL450" { + spv.func @in_nested_region(%cond: i1) -> () "None" { + spv.selection { + spv.BranchConditional %cond, ^then, ^merge + ^then: + %cst = spv.constant 0: i32 + // expected-error @+1 {{op returns 1 value but enclosing function requires 0 results}} + spv.ReturnValue %cst: i32 + ^merge: + spv._merge + } + + spv.Return + } } // ----- diff --git a/mlir/test/Dialect/SPIRV/ops.mlir b/mlir/test/Dialect/SPIRV/ops.mlir --- a/mlir/test/Dialect/SPIRV/ops.mlir +++ b/mlir/test/Dialect/SPIRV/ops.mlir @@ -417,7 +417,7 @@ //===----------------------------------------------------------------------===// spv.module "Logical" "GLSL450" { - func @do_nothing() -> () { + spv.func @do_nothing() -> () "None" { spv.Return } spv.EntryPoint "GLCompute" @do_nothing @@ -426,7 +426,7 @@ } spv.module "Logical" "GLSL450" { - func @do_nothing() -> () { + spv.func @do_nothing() -> () "None" { spv.Return } spv.EntryPoint "GLCompute" @do_nothing @@ -437,7 +437,7 @@ // ----- spv.module "Logical" "GLSL450" { - func @do_nothing() -> () { + spv.func @do_nothing() -> () "None" { spv.Return } spv.EntryPoint "GLCompute" @do_nothing @@ -642,7 +642,7 @@ spv.module "Logical" "GLSL450" { spv.globalVariable @var0 : !spv.ptr // CHECK_LABEL: @simple_load - func @simple_load() -> () { + spv.func @simple_load() -> () "None" { // CHECK: spv.Load "Input" {{%.*}} : f32 %0 = spv._address_of @var0 : !spv.ptr %1 = spv.Load "Input" %0 : f32 @@ -1059,7 +1059,7 @@ spv.module "Logical" "GLSL450" { spv.globalVariable @var0 : !spv.ptr - func @simple_store(%arg0 : f32) -> () { + spv.func @simple_store(%arg0 : f32) -> () "None" { %0 = spv._address_of @var0 : !spv.ptr // CHECK: spv.Store "Input" {{%.*}}, {{%.*}} : f32 spv.Store "Input" %0, %arg0 : f32 @@ -1132,7 +1132,7 @@ spv.module "Logical" "GLSL450" { spv.globalVariable @global : !spv.ptr - func @variable_init_global_variable() -> () { + spv.func @variable_init_global_variable() -> () "None" { %0 = spv._address_of @global : !spv.ptr // CHECK: spv.Variable init({{.*}}) : !spv.ptr, Function> %1 = spv.Variable init(%0) : !spv.ptr, Function> @@ -1148,7 +1148,7 @@ spv.module "Logical" "GLSL450" { spv.specConstant @sc = 42 : i32 // CHECK-LABEL: @variable_init_spec_constant - func @variable_init_spec_constant() -> () { + spv.func @variable_init_spec_constant() -> () "None" { %0 = spv._reference_of @sc : i32 // CHECK: spv.Variable init(%0) : !spv.ptr %1 = spv.Variable init(%0) : !spv.ptr diff --git a/mlir/test/Dialect/SPIRV/structure-ops.mlir b/mlir/test/Dialect/SPIRV/structure-ops.mlir --- a/mlir/test/Dialect/SPIRV/structure-ops.mlir +++ b/mlir/test/Dialect/SPIRV/structure-ops.mlir @@ -6,7 +6,7 @@ spv.module "Logical" "GLSL450" { spv.globalVariable @var1 : !spv.ptr>, Input> - func @access_chain() -> () { + spv.func @access_chain() -> () "None" { %0 = spv.constant 1: i32 // CHECK: [[VAR1:%.*]] = spv._address_of @var1 : !spv.ptr>, Input> // CHECK-NEXT: spv.AccessChain [[VAR1]][{{.*}}, {{.*}}] : !spv.ptr>, Input> @@ -30,7 +30,7 @@ spv.module "Logical" "GLSL450" { spv.globalVariable @var1 : !spv.ptr>, Input> - func @foo() -> () { + spv.func @foo() -> () "None" { // expected-error @+1 {{expected spv.globalVariable symbol}} %0 = spv._address_of @var2 : !spv.ptr>, Input> } @@ -40,7 +40,7 @@ spv.module "Logical" "GLSL450" { spv.globalVariable @var1 : !spv.ptr>, Input> - func @foo() -> () { + spv.func @foo() -> () "None" { // expected-error @+1 {{result type mismatch with the referenced global variable's type}} %0 = spv._address_of @var1 : !spv.ptr } @@ -136,7 +136,7 @@ //===----------------------------------------------------------------------===// spv.module "Logical" "GLSL450" { - func @do_nothing() -> () { + spv.func @do_nothing() -> () "None" { spv.Return } // CHECK: spv.EntryPoint "GLCompute" @do_nothing @@ -146,7 +146,7 @@ spv.module "Logical" "GLSL450" { spv.globalVariable @var2 : !spv.ptr spv.globalVariable @var3 : !spv.ptr - func @do_something(%arg0 : !spv.ptr, %arg1 : !spv.ptr) -> () { + spv.func @do_something(%arg0 : !spv.ptr, %arg1 : !spv.ptr) -> () "None" { %1 = spv.Load "Input" %arg0 : f32 spv.Store "Output" %arg1, %1 : f32 spv.Return @@ -158,7 +158,7 @@ // ----- spv.module "Logical" "GLSL450" { - func @do_nothing() -> () { + spv.func @do_nothing() -> () "None" { spv.Return } // expected-error @+1 {{invalid kind of attribute specified}} @@ -168,7 +168,7 @@ // ----- spv.module "Logical" "GLSL450" { - func @do_nothing() -> () { + spv.func @do_nothing() -> () "None" { spv.Return } // expected-error @+1 {{function 'do_something' not found in 'spv.module'}} @@ -183,7 +183,7 @@ // ----- spv.module "Logical" "GLSL450" { - func @do_nothing() -> () { + spv.func @do_nothing() -> () "None" { // expected-error @+1 {{op must appear in a module-like op's block}} spv.EntryPoint "GLCompute" @do_something } @@ -192,7 +192,7 @@ // ----- spv.module "Logical" "GLSL450" { - func @do_nothing() -> () { + spv.func @do_nothing() -> () "None" { spv.Return } spv.EntryPoint "GLCompute" @do_nothing @@ -203,7 +203,7 @@ // ----- spv.module "Logical" "GLSL450" { - func @do_nothing() -> () { + spv.func @do_nothing() -> () "None" { spv.Return } spv.EntryPoint "GLCompute" @do_nothing @@ -213,6 +213,55 @@ // ----- +//===----------------------------------------------------------------------===// +// spv.func +//===----------------------------------------------------------------------===// + +// CHECK: spv.func @foo() "None" +spv.func @foo() "None" + +// CHECK: spv.func @bar(%{{.+}}: i32) -> i32 "Inline|Pure" { +spv.func @bar(%arg: i32) -> (i32) "Inline|Pure" { + // CHECK-NEXT: spv. + spv.ReturnValue %arg: i32 +// CHECK-NEXT: } +} + +// CHECK: spv.func @baz(%{{.+}}: i32) "DontInline" attributes {additional_stuff = 64 : i64} +spv.func @baz(%arg: i32) "DontInline" attributes { + additional_stuff = 64 +} { spv.Return } + +// ----- + +// expected-error @+1 {{expected function_control attribute specified as string}} +spv.func @missing_function_control() { spv.Return } + +// ----- + +// expected-error @+1 {{cannot have more than one result}} +spv.func @cannot_have_more_than_one_result(%arg: i32) -> (i32, i32) "None" + +// ----- + +// expected-error @+1 {{expected SSA identifier}} +spv.func @cannot_have_variadic_arguments(%arg: i32, ...) "None" + +// ----- + +// Nested function +spv.module "Logical" "GLSL450" { + spv.func @outer_func() -> () "None" { + // expected-error @+1 {{must appear in a module-like op's block}} + spv.func @inner_func() -> () "None" { + spv.Return + } + spv.Return + } +} + +// ----- + //===----------------------------------------------------------------------===// // spv.globalVariable //===----------------------------------------------------------------------===// @@ -299,7 +348,7 @@ // ----- spv.module "Logical" "GLSL450" { - func @foo() { + spv.func @foo() "None" { // expected-error @+1 {{op must appear in a module-like op's block}} spv.globalVariable @var0 : !spv.ptr spv.Return @@ -332,7 +381,7 @@ // Module with function // CHECK: spv.module spv.module "Logical" "GLSL450" { - func @do_nothing() -> () { + spv.func @do_nothing() -> () "None" { spv.Return } } @@ -383,9 +432,9 @@ // ----- -// Use non SPIR-V op inside.module +// Use non SPIR-V op inside module spv.module "Logical" "GLSL450" { - // expected-error @+1 {{'spv.module' can only contain func and spv.* ops}} + // expected-error @+1 {{'spv.module' can only contain spv.* ops}} "dialect.op"() : () -> () } @@ -393,7 +442,7 @@ // Use non SPIR-V op inside function spv.module "Logical" "GLSL450" { - func @do_nothing() -> () { + spv.func @do_nothing() -> () "None" { // expected-error @+1 {{functions in 'spv.module' can only contain spv.* ops}} "dialect.op"() : () -> () } @@ -404,20 +453,7 @@ // Use external function spv.module "Logical" "GLSL450" { // expected-error @+1 {{'spv.module' cannot contain external functions}} - func @extern() -> () -} - -// ----- - -// Module with nested function -spv.module "Logical" "GLSL450" { - func @outer_func() -> () { - // expected-error @+1 {{'spv.module' cannot contain nested functions}} - func @inner_func() -> () { - spv.Return - } - spv.Return - } + spv.func @extern() -> () "None" } // ----- @@ -459,14 +495,14 @@ spv.specConstant @sc3 = 1.5 : f32 // CHECK-LABEL: @reference - func @reference() -> i1 { + spv.func @reference() -> i1 "None" { // CHECK: spv._reference_of @sc1 : i1 %0 = spv._reference_of @sc1 : i1 spv.ReturnValue %0 : i1 } // CHECK-LABEL: @initialize - func @initialize() -> i64 { + spv.func @initialize() -> i64 "None" { // CHECK: spv._reference_of @sc2 : i64 %0 = spv._reference_of @sc2 : i64 %1 = spv.Variable init(%0) : !spv.ptr @@ -475,7 +511,7 @@ } // CHECK-LABEL: @compute - func @compute() -> f32 { + spv.func @compute() -> f32 "None" { // CHECK: spv._reference_of @sc3 : f32 %0 = spv._reference_of @sc3 : f32 %1 = spv.constant 6.0 : f32 @@ -497,7 +533,7 @@ // ----- spv.module "Logical" "GLSL450" { - func @foo() -> () { + spv.func @foo() -> () "None" { // expected-error @+1 {{expected spv.specConstant symbol}} %0 = spv._reference_of @sc : i32 spv.Return @@ -508,7 +544,7 @@ spv.module "Logical" "GLSL450" { spv.specConstant @sc = 42 : i32 - func @foo() -> () { + spv.func @foo() -> () "None" { // expected-error @+1 {{result type mismatch with the referenced specialization constant's type}} %0 = spv._reference_of @sc : f32 spv.Return