diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -36,6 +36,7 @@
 static constexpr const char kBranchWeightAttrName[] = "branch_weights";
 static constexpr const char kCallee[] = "callee";
 static constexpr const char kClusterSize[] = "cluster_size";
+static constexpr const char kControl[] = "control";
 static constexpr const char kDefaultValueAttrName[] = "default_value";
 static constexpr const char kExecutionScopeAttrName[] = "execution_scope";
 static constexpr const char kEqualSemanticsAttrName[] = "equal_semantics";
@@ -161,6 +162,32 @@
   return success();
 }
 
+/// Parses an optional Function, Selection or Loop control clause of the form
+/// `control(<enumerant>)`. An explicit enumerant is handed to
+/// parseEnumKeywordAttr together with `attrName`; when the clause is absent,
+/// the `None` enumerant is added to `state` under `attrName` as the default.
+template <typename EnumClass>
+static ParseResult
+parseControlAttribute(OpAsmParser &parser, OperationState &state,
+                      StringRef attrName = spirv::attributeName<EnumClass>()) {
+  if (succeeded(parser.parseOptionalKeyword(kControl))) {
+    EnumClass control;
+    // Forward `attrName`; otherwise a caller-supplied name would only be
+    // honored on the defaulted path below.
+    if (parser.parseLParen() ||
+        parseEnumKeywordAttr(control, parser, state, attrName) ||
+        parser.parseRParen())
+      return failure();
+    return success();
+  }
+  // No `control(...)` clause: default to the `None` enumerant rather than a
+  // hard-coded 0 so the default is spelled in terms of `EnumClass`.
+  Builder builder = parser.getBuilder();
+  state.addAttribute(attrName, builder.getI32IntegerAttr(
+                                   static_cast<uint32_t>(EnumClass::None)));
+  return success();
+}
+
 /// Parses optional memory access attributes attached to a memory access
 /// operand/pointer. Specifically, parses the following syntax:
 /// (`[` memory-access `]`)?
