diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -36,6 +36,7 @@
 static constexpr const char kBranchWeightAttrName[] = "branch_weights";
 static constexpr const char kCallee[] = "callee";
 static constexpr const char kClusterSize[] = "cluster_size";
+static constexpr const char kControl[] = "control";
 static constexpr const char kDefaultValueAttrName[] = "default_value";
 static constexpr const char kExecutionScopeAttrName[] = "execution_scope";
 static constexpr const char kEqualSemanticsAttrName[] = "equal_semantics";
@@ -161,6 +162,25 @@
   return success();
 }
 
+/// Parses Function, Selection and Loop control attributes. If no control is
+/// specified, "None" is used as a default.
+template <typename EnumClass>
+static ParseResult
+parseControlAttribute(OpAsmParser &parser, OperationState &state,
+                      StringRef attrName = spirv::attributeName<EnumClass>()) {
+  if (succeeded(parser.parseOptionalKeyword(kControl))) {
+    EnumClass control;
+    if (parser.parseLParen() || parseEnumKeywordAttr(control, parser, state) ||
+        parser.parseRParen())
+      return failure();
+    return success();
+  }
+  // Set control to "None" otherwise.
+  Builder builder = parser.getBuilder();
+  state.addAttribute(attrName, builder.getI32IntegerAttr(0));
+  return success();
+}
+
 /// Parses optional memory access attributes attached to a memory access
 /// operand/pointer. Specifically, parses the following syntax:
 /// (`[` memory-access `]`)?
@@ -2077,12 +2097,8 @@
 }
 
 static ParseResult parseLoopOp(OpAsmParser &parser, OperationState &state) {
-  // TODO: support loop control properly
-  Builder builder = parser.getBuilder();
-  state.addAttribute("loop_control",
-                     builder.getI32IntegerAttr(
-                         static_cast<uint32_t>(spirv::LoopControl::None)));
-
+  if (parseControlAttribute<spirv::LoopControl>(parser, state))
+    return failure();
   return parser.parseRegion(*state.addRegion(), /*arguments=*/{},
                             /*argTypes=*/{});
 }
@@ -2091,6 +2107,9 @@
   auto *op = loopOp.getOperation();
 
   printer << spirv::LoopOp::getOperationName();
+  auto control = loopOp.loop_control();
+  if (control != spirv::LoopControl::None)
+    printer << " control(" << spirv::stringifyLoopControl(control) << ")";
   printer.printRegion(op->getRegion(0), /*printEntryBlockArgs=*/false,
                       /*printBlockTerminators=*/true);
 }
@@ -2440,12 +2459,8 @@
 
 static ParseResult parseSelectionOp(OpAsmParser &parser,
                                     OperationState &state) {
-  // TODO: support selection control properly
-  Builder builder = parser.getBuilder();
-  state.addAttribute("selection_control",
-                     builder.getI32IntegerAttr(
-                         static_cast<uint32_t>(spirv::SelectionControl::None)));
-
+  if (parseControlAttribute<spirv::SelectionControl>(parser, state))
+    return failure();
   return parser.parseRegion(*state.addRegion(), /*arguments=*/{},
                             /*argTypes=*/{});
 }
@@ -2454,6 +2469,9 @@
   auto *op = selectionOp.getOperation();
 
   printer << spirv::SelectionOp::getOperationName();
+  auto control = selectionOp.selection_control();
+  if (control != spirv::SelectionControl::None)
+    printer << " control(" << spirv::stringifySelectionControl(control) << ")";
   printer.printRegion(op->getRegion(0), /*printEntryBlockArgs=*/false,
                       /*printBlockTerminators=*/true);
 }
diff --git a/mlir/test/Dialect/SPIRV/control-flow-ops.mlir b/mlir/test/Dialect/SPIRV/control-flow-ops.mlir
--- a/mlir/test/Dialect/SPIRV/control-flow-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/control-flow-ops.mlir
@@ -317,6 +317,16 @@
 
 // -----
 
+// CHECK-LABEL: @loop_with_control
+func @loop_with_control() -> () {
+  // CHECK: spv.loop control(Unroll)
+  spv.loop control(Unroll) {
+  }
+  return
+}
+
+// -----
+
 func @wrong_merge_block() -> () {
   // expected-error @+1 {{last block must be the merge block with only one 'spv._merge' op}}
   spv.loop {
@@ -718,6 +728,16 @@
 
 // -----
 
+// CHECK-LABEL: @selection_with_control
+func @selection_with_control() -> () {
+  // CHECK: spv.selection control(Flatten)
+  spv.selection control(Flatten) {
+  }
+  return
+}
+
+// -----
+
 func @wrong_merge_block() -> () {
   // expected-error @+1 {{last block must be the merge block with only one 'spv._merge' op}}
   spv.selection {