diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td @@ -472,7 +472,7 @@ let summary = "Reference a specialization constant."; let description = [{ - Specialization constant in module scope are defined using symbol names. + Specialization constants in module scope are defined using symbol names. This op generates an SSA value that can be used to refer to the symbol within function scope for use in ops that expect an SSA value. This operation has no corresponding SPIR-V instruction; it's merely used diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -2568,17 +2568,27 @@ //===----------------------------------------------------------------------===// static LogicalResult verify(spirv::ReferenceOfOp referenceOfOp) { - auto specConstOp = dyn_cast_or_null<spirv::SpecConstantOp>( - SymbolTable::lookupNearestSymbolFrom(referenceOfOp.getParentOp(), - referenceOfOp.spec_const())); - if (!specConstOp) { - return referenceOfOp.emitOpError("expected spv.specConstant symbol"); - } - if (referenceOfOp.reference().getType() != - specConstOp.default_value().getType()) { + auto *specConstSym = SymbolTable::lookupNearestSymbolFrom( + referenceOfOp.getParentOp(), referenceOfOp.spec_const()); + Type constType; + + auto specConstOp = dyn_cast_or_null<spirv::SpecConstantOp>(specConstSym); + if (specConstOp) + constType = specConstOp.default_value().getType(); + + auto specConstCompositeOp = + dyn_cast_or_null<spirv::SpecConstantCompositeOp>(specConstSym); + if (specConstCompositeOp) + constType = specConstCompositeOp.type(); + + if (!specConstOp && !specConstCompositeOp) + return referenceOfOp.emitOpError( + "expected spv.specConstant or spv.specConstantComposite symbol"); + + if (referenceOfOp.reference().getType() != constType) return
referenceOfOp.emitOpError("result type mismatch with the referenced " "specialization constant's type"); - } + return success(); } diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp --- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp @@ -187,6 +187,11 @@ return specConstMap.lookup(id); } + /// Gets the composite specialization constant with the given result <id>. + spirv::SpecConstantCompositeOp getSpecConstantComposite(uint32_t id) { + return specConstCompositeMap.lookup(id); + } + /// Creates a spirv::SpecConstantOp. spirv::SpecConstantOp createSpecConstant(Location loc, uint32_t resultID, Attribute defaultValue); @@ -461,9 +466,12 @@ /// (and type) here. Later when it's used, we materialize the constant. DenseMap<uint32_t, std::pair<Attribute, Type>> constantMap; - // Result <id> to variable mapping. + // Result <id> to spec constant mapping. DenseMap<uint32_t, spirv::SpecConstantOp> specConstMap; + // Result <id> to composite spec constant mapping. + DenseMap<uint32_t, spirv::SpecConstantCompositeOp> specConstCompositeMap; + // Result <id> to variable mapping.
DenseMap<uint32_t, spirv::GlobalVariableOp> globalVariableMap; @@ -1565,7 +1573,8 @@ << operands[0]; } - auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(operands[1])); + auto resultID = operands[1]; + auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID)); SmallVector<Attribute, 4> elements; elements.reserve(operands.size() - 2); @@ -1574,9 +1583,10 @@ elements.push_back(opBuilder.getSymbolRefAttr(elementInfo)); } - opBuilder.create<spirv::SpecConstantCompositeOp>( + auto op = opBuilder.create<spirv::SpecConstantCompositeOp>( unknownLoc, TypeAttr::get(resultType), symName, opBuilder.getArrayAttr(elements)); + specConstCompositeMap[resultID] = op; return success(); } @@ -2208,6 +2218,12 @@ opBuilder.getSymbolRefAttr(constOp.getOperation())); return referenceOfOp.reference(); } + if (auto constCompositeOp = getSpecConstantComposite(id)) { + auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>( + unknownLoc, constCompositeOp.type(), + opBuilder.getSymbolRefAttr(constCompositeOp.getOperation())); + return referenceOfOp.reference(); + } if (auto undef = getUndefType(id)) { return opBuilder.create<spirv::UndefOp>(unknownLoc, undef); } diff --git a/mlir/test/Dialect/SPIRV/Serialization/spec-constant.mlir b/mlir/test/Dialect/SPIRV/Serialization/spec-constant.mlir --- a/mlir/test/Dialect/SPIRV/Serialization/spec-constant.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/spec-constant.mlir @@ -12,6 +12,9 @@ // CHECK: spv.specConstant @sc_float spec_id(5) = 1.000000e+00 : f32 spv.specConstant @sc_float spec_id(5) = 1. : f32 + // CHECK: spv.specConstantComposite @scc (@sc_int, @sc_int) : !spv.array<2 x i32> + spv.specConstantComposite @scc (@sc_int, @sc_int) : !spv.array<2 x i32> + // CHECK-LABEL: @use spv.func @use() -> (i32) "None" { // We materialize a `spv._reference_of` op at every use of a @@ -24,6 +27,23 @@ %1 = spv.IAdd %0, %0 : i32 spv.ReturnValue %1 : i32 } + + // CHECK-LABEL: @use_composite + spv.func @use_composite() -> (i32) "None" { + // We materialize a `spv._reference_of` op at every use of a + // specialization constant in the deserializer. So two ops here.
+ // CHECK: %[[USE1:.*]] = spv._reference_of @scc : !spv.array<2 x i32> + // CHECK: %[[ITM0:.*]] = spv.CompositeExtract %[[USE1]][0 : i32] : !spv.array<2 x i32> + // CHECK: %[[USE2:.*]] = spv._reference_of @scc : !spv.array<2 x i32> + // CHECK: %[[ITM1:.*]] = spv.CompositeExtract %[[USE2]][1 : i32] : !spv.array<2 x i32> + // CHECK: spv.IAdd %[[ITM0]], %[[ITM1]] + + %0 = spv._reference_of @scc : !spv.array<2 x i32> + %1 = spv.CompositeExtract %0[0 : i32] : !spv.array<2 x i32> + %2 = spv.CompositeExtract %0[1 : i32] : !spv.array<2 x i32> + %3 = spv.IAdd %1, %2 : i32 + spv.ReturnValue %3 : i32 + } } // ----- diff --git a/mlir/test/Dialect/SPIRV/structure-ops.mlir b/mlir/test/Dialect/SPIRV/structure-ops.mlir --- a/mlir/test/Dialect/SPIRV/structure-ops.mlir +++ b/mlir/test/Dialect/SPIRV/structure-ops.mlir @@ -496,6 +496,8 @@ spv.specConstant @sc2 = 42 : i64 spv.specConstant @sc3 = 1.5 : f32 + spv.specConstantComposite @scc (@sc1, @sc2, @sc3) : !spv.struct<i1, i64, f32> + // CHECK-LABEL: @reference spv.func @reference() -> i1 "None" { // CHECK: spv._reference_of @sc1 : i1 @@ -503,6 +505,14 @@ spv.ReturnValue %0 : i1 } + // CHECK-LABEL: @reference_composite + spv.func @reference_composite() -> i1 "None" { + // CHECK: spv._reference_of @scc : !spv.struct<i1, i64, f32> + %0 = spv._reference_of @scc : !spv.struct<i1, i64, f32> + %1 = spv.CompositeExtract %0[0 : i32] : !spv.struct<i1, i64, f32> + spv.ReturnValue %1 : i1 + } + // CHECK-LABEL: @initialize spv.func @initialize() -> i64 "None" { // CHECK: spv._reference_of @sc2 : i64 @@ -534,9 +544,21 @@ // ----- +spv.specConstant @sc = 5 : i32 +spv.specConstantComposite @scc (@sc) : !spv.array<1 x i32> + +func @reference_of_composite() { + // CHECK: spv._reference_of @scc : !spv.array<1 x i32> + %0 = spv._reference_of @scc : !spv.array<1 x i32> + %1 = spv.CompositeExtract %0[0 : i32] : !spv.array<1 x i32> + return +} + +// ----- + spv.module Logical GLSL450 { spv.func @foo() -> () "None" { - // expected-error @+1 {{expected spv.specConstant symbol}} + // expected-error @+1 {{expected
spv.specConstant or spv.specConstantComposite symbol}} %0 = spv._reference_of @sc : i32 spv.Return } @@ -555,6 +577,18 @@ // ----- +spv.module Logical GLSL450 { + spv.specConstant @sc = 42 : i32 + spv.specConstantComposite @scc (@sc) : !spv.array<1 x i32> + spv.func @foo() -> () "None" { + // expected-error @+1 {{result type mismatch with the referenced specialization constant's type}} + %0 = spv._reference_of @scc : f32 + spv.Return + } +} + +// ----- + //===----------------------------------------------------------------------===// // spv.specConstant //===----------------------------------------------------------------------===//