diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -235,7 +235,7 @@
     }
     /// Return the `index`-th region iteration argument.
     BlockArgument getRegionIterArg(unsigned index) {
-      assert(index < getNumRegionIterArgs() && 
+      assert(index < getNumRegionIterArgs() &&
         "expected an index less than the number of region iter args");
       return getBody()->getArguments().drop_front(getNumInductionVars())[index];
     }
@@ -434,7 +434,7 @@
     ```
 
     Example with thread_dim_mapping attribute:
-    
+
     ```mlir
     //
     // Sequential context.
@@ -453,10 +453,9 @@
     // Implicit synchronization point.
     // Sequential context.
     //
-    
     ```
     Example with privatized tensors:
-    
+
     ```mlir
     %t0 = ...
     %t1 = ...
@@ -469,7 +468,6 @@
     "some_use"(%t0)
     "some_use"(%t1)
     }
-    
     ```
   }];
   let arguments = (ins Variadic<Index>:$num_threads, Variadic<AnyRankedTensor>:$outputs,
@@ -527,8 +525,8 @@
       return getBody()->getArguments().drop_front(getRank());
     }
 
-    /// Return the thread indices in the order specified by the 
-    /// thread_dim_mapping attribute. Return failure is 
+    /// Return the thread indices in the order specified by the
+    /// thread_dim_mapping attribute. Return failure if
     /// thread_dim_mapping is not a valid permutation.
     FailureOr<SmallVector<Value>> getPermutedThreadIndices();
 
@@ -988,13 +986,72 @@
   let hasVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// SwitchOp
+//===----------------------------------------------------------------------===//
+
+def SwitchOp : SCF_Op<"switch", [RecursiveMemoryEffects,
+    SingleBlockImplicitTerminator<"scf::YieldOp">]> {
+  let summary = "switch-case operation";
+  let description = [{
+    The `scf.switch` is a control-flow operation that branches to one of the
+    given regions based on the values of the argument and the cases. The
+    argument is a signless integer or `index`; case values share its type.
+
+    Example:
+
+    ```mlir
+    %0 = scf.switch %arg0 : index -> i32
+    case 2 {
+      %1 = arith.constant 10 : i32
+      scf.yield %1 : i32
+    }
+    case 5 {
+      %2 = arith.constant 20 : i32
+      scf.yield %2 : i32
+    }
+    default {
+      %3 = arith.constant 30 : i32
+      scf.yield %3 : i32
+    }
+    ```
+  }];
+
+  // NOTE(review): the element constraint of `$cases` was reconstructed from
+  // the C++ (elements are read as `IntegerAttr`); confirm the exact ODS def.
+  let arguments = (ins AnySignlessIntegerOrIndex:$arg,
+                       TypedArrayAttrBase<APIntAttr, "case value array">:$cases);
+  let results = (outs Variadic<AnyType>:$results);
+  let regions = (region SizedRegion<1>:$defaultRegion,
+                        VariadicRegion<SizedRegion<1>>:$caseRegions);
+
+  let assemblyFormat = [{
+    $arg attr-dict `:` type($arg) (`->` type($results)^)?
+    custom<SwitchCases>(ref(type($arg)), $cases, $caseRegions) `\n`
+    `` `default` $defaultRegion
+  }];
+
+  let extraClassDeclaration = [{
+    /// Get the number of cases.
+    unsigned getNumCases();
+
+    /// Get the default region body.
+    Block &getDefaultBlock();
+
+    /// Get the body of a case region.
+    Block &getCaseBlock(unsigned idx);
+  }];
+
+  let hasRegionVerifier = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // YieldOp
 //===----------------------------------------------------------------------===//
 
 def YieldOp : SCF_Op<"yield", [Pure, ReturnLike, Terminator,
-                               ParentOneOf<["ExecuteRegionOp, ForOp",
-                                            "IfOp, ParallelOp, WhileOp"]>]> {
+    ParentOneOf<["ExecuteRegionOp", "ForOp", "IfOp", "ParallelOp", "WhileOp",
+                 "SwitchOp"]>]> {
   let summary = "loop yield and termination operation";
   let description = [{
     "scf.yield" yields an SSA value from the SCF dialect op region and
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -19,6 +19,7 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Support/MathExtras.h"
 #include "mlir/Transforms/InliningUtils.h"
+#include "llvm/ADT/SmallPtrSet.h"
 
 using namespace mlir;
 using namespace mlir::scf;
@@ -3386,6 +3387,106 @@
               WhileCmpCond, WhileUnusedResult>(context);
 }
 
+//===----------------------------------------------------------------------===//
+// SwitchOp
+//===----------------------------------------------------------------------===//
+
+/// Parse zero or more `case <value> <region>` clauses of the `SwitchCases`
+/// custom directive. Every case value is parsed as an integer attribute of
+/// type `argType`. The parsed values are stored in `cases` and one region per
+/// clause is appended to `caseRegions`, both in source order.
+static ParseResult
+parseSwitchCases(OpAsmParser &p, Type argType, ArrayAttr &cases,
+                 SmallVectorImpl<std::unique_ptr<Region>> &caseRegions) {
+  SmallVector<Attribute> caseValues;
+  while (succeeded(p.parseOptionalKeyword("case"))) {
+    IntegerAttr value;
+    Region &region =
+        *caseRegions.emplace_back(std::make_unique<Region>()).get();
+    if (p.parseAttribute(value, argType) ||
+        p.parseRegion(region, /*arguments=*/{}))
+      return failure();
+    caseValues.push_back(value);
+  }
+  cases = p.getBuilder().getArrayAttr(caseValues);
+  return success();
+}
+
+/// Print every case as `case <value> <region>` on a line of its own. Case
+/// values are printed without their type and regions are printed without
+/// entry block arguments. `op` and `argType` are not used by the printer.
+static void printSwitchCases(OpAsmPrinter &p, Operation *op, Type argType,
+                             ArrayAttr cases, RegionRange caseRegions) {
+  for (auto [value, region] : llvm::zip(cases, caseRegions)) {
+    p.printNewline();
+    p << "case ";
+    p.printAttributeWithoutType(value);
+    p << ' ';
+    p.printRegion(*region, /*printEntryBlockArgs=*/false);
+  }
+}
+
+/// Verify that there is exactly one region per case value, that the case
+/// values are unique and have the same type as the switch argument, and that
+/// the yield of every region matches the result types of the op.
+LogicalResult scf::SwitchOp::verifyRegions() {
+  if (getCases().size() != getCaseRegions().size()) {
+    return emitOpError("has ")
+           << getCaseRegions().size() << " case regions but "
+           << getCases().size() << " case values";
+  }
+
+  // Case values must be unique and must have the type of the argument.
+  SmallPtrSet<Attribute, 8> valueSet;
+  for (auto value : getCases().getAsRange<IntegerAttr>()) {
+    if (!valueSet.insert(value).second)
+      return emitOpError("has duplicate case value: ") << value;
+    if (value.getType() != getArg().getType())
+      return emitOpError("expected all case values to be of type ")
+             << getArg().getType() << " but got " << value;
+  }
+
+  // Check the yield terminating `region` against the results of this op.
+  auto verifyRegion = [&](Region &region, const Twine &name) -> LogicalResult {
+    auto yield = cast<YieldOp>(region.front().getTerminator());
+    if (yield.getNumOperands() != getNumResults()) {
+      return (emitOpError("expected each region to return ")
+              << getNumResults() << " values, but " << name << " returns "
+              << yield.getNumOperands())
+                 .attachNote(yield.getLoc())
+             << "see yield operation here";
+    }
+    for (auto [idx, result, operand] :
+         llvm::zip(llvm::seq<unsigned>(0, getNumResults()), getResultTypes(),
+                   yield.getOperandTypes())) {
+      if (result == operand)
+        continue;
+      return (emitOpError("expected result #")
+              << idx << " of each region to be " << result)
+                 .attachNote(yield.getLoc())
+             << name << " returns " << operand << " here";
+    }
+    return success();
+  };
+
+  if (failed(verifyRegion(getDefaultRegion(), "default region")))
+    return failure();
+  for (auto &[idx, caseRegion] : llvm::enumerate(getCaseRegions()))
+    if (failed(verifyRegion(caseRegion, "case region #" + Twine(idx))))
+      return failure();
+
+  return success();
+}
+
+unsigned scf::SwitchOp::getNumCases() { return getCases().size(); }
+
+Block &scf::SwitchOp::getDefaultBlock() { return getDefaultRegion().front(); }
+
+Block &scf::SwitchOp::getCaseBlock(unsigned idx) {
+  assert(idx < getNumCases() && "case index out-of-bounds");
+  return getCaseRegions()[idx].front();
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/invalid.mlir b/mlir/test/Dialect/SCF/invalid.mlir
--- a/mlir/test/Dialect/SCF/invalid.mlir
+++ b/mlir/test/Dialect/SCF/invalid.mlir
@@ -428,7 +428,7 @@
 
 func.func @yield_invalid_parent_op() {
   "my.op"() ({
-   // expected-error@+1 {{'scf.yield' op expects parent op to be one of 'scf.execute_region, scf.for, scf.if, scf.parallel, scf.while'}}
+   // expected-error@+1 {{'scf.yield' op expects parent op to be one of 'scf.execute_region, scf.for, scf.if, scf.parallel, scf.while, scf.switch'}}
     scf.yield
   }) : () -> ()
   return
@@ -572,3 +572,69 @@
   }
   return
 }
+
+// -----
+
+func.func @switch_wrong_case_count(%arg0: index) {
+  // expected-error @below {{'scf.switch' op has 0 case regions but 1 case values}}
+  "scf.switch"(%arg0) ({
+    scf.yield
+  }) {cases = [1 : index]} : (index) -> ()
+  return
+}
+
+// -----
+
+func.func @switch_duplicate_case(%arg0: index) {
+  // expected-error @below {{'scf.switch' op has duplicate case value: 0}}
+  scf.switch %arg0 : index
+  case 0 {
+    scf.yield
+  }
+  case 0 {
+    scf.yield
+  }
+  default {
+    scf.yield
+  }
+  return
+}
+
+// -----
+
+func.func @switch_wrong_result_count(%arg0: index) {
+  // expected-error @below {{'scf.switch' op expected each region to return 0 values, but default region returns 1}}
+  scf.switch %arg0 : index
+  default {
+    // expected-note @below {{see yield operation here}}
+    scf.yield %arg0 : index
+  }
+  return
+}
+
+// -----
+
+func.func @switch_wrong_result_type(%arg0: index, %arg1: i32) {
+  // expected-error @below {{'scf.switch' op expected result #0 of each region to be 'index'}}
+  scf.switch %arg0 : index -> index
+  case 0 {
+    // expected-note @below {{case region #0 returns 'i32' here}}
+    scf.yield %arg1 : i32
+  }
+  default {
+    scf.yield %arg0 : index
+  }
+  return
+}
+
+// -----
+
+func.func @switch_wrong_case_type(%arg0: index) {
+  // expected-error @below {{'scf.switch' op expected all case values to be of type 'index' but got 1 : i32}}
+  "scf.switch"(%arg0) ({
+    scf.yield
+  }, {
+    scf.yield
+  }) {cases = [1 : i32]} : (index) -> ()
+  return
+}
diff --git a/mlir/test/Dialect/SCF/ops.mlir b/mlir/test/Dialect/SCF/ops.mlir
--- a/mlir/test/Dialect/SCF/ops.mlir
+++ b/mlir/test/Dialect/SCF/ops.mlir
@@ -346,3 +346,36 @@
   } {thread_dim_mapping = [42]}
   return
 }
+
+// CHECK-LABEL: @switch
+func.func @switch(%arg0: index) -> i32 {
+  // CHECK: %{{.*}} = scf.switch %arg0 : index -> i32
+  %0 = scf.switch %arg0 : index -> i32
+  // CHECK-NEXT: case 2 {
+  case 2 {
+    // CHECK-NEXT: arith.constant
+    %c10_i32 = arith.constant 10 : i32
+    // CHECK-NEXT: scf.yield %{{.*}} : i32
+    scf.yield %c10_i32 : i32
+    // CHECK-NEXT: }
+  }
+  // CHECK-NEXT: case 5 {
+  case 5 {
+    %c20_i32 = arith.constant 20 : i32
+    scf.yield %c20_i32 : i32
+  }
+  // CHECK: default {
+  default {
+    %c30_i32 = arith.constant 30 : i32
+    scf.yield %c30_i32 : i32
+  }
+
+  // CHECK: scf.switch %arg0 : index
+  scf.switch %arg0 : index
+  // CHECK-NEXT: default {
+  default {
+    scf.yield
+  }
+
+  return %0 : i32
+}