@@ -2082,12 +2102,8 @@ } static ParseResult parseLoopOp(OpAsmParser &parser, OperationState &state) { - // TODO: support loop control properly - Builder builder = parser.getBuilder(); - state.addAttribute("loop_control", - builder.getI32IntegerAttr( - static_cast(spirv::LoopControl::None))); - + if (parseControlAttribute(parser, state)) + return failure(); return parser.parseRegion(*state.addRegion(), /*arguments=*/{}, /*argTypes=*/{}); } @@ -2096,6 +2112,9 @@ auto *op = loopOp.getOperation(); printer << spirv::LoopOp::getOperationName(); + auto control = loopOp.loop_control(); + if (control != spirv::LoopControl::None) + printer << " control(" << spirv::stringifyLoopControl(control) << ")"; printer.printRegion(op->getRegion(0), /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/true); } @@ -2445,12 +2464,8 @@ static ParseResult parseSelectionOp(OpAsmParser &parser, OperationState &state) { - // TODO: support selection control properly - Builder builder = parser.getBuilder(); - state.addAttribute("selection_control", - builder.getI32IntegerAttr( - static_cast(spirv::SelectionControl::None))); - + if (parseControlAttribute(parser, state)) + return failure(); return parser.parseRegion(*state.addRegion(), /*arguments=*/{}, /*argTypes=*/{}); } @@ -2459,6 +2474,9 @@ auto *op = selectionOp.getOperation(); printer << spirv::SelectionOp::getOperationName(); + auto control = selectionOp.selection_control(); + if (control != spirv::SelectionControl::None) + printer << " control(" << spirv::stringifySelectionControl(control) << ")"; printer.printRegion(op->getRegion(0), /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/true); } diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp --- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp @@ -64,11 +64,14 @@ Block *mergeBlock; Block *continueBlock; // nullptr for spv.selection Location 
loc; - - BlockMergeInfo(Location location) - : mergeBlock(nullptr), continueBlock(nullptr), loc(location) {} - BlockMergeInfo(Location location, Block *m, Block *c = nullptr) - : mergeBlock(m), continueBlock(c), loc(location) {} + uint32_t control; + + BlockMergeInfo(Location location, uint32_t control) + : mergeBlock(nullptr), continueBlock(nullptr), loc(location), + control(control) {} + BlockMergeInfo(Location location, uint32_t control, Block *m, + Block *c = nullptr) + : mergeBlock(m), continueBlock(c), loc(location), control(control) {} }; /// A struct for containing OpLine instruction information. @@ -1681,16 +1684,12 @@ "OpSelectionMerge must specify merge target and selection control"); } - if (static_cast(spirv::SelectionControl::None) != operands[1]) { - return emitError(unknownLoc, - "unimplmented OpSelectionMerge selection control: ") - << operands[2]; - } - auto *mergeBlock = getOrCreateBlock(operands[0]); auto loc = createFileLineColLoc(opBuilder); + auto selectionControl = operands[1]; - if (!blockMergeInfo.try_emplace(curBlock, loc, mergeBlock).second) { + if (!blockMergeInfo.try_emplace(curBlock, loc, selectionControl, mergeBlock) + .second) { return emitError( unknownLoc, "a block cannot have more than one OpSelectionMerge instruction"); @@ -1709,16 +1708,13 @@ "continue target and loop control"); } - if (static_cast(spirv::LoopControl::None) != operands[2]) { - return emitError(unknownLoc, "unimplmented OpLoopMerge loop control: ") - << operands[2]; - } - auto *mergeBlock = getOrCreateBlock(operands[0]); auto *continueBlock = getOrCreateBlock(operands[1]); auto loc = createFileLineColLoc(opBuilder); + uint32_t loopControl = operands[2]; - if (!blockMergeInfo.try_emplace(curBlock, loc, mergeBlock, continueBlock) + if (!blockMergeInfo + .try_emplace(curBlock, loc, loopControl, mergeBlock, continueBlock) .second) { return emitError( unknownLoc, @@ -1771,25 +1767,27 @@ /// the `headerBlock` will be redirected to the `mergeBlock`. 
/// This method will also update `mergeInfo` by remapping all blocks inside to /// the newly cloned ones inside structured control flow op's regions. - static LogicalResult structurize(Location loc, BlockMergeInfoMap &mergeInfo, + static LogicalResult structurize(Location loc, uint32_t control, + BlockMergeInfoMap &mergeInfo, Block *headerBlock, Block *mergeBlock, Block *continueBlock) { - return ControlFlowStructurizer(loc, mergeInfo, headerBlock, mergeBlock, - continueBlock) + return ControlFlowStructurizer(loc, control, mergeInfo, headerBlock, + mergeBlock, continueBlock) .structurizeImpl(); } private: - ControlFlowStructurizer(Location loc, BlockMergeInfoMap &mergeInfo, - Block *header, Block *merge, Block *cont) - : location(loc), blockMergeInfo(mergeInfo), headerBlock(header), - mergeBlock(merge), continueBlock(cont) {} + ControlFlowStructurizer(Location loc, uint32_t control, + BlockMergeInfoMap &mergeInfo, Block *header, + Block *merge, Block *cont) + : location(loc), control(control), blockMergeInfo(mergeInfo), + headerBlock(header), mergeBlock(merge), continueBlock(cont) {} /// Creates a new spv.selection op at the beginning of the `mergeBlock`. - spirv::SelectionOp createSelectionOp(); + spirv::SelectionOp createSelectionOp(uint32_t selectionControl); /// Creates a new spv.loop op at the beginning of the `mergeBlock`. - spirv::LoopOp createLoopOp(); + spirv::LoopOp createLoopOp(uint32_t loopControl); /// Collects all blocks reachable from `headerBlock` except `mergeBlock`. 
void collectBlocksInConstruct(); @@ -1797,6 +1795,7 @@ LogicalResult structurizeImpl(); Location location; + uint32_t control; BlockMergeInfoMap &blockMergeInfo; @@ -1808,26 +1807,26 @@ }; } // namespace -spirv::SelectionOp ControlFlowStructurizer::createSelectionOp() { +spirv::SelectionOp +ControlFlowStructurizer::createSelectionOp(uint32_t selectionControl) { // Create a builder and set the insertion point to the beginning of the // merge block so that the newly created SelectionOp will be inserted there. OpBuilder builder(&mergeBlock->front()); - auto control = builder.getI32IntegerAttr( - static_cast(spirv::SelectionControl::None)); + auto control = builder.getI32IntegerAttr(selectionControl); auto selectionOp = builder.create(location, control); selectionOp.addMergeBlock(); return selectionOp; } -spirv::LoopOp ControlFlowStructurizer::createLoopOp() { +spirv::LoopOp ControlFlowStructurizer::createLoopOp(uint32_t loopControl) { // Create a builder and set the insertion point to the beginning of the // merge block so that the newly created LoopOp will be inserted there. OpBuilder builder(&mergeBlock->front()); - // TODO: handle loop control properly - auto loopOp = builder.create(location); + auto control = builder.getI32IntegerAttr(loopControl); + auto loopOp = builder.create(location, control); loopOp.addEntryAndMergeBlock(); return loopOp; @@ -1852,10 +1851,10 @@ Operation *op = nullptr; bool isLoop = continueBlock != nullptr; if (isLoop) { - if (auto loopOp = createLoopOp()) + if (auto loopOp = createLoopOp(control)) op = loopOp.getOperation(); } else { - if (auto selectionOp = createSelectionOp()) + if (auto selectionOp = createSelectionOp(control)) op = selectionOp.getOperation(); } if (!op) @@ -1992,7 +1991,8 @@ // The iterator should be erased before adding a new entry into // blockMergeInfo to avoid iterator invalidation. 
-      blockMergeInfo.erase(it);
-      blockMergeInfo.try_emplace(newHeader, loc, newMerge, newContinue);
+      uint32_t ctrl = it->second.control; // Must precede erase(it).
+      blockMergeInfo.erase(it);
+      blockMergeInfo.try_emplace(newHeader, loc, ctrl, newMerge, newContinue);
     }
 
     // The structured selection/loop's entry block does not have arguments.
