[mlir][linalg] Propagate tensor encoding through InitTensorOp::inferResultType

InitTensorOp::inferResultType() gains an optional `encoding` attribute that is
forwarded to RankedTensorType::get(). It defaults to a null Attribute, so
callers that omit it keep their previous behavior.

The InitTensorOp verifier now passes the declared result type's encoding when
computing the expected type. Previously the inferred type never carried an
encoding, so `resultType != expectedType` rejected any init_tensor whose result
type had one.

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -67,7 +67,8 @@
 
     // Infer the shape of the result tensor given the static shapes
     // and element type of the result tensor.
-    static Type inferResultType(ArrayRef<int64_t> staticSizes, Type elementType);
+    static Type inferResultType(ArrayRef<int64_t> staticSizes, Type elementType,
+                                Attribute encoding = {});
 
     // Return true if the size of the tensor is dynamic at `idx`
     bool isDynamicSize(unsigned idx) {
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -906,8 +906,8 @@
     return op->emitError("expected ")
            << resultType.getRank() << " sizes values";
 
-  Type expectedType =
-      InitTensorOp::inferResultType(staticSizes, resultType.getElementType());
+  Type expectedType = InitTensorOp::inferResultType(
+      staticSizes, resultType.getElementType(), resultType.getEncoding());
   if (resultType != expectedType) {
     return op.emitError("specified type ")
            << resultType << " does not match the inferred type "
@@ -917,8 +917,8 @@
 }
 
 Type InitTensorOp::inferResultType(ArrayRef<int64_t> staticSizes,
-                                   Type elementType) {
-  return RankedTensorType::get(staticSizes, elementType);
+                                   Type elementType, Attribute encoding) {
+  return RankedTensorType::get(staticSizes, elementType, encoding);
 }
 
 namespace {