diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
@@ -491,6 +491,8 @@
     ```mlir
     %0 = spv._reference_of @spec_const : f32
     ```
+
+    TODO Add support for composite specialization constants.
   }];
 
   let arguments = (ins
@@ -541,8 +543,6 @@
     spv.specConstant @spec_const1 = true
     spv.specConstant @spec_const2 spec_id(5) = 42 : i32
     ```
-
-    TODO: support composite spec constants with another op
   }];
 
   let arguments = (ins
@@ -557,6 +557,57 @@
   let autogenSerialization = 0;
 }
 
+def SPV_SpecConstantCompositeOp : SPV_Op<"specConstantComposite", [InModuleScope, Symbol]> {
+  let summary = "Declare a new composite specialization constant.";
+
+  let description = [{
+    This op declares a SPIR-V composite specialization constant. This covers
+    the `OpSpecConstantComposite` SPIR-V instruction. Scalar constants are
+    covered by `spv.specConstant`.
+
+    A constituent of a spec constant composite can be:
+    - A symbol referring to another spec constant.
+    - The SSA ID of a non-specialization constant (i.e. defined through
+      `spv.constant`).
+    - The SSA ID of a `spv.undef`.
+
+    ```
+    spv-spec-constant-composite-op ::= `spv.specConstantComposite` symbol-ref-id ` (`
+                                       symbol-ref-id (`, ` symbol-ref-id)*
+                                       `) :` composite-type
+    ```
+
+    where `composite-type` is a `spv.struct`, `spv.array`, or `vector` type.
+    `spv.coopmatrix` is a composite type too, but is not supported yet.
+
+    #### Example:
+
+    ```mlir
+    spv.specConstant @sc1 = 1 : i32
+    spv.specConstant @sc2 = 2.5 : f32
+    spv.specConstant @sc3 = 3.5 : f32
+    spv.specConstantComposite @scc (@sc1, @sc2, @sc3) : !spv.struct<i32, f32, f32>
+    ```
+
+    TODO Add support for constituents that are:
+    - regular constants.
+    - undef.
+    - spec constant composite.
+  }];
+
+  let arguments = (ins
+    TypeAttr:$type,
+    StrAttr:$sym_name,
+    SymbolRefArrayAttr:$constituents
+  );
+
+  let results = (outs);
+
+  let hasOpcode = 0;
+
+  let autogenSerialization = 0;
+}
+
 // -----
 
 #endif // SPIRV_STRUCTURE_OPS
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -53,6 +53,7 @@
 static constexpr const char kUnequalSemanticsAttrName[] = "unequal_semantics";
 static constexpr const char kValueAttrName[] = "value";
 static constexpr const char kValuesAttrName[] = "values";
+static constexpr const char kCompositeSpecConstituentsName[] = "constituents";
 
 //===----------------------------------------------------------------------===//
 // Common utility functions
@@ -3287,6 +3288,114 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// spv.specConstantComposite
+//===----------------------------------------------------------------------===//
+
+/// Parses the custom assembly form of `spv.specConstantComposite`:
+///
+///   spv.specConstantComposite @sym (@sc0, @sc1, ...) : composite-type
+///
+/// The constituents are stored as an array of flat symbol references under
+/// the `constituents` attribute and the composite type under `type`.
+static ParseResult parseSpecConstantCompositeOp(OpAsmParser &parser,
+                                                OperationState &state) {
+  StringAttr compositeName;
+  if (parser.parseSymbolName(compositeName, SymbolTable::getSymbolAttrName(),
+                             state.attributes))
+    return failure();
+
+  if (parser.parseLParen())
+    return failure();
+
+  SmallVector<Attribute, 4> constituents;
+
+  do {
+    // Each constituent is parsed into a scratch attribute list that is then
+    // discarded; only the symbol reference itself is kept, so the attribute
+    // name used here is irrelevant.
+    const char *attrName = "spec_const";
+    FlatSymbolRefAttr specConstRef;
+    NamedAttrList attrs;
+
+    if (parser.parseAttribute(specConstRef, Type(), attrName, attrs))
+      return failure();
+
+    constituents.push_back(specConstRef);
+  } while (!parser.parseOptionalComma());
+
+  if (parser.parseRParen())
+    return failure();
+
+  state.addAttribute(kCompositeSpecConstituentsName,
+                     parser.getBuilder().getArrayAttr(constituents));
+
+  Type type;
+  if (parser.parseColonType(type))
+    return failure();
+
+  state.addAttribute(kTypeAttrName, TypeAttr::get(type));
+
+  return success();
+}
+
+/// Prints `spv.specConstantComposite` in the custom form accepted by
+/// parseSpecConstantCompositeOp.
+static void print(spirv::SpecConstantCompositeOp op, OpAsmPrinter &printer) {
+  printer << spirv::SpecConstantCompositeOp::getOperationName() << " ";
+  printer.printSymbolName(op.sym_name());
+  printer << " (";
+  llvm::interleaveComma(op.constituents().getValue(), printer);
+  printer << ") : " << op.type();
+}
+
+/// Verifies that the declared type is a supported composite type and that
+/// each constituent names an `spv.specConstant` whose type matches the
+/// corresponding element type of the composite.
+static LogicalResult verify(spirv::SpecConstantCompositeOp constOp) {
+  auto cType = constOp.type().dyn_cast<spirv::CompositeType>();
+  auto constituents = constOp.constituents().getValue();
+
+  if (!cType)
+    return constOp.emitError(
+               "result type must be a composite type, but provided ")
+           << constOp.type();
+
+  // Cooperative matrices have no statically known element count to check the
+  // constituent list against, so they are rejected up front.
+  if (cType.isa<spirv::CooperativeMatrixNVType>())
+    return constOp.emitError("unsupported composite type ") << cType;
+
+  if (constituents.size() != cType.getNumElements())
+    return constOp.emitError("has incorrect number of operands: expected ")
+           << cType.getNumElements() << ", but provided "
+           << constituents.size();
+
+  for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
+    auto constituent = constituents[index].dyn_cast<FlatSymbolRefAttr>();
+    if (!constituent)
+      return constOp.emitError("constituent #")
+             << index << " must be a flat symbol reference";
+
+    // The symbol may be undefined or may name something other than a
+    // `spv.specConstant`; diagnose both instead of dereferencing null.
+    auto constituentSpecConstOp = dyn_cast_or_null<spirv::SpecConstantOp>(
+        SymbolTable::lookupNearestSymbolFrom(constOp.getParentOp(),
+                                             constituent.getValue()));
+    if (!constituentSpecConstOp)
+      return constOp.emitError("unknown specialization constant ")
+             << constituent;
+
+    Type constituentType = constituentSpecConstOp.default_value().getType();
+    if (constituentType != cType.getElementType(index))
+      return constOp.emitError("has incorrect types of operands: expected ")
+             << cType.getElementType(index) << ", but provided "
+             << constituentType;
+  }
+
+  return success();
+}
+
 namespace mlir {
 namespace spirv {
 
diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
--- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
+++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
@@ -249,6 +249,10 @@
   /// `operands`.
   LogicalResult processConstantComposite(ArrayRef<uint32_t> operands);
 
+  /// Processes a SPIR-V OpSpecConstantComposite instruction with the given
+  /// `operands`.
+  LogicalResult processSpecConstantComposite(ArrayRef<uint32_t> operands);
+
   /// Processes a SPIR-V OpConstantNull instruction with the given `operands`.
   LogicalResult processConstantNull(ArrayRef<uint32_t> operands);
 
@@ -1546,6 +1550,46 @@
   return success();
 }
 
+// Operand layout: <result type id> <result id> <constituent id>... . Each
+// constituent must resolve through getSpecConstant(), i.e. it must be a
+// scalar spec constant that has already been deserialized.
+LogicalResult
+Deserializer::processSpecConstantComposite(ArrayRef<uint32_t> operands) {
+  if (operands.size() < 2) {
+    return emitError(unknownLoc,
+                     "OpSpecConstantComposite must have type and result <id>");
+  }
+  if (operands.size() < 3) {
+    return emitError(unknownLoc,
+                     "OpSpecConstantComposite must have at least 1 parameter");
+  }
+
+  Type resultType = getType(operands[0]);
+  if (!resultType) {
+    return emitError(unknownLoc, "undefined result type from <id> ")
+           << operands[0];
+  }
+
+  auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(operands[1]));
+
+  SmallVector<Attribute, 4> elements;
+  elements.reserve(operands.size() - 2);
+  for (unsigned i = 2, e = operands.size(); i < e; ++i) {
+    auto elementInfo = getSpecConstant(operands[i]);
+    if (!elementInfo) {
+      return emitError(unknownLoc, "unknown specialization constant <id> ")
+             << operands[i] << " used in OpSpecConstantComposite";
+    }
+    elements.push_back(opBuilder.getSymbolRefAttr(elementInfo));
+  }
+
+  opBuilder.create<spirv::SpecConstantCompositeOp>(
+      unknownLoc, TypeAttr::get(resultType), symName,
+      opBuilder.getArrayAttr(elements));
+
+  return success();
+}
+
 LogicalResult Deserializer::processConstantNull(ArrayRef<uint32_t> operands) {
   if (operands.size() != 2) {
     return emitError(unknownLoc,
@@ -2276,6 +2320,8 @@
     return processConstant(operands, /*isSpec=*/true);
   case spirv::Opcode::OpConstantComposite:
     return processConstantComposite(operands);
+  case spirv::Opcode::OpSpecConstantComposite:
+    return processSpecConstantComposite(operands);
   case spirv::Opcode::OpConstantTrue:
     return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/false);
   case spirv::Opcode::OpSpecConstantTrue:
diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
--- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
@@ -200,6 +200,11 @@
 
   LogicalResult processSpecConstantOp(spirv::SpecConstantOp op);
 
+  /// Serializes an spv.specConstantComposite op into an
+  /// OpSpecConstantComposite instruction.
+  LogicalResult
+  processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op);
+
   /// SPIR-V dialect supports OpUndef using spv.UndefOp that produces a SSA
   /// value to use with other operations. The SPIR-V spec recommends that
   /// OpUndef be generated at module level. The serialization generates an
@@ -645,6 +650,44 @@
   return failure();
 }
 
+// Emits OpSpecConstantComposite into the types/global-values section. Every
+// constituent must name a spec constant whose result <id> is already known.
+LogicalResult
+Serializer::processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op) {
+  uint32_t typeID = 0;
+  if (failed(processType(op.getLoc(), op.type(), typeID))) {
+    return failure();
+  }
+
+  auto resultID = getNextID();
+
+  SmallVector<uint32_t, 4> operands;
+  operands.push_back(typeID);
+  operands.push_back(resultID);
+
+  for (auto attr : op.constituents()) {
+    auto constituent = attr.dyn_cast<FlatSymbolRefAttr>();
+    if (!constituent) {
+      return op.emitError("expected flat symbol reference as constituent");
+    }
+
+    auto constituentName = constituent.getValue();
+    auto constituentID = getSpecConstID(constituentName);
+    if (!constituentID) {
+      return op.emitError("unknown result <id> for specialization constant ")
+             << constituentName;
+    }
+
+    operands.push_back(constituentID);
+  }
+
+  encodeInstructionInto(typesGlobalValues,
+                        spirv::Opcode::OpSpecConstantComposite, operands);
+  specConstIDMap[op.sym_name()] = resultID;
+
+  return processName(resultID, op.sym_name());
+}
+
 LogicalResult Serializer::processUndefOp(spirv::UndefOp op) {
   auto undefType = op.getType();
   auto &id = undefValIDMap[undefType];
@@ -1765,6 +1808,9 @@
       .Case([&](spirv::ReferenceOfOp op) { return processReferenceOfOp(op); })
       .Case([&](spirv::SelectionOp op) { return processSelectionOp(op); })
       .Case([&](spirv::SpecConstantOp op) { return processSpecConstantOp(op); })
+      .Case([&](spirv::SpecConstantCompositeOp op) {
+        return processSpecConstantCompositeOp(op);
+      })
       .Case([&](spirv::UndefOp op) { return processUndefOp(op); })
       .Case([&](spirv::VariableOp op) { return processVariableOp(op); })
 
diff --git a/mlir/test/Dialect/SPIRV/Serialization/spec-constant.mlir b/mlir/test/Dialect/SPIRV/Serialization/spec-constant.mlir
--- a/mlir/test/Dialect/SPIRV/Serialization/spec-constant.mlir
+++ b/mlir/test/Dialect/SPIRV/Serialization/spec-constant.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-translate -test-spirv-roundtrip %s | FileCheck %s
+// RUN: mlir-translate -test-spirv-roundtrip -split-input-file %s | FileCheck %s
 
 spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
   // CHECK: spv.specConstant @sc_true = true
@@ -25,3 +25,23 @@
     spv.ReturnValue %1 : i32
   }
 }
+
+// -----
+
+spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
+
+  spv.specConstant @sc_f32_1 = 1.5 : f32
+  spv.specConstant @sc_f32_2 = 2.5 : f32
+  spv.specConstant @sc_f32_3 = 3.5 : f32
+
+  spv.specConstant @sc_i32_1 = 1 : i32
+
+  // CHECK: spv.specConstantComposite @scc_array (@sc_f32_1, @sc_f32_2, @sc_f32_3) : !spv.array<3 x f32>
+  spv.specConstantComposite @scc_array (@sc_f32_1, @sc_f32_2, @sc_f32_3) : !spv.array<3 x f32>
+
+  // CHECK: spv.specConstantComposite @scc_struct (@sc_i32_1, @sc_f32_2, @sc_f32_3) : !spv.struct<i32, f32, f32>
+  spv.specConstantComposite @scc_struct (@sc_i32_1, @sc_f32_2, @sc_f32_3) : !spv.struct<i32, f32, f32>
+
+  // CHECK: spv.specConstantComposite @scc_vector (@sc_f32_1, @sc_f32_2, @sc_f32_3) : vector<3xf32>
+  spv.specConstantComposite @scc_vector (@sc_f32_1, @sc_f32_2, @sc_f32_3) : vector<3 x f32>
+}
diff --git a/mlir/test/Dialect/SPIRV/structure-ops.mlir b/mlir/test/Dialect/SPIRV/structure-ops.mlir
--- a/mlir/test/Dialect/SPIRV/structure-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/structure-ops.mlir
@@ -596,3 +596,135 @@
   spv.specConstant @sc = false
   return
 }
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spv.specConstantComposite
+//===----------------------------------------------------------------------===//
+
+spv.module Logical GLSL450 {
+  // expected-error @+1 {{result type must be a composite type}}
+  spv.specConstantComposite @scc2 (@sc1, @sc2, @sc3) : i32
+}
+
+// -----
+
+spv.module Logical GLSL450 {
+  // expected-error @+1 {{unknown specialization constant}}
+  spv.specConstantComposite @scc (@sc1) : !spv.array<1 x f32>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spv.specConstantComposite (spv.array)
+//===----------------------------------------------------------------------===//
+
+spv.module Logical GLSL450 {
+  spv.specConstant @sc1 = 1.5 : f32
+  spv.specConstant @sc2 = 2.5 : f32
+  spv.specConstant @sc3 = 3.5 : f32
+  // CHECK: spv.specConstantComposite @scc (@sc1, @sc2, @sc3) : !spv.array<3 x f32>
+  spv.specConstantComposite @scc (@sc1, @sc2, @sc3) : !spv.array<3 x f32>
+}
+
+// -----
+
+spv.module Logical GLSL450 {
+  spv.specConstant @sc1 = false
+  spv.specConstant @sc2 spec_id(5) = 42 : i64
+  spv.specConstant @sc3 = 1.5 : f32
+  // expected-error @+1 {{has incorrect number of operands: expected 4, but provided 3}}
+  spv.specConstantComposite @scc (@sc1, @sc2, @sc3) : !spv.array<4 x f32>
+}
+
+// -----
+
+spv.module Logical GLSL450 {
+  spv.specConstant @sc1 = 1 : i32
+  spv.specConstant @sc2 = 2.5 : f32
+  spv.specConstant @sc3 = 3.5 : f32
+  // expected-error @+1 {{has incorrect types of operands: expected 'f32', but provided 'i32'}}
+  spv.specConstantComposite @scc (@sc1, @sc2, @sc3) : !spv.array<3 x f32>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spv.specConstantComposite (spv.struct)
+//===----------------------------------------------------------------------===//
+
+spv.module Logical GLSL450 {
+  spv.specConstant @sc1 = 1 : i32
+  spv.specConstant @sc2 = 2.5 : f32
+  spv.specConstant @sc3 = 3.5 : f32
+  // CHECK: spv.specConstantComposite @scc (@sc1, @sc2, @sc3) : !spv.struct<i32, f32, f32>
+  spv.specConstantComposite @scc (@sc1, @sc2, @sc3) : !spv.struct<i32, f32, f32>
+}
+
+// -----
+
+spv.module Logical GLSL450 {
+  spv.specConstant @sc1 = 1 : i32
+  spv.specConstant @sc2 = 2.5 : f32
+  spv.specConstant @sc3 = 3.5 : f32
+  // expected-error @+1 {{has incorrect number of operands: expected 2, but provided 3}}
+  spv.specConstantComposite @scc (@sc1, @sc2, @sc3) : !spv.struct<i32, f32>
+}
+
+// -----
+
+spv.module Logical GLSL450 {
+  spv.specConstant @sc1 = 1.5 : f32
+  spv.specConstant @sc2 = 2.5 : f32
+  spv.specConstant @sc3 = 3.5 : f32
+  // expected-error @+1 {{has incorrect types of operands: expected 'i32', but provided 'f32'}}
+  spv.specConstantComposite @scc (@sc1, @sc2, @sc3) : !spv.struct<i32, f32, f32>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spv.specConstantComposite (vector)
+//===----------------------------------------------------------------------===//
+
+spv.module Logical GLSL450 {
+  spv.specConstant @sc1 = 1.5 : f32
+  spv.specConstant @sc2 = 2.5 : f32
+  spv.specConstant @sc3 = 3.5 : f32
+  // CHECK: spv.specConstantComposite @scc (@sc1, @sc2, @sc3) : vector<3xf32>
+  spv.specConstantComposite @scc (@sc1, @sc2, @sc3) : vector<3 x f32>
+}
+
+// -----
+
+spv.module Logical GLSL450 {
+  spv.specConstant @sc1 = false
+  spv.specConstant @sc2 spec_id(5) = 42 : i64
+  spv.specConstant @sc3 = 1.5 : f32
+  // expected-error @+1 {{has incorrect number of operands: expected 4, but provided 3}}
+  spv.specConstantComposite @scc (@sc1, @sc2, @sc3) : vector<4xf32>
+}
+
+// -----
+
+spv.module Logical GLSL450 {
+  spv.specConstant @sc1 = 1 : i32
+  spv.specConstant @sc2 = 2.5 : f32
+  spv.specConstant @sc3 = 3.5 : f32
+  // expected-error @+1 {{has incorrect types of operands: expected 'f32', but provided 'i32'}}
+  spv.specConstantComposite @scc (@sc1, @sc2, @sc3) : vector<3xf32>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spv.specConstantComposite (spv.coopmatrix)
+//===----------------------------------------------------------------------===//
+
+spv.module Logical GLSL450 {
+  spv.specConstant @sc1 = 1.5 : f32
+  // expected-error @+1 {{unsupported composite type}}
+  spv.specConstantComposite @scc (@sc1) : !spv.coopmatrix<8x16xf32, Device>
+}