diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -73,7 +73,19 @@
       "::mlir::Block *", "getSuccessorForOperands",
       (ins "::mlir::ArrayRef<::mlir::Attribute>":$operands), [{}],
       /*defaultImplementation=*/[{ return nullptr; }]
-    >
+    >,
+    StaticInterfaceMethod<
+      /*desc=*/[{
+        Returns true if a value of type `operandTy` may be forwarded along a
+        control-flow edge to a successor block argument of type `blockArgTy`.
+        The default implementation requires the two types to be identical.
+      }],
+      /*retTy=*/"bool",
+      /*methodName=*/"areCompatibleControlFlowEdgeOperandTypes",
+      /*args=*/(ins "::mlir::Type":$operandTy, "::mlir::Type":$blockArgTy),
+      /*methodBody=*/[{}],
+      /*defaultImplementation=*/"return operandTy == blockArgTy;"
+    >
   ];
 
   let verify = [{
@@ -149,7 +161,19 @@
         assert(countPerRegion.empty());
         countPerRegion.resize(numRegions, kUnknownNumRegionInvocations);
       }]
-    >
+    >,
+    StaticInterfaceMethod<
+      /*desc=*/[{
+        Returns true if a value of type `operandTy` may be forwarded along a
+        control-flow edge to a successor input of type `blockArgTy`. The
+        default implementation requires the two types to be identical.
+      }],
+      /*retTy=*/"bool",
+      /*methodName=*/"areCompatibleControlFlowEdgeOperandTypes",
+      /*args=*/(ins "::mlir::Type":$operandTy, "::mlir::Type":$blockArgTy),
+      /*methodBody=*/[{}],
+      /*defaultImplementation=*/"return operandTy == blockArgTy;"
+    >
   ];
 
   let verify = [{
diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
--- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
+++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
@@ -62,7 +62,8 @@
   // Check the types.
   auto operandIt = operands->begin();
   for (unsigned i = 0; i != operandCount; ++i, ++operandIt) {
-    if ((*operandIt).getType() != destBB->getArgument(i).getType())
+    if (!cast<BranchOpInterface>(op).areCompatibleControlFlowEdgeOperandTypes(
+            (*operandIt).getType(), destBB->getArgument(i).getType()))
       return op->emitError() << "type mismatch for bb argument #" << i
                              << " of successor #" << succNo;
   }
@@ -135,7 +136,8 @@
          llvm::enumerate(llvm::zip(*sourceTypes, succInputsTypes))) {
       Type sourceType = std::get<0>(typesIdx.value());
       Type inputType = std::get<1>(typesIdx.value());
-      if (sourceType != inputType) {
+      if (!regionInterface.areCompatibleControlFlowEdgeOperandTypes(
+              sourceType, inputType)) {
         InFlightDiagnostic diag = op->emitOpError(" along control flow edge ");
         return printEdgeName(diag) << ": source type #" << typesIdx.index()
                                    << " " << sourceType
diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir
--- a/mlir/test/IR/invalid.mlir
+++ b/mlir/test/IR/invalid.mlir
@@ -368,31 +368,6 @@
   return
 }
 
-// -----
-
-func @br_mismatch() {
-^bb0:
-  %0:2 = "foo"() : () -> (i1, i17)
-  // expected-error @+1 {{branch has 2 operands for successor #0, but target block has 1}}
-  br ^bb1(%0#1, %0#0 : i17, i1)
-
-^bb1(%x: i17):
-  return
-}
-
-// -----
-
-func @succ_arg_type_mismatch() {
-^bb0:
-  %0 = "getBool"() : () -> i1
-  // expected-error @+1 {{type mismatch for bb argument #0 of successor #0}}
-  br ^bb1(%0 : i1)
-
-^bb1(%x: i32):
-  return
-}
-
-
 // -----
 
 // Test no nested vector.
diff --git a/mlir/test/Interfaces/ControlFlowInterfaces/operand-types.mlir b/mlir/test/Interfaces/ControlFlowInterfaces/operand-types.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Interfaces/ControlFlowInterfaces/operand-types.mlir
@@ -0,0 +1,129 @@
+// RUN: mlir-opt -allow-unregistered-dialect --split-input-file --verify-diagnostics %s
+
+//===----------------------------------------------------------------------===//
+// Test BranchOpInterface
+//===----------------------------------------------------------------------===//
+
+func @succ_arg_number_mismatch() {
+^bb0:
+  %values:2 = "getValues"() : () -> (i1, i17)
+  "test.br"(%values#1, %values#0)[^bb1] : (i17, i1) -> ()
+  // expected-error@-1 {{branch has 2 operands for successor #0, but target block has 1}}
+
+^bb1(%arg1: i17):
+  return
+}
+
+// -----
+
+func @succ_arg_type_mismatch() {
+^bb0:
+  %value = "getValue"() : () -> i1
+  "test.br"(%value)[^bb1] : (i1) -> ()
+  // expected-error@-1 {{type mismatch for bb argument #0 of successor #0}}
+
+^bb1(%arg1: i32):
+  return
+}
+
+// -----
+
+func @succ_type_fixed_to_dynamic() {
+^bb0:
+  %0 = "getValue"() : () -> tensor<1xi32>
+  "test.br"(%0)[^bb1] : (tensor<1xi32>) -> ()
+  // expected-error@-1 {{type mismatch for bb argument #0 of successor #0}}
+^bb1(%arg1: tensor<?xi32>):
+  return
+}
+
+// -----
+
+// test.cfe.br accepts an operand type that is more specialized than the block
+// argument type. Kept in its own split so it is verified independently.
+func @succ_type_fixed_to_dynamic_allowed() {
+^bb0:
+  %0 = "getValue"() : () -> tensor<1xi32>
+  "test.cfe.br"(%0)[^bb1] : (tensor<1xi32>) -> ()
+^bb1(%arg1: tensor<?xi32>):
+  return
+}
+
+// -----
+
+func @succ_type_dynamic_to_fixed() {
+^bb0:
+  %0 = "getValue"() : () -> tensor<?xi32>
+  "test.cfe.br"(%0)[^bb1] : (tensor<?xi32>) -> ()
+  // expected-error@-1 {{type mismatch for bb argument #0 of successor #0}}
+^bb1(%arg1: tensor<1xi32>):
+  return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Test RegionBranchOpInterface
+//===----------------------------------------------------------------------===//
+
+func @succ_arg_type_match() {
+  %0 = "getValue"() : () -> (i32)
+
+  %1 = test.region_if %0 : i32 -> (i32) then {
+  ^bb0(%arg1 : i32):
+    test.region_if_yield %arg1 : i32
+  } else {
+  ^bb0(%arg1 : i32):
+    test.region_if_yield %arg1 : i32
+  } join {
+  ^bb0(%arg1 : i32):
+    test.region_if_yield %arg1 : i32
+  }
+
+  return
+}
+
+// -----
+
+func @succ_type_fixed_to_dynamic() {
+  %0 = "getValue"() : () -> (memref<1xi32>)
+
+  // expected-error@+1 {{'test.region_if' op along control flow edge from parent operands to Region #1: source type #0 'memref<1xi32>' should match input type #0 'memref<?xi32>'}}
+  %tmp1 = test.region_if %0 : memref<1xi32> -> (memref<?xi32>) then {
+  ^bb0(%arg1 : memref<1xi32>):
+    %true_value = "getValue"(%arg1) : (memref<1xi32>) -> (memref<2xi32>)
+    test.region_if_yield %true_value : memref<2xi32>
+  } else {
+  ^bb0(%arg1 : memref<?xi32>):
+    %false_value = "getValue"(%arg1) : (memref<?xi32>) -> (memref<3xi32>)
+    test.region_if_yield %false_value : memref<3xi32>
+  } join {
+  ^bb0(%arg1 : memref<?xi32>):
+    test.region_if_yield %arg1 : memref<?xi32>
+  }
+
+  return
+}
+
+// -----
+
+// test.cfe.region_if accepts operand types that are more specialized than the
+// successor input types. Kept in its own split so it is verified independently.
+func @succ_type_fixed_to_dynamic_allowed() {
+  %0 = "getValue"() : () -> (memref<1xi32>)
+
+  %tmp2 = test.cfe.region_if %0 : memref<1xi32> -> (memref<?xi32>) then {
+  ^bb0(%arg1 : memref<1xi32>):
+    %true_value = "getValue"(%arg1) : (memref<1xi32>) -> (memref<2xi32>)
+    test.cfe.region_if_yield %true_value : memref<2xi32>
+  } else {
+  ^bb0(%arg1 : memref<?xi32>):
+    %false_value = "getValue"(%arg1) : (memref<?xi32>) -> (memref<3xi32>)
+    test.cfe.region_if_yield %false_value : memref<3xi32>
+  } join {
+  ^bb0(%arg1 : memref<?xi32>):
+    test.cfe.region_if_yield %arg1 : memref<?xi32>
+  }
+
+  return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -1008,7 +1008,9 @@
 // RegionIfOp
 //===----------------------------------------------------------------------===//
 
-static void print(OpAsmPrinter &p, RegionIfOp op) {
+/// Prints the custom assembly form shared by RegionIfOp and CFERegionIfOp.
+template <typename RegionIfOpT>
+static void printRegionIfOp(OpAsmPrinter &p, RegionIfOpT op) {
   p << " ";
   p.printOperands(op.getOperands());
   p << ": " << op.getOperandTypes();
@@ -1027,6 +1029,8 @@
                 /*printBlockTerminators=*/true);
 }
 
+static void print(OpAsmPrinter &p, RegionIfOp op) { printRegionIfOp(p, op); }
+
 static ParseResult parseRegionIfOp(OpAsmParser &parser,
                                    OperationState &result) {
   SmallVector<OpAsmParser::OperandType, 2> operandInfos;
@@ -1095,6 +1099,52 @@
       /*printBlockTerminators=*/false);
 }
 
+//===----------------------------------------------------------------------===//
+// Control-Flow Edge Operand Types Test Ops
+//===----------------------------------------------------------------------===//
+
+Optional<MutableOperandRange>
+CFEBranchOp::getMutableSuccessorOperands(unsigned index) {
+  assert(index == 0 && "invalid successor index");
+  return getTargetOperandsMutable();
+}
+
+/// Relaxes the default exact type match required along branch edges.
+bool CFEBranchOp::areCompatibleControlFlowEdgeOperandTypes(Type operandTy,
+                                                           Type blockArgTy) {
+  return isMoreSpecializedOrSame(operandTy, blockArgTy);
+}
+
+static void print(OpAsmPrinter &p, CFERegionIfOp op) { printRegionIfOp(p, op); }
+
+OperandRange CFERegionIfOp::getSuccessorEntryOperands(unsigned index) {
+  assert(index < 2 && "invalid region index");
+  return getOperands();
+}
+
+void CFERegionIfOp::getSuccessorRegions(
+    Optional<unsigned> index, ArrayRef<Attribute> operands,
+    SmallVectorImpl<RegionSuccessor> &regions) {
+  // Then and else branch to the join region; join returns to the parent op.
+  if (index.hasValue()) {
+    if (index.getValue() < 2)
+      regions.push_back(RegionSuccessor(&getJoinRegion(), getJoinArgs()));
+    else
+      regions.push_back(RegionSuccessor(getResults()));
+    return;
+  }
+
+  // The then and else regions are the entry regions of this op.
+  regions.push_back(RegionSuccessor(&getThenRegion(), getThenArgs()));
+  regions.push_back(RegionSuccessor(&getElseRegion(), getElseArgs()));
+}
+
+/// Relaxes the default exact type match required along region edges.
+bool CFERegionIfOp::areCompatibleControlFlowEdgeOperandTypes(Type operandTy,
+                                                             Type blockArgTy) {
+  return isMoreSpecializedOrSame(operandTy, blockArgTy);
+}
+
 #include "TestOpEnums.cpp.inc"
 #include "TestOpInterfaces.cpp.inc"
 #include "TestOpStructs.cpp.inc"
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2161,6 +2161,8 @@
     }
     ::mlir::OperandRange getSuccessorEntryOperands(unsigned index);
   }];
+
+  let verifier = [{ return RegionBranchOpInterface::verifyTypes(*this); }];
 }
 
 //===----------------------------------------------------------------------===//
@@ -2359,4 +2361,59 @@
 def : Pat<(TestDefaultStrAttrNoValueOp $value),
           (TestDefaultStrAttrHasValueOp ConstantStrAttr<StrAttr, "foo">)>;
 
+//===----------------------------------------------------------------------===//
+// Test Control-Flow Edge Operand Types
+//===----------------------------------------------------------------------===//
+
+def CFEBranchOp : TEST_Op<"cfe.br",
+    [DeclareOpInterfaceMethods<BranchOpInterface,
+         ["areCompatibleControlFlowEdgeOperandTypes"]>,
+     Terminator]> {
+  let arguments = (ins Variadic<AnyType>:$targetOperands);
+  let successors = (successor AnySuccessor:$target);
+}
+
+def CFERegionIfYieldOp : TEST_Op<"cfe.region_if_yield",
+    [NoSideEffect, ReturnLike, Terminator]> {
+  let arguments = (ins Variadic<AnyType>:$results);
+  let assemblyFormat = [{
+    $results `:` type($results) attr-dict
+  }];
+}
+
+def CFERegionIfOp
+    : TEST_Op<"cfe.region_if",
+          [DeclareOpInterfaceMethods<RegionBranchOpInterface,
+               ["areCompatibleControlFlowEdgeOperandTypes"]>,
+           SingleBlockImplicitTerminator<"CFERegionIfYieldOp">,
+           RecursiveSideEffects]> {
+  let description = [{
+    Same as `test.region_if`, but control-flow edges accept operand types that
+    are more specialized than the successor input types. The then and else
+    regions jump to the join region, which finally returns to the parent op.
+  }];
+
+  let printer = [{ return ::print(p, *this); }];
+  let parser = [{ return ::parseRegionIfOp(parser, result); }];
+  let arguments = (ins Variadic<AnyType>);
+  let results = (outs Variadic<AnyType>:$results);
+  let regions = (region SizedRegion<1>:$thenRegion,
+                        AnyRegion:$elseRegion,
+                        AnyRegion:$joinRegion);
+  let extraClassDeclaration = [{
+    ::mlir::Block::BlockArgListType getThenArgs() {
+      return getBody(0)->getArguments();
+    }
+    ::mlir::Block::BlockArgListType getElseArgs() {
+      return getBody(1)->getArguments();
+    }
+    ::mlir::Block::BlockArgListType getJoinArgs() {
+      return getBody(2)->getArguments();
+    }
+    ::mlir::OperandRange getSuccessorEntryOperands(unsigned index);
+  }];
+
+  let verifier = [{ return RegionBranchOpInterface::verifyTypes(*this); }];
+}
+
 #endif // TEST_OPS