diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
--- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
+++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
@@ -78,11 +78,34 @@
 
 /// Range of values and shapes (corresponding effectively to Shapes dialect's
 /// ValueShape type concept).
-using ValueShapeRange = ValueRange;
+class ValueShapeRange : public ValueRange::RangeBaseT {
+public:
+  ValueShapeRange(ValueRange values) : RangeBaseT(values) {}
+  /// Accept any argument ValueRange is constructible from, so call sites
+  /// written against the former `using ValueShapeRange = ValueRange` work.
+  template <typename Arg, typename = typename std::enable_if_t<
+                              std::is_constructible<ValueRange, Arg>::value>>
+  ValueShapeRange(Arg &&arg)
+      : ValueShapeRange(ValueRange(std::forward<Arg>(arg))) {}
+  ValueShapeRange(const std::initializer_list<Value> &values)
+      : ValueShapeRange(ValueRange(values)) {}
+
+  /// Returns the types of the values within this range.
+  /// Note: This returns only the types of Values in the ValueRange and not a
+  /// more refined type.
+  using type_iterator = ValueTypeIterator<iterator>;
+  using type_range = ValueTypeRange<ValueRange>;
+  type_range getTypes() const { return {begin(), end()}; }
+  /// Equivalent to getTypes().
+  auto getType() const { return getTypes(); }
+
+  /// Returns the Values in the ValueRange.
+  ValueRange getValues() const { return ValueRange(begin(), end()); }
+};
 
 namespace detail {
-// Helper function to infer return tensor returns types given element and shape
-// inference function.
+// Helper function to infer return tensor types given element and shape
+// inference function.
 //
 // TODO: Consider generating typedefs for trait member functions if this usage
 // becomes more common.
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -736,7 +736,8 @@
 #define REDUCE_SHAPE_INFER(OP)                                                 \
   LogicalResult OP::inferReturnTypeComponents(                                 \
       MLIRContext *context, ::llvm::Optional<Location> location,               \
-      ValueRange operands, DictionaryAttr attributes, RegionRange regions,     \
+      ValueShapeRange operands, DictionaryAttr attributes,                     \
+      RegionRange regions,                                                     \
       SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {           \
     return ReduceInferReturnTypes(operands[0],                                 \
                                   attributes.get("axis").cast<IntegerAttr>(),  \
@@ -802,7 +803,8 @@
 #define NARY_SHAPE_INFER(OP)                                                   \
   LogicalResult OP::inferReturnTypeComponents(                                 \
       MLIRContext *context, ::llvm::Optional<Location> location,               \
-      ValueRange operands, DictionaryAttr attributes, RegionRange regions,     \
+      ValueShapeRange operands, DictionaryAttr attributes,                     \
+      RegionRange regions,                                                     \
       SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {           \
     return NAryInferReturnTypes(operands, inferredReturnShapes);               \
   }
@@ -892,7 +894,7 @@
     ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
   llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamicSize);
-  Conv2DOp::Adaptor adaptor(operands);
+  Conv2DOp::Adaptor adaptor(operands.getValues());
 
   int32_t inputWidth = ShapedType::kDynamicSize;
   int32_t inputHeight = ShapedType::kDynamicSize;
@@ -953,7 +955,7 @@
     ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
   llvm::SmallVector<int64_t> outputShape(5, ShapedType::kDynamicSize);
-  Conv2DOp::Adaptor adaptor(operands);
+  Conv2DOp::Adaptor adaptor(operands.getValues());
 
   int32_t inputWidth = ShapedType::kDynamicSize;
   int32_t inputHeight = ShapedType::kDynamicSize;
@@ -1040,7 +1042,7 @@
     ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
   llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamicSize);
-  DepthwiseConv2DOp::Adaptor adaptor(operands);
+  DepthwiseConv2DOp::Adaptor adaptor(operands.getValues());
 
   int32_t inputWidth = ShapedType::kDynamicSize;
   int32_t inputHeight = ShapedType::kDynamicSize;
@@ -1114,7 +1116,7 @@
     MLIRContext *context, ::llvm::Optional<Location> location,
     ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-  TransposeConv2DOp::Adaptor adaptor(operands);
+  TransposeConv2DOp::Adaptor adaptor(operands.getValues());
   llvm::SmallVector<int64_t> outputShape;
   getI64Values(attributes.get("out_shape").cast<ArrayAttr>(), outputShape);
 
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -785,7 +785,7 @@
     DictionaryAttr attributes, RegionRange regions,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
   // Create return type consisting of the last element of the first operand.
-  auto operandType = *operands.getTypes().begin();
+  auto operandType = operands.front().getType();
   auto sval = operandType.dyn_cast<ShapedType>();
   if (!sval) {
     return emitOptionalError(location, "only shaped type operands allowed");