diff --git a/mlir/lib/Interfaces/InferTypeOpInterface.cpp b/mlir/lib/Interfaces/InferTypeOpInterface.cpp --- a/mlir/lib/Interfaces/InferTypeOpInterface.cpp +++ b/mlir/lib/Interfaces/InferTypeOpInterface.cpp @@ -187,15 +187,17 @@ retComponents))) return failure(); for (const auto &shapeAndType : retComponents) { - assert(shapeAndType.getAttribute() == nullptr && "attribute not supported"); - assert(shapeAndType.getElementType() && - "element type required to construct tensor"); - if (shapeAndType.hasRank()) - inferredReturnTypes.push_back(RankedTensorType::get( - shapeAndType.getDims(), shapeAndType.getElementType())); - else + Type element_ty = shapeAndType.getElementType(); + assert(element_ty && "element type required to construct tensor"); + + Attribute attr = shapeAndType.getAttribute(); + if (shapeAndType.hasRank()) { inferredReturnTypes.push_back( - UnrankedTensorType::get(shapeAndType.getElementType())); + RankedTensorType::get(shapeAndType.getDims(), element_ty, attr)); + } else { + assert(attr == nullptr && "attribute not supported"); + inferredReturnTypes.push_back(UnrankedTensorType::get(element_ty)); + } } return success(); } diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -1180,7 +1180,11 @@ int64_t dim = sval.hasRank() ? 
sval.getShape().front() : ShapedType::kDynamic; auto type = IntegerType::get(context, 17); - inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type)); + + Attribute encoding; + if (auto ranked_ty = sval.dyn_cast<RankedTensorType>()) + encoding = ranked_ty.getEncoding(); + inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type, encoding)); return success(); } diff --git a/mlir/test/mlir-tblgen/return-types.mlir b/mlir/test/mlir-tblgen/return-types.mlir --- a/mlir/test/mlir-tblgen/return-types.mlir +++ b/mlir/test/mlir-tblgen/return-types.mlir @@ -3,19 +3,19 @@ // CHECK-LABEL: testCreateFunctions // This function tests invoking the create method with different inference // methods. The attributes of the ops inside are used to test creation. -func.func @testCreateFunctions(%arg0 : tensor<10xf32>, %arg1 : tensor<20xi32>) { +func.func @testCreateFunctions(%arg0 : tensor<10xf32, !test.smpla>, %arg1 : tensor<20xi32>) { // CHECK: "test.no_attributes" - %good = "test.no_attributes"(%arg0, %arg0) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32> + %good = "test.no_attributes"(%arg0, %arg0) : (tensor<10xf32, !test.smpla>, tensor<10xf32, !test.smpla>) -> tensor<10xf32, !test.smpla> // CHECK: "test.op_with_shaped_type_infer_type_if" -// CHECK-SAME: (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi17> +// CHECK-SAME: (tensor<10xf32, !test.smpla>, tensor<10xf32, !test.smpla>) -> tensor<10xi17, !test.smpla> // CHECK: "test.op_with_shaped_type_infer_type_if" -// CHECK-SAME: (tensor<10xf32>, tensor<20xi32>) -> tensor<10xi17> +// CHECK-SAME: (tensor<10xf32, !test.smpla>, tensor<20xi32>) -> tensor<10xi17, !test.smpla> // CHECK: "test.op_with_shaped_type_infer_type_if" -// CHECK-SAME: (tensor<20xi32>, tensor<10xf32>) -> tensor<20xi17> +// CHECK-SAME: (tensor<20xi32>, tensor<10xf32, !test.smpla>) -> tensor<20xi17> // CHECK: "test.op_with_shaped_type_infer_type_if" // CHECK-SAME: (tensor<20xi32>, tensor<20xi32>) -> tensor<20xi17> // CHECK: "test.op_with_infer_type_if" // CHECK-SAME: 
(tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32> +// CHECK-SAME: (tensor<10xf32, !test.smpla>, tensor<10xf32, !test.smpla>) -> tensor<10xf32, !test.smpla> // CHECK: "test.op_with_infer_type_if" // CHECK-SAME: (tensor<20xi32>, tensor<20xi32>) -> tensor<20xi32> return