diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -701,7 +701,7 @@
   if (useInfo.number < entries.size() && entries[useInfo.number].value) {
     Value result = entries[useInfo.number].value;
     // Check that the type matches the other uses.
-    if (result.getType() == type)
+    if (!type || result.getType() == type)
       return maybeRecordUse(result);
 
     emitError(useInfo.loc, "use of value '")
@@ -713,6 +713,16 @@
     return nullptr;
   }
 
+  // A null type can only be resolved against an existing definition. For a
+  // forward reference there is nothing to infer it from, so diagnose. No note
+  // is attached: `entries` has not been resized yet, so indexing it with
+  // `useInfo.number` here could read out of bounds.
+  if (!type) {
+    emitError(useInfo.loc, "forward reference of value '")
+        .append(useInfo.name, "' requires explicit type specification");
+    return nullptr;
+  }
+
   // Make sure we have enough slots for this.
   if (entries.size() <= useInfo.number)
     entries.resize(useInfo.number + 1);
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Reducer/ReductionPatternInterface.h"
@@ -1095,6 +1096,41 @@
       /*printBlockTerminators=*/false);
 }
 
+//===----------------------------------------------------------------------===//
+// FormatOperandOptionalTypeOp
+//===----------------------------------------------------------------------===//
+
+/// Returns true if `val` is a block argument, or if it is defined by an
+/// operation that precedes `op` within the same block.
+static bool isDefinedAbove(Value val, Operation *op) {
+  if (val.isa<BlockArgument>())
+    return true;
+
+  return val.getDefiningOp()->getBlock() == op->getBlock() &&
+         val.getDefiningOp()->isBeforeInBlock(op);
+}
+
+/// Custom printer for the optional operand type. The `:type` suffix is
+/// omitted when the operand is already defined above `op`; the parser change
+/// in this patch lets a null type resolve against an existing definition.
+static void printOptionalType(OpAsmPrinter &printer,
+                              FormatOperandOptionalTypeOp op, Type type) {
+  if (isDefinedAbove(op.operand(), op))
+    return;
+
+  printer << ":";
+  printer.printType(type);
+}
+
+/// Custom parser for the optional operand type. Parses `: type` if a colon is
+/// present; otherwise succeeds and leaves `type` untouched.
+static ParseResult parseOptionalType(OpAsmParser &parser, Type &type) {
+  if (parser.parseOptionalColon())
+    return success();
+
+  return parser.parseType(type);
+}
+
 #include "TestOpEnums.cpp.inc"
 #include "TestOpInterfaces.cpp.inc"
 #include "TestOpStructs.cpp.inc"
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1753,6 +1753,9 @@
 def FormatOperandEOp : FormatOperandBase<"e", [{
   $buildable `,` $operand `:` type($buildable) `,` type($operand) attr-dict
 }]>;
+def FormatOperandOptionalTypeOp : FormatOperandBase<"optional_type", [{
+  $buildable `,` $operand custom<OptionalType>(type($operand)) attr-dict
+}]>;
 
 def FormatSuccessorAOp : TEST_Op<"format_successor_a_op", [Terminator]> {
   let successors = (successor VariadicSuccessor<AnySuccessor>:$targets);
diff --git a/mlir/test/mlir-tblgen/op-format.mlir b/mlir/test/mlir-tblgen/op-format.mlir
--- a/mlir/test/mlir-tblgen/op-format.mlir
+++ b/mlir/test/mlir-tblgen/op-format.mlir
@@ -148,6 +148,12 @@
 // CHECK: test.format_operand_e_op %[[I64]], %[[MEMREF]] : i64, memref<1xf64>
 test.format_operand_e_op %i64, %memref : i64, memref<1xf64>
 
+// CHECK: test.format_operand_optional_type_op %[[I64]], %[[MEMREF]]
+test.format_operand_optional_type_op %i64, %memref
+
+// CHECK: test.format_operand_optional_type_op %[[I64]], %[[MEMREF]]
+test.format_operand_optional_type_op %i64, %memref : memref<1xf64>
+
 // CHECK: test.format_variadic_operand %[[I64]], %[[I64]], %[[I64]] : i64, i64, i64
 test.format_variadic_operand %i64, %i64, %i64 : i64, i64, i64
 