diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h --- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h +++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h @@ -81,12 +81,12 @@ // TODO: Consider generating typedefs for trait member functions if this usage // becomes more common. LogicalResult inferReturnTensorTypes( - function_ref location, ValueRange operands, - DictionaryAttr attributes, RegionRange regions, - SmallVectorImpl &retComponents)> + function_ref< + LogicalResult(MLIRContext *, Location location, ValueRange operands, + DictionaryAttr attributes, RegionRange regions, + SmallVectorImpl &retComponents)> componentTypeFn, - MLIRContext *context, Optional location, ValueRange operands, + MLIRContext *context, Location location, ValueRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnTypes); @@ -104,9 +104,8 @@ class InferTensorType : public TraitBase { public: static LogicalResult - inferReturnTypes(MLIRContext *context, Optional location, - ValueRange operands, DictionaryAttr attributes, - RegionRange regions, + inferReturnTypes(MLIRContext *context, Location location, ValueRange operands, + DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { return ::mlir::detail::inferReturnTensorTypes( ConcreteType::inferReturnTypeComponents, context, location, operands, diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td --- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td +++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td @@ -31,15 +31,14 @@ StaticInterfaceMethod< /*desc=*/[{Infer the return types that an op would generate. - The method takes an optional location which, if set, will be used to - report errors on. The operands and attributes correspond to those with - which an Operation would be created (e.g., as used in Operation::create) - and the regions of the op. + The method takes a location which will be used to report errors on. The + operands and attributes correspond to those with which an Operation would + be created (e.g., as used in Operation::create) and the regions of the op. }], /*retTy=*/"::mlir::LogicalResult", /*methodName=*/"inferReturnTypes", /*args=*/(ins "::mlir::MLIRContext *":$context, - "::llvm::Optional<::mlir::Location>":$location, + "::mlir::Location":$location, "::mlir::ValueRange":$operands, "::mlir::DictionaryAttr":$attributes, "::mlir::RegionRange":$regions, @@ -81,10 +80,9 @@ StaticInterfaceMethod< /*desc=*/[{Infer the components of return type of shape containter. - The method takes an optional location which, if set, will be used to - report errors on. The operands and attributes correspond to those with - which an Operation would be created (e.g., as used in Operation::create) - and the regions of the op. + The method takes a location which will be used to report errors on. The + operands and attributes correspond to those with which an Operation would + be created (e.g., as used in Operation::create) and the regions of the op. Unknown (e.g., unranked) shape and nullptrs for element type and attribute may be returned by this function while returning success. E.g., partial @@ -93,7 +91,7 @@ /*retTy=*/"::mlir::LogicalResult", /*methodName=*/"inferReturnTypeComponents", /*args=*/(ins "::mlir::MLIRContext*":$context, - "::mlir::Optional<::mlir::Location>":$location, + "::mlir::Location":$location, "::mlir::ValueRange":$operands, "::mlir::DictionaryAttr":$attributes, "::mlir::RegionRange":$regions, diff --git a/mlir/lib/Interfaces/InferTypeOpInterface.cpp b/mlir/lib/Interfaces/InferTypeOpInterface.cpp --- a/mlir/lib/Interfaces/InferTypeOpInterface.cpp +++ b/mlir/lib/Interfaces/InferTypeOpInterface.cpp @@ -22,12 +22,12 @@ } // namespace mlir LogicalResult mlir::detail::inferReturnTensorTypes( - function_ref location, ValueRange operands, - DictionaryAttr attributes, RegionRange regions, - SmallVectorImpl &retComponents)> + function_ref< + LogicalResult(MLIRContext *, Location location, ValueRange operands, + DictionaryAttr attributes, RegionRange regions, + SmallVectorImpl &retComponents)> componentTypeFn, - MLIRContext *context, Optional location, ValueRange operands, + MLIRContext *context, Location location, ValueRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { SmallVector retComponents; diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -648,7 +648,7 @@ } LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes( - MLIRContext *, Optional location, ValueRange operands, + MLIRContext *, Location location, ValueRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { if (operands[0].getType() != operands[1].getType()) { @@ -661,7 +661,7 @@ } LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents( - MLIRContext *context, Optional location, ValueRange operands, + MLIRContext *context, Location location, ValueRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { // Create return type consisting of the last element of the first operand. diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -98,6 +98,9 @@ template static void invokeCreateWithInferredReturnType(Operation *op) { auto *context = op->getContext(); + // Squash error emission/propagation. + DiagnosticEngine::HandlerID diagHandler = + context->getDiagEngine().registerHandler([](Diagnostic &) {}); auto fop = op->getParentOfType(); auto location = UnknownLoc::get(context); OpBuilder b(op); @@ -110,7 +113,7 @@ std::array values = {{fop.getArgument(i), fop.getArgument(j)}}; SmallVector inferredReturnTypes; if (succeeded(OpTy::inferReturnTypes( - context, llvm::None, values, op->getAttrDictionary(), + context, op->getLoc(), values, op->getAttrDictionary(), op->getRegions(), inferredReturnTypes))) { OperationState state(location, OpTy::getOperationName()); // TODO: Expand to regions. @@ -119,6 +122,7 @@ } } } + context->getDiagEngine().eraseHandler(diagHandler); } static void reifyReturnShape(Operation *op) {