diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2150,6 +2150,48 @@
   let assemblyFormat = "`(` operands `)` attr-dict `:` type($args)";
 }
 
+// Test inferReturnTypes coupled with regions.
+def FormatInferTypeRegionsOp
+    : TEST_Op<"format_infer_type_regions", [InferTypeOpInterface]> {
+  let results = (outs Variadic<AnyType>:$outs);
+  let regions = (region AnyRegion:$region);
+  let assemblyFormat = "$region attr-dict";
+  let extraClassDeclaration = [{
+    static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context,
+          ::llvm::Optional<::mlir::Location> location, ::mlir::ValueRange operands,
+          ::mlir::DictionaryAttr attributes, ::mlir::RegionRange regions,
+          ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
+      if (regions.empty())
+        return ::mlir::failure();
+      auto types = regions.front()->getArgumentTypes();
+      inferredReturnTypes.assign(types.begin(), types.end());
+      return ::mlir::success();
+    }
+  }];
+}
+
+// Test inferReturnTypes coupled with variadic operands (operand_segment_sizes).
+def FormatInferTypeVariadicOperandsOp
+    : TEST_Op<"format_infer_type_variadic_operands",
+              [InferTypeOpInterface, AttrSizedOperandSegments]> {
+  let arguments = (ins Variadic<I32>:$a, Variadic<I64>:$b);
+  let results = (outs Variadic<AnyType>:$outs);
+  let assemblyFormat = "`(` $a `:` type($a) `)` `(` $b `:` type($b) `)` attr-dict";
+  let extraClassDeclaration = [{
+    static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context,
+          ::llvm::Optional<::mlir::Location> location, ::mlir::ValueRange operands,
+          ::mlir::DictionaryAttr attributes, ::mlir::RegionRange regions,
+          ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
+      FormatInferTypeVariadicOperandsOpAdaptor adaptor(operands, attributes);
+      auto aTypes = adaptor.getA().getTypes();
+      auto bTypes = adaptor.getB().getTypes();
+      inferredReturnTypes.append(aTypes.begin(), aTypes.end());
+      inferredReturnTypes.append(bTypes.begin(), bTypes.end());
+      return ::mlir::success();
+    }
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // Test SideEffects
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/mlir-tblgen/op-format.mlir b/mlir/test/mlir-tblgen/op-format.mlir
--- a/mlir/test/mlir-tblgen/op-format.mlir
+++ b/mlir/test/mlir-tblgen/op-format.mlir
@@ -423,6 +423,16 @@
 // CHECK: test.format_infer_type_all_types(%[[I64]], %[[I32]]) : i64, i32
 %ignored_res11:2 = test.format_infer_type_all_types(%i64, %i32) : i64, i32
 
+// CHECK: test.format_infer_type_regions
+// CHECK-NEXT: ^bb0(%{{.*}}: {{.*}}, %{{.*}}: {{.*}}):
+%ignored_res12:2 = test.format_infer_type_regions {
+^bb0(%arg0: i32, %arg1: f32):
+  "test.terminator"() : () -> ()
+}
+
+// CHECK: test.format_infer_type_variadic_operands(%[[I32]], %[[I32]] : i32, i32) (%[[I64]], %[[I64]] : i64, i64)
+%ignored_res13:4 = test.format_infer_type_variadic_operands(%i32, %i32 : i32, i32) (%i64, %i64 : i64, i64)
+
//===----------------------------------------------------------------------===// // Check DefaultValuedStrAttr //===----------------------------------------------------------------------===// diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -1185,10 +1185,10 @@ // Generate the code to resolve the operand/result types and successors now // that they have been parsed. - genParserTypeResolution(op, body); genParserRegionResolution(op, body); genParserSuccessorResolution(op, body); genParserVariadicSegmentResolution(op, body); + genParserTypeResolution(op, body); body << " return ::mlir::success();\n"; }