diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -836,9 +836,13 @@ sparse_tensor::getSparseTensorEncoding(resultType)) returnedArg = rewriter.create( genericOp.getLoc(), resultType, returnedArg); - else + else { + if (!tensor::CastOp::areCastCompatible(returnedArg.getType(), + resultType)) + return failure(); returnedArg = rewriter.create( genericOp.getLoc(), resultType, returnedArg); + } } returnedArgs.push_back(returnedArg); } diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir --- a/mlir/test/Dialect/Linalg/canonicalize.mlir +++ b/mlir/test/Dialect/Linalg/canonicalize.mlir @@ -175,6 +175,24 @@ // ----- +#map = affine_map<() -> ()> +func @cant_fold_to_tensor_cast(%arg0 : f32) -> tensor { + %out = linalg.init_tensor [] : tensor + %g = linalg.generic { + indexing_maps = [#map, #map], + iterator_types = [] + } ins(%arg0 : f32) + outs(%out : tensor) { + ^bb0(%arg2 : f32, %arg3 : f32): + linalg.yield %arg2 : f32 + } -> (tensor) + return %g : tensor +} +// CHECK-LABEL: func @cant_fold_to_tensor_cast +// CHECK: linalg.generic + +// ----- + #map = affine_map<(d0, d1) -> (d0, d1)> func @keep_not_noop(%arg0 : tensor) -> tensor { %c0 = arith.constant 0 : index