@@ -2096,9 +2096,9 @@
     // Erase this case before calling into structurizer, who will update
     // blockMergeInfo.
     blockMergeInfo.erase(blockMergeInfo.begin());
-    if (failed(ControlFlowStructurizer::structurize(mergeInfo.loc,
-                                                    blockMergeInfo, headerBlock,
-                                                    mergeBlock, continueBlock)))
+    if (failed(ControlFlowStructurizer::structurize(
+            mergeInfo.loc, mergeInfo.control, blockMergeInfo, headerBlock,
+            mergeBlock, continueBlock)))
       return failure();
   }
 
diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
--- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
@@ -1573,10 +1573,9 @@
   auto emitSelectionMerge = [&]() {
     emitDebugLine(functionBody, loc);
     lastProcessedWasMergeInst = true;
-    // TODO: properly support selection control here
     encodeInstructionInto(
         functionBody, spirv::Opcode::OpSelectionMerge,
-        {mergeID, static_cast<uint32_t>(spirv::SelectionControl::None)});
+        {mergeID, static_cast<uint32_t>(selectionOp.selection_control())});
   };
   // For structured selection, we cannot have blocks in the selection construct
   // branching to the selection header block.
Entering the selection (and @@ -1636,10 +1635,9 @@ auto emitLoopMerge = [&]() { emitDebugLine(functionBody, loc); lastProcessedWasMergeInst = true; - // TODO: properly support loop control here encodeInstructionInto( functionBody, spirv::Opcode::OpLoopMerge, - {mergeID, continueID, static_cast(spirv::LoopControl::None)}); + {mergeID, continueID, static_cast(loopOp.loop_control())}); }; if (failed(processBlock(headerBlock, /*omitLabel=*/false, emitLoopMerge))) return failure(); diff --git a/mlir/test/Dialect/SPIRV/Serialization/loop.mlir b/mlir/test/Dialect/SPIRV/Serialization/loop.mlir --- a/mlir/test/Dialect/SPIRV/Serialization/loop.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/loop.mlir @@ -119,8 +119,8 @@ // CHECK: spv.Branch ^bb1 // CHECK-NEXT: ^bb1: -// CHECK-NEXT: spv.loop - spv.loop { +// CHECK-NEXT: spv.loop control(Unroll) + spv.loop control(Unroll) { // CHECK-NEXT: spv.Branch ^bb1 spv.Branch ^header @@ -140,8 +140,8 @@ spv.Store "Function" %jvar, %zero : i32 // CHECK-NEXT: spv.Branch ^bb3 // CHECK-NEXT: ^bb3: -// CHECK-NEXT: spv.loop - spv.loop { +// CHECK-NEXT: spv.loop control(DontUnroll) + spv.loop control(DontUnroll) { // CHECK-NEXT: spv.Branch ^bb1 spv.Branch ^header diff --git a/mlir/test/Dialect/SPIRV/Serialization/selection.mlir b/mlir/test/Dialect/SPIRV/Serialization/selection.mlir --- a/mlir/test/Dialect/SPIRV/Serialization/selection.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/selection.mlir @@ -11,10 +11,10 @@ %two = spv.constant 2: i32 %var = spv.Variable init(%zero) : !spv.ptr -// CHECK-NEXT: spv.selection { +// CHECK-NEXT: spv.selection control(Flatten) // CHECK-NEXT: spv.constant 0 // CHECK-NEXT: spv.Variable - spv.selection { + spv.selection control(Flatten) { // CHECK-NEXT: spv.BranchConditional %{{.*}} [5, 10], ^bb1, ^bb2 spv.BranchConditional %cond [5, 10], ^then, ^else diff --git a/mlir/test/Dialect/SPIRV/control-flow-ops.mlir b/mlir/test/Dialect/SPIRV/control-flow-ops.mlir --- a/mlir/test/Dialect/SPIRV/control-flow-ops.mlir 
+++ b/mlir/test/Dialect/SPIRV/control-flow-ops.mlir @@ -317,6 +317,16 @@ // ----- +// CHECK-LABEL: @loop_with_control +func @loop_with_control() -> () { + // CHECK: spv.loop control(Unroll) + spv.loop control(Unroll) { + } + return +} + +// ----- + func @wrong_merge_block() -> () { // expected-error @+1 {{last block must be the merge block with only one 'spv._merge' op}} spv.loop { @@ -718,6 +728,16 @@ // ----- +// CHECK-LABEL: @selection_with_control +func @selection_with_control() -> () { + // CHECK: spv.selection control(Flatten) + spv.selection control(Flatten) { + } + return +} + +// ----- + func @wrong_merge_block() -> () { // expected-error @+1 {{last block must be the merge block with only one 'spv._merge' op}} spv.selection {