diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -235,7 +235,7 @@
     }
     /// Return the `index`-th region iteration argument.
     BlockArgument getRegionIterArg(unsigned index) {
-      assert(index < getNumRegionIterArgs() && 
+      assert(index < getNumRegionIterArgs() &&
         "expected an index less than the number of region iter args");
       return getBody()->getArguments().drop_front(getNumInductionVars())[index];
     }
@@ -434,7 +434,7 @@
     ```
 
     Example with thread_dim_mapping attribute:
-    
+
     ```mlir
     //
     // Sequential context.
@@ -456,7 +456,7 @@
     ```
 
     Example with privatized tensors:
-    
+
     ```mlir
     %t0 = ...
     %t1 = ...
@@ -527,8 +527,8 @@
       return getBody()->getArguments().drop_front(getRank());
     }
 
-    /// Return the thread indices in the order specified by the 
-    /// thread_dim_mapping attribute. Return failure is 
+    /// Return the thread indices in the order specified by the
+    /// thread_dim_mapping attribute. Return failure if
     /// thread_dim_mapping is not a valid permutation.
     FailureOr<SmallVector<Value>> getPermutedThreadIndices();
 
@@ -988,13 +988,70 @@
   let hasVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// SwitchOp
+//===----------------------------------------------------------------------===//
+
+def SwitchOp : SCF_Op<"switch", [RecursiveMemoryEffects,
+    SingleBlockImplicitTerminator<"scf::YieldOp">,
+    DeclareOpInterfaceMethods<RegionBranchOpInterface>]> {
+  let summary = "switch-case operation";
+  let description = [{
+    The `scf.switch` is a control-flow operation that executes the region of
+    the case whose value equals the argument, or the default region if no
+    case matches. The argument is always of type `index`.
+
+    Example:
+
+    ```mlir
+    %0 = scf.switch %arg0 -> i32
+    case 2 {
+      %1 = arith.constant 10 : i32
+      scf.yield %1 : i32
+    }
+    case 5 {
+      %2 = arith.constant 20 : i32
+      scf.yield %2 : i32
+    }
+    default {
+      %3 = arith.constant 30 : i32
+      scf.yield %3 : i32
+    }
+    ```
+  }];
+
+  let arguments = (ins Index:$arg, DenseI64ArrayAttr:$cases);
+  let results = (outs Variadic<AnyType>:$results);
+  let regions = (region SizedRegion<1>:$defaultRegion,
+                        VariadicRegion<SizedRegion<1>>:$caseRegions);
+
+  let assemblyFormat = [{
+    $arg attr-dict (`->` type($results)^)?
+    custom<SwitchCases>($cases, $caseRegions) `\n`
+    `` `default` $defaultRegion
+  }];
+
+  let extraClassDeclaration = [{
+    /// Get the number of cases.
+    unsigned getNumCases();
+
+    /// Get the default region body.
+    Block &getDefaultBlock();
+
+    /// Get the body of a case region.
+    Block &getCaseBlock(unsigned idx);
+  }];
+
+  let hasRegionVerifier = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // YieldOp
 //===----------------------------------------------------------------------===//
 
 def YieldOp : SCF_Op<"yield", [Pure, ReturnLike, Terminator,
-    ParentOneOf<["ExecuteRegionOp, ForOp",
-                 "IfOp, ParallelOp, WhileOp"]>]> {
+    ParentOneOf<["ExecuteRegionOp", "ForOp", "IfOp", "ParallelOp", "SwitchOp",
+                 "WhileOp"]>]> {
   let summary = "loop yield and termination operation";
   let description = [{
     "scf.yield" yields an SSA value from the SCF dialect op region and
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -3386,6 +3386,90 @@
                  WhileCmpCond, WhileUnusedResult>(context);
 }
 
+//===----------------------------------------------------------------------===//
+// SwitchOp
+//===----------------------------------------------------------------------===//
+
+/// Parse the `case <int> <region>` list into the case values and regions.
+static ParseResult
+parseSwitchCases(OpAsmParser &p, DenseI64ArrayAttr &cases,
+                 SmallVectorImpl<std::unique_ptr<Region>> &caseRegions) {
+  SmallVector<int64_t> caseValues;
+  while (succeeded(p.parseOptionalKeyword("case"))) {
+    int64_t value;
+    Region &region =
+        *caseRegions.emplace_back(std::make_unique<Region>()).get();
+    if (p.parseInteger(value) || p.parseRegion(region, /*arguments=*/{}))
+      return failure();
+    caseValues.push_back(value);
+  }
+  cases = p.getBuilder().getDenseI64ArrayAttr(caseValues);
+  return success();
+}
+
+/// Print each case as `case <value> <region>` on its own line.
+static void printSwitchCases(OpAsmPrinter &p, Operation *op,
+                             DenseI64ArrayAttr cases, RegionRange caseRegions) {
+  for (auto [value, region] : llvm::zip(cases.asArrayRef(), caseRegions)) {
+    p.printNewline();
+    p << "case " << value << ' ';
+    p.printRegion(*region, /*printEntryBlockArgs=*/false);
+  }
+}
+
+LogicalResult scf::SwitchOp::verifyRegions() {
+  if (getCases().size() != getCaseRegions().size()) {
+    return emitOpError("has ")
+           << getCaseRegions().size() << " case regions but "
+           << getCases().size() << " case values";
+  }
+
+  DenseSet<int64_t> valueSet;
+  for (int64_t value : getCases())
+    if (!valueSet.insert(value).second)
+      return emitOpError("has duplicate case value: ") << value;
+
+  return success();
+}
+
+unsigned scf::SwitchOp::getNumCases() { return getCases().size(); }
+
+Block &scf::SwitchOp::getDefaultBlock() { return getDefaultRegion().front(); }
+
+Block &scf::SwitchOp::getCaseBlock(unsigned idx) {
+  assert(idx < getNumCases() && "case index out-of-bounds");
+  return getCaseRegions()[idx].front();
+}
+
+void scf::SwitchOp::getSuccessorRegions(
+    Optional<unsigned> index, ArrayRef<Attribute> operands,
+    SmallVectorImpl<RegionSuccessor> &successors) {
+  // Every region (default or case) branches back to the parent op's results.
+  if (index) {
+    successors.emplace_back(getResults());
+    return;
+  }
+
+  // Without a constant switch argument, any region may be entered.
+  auto operandValue = operands.front().dyn_cast_or_null<IntegerAttr>();
+  if (!operandValue) {
+    for (Region &caseRegion : getCaseRegions())
+      successors.emplace_back(&caseRegion);
+    successors.emplace_back(&getDefaultRegion());
+    return;
+  }
+
+  // Otherwise, try to find a case with a matching value. If none matches, the
+  // default region is the only successor.
+  for (auto [caseValue, caseRegion] : llvm::zip(getCases(), getCaseRegions())) {
+    if (caseValue == operandValue.getInt()) {
+      successors.emplace_back(&caseRegion);
+      return;
+    }
+  }
+  successors.emplace_back(&getDefaultRegion());
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/invalid.mlir b/mlir/test/Dialect/SCF/invalid.mlir
--- a/mlir/test/Dialect/SCF/invalid.mlir
+++ b/mlir/test/Dialect/SCF/invalid.mlir
@@ -428,7 +428,7 @@
 
 func.func @yield_invalid_parent_op() {
   "my.op"() ({
-    // expected-error@+1 {{'scf.yield' op expects parent op to be one of 'scf.execute_region, scf.for, scf.if, scf.parallel, scf.while'}}
+    // expected-error@+1 {{'scf.yield' op expects parent op to be one of 'scf.execute_region, scf.for, scf.if, scf.parallel, scf.switch, scf.while'}}
     scf.yield
   }) : () -> ()
   return
@@ -572,3 +572,55 @@
   }
   return
 }
+
+// -----
+
+func.func @switch_wrong_case_count(%arg0: index) {
+  // expected-error @below {{'scf.switch' op has 0 case regions but 1 case values}}
+  "scf.switch"(%arg0) ({
+    scf.yield
+  }) {cases = array<i64: 1>} : (index) -> ()
+  return
+}
+
+// -----
+
+func.func @switch_duplicate_case(%arg0: index) {
+  // expected-error @below {{'scf.switch' op has duplicate case value: 0}}
+  scf.switch %arg0
+  case 0 {
+    scf.yield
+  }
+  case 0 {
+    scf.yield
+  }
+  default {
+    scf.yield
+  }
+  return
+}
+
+// -----
+
+func.func @switch_wrong_result_count(%arg0: index) {
+  // expected-error @below {{region control flow edge from Region #0 to parent results: source has 1 operands, but target successor needs 0}}
+  scf.switch %arg0
+  default {
+    scf.yield %arg0 : index
+  }
+  return
+}
+
+// -----
+
+func.func @switch_wrong_types(%arg0: index, %arg1: i32) {
+  // expected-error @below {{along control flow edge from Region #1 to parent results: source type #0 'i32' should match input type #0 'index'}}
+  scf.switch %arg0 -> index
+  case 0 {
+    scf.yield %arg1 : i32
+  }
+  default {
+    scf.yield %arg0 : index
+  }
+  return
+}
diff --git a/mlir/test/Dialect/SCF/ops.mlir b/mlir/test/Dialect/SCF/ops.mlir
--- a/mlir/test/Dialect/SCF/ops.mlir
+++ b/mlir/test/Dialect/SCF/ops.mlir
@@ -346,3 +346,36 @@
   } {thread_dim_mapping = [42]}
   return
 }
+
+// CHECK-LABEL: @switch
+func.func @switch(%arg0: index) -> i32 {
+  // CHECK: %{{.*}} = scf.switch %arg0 -> i32
+  %0 = scf.switch %arg0 -> i32
+  // CHECK-NEXT: case 2 {
+  case 2 {
+    // CHECK-NEXT: arith.constant
+    %c10_i32 = arith.constant 10 : i32
+    // CHECK-NEXT: scf.yield %{{.*}} : i32
+    scf.yield %c10_i32 : i32
+    // CHECK-NEXT: }
+  }
+  // CHECK-NEXT: case 5 {
+  case 5 {
+    %c20_i32 = arith.constant 20 : i32
+    scf.yield %c20_i32 : i32
+  }
+  // CHECK: default {
+  default {
+    %c30_i32 = arith.constant 30 : i32
+    scf.yield %c30_i32 : i32
+  }
+
+  // CHECK: scf.switch %arg0
+  scf.switch %arg0
+  // CHECK-NEXT: default {
+  default {
+    scf.yield
+  }
+
+  return %0 : i32
+}