# [mlir][linalg] Propagate the tensor encoding through InitTensorOp::inferResultType.
#
# - LinalgOps.td: `inferResultType` gains a trailing `Attribute encoding`
#   parameter defaulted to `{}`, so existing two-argument callers keep working.
# - LinalgOps.cpp: the new parameter is forwarded to `RankedTensorType::get`.
#   The op verifier now passes `resultType.getEncoding()` when it rebuilds the
#   expected type. Before this, the inferred type was always built without an
#   encoding. A declared result type that carries one would therefore not
#   compare equal in the `resultType != expectedType` check, and the verifier
#   would emit "specified type ... does not match the inferred type".
# - roundtrip.mlir: adds `%2 = linalg.init_tensor [2, 2] : tensor<2x2xf32, #attr>`
#   with `#attr = {"foo"}`. A matching CHECK line expects the printed type
#   `tensor<2x2xf32, {foo}>`.
#
# NOTE(review): this copy of the patch has had its newlines collapsed, so it
# will not apply as-is. Each diff line must be restored to its own line first.
# NOTE(review): template arguments appear to have been stripped during
# extraction (e.g. `ArrayRef staticSizes`, presumably `ArrayRef<int64_t>`).
# TODO: confirm against the upstream change before applying.
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td @@ -67,7 +67,8 @@ // Infer the shape of the result tensor given the static shapes // and element type of the result tensor. - static Type inferResultType(ArrayRef staticSizes, Type elementType); + static Type inferResultType(ArrayRef staticSizes, Type elementType, + Attribute encoding = {}); // Return true if the size of the tensor is dynamic at `idx` bool isDynamicSize(unsigned idx) { diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -906,8 +906,8 @@ return op->emitError("expected ") << resultType.getRank() << " sizes values"; - Type expectedType = - InitTensorOp::inferResultType(staticSizes, resultType.getElementType()); + Type expectedType = InitTensorOp::inferResultType( + staticSizes, resultType.getElementType(), resultType.getEncoding()); if (resultType != expectedType) { return op.emitError("specified type ") << resultType << " does not match the inferred type " @@ -917,8 +917,8 @@ } Type InitTensorOp::inferResultType(ArrayRef staticSizes, - Type elementType) { - return RankedTensorType::get(staticSizes, elementType); + Type elementType, Attribute encoding) { + return RankedTensorType::get(staticSizes, elementType, encoding); } namespace { diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir --- a/mlir/test/Dialect/Linalg/roundtrip.mlir +++ b/mlir/test/Dialect/Linalg/roundtrip.mlir @@ -435,15 +435,18 @@ // ----- +#attr = {"foo"} func @init_tensor(%arg0 : index, %arg1 : index) { %0 = linalg.init_tensor [3, 42] : tensor<3x42xf32> %1 = linalg.init_tensor [4, %arg0, %arg1, 5] : tensor<4x?x?x5xf32> + %2 = linalg.init_tensor [2, 2] : tensor<2x2xf32, #attr> return 
} // CHECK-LABEL: func @init_tensor // CHECK: linalg.init_tensor [3, 42] : tensor<3x42xf32> // CHECK: linalg.init_tensor [4, %{{.*}}, %{{.*}}, 5] : tensor<4x?x?x5xf32> // CHECK: linalg.init_tensor [2, 2] : tensor<2x2xf32, {foo}> // -----