Add (de)serialization support for spv.SpecConstantOperation (SPIR-V opcode 52).

NOTE(review): the SPIR-V specification spells opcode 52 `OpSpecConstantOp`; the
enum case added below is spelled `OpSpecConstantOperation`, so it may not match
a spec-derived regeneration of this opcode list -- confirm this is intended.
NOTE(review): template argument lists appear to have been stripped from this
copy of the patch (e.g. `isa(operand.getDefiningOp())` in the SPIRVOps.cpp hunk,
whose -9/+8 line counts show the lost `isa<...>` lists spanned several lines);
restore them from the original change before applying.

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td @@ -3166,6 +3166,7 @@ def SPV_OC_OpSpecConstantFalse : I32EnumAttrCase<"OpSpecConstantFalse", 49>; def SPV_OC_OpSpecConstant : I32EnumAttrCase<"OpSpecConstant", 50>; def SPV_OC_OpSpecConstantComposite : I32EnumAttrCase<"OpSpecConstantComposite", 51>; +def SPV_OC_OpSpecConstantOperation : I32EnumAttrCase<"OpSpecConstantOperation", 52>; def SPV_OC_OpFunction : I32EnumAttrCase<"OpFunction", 54>; def SPV_OC_OpFunctionParameter : I32EnumAttrCase<"OpFunctionParameter", 55>; def SPV_OC_OpFunctionEnd : I32EnumAttrCase<"OpFunctionEnd", 56>; @@ -3310,7 +3311,8 @@ SPV_OC_OpConstantTrue, SPV_OC_OpConstantFalse, SPV_OC_OpConstant, SPV_OC_OpConstantComposite, SPV_OC_OpConstantNull, SPV_OC_OpSpecConstantTrue, SPV_OC_OpSpecConstantFalse, SPV_OC_OpSpecConstant, - SPV_OC_OpSpecConstantComposite, SPV_OC_OpFunction, SPV_OC_OpFunctionParameter, + SPV_OC_OpSpecConstantComposite, SPV_OC_OpSpecConstantOperation, + SPV_OC_OpFunction, SPV_OC_OpFunctionParameter, SPV_OC_OpFunctionEnd, SPV_OC_OpFunctionCall, SPV_OC_OpVariable, SPV_OC_OpLoad, SPV_OC_OpStore, SPV_OC_OpCopyMemory, SPV_OC_OpAccessChain, SPV_OC_OpDecorate, SPV_OC_OpMemberDecorate, SPV_OC_OpVectorExtractDynamic, diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -3457,9 +3457,8 @@ return constOp.emitOpError("invalid enclosed op"); for (auto operand : enclosedOp.getOperands()) - if (!isa( - operand.getDefiningOp())) + if (!isa(operand.getDefiningOp())) return constOp.emitOpError( "invalid operand, must be defined by a constant operation"); diff --git a/mlir/lib/Target/SPIRV/Deserialization.cpp b/mlir/lib/Target/SPIRV/Deserialization.cpp --- a/mlir/lib/Target/SPIRV/Deserialization.cpp 
+++ b/mlir/lib/Target/SPIRV/Deserialization.cpp @@ -295,6 +295,9 @@ LogicalResult processSpecConstantComposite(ArrayRef<uint32_t> operands); + /// Processes a SPIR-V OpSpecConstantOp instruction with the given `operands`. + LogicalResult processSpecConstantOperation(ArrayRef<uint32_t> operands); + /// Processes a SPIR-V OpConstantNull instruction with the given `operands`. LogicalResult processConstantNull(ArrayRef<uint32_t> operands); @@ -1721,6 +1724,92 @@ return success(); } +LogicalResult +Deserializer::processSpecConstantOperation(ArrayRef<uint32_t> operands) { + if (operands.size() < 3) + return emitError(unknownLoc, "OpSpecConstantOperation must have type <id>, " + "result <id>, and operand opcode"); + + uint32_t resultTypeID = operands[0]; + Type resultType = getType(resultTypeID); + if (!resultType) + return emitError(unknownLoc, "undefined result type from <id> ") + << resultTypeID; + + // The enclosed op is emitted into, and then split out of, the current block; + // without one (e.g. for a module-level OpSpecConstantOp) there is nothing to + // split from. + if (!curBlock) + return emitError(unknownLoc, "OpSpecConstantOperation is only supported " + "inside a function body"); + + // Instructions wrapped by OpSpecConstantOp need an ID for their + // Deserializer::processOp<op_name>(...) to emit the corresponding spv + // dialect wrapped op. Since deserialization methods add created spv Values + // to `valueMap`, we need to choose an ID that doesn't clash with the IDs + // already used by the input SPIR-V module's values. That's why we use a + // "fake" ID from the high range of uint32_t. Its `valueMap` entry is erased + // as soon as the enclosed op has been created: that op's Value is only + // referenced from inside the SpecConstantOperation's region (by the + // spv::YieldOp created later in this method), so every + // SpecConstantOperation in the module can re-use the same fake ID.
+ constexpr uint32_t fakeID = static_cast<uint32_t>(-3); + + SmallVector<uint32_t, 4> enclosedOpResultTypeAndOperands; + enclosedOpResultTypeAndOperands.push_back(resultTypeID); + enclosedOpResultTypeAndOperands.push_back(fakeID); + enclosedOpResultTypeAndOperands.append(operands.begin() + 3, operands.end()); + + // Process enclosed instruction before creating the enclosing + // specConstantOperation (and its region). This way, references to constants, + // global variables, and spec constants will be materialized outside the new + // op's region. For more info, see Deserializer::getValue's implementation. + spirv::Opcode enclosedOpcode = static_cast<spirv::Opcode>(operands[2]); + if (failed( + processInstruction(enclosedOpcode, enclosedOpResultTypeAndOperands))) + return emitError(unknownLoc, "failed to add op with opcode: ") + << static_cast<uint32_t>(enclosedOpcode); + + // Fetch the enclosed op's result through the fake ID and drop that entry + // right away. The defining op must be the one just appended to the current + // block; anything else means the enclosed opcode produced no usable value. + Value enclosedResult = valueMap.lookup(fakeID); + valueMap.erase(fakeID); + Operation *enclosedOp = + enclosedResult ? enclosedResult.getDefiningOp() : nullptr; + if (!enclosedOp || enclosedOp->getBlock() != curBlock || + enclosedOp != &curBlock->back()) + return emitError(unknownLoc, "enclosed op with opcode ") + << static_cast<uint32_t>(enclosedOpcode) + << " did not produce a result in the current block"; + + // Since the enclosed op is emitted in the current block, split it in a + // separate new block. + Block *enclosedBlock = curBlock->splitBlock(enclosedOp); + + auto loc = createFileLineColLoc(opBuilder); + auto specConstOperationOp = + opBuilder.create<spirv::SpecConstantOperationOp>(loc, resultType); + + Region &body = specConstOperationOp.body(); + // Move the new block into SpecConstantOperation's body. + body.getBlocks().splice(body.end(), curBlock->getParent()->getBlocks(), + Region::iterator(enclosedBlock)); + Block &block = body.back(); + + // RAII guard to reset the insertion point to the module's region after + // deserializing the body of the specConstantOperation.
+ OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder); + opBuilder.setInsertionPointToEnd(&block); + + opBuilder.create<spirv::YieldOp>(loc, enclosedResult); + + uint32_t resultID = operands[1]; + valueMap[resultID] = specConstOperationOp.getResult(); + + return success(); +} + LogicalResult Deserializer::processConstantNull(ArrayRef<uint32_t> operands) { if (operands.size() != 2) { return emitError(unknownLoc, @@ -2459,6 +2548,8 @@ return processConstantComposite(operands); case spirv::Opcode::OpSpecConstantComposite: return processSpecConstantComposite(operands); + case spirv::Opcode::OpSpecConstantOperation: + return processSpecConstantOperation(operands); case spirv::Opcode::OpConstantTrue: return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/false); case spirv::Opcode::OpSpecConstantTrue: diff --git a/mlir/lib/Target/SPIRV/Serialization.cpp b/mlir/lib/Target/SPIRV/Serialization.cpp --- a/mlir/lib/Target/SPIRV/Serialization.cpp +++ b/mlir/lib/Target/SPIRV/Serialization.cpp @@ -204,6 +204,9 @@ LogicalResult processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op); + LogicalResult + processSpecConstantOperationOp(spirv::SpecConstantOperationOp op); + /// SPIR-V dialect supports OpUndef using spv.UndefOp that produces a SSA /// value to use with other operations. The SPIR-V spec recommends that /// OpUndef be generated at module level.
The serialization generates an @@ -703,6 +706,63 @@ return processName(resultID, op.sym_name()); } +/// Serializes a spv.SpecConstantOperation as an OpSpecConstantOp instruction +/// whose operands are the result type <id>, a fresh result <id>, the opcode +/// of the enclosed op (the first op in the region), and the <id>s of the +/// enclosed op's operands. +LogicalResult +Serializer::processSpecConstantOperationOp(spirv::SpecConstantOperationOp op) { + uint32_t typeID = 0; + if (failed(processType(op.getLoc(), op.getType(), typeID))) { + return failure(); + } + + auto resultID = getNextID(); + + SmallVector<uint32_t, 8> operands; + operands.push_back(typeID); + operands.push_back(resultID); + + Block &block = op.getRegion().getBlocks().front(); + Operation &enclosedOp = block.getOperations().front(); + + std::string enclosedOpName; + llvm::raw_string_ostream rss(enclosedOpName); + rss << "Op" << enclosedOp.getName().stripDialect(); + auto enclosedOpcode = spirv::symbolizeOpcode(rss.str()); + + if (!enclosedOpcode) { + op.emitError("Couldn't find op code for op ") + << enclosedOp.getName().getStringRef(); + return failure(); + } + + operands.push_back(static_cast<uint32_t>(enclosedOpcode.getValue())); + + // Append the <id>s of the enclosed op's operands. A missing <id> would be + // encoded as 0 (an invalid <id>), so report it instead of relying on an + // assert that is compiled out of release builds.
+ for (Value operand : enclosedOp.getOperands()) { + uint32_t id = getValueID(operand); + if (!id) { + op.emitError("use before def of an operand of enclosed op ") + << enclosedOp.getName().getStringRef(); + return failure(); + } + operands.push_back(id); + } + + // NOTE(review): the SPIR-V logical module layout places OpSpecConstantOp + // among the module-level constant instructions, not inside function + // bodies. Emitting it into `functionBody` matches the deserializer, which + // only materializes it inside a function body; move both together. + encodeInstructionInto(functionBody, spirv::Opcode::OpSpecConstantOperation, + operands); + valueIDMap[op.getResult()] = resultID; + + return success(); +} + LogicalResult Serializer::processUndefOp(spirv::UndefOp op) { auto undefType = op.getType(); auto &id = undefValIDMap[undefType]; @@ -1921,6 +1981,9 @@ .Case([&](spirv::SpecConstantCompositeOp op) { return processSpecConstantCompositeOp(op); }) + .Case([&](spirv::SpecConstantOperationOp op) { + return processSpecConstantOperationOp(op); + }) .Case([&](spirv::UndefOp op) { return processUndefOp(op); }) .Case([&](spirv::VariableOp op) { return processVariableOp(op); }) diff --git a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir --- a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir @@ -780,6 +780,20 @@ // ----- +spv.module Logical GLSL450 { + spv.specConstant @sc = 42 : i32 + + spv.func @foo() -> i32 "None" { + // CHECK: [[SC:%.*]] = spv.mlir.referenceof @sc + %0 = spv.mlir.referenceof @sc : i32 + // CHECK: spv.SpecConstantOperation wraps "spv.ISub"([[SC]], [[SC]]) : (i32, i32) -> i32 + %1 = spv.SpecConstantOperation wraps "spv.ISub"(%0, %0) : (i32, i32) -> i32 + spv.ReturnValue %1 : i32 + } +} + +// ----- + spv.module Logical GLSL450 { spv.func @foo() -> i32 "None" { %0 = spv.constant 1: i32 diff --git a/mlir/test/Target/SPIRV/spec-constant.mlir b/mlir/test/Target/SPIRV/spec-constant.mlir --- a/mlir/test/Target/SPIRV/spec-constant.mlir +++ b/mlir/test/Target/SPIRV/spec-constant.mlir @@ -85,3 +85,27 @@ // CHECK: spv.specConstantComposite @scc_vector (@sc_f32_1, @sc_f32_2, @sc_f32_3) : vector<3xf32> spv.specConstantComposite @scc_vector (@sc_f32_1, @sc_f32_2, @sc_f32_3) : vector<3 x f32> } + +// ----- + +spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> { + 
+ spv.specConstant @sc_i32_1 = 1 : i32 + + spv.func @use_spec_constant_operation() -> (i32) "None" { + // CHECK: [[USE1:%.*]] = spv.mlir.referenceof @sc_i32_1 : i32 + // CHECK: [[USE2:%.*]] = spv.mlir.referenceof @sc_i32_1 : i32 + %0 = spv.mlir.referenceof @sc_i32_1 : i32 + + // CHECK: [[RES1:%.*]] = spv.SpecConstantOperation wraps "spv.ISub"([[USE1]], [[USE2]]) : (i32, i32) -> i32 + %2 = spv.SpecConstantOperation wraps "spv.ISub"(%0, %0) : (i32, i32) -> i32 + + // CHECK: [[RES2:%.*]] = spv.SpecConstantOperation wraps "spv.IMul"([[RES1]], [[RES1]]) : (i32, i32) -> i32 + %3 = spv.SpecConstantOperation wraps "spv.IMul"(%2, %2) : (i32, i32) -> i32 + + // Make sure deserialization continues from the right place after creating + // the previous op. + // CHECK: spv.ReturnValue [[RES2]] + spv.ReturnValue %3 : i32 + } +}