diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -700,16 +700,35 @@
   // If we have already seen a value of this name, return it.
   if (useInfo.number < entries.size() && entries[useInfo.number].value) {
     Value result = entries[useInfo.number].value;
+
     // Check that the type matches the other uses.
-    if (result.getType() == type)
-      return maybeRecordUse(result);
-
-    emitError(useInfo.loc, "use of value '")
-        .append(useInfo.name,
-                "' expects different type than prior uses: ", type, " vs ",
-                result.getType())
-        .attachNote(getEncodedSourceLocation(entries[useInfo.number].loc))
-        .append("prior use here");
+    if (type && result.getType() != type) {
+      emitError(useInfo.loc, "use of value '")
+          .append(useInfo.name,
+                  "' expects different type than prior uses: ", type, " vs ",
+                  result.getType())
+          .attachNote(getEncodedSourceLocation(entries[useInfo.number].loc))
+          .append("prior use here");
+      return nullptr;
+    }
+
+    // The explicit type may be omitted only if the value was already defined
+    // earlier in the same block.
+    if (!type && result.getParentBlock() != opBuilder.getBlock()) {
+      emitError(useInfo.loc, "value '")
+          .append(useInfo.name, "' was defined in separate block and requires "
+                                "explicit type definition")
+          .attachNote(getEncodedSourceLocation(entries[useInfo.number].loc))
+          .append("defined here");
+      return nullptr;
+    }
+
+    return maybeRecordUse(result);
+  }
+
+  if (!type) {
+    emitError(useInfo.loc, "forward reference of value '")
+        .append(useInfo.name, "' requires explicit type specification");
     return nullptr;
   }
 
diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir
--- a/mlir/test/IR/invalid.mlir
+++ b/mlir/test/IR/invalid.mlir
@@ -1639,3 +1639,31 @@
 // -----
 
 func @foo() {} // expected-error {{expected non-empty function body}}
+
+// -----
+
+// expected-note@+1 {{defined here}}
+func @foo(%arg0: i64, %arg1: memref<1xf64>) {
+  br ^bb1
+
+  ^bb1:
+  // expected-error@+1 {{value '%arg1' was defined in separate block and requires explicit type definition}}
+  test.format_operand_optional_type_op %arg0, %arg1
+  return
+}
+
+// -----
+
+func @foo() {
+  br ^bb2
+
+  ^bb1:
+  // expected-error@+1 {{forward reference of value '%1' requires explicit type specification}}
+  test.format_operand_optional_type_op %0, %1
+  return
+
+  ^bb2:
+  %0 = arith.constant 0 : i64
+  %1 = memref.alloc() : memref<1xf64>
+  br ^bb1
+}
diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir
--- a/mlir/test/IR/parser.mlir
+++ b/mlir/test/IR/parser.mlir
@@ -1470,3 +1470,13 @@
 // This is an unregister operation, the printing/parsing is handled by the dialect.
 // CHECK: test.dialect_custom_printer custom_format
 test.dialect_custom_printer custom_format
+
+// Can skip type definition for operands, if they are already defined in the same block
+// CHECK-LABEL: func @optional_operand_types
+func @optional_operand_types(%arg0: i64, %arg1: memref<1xf64>) {
+  // CHECK: test.format_operand_optional_type_op %arg0, %arg1
+  test.format_operand_optional_type_op %arg0, %arg1
+  // CHECK: test.format_operand_optional_type_op %arg0, %arg1
+  test.format_operand_optional_type_op %arg0, %arg1 : memref<1xf64>
+  return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Reducer/ReductionPatternInterface.h"
@@ -1095,6 +1096,45 @@
                 /*printBlockTerminators=*/false);
 }
 
+//===----------------------------------------------------------------------===//
+// FormatOperandOptionalTypeOp
+//===----------------------------------------------------------------------===//
+
+/// Returns true when the type of `val` may be elided on `op`: `val` is either
+/// an argument of the block that contains `op`, or the result of an operation
+/// that precedes `op` within that same block. Arguments of other blocks are
+/// rejected because the parser only accepts an omitted type for values defined
+/// in the block currently being parsed.
+static bool isDefinedAbove(Value val, Operation *op) {
+  Block *opBlock = op->getBlock();
+  if (auto blockArg = val.dyn_cast<BlockArgument>())
+    return blockArg.getOwner() == opBlock;
+
+  Operation *defOp = val.getDefiningOp();
+  return defOp->getBlock() == opBlock && defOp->isBeforeInBlock(op);
+}
+
+/// Custom directive printer: emits `:` followed by `type`, unless the operand
+/// is already defined above `op` in the same block, in which case nothing is
+/// printed.
+static void printOptionalType(OpAsmPrinter &printer,
+                              FormatOperandOptionalTypeOp op, Type type) {
+  if (isDefinedAbove(op.operand(), op))
+    return;
+
+  printer << ":";
+  printer.printType(type);
+}
+
+/// Custom directive parser: parses an optional `:` type suffix. If the
+/// optional colon is not parsed, `type` is left unmodified and success is
+/// returned.
+static ParseResult parseOptionalType(OpAsmParser &parser, Type &type) {
+  if (parser.parseOptionalColon())
+    return success();
+
+  return parser.parseType(type);
+}
+
 #include "TestOpEnums.cpp.inc"
 #include "TestOpInterfaces.cpp.inc"
 #include "TestOpStructs.cpp.inc"
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1757,6 +1757,9 @@
 def FormatOperandEOp : FormatOperandBase<"e", [{
   $buildable `,` $operand `:` type($buildable) `,` type($operand) attr-dict
 }]>;
+def FormatOperandOptionalTypeOp : FormatOperandBase<"optional_type", [{
+  $buildable `,` $operand custom<OptionalType>(type($operand)) attr-dict
+}]>;
 
 def FormatSuccessorAOp : TEST_Op<"format_successor_a_op", [Terminator]> {
   let successors = (successor VariadicSuccessor<AnySuccessor>:$targets);