diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td @@ -3170,6 +3170,7 @@ def SPV_OC_OpSpecConstantFalse : I32EnumAttrCase<"OpSpecConstantFalse", 49>; def SPV_OC_OpSpecConstant : I32EnumAttrCase<"OpSpecConstant", 50>; def SPV_OC_OpSpecConstantComposite : I32EnumAttrCase<"OpSpecConstantComposite", 51>; +def SPV_OC_OpSpecConstantOperation : I32EnumAttrCase<"OpSpecConstantOperation", 52>; def SPV_OC_OpFunction : I32EnumAttrCase<"OpFunction", 54>; def SPV_OC_OpFunctionParameter : I32EnumAttrCase<"OpFunctionParameter", 55>; def SPV_OC_OpFunctionEnd : I32EnumAttrCase<"OpFunctionEnd", 56>; @@ -3314,7 +3315,8 @@ SPV_OC_OpConstantTrue, SPV_OC_OpConstantFalse, SPV_OC_OpConstant, SPV_OC_OpConstantComposite, SPV_OC_OpConstantNull, SPV_OC_OpSpecConstantTrue, SPV_OC_OpSpecConstantFalse, SPV_OC_OpSpecConstant, - SPV_OC_OpSpecConstantComposite, SPV_OC_OpFunction, SPV_OC_OpFunctionParameter, + SPV_OC_OpSpecConstantComposite, SPV_OC_OpSpecConstantOperation, + SPV_OC_OpFunction, SPV_OC_OpFunctionParameter, SPV_OC_OpFunctionEnd, SPV_OC_OpFunctionCall, SPV_OC_OpVariable, SPV_OC_OpLoad, SPV_OC_OpStore, SPV_OC_OpCopyMemory, SPV_OC_OpAccessChain, SPV_OC_OpDecorate, SPV_OC_OpMemberDecorate, SPV_OC_OpVectorExtractDynamic, diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -3445,9 +3445,8 @@ return constOp.emitOpError("invalid enclosed op"); for (auto operand : enclosedOp.getOperands()) - if (!isa( - operand.getDefiningOp())) + if (!operand.getDefiningOp() || !isa(operand.getDefiningOp())) return constOp.emitOpError( "invalid operand, must be defined by a constant operation"); diff --git a/mlir/lib/Target/SPIRV/Deserialization.cpp b/mlir/lib/Target/SPIRV/Deserialization.cpp --- a/mlir/lib/Target/SPIRV/Deserialization.cpp 
+++ b/mlir/lib/Target/SPIRV/Deserialization.cpp @@ -13,6 +13,7 @@ #include "mlir/Target/SPIRV/Deserialization.h" #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" +#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" #include "mlir/Dialect/SPIRV/IR/SPIRVModule.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" @@ -28,6 +29,7 @@ #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/bit.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/SaveAndRestore.h" #include "llvm/Support/raw_ostream.h" using namespace mlir; @@ -132,6 +134,14 @@ SmallVector memberDecorationsInfo; }; +/// A struct that collects the info needed to materialize/emit a +/// SpecConstantOperation op. +struct SpecConstOperationMaterializationInfo { + spirv::Opcode enclosedOpcode; + uint32_t resultTypeID; + SmallVector enclosedOpOperands; +}; + //===----------------------------------------------------------------------===// // Deserializer Declaration //===----------------------------------------------------------------------===// @@ -216,9 +226,14 @@ /// Gets the constant's attribute and type associated with the given . Optional> getConstant(uint32_t id); - /// Gets the constant's integer attribute with the given . Returns a null - /// IntegerAttr if the given is not registered or does not correspond to an - /// integer constant. + /// Gets the info needed to materialize the spec constant operation op + /// associated with the given . + Optional + getSpecConstantOperation(uint32_t id); + + /// Gets the constant's integer attribute with the given . Returns a + /// null IntegerAttr if the given is not registered or does not correspond + /// to an integer constant. IntegerAttr getConstantInt(uint32_t id); /// Returns a symbol to be used for the function name with the given @@ -305,8 +320,20 @@ /// `operands`. LogicalResult processConstantComposite(ArrayRef operands); + /// Processes a SPIR-V OpSpecConstantComposite instruction with the given + /// `operands`. 
LogicalResult processSpecConstantComposite(ArrayRef operands); + /// Processes a SPIR-V OpSpecConstantOperation instruction with the given + /// `operands`. + LogicalResult processSpecConstantOperation(ArrayRef operands); + + /// Materializes/emits an OpSpecConstantOperation instruction. + Value materializeSpecConstantOperation(uint32_t resultID, + spirv::Opcode enclosedOpcode, + uint32_t resultTypeID, + ArrayRef enclosedOpOperands); + /// Processes a SPIR-V OpConstantNull instruction with the given `operands`. LogicalResult processConstantNull(ArrayRef operands); @@ -534,6 +561,11 @@ // Result to composite spec constant mapping. DenseMap specConstCompositeMap; + /// Result to info needed to materialize an OpSpecConstantOperation + /// mapping. + DenseMap + specConstOperationMap; + // Result to variable mapping. DenseMap globalVariableMap; @@ -1036,6 +1068,14 @@ return constIt->getSecond(); } +Optional +Deserializer::getSpecConstantOperation(uint32_t id) { + auto constIt = specConstOperationMap.find(id); + if (constIt == specConstOperationMap.end()) + return llvm::None; + return constIt->getSecond(); +} + std::string Deserializer::getFunctionSymbol(uint32_t id) { auto funcName = nameMap.lookup(id).str(); if (funcName.empty()) { @@ -1745,6 +1785,91 @@ return success(); } +LogicalResult +Deserializer::processSpecConstantOperation(ArrayRef operands) { + if (operands.size() < 3) + return emitError(unknownLoc, "OpSpecConstantOperation must have type , " + "result , and operand opcode"); + + uint32_t resultTypeID = operands[0]; + + if (!getType(resultTypeID)) + return emitError(unknownLoc, "undefined result type from ") + << resultTypeID; + + uint32_t resultID = operands[1]; + spirv::Opcode enclosedOpcode = static_cast(operands[2]); + auto emplaceResult = specConstOperationMap.try_emplace( + resultID, + SpecConstOperationMaterializationInfo{ + enclosedOpcode, resultTypeID, + SmallVector{operands.begin() + 3, operands.end()}}); + + if (!emplaceResult.second) + return 
emitError(unknownLoc, "value with : ") + << resultID << " is probably defined before."; + + return success(); +} + +Value Deserializer::materializeSpecConstantOperation( + uint32_t resultID, spirv::Opcode enclosedOpcode, uint32_t resultTypeID, + ArrayRef enclosedOpOperands) { + + Type resultType = getType(resultTypeID); + + // Instructions wrapped by OpSpecConstantOp need an ID for their + // Deserializer::processOp(...) to emit the corresponding SPIR-V + // dialect wrapped op. For that purpose, a new value map is created and "fake" + // ID in that map is assigned to the result of the enclosed instruction. Note + // that there is no need to update this fake ID since we only need to + // reference the created Value for the enclosed op from the spv::YieldOp + // created later in this method (both of which are the only values in their + // region: the SpecConstantOperation's region). If we encounter another + // SpecConstantOperation in the module, we simply re-use the fake ID since the + // previous Value assigned to it isn't visible in the current scope anyway. + DenseMap newValueMap; + llvm::SaveAndRestore> valueMapGuard(valueMap, + newValueMap); + constexpr uint32_t fakeID = static_cast(-3); + + SmallVector enclosedOpResultTypeAndOperands; + enclosedOpResultTypeAndOperands.push_back(resultTypeID); + enclosedOpResultTypeAndOperands.push_back(fakeID); + enclosedOpResultTypeAndOperands.append(enclosedOpOperands.begin(), + enclosedOpOperands.end()); + + // Process enclosed instruction before creating the enclosing + // specConstantOperation (and its region). This way, references to constants, + // global variables, and spec constants will be materialized outside the new + // op's region. For more info, see Deserializer::getValue's implementation. + if (failed( + processInstruction(enclosedOpcode, enclosedOpResultTypeAndOperands))) + return Value(); + + // Since the enclosed op is emitted in the current block, split it in a + // separate new block. 
+ Block *enclosedBlock = curBlock->splitBlock(&curBlock->back()); + + auto loc = createFileLineColLoc(opBuilder); + auto specConstOperationOp = + opBuilder.create(loc, resultType); + + Region &body = specConstOperationOp.body(); + // Move the new block into SpecConstantOperation's body. + body.getBlocks().splice(body.end(), curBlock->getParent()->getBlocks(), + Region::iterator(enclosedBlock)); + Block &block = body.back(); + + // RAII guard to reset the insertion point to the module's region after + // deserializing the body of the specConstantOperation. + OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder); + opBuilder.setInsertionPointToEnd(&block); + + opBuilder.create(loc, block.front().getResult(0)); + return specConstOperationOp.getResult(); +} + LogicalResult Deserializer::processConstantNull(ArrayRef operands) { if (operands.size() != 2) { return emitError(unknownLoc, @@ -2378,6 +2503,12 @@ opBuilder.getSymbolRefAttr(constCompositeOp.getOperation())); return referenceOfOp.reference(); } + if (auto specConstOperationInfo = getSpecConstantOperation(id)) { + return materializeSpecConstantOperation( + id, specConstOperationInfo->enclosedOpcode, + specConstOperationInfo->resultTypeID, + specConstOperationInfo->enclosedOpOperands); + } if (auto undef = getUndefType(id)) { return opBuilder.create(unknownLoc, undef); } @@ -2483,6 +2614,8 @@ return processConstantComposite(operands); case spirv::Opcode::OpSpecConstantComposite: return processSpecConstantComposite(operands); + case spirv::Opcode::OpSpecConstantOperation: + return processSpecConstantOperation(operands); case spirv::Opcode::OpConstantTrue: return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/false); case spirv::Opcode::OpSpecConstantTrue: diff --git a/mlir/lib/Target/SPIRV/Serialization.cpp b/mlir/lib/Target/SPIRV/Serialization.cpp --- a/mlir/lib/Target/SPIRV/Serialization.cpp +++ b/mlir/lib/Target/SPIRV/Serialization.cpp @@ -204,6 +204,9 @@ LogicalResult 
processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op); + LogicalResult + processSpecConstantOperationOp(spirv::SpecConstantOperationOp op); + /// SPIR-V dialect supports OpUndef using spv.UndefOp that produces a SSA /// value to use with other operations. The SPIR-V spec recommends that /// OpUndef be generated at module level. The serialization generates an @@ -711,6 +714,49 @@ return processName(resultID, op.sym_name()); } +LogicalResult +Serializer::processSpecConstantOperationOp(spirv::SpecConstantOperationOp op) { + uint32_t typeID = 0; + if (failed(processType(op.getLoc(), op.getType(), typeID))) { + return failure(); + } + + auto resultID = getNextID(); + + SmallVector operands; + operands.push_back(typeID); + operands.push_back(resultID); + + Block &block = op.getRegion().getBlocks().front(); + Operation &enclosedOp = block.getOperations().front(); + + std::string enclosedOpName; + llvm::raw_string_ostream rss(enclosedOpName); + rss << "Op" << enclosedOp.getName().stripDialect(); + auto enclosedOpcode = spirv::symbolizeOpcode(rss.str()); + + if (!enclosedOpcode) { + op.emitError("Couldn't find op code for op ") + << enclosedOp.getName().getStringRef(); + return failure(); + } + + operands.push_back(static_cast(enclosedOpcode.getValue())); + + // Append the enclosed op's operands to the list of operands. 
+ for (Value operand : enclosedOp.getOperands()) { + uint32_t id = getValueID(operand); + assert(id && "use before def!"); + operands.push_back(id); + } + + encodeInstructionInto(typesGlobalValues, + spirv::Opcode::OpSpecConstantOperation, operands); + valueIDMap[op.getResult()] = resultID; + + return success(); +} + LogicalResult Serializer::processUndefOp(spirv::UndefOp op) { auto undefType = op.getType(); auto &id = undefValIDMap[undefType]; @@ -1929,6 +1975,9 @@ .Case([&](spirv::SpecConstantCompositeOp op) { return processSpecConstantCompositeOp(op); }) + .Case([&](spirv::SpecConstantOperationOp op) { + return processSpecConstantOperationOp(op); + }) .Case([&](spirv::UndefOp op) { return processUndefOp(op); }) .Case([&](spirv::VariableOp op) { return processVariableOp(op); }) diff --git a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir --- a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir @@ -780,6 +780,20 @@ // ----- +spv.module Logical GLSL450 { + spv.specConstant @sc = 42 : i32 + + spv.func @foo() -> i32 "None" { + // CHECK: [[SC:%.*]] = spv.mlir.referenceof @sc + %0 = spv.mlir.referenceof @sc : i32 + // CHECK: spv.SpecConstantOperation wraps "spv.ISub"([[SC]], [[SC]]) : (i32, i32) -> i32 + %1 = spv.SpecConstantOperation wraps "spv.ISub"(%0, %0) : (i32, i32) -> i32 + spv.ReturnValue %1 : i32 + } +} + +// ----- + spv.module Logical GLSL450 { spv.func @foo() -> i32 "None" { %0 = spv.constant 1: i32 diff --git a/mlir/test/Target/SPIRV/spec-constant.mlir b/mlir/test/Target/SPIRV/spec-constant.mlir --- a/mlir/test/Target/SPIRV/spec-constant.mlir +++ b/mlir/test/Target/SPIRV/spec-constant.mlir @@ -85,3 +85,34 @@ // CHECK: spv.specConstantComposite @scc_vector (@sc_f32_1, @sc_f32_2, @sc_f32_3) : vector<3xf32> spv.specConstantComposite @scc_vector (@sc_f32_1, @sc_f32_2, @sc_f32_3) : vector<3 x f32> } + +// ----- + +spv.module Logical GLSL450 requires #spv.vce { + + 
spv.specConstant @sc_i32_1 = 1 : i32 + + spv.func @use_composite() -> (i32) "None" { + // CHECK: [[USE1:%.*]] = spv.mlir.referenceof @sc_i32_1 : i32 + // CHECK: [[USE2:%.*]] = spv.constant 0 : i32 + + // CHECK: [[RES1:%.*]] = spv.SpecConstantOperation wraps "spv.ISub"([[USE1]], [[USE2]]) : (i32, i32) -> i32 + + // CHECK: [[USE3:%.*]] = spv.mlir.referenceof @sc_i32_1 : i32 + // CHECK: [[USE4:%.*]] = spv.constant 0 : i32 + + // CHECK: [[RES2:%.*]] = spv.SpecConstantOperation wraps "spv.ISub"([[USE3]], [[USE4]]) : (i32, i32) -> i32 + + %0 = spv.mlir.referenceof @sc_i32_1 : i32 + %1 = spv.constant 0 : i32 + %2 = spv.SpecConstantOperation wraps "spv.ISub"(%0, %1) : (i32, i32) -> i32 + + // CHECK: [[RES3:%.*]] = spv.SpecConstantOperation wraps "spv.IMul"([[RES1]], [[RES2]]) : (i32, i32) -> i32 + %3 = spv.SpecConstantOperation wraps "spv.IMul"(%2, %2) : (i32, i32) -> i32 + + // Make sure deserialization continues from the right place after creating + // the previous op. + // CHECK: spv.ReturnValue [[RES3]] + spv.ReturnValue %3 : i32 + } +}