diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -1712,10 +1712,17 @@ PatternRewriter &rewriter) const override { if (!tensor::canFoldIntoProducerOp(castOp)) return failure(); + auto linalgOp = castOp.getSource().getDefiningOp<LinalgOp>(); if (!linalgOp) return failure(); + // Cast can be in a conditionally reachable region, in which case folding will + // generate invalid code. Only conservatively fold ops in the same block for + // now. + if (castOp->getBlock() != linalgOp->getBlock()) + return failure(); + OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(linalgOp); diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir --- a/mlir/test/Dialect/Linalg/canonicalize.mlir +++ b/mlir/test/Dialect/Linalg/canonicalize.mlir @@ -846,6 +846,33 @@ // CHECK: %[[RESULT_CAST:.+]] = tensor.cast %[[MATMUL]] // CHECK: return %[[MATMUL]], %[[RESULT_CAST]] +// ----- + +func.func private @some_use(%0 : tensor<4x8xf32>) + +func.func @linalgop_with_cond_cast_consumer(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>, + %arg2 : tensor<?x?xf32>, %arg3 : i1) -> tensor<?x?xf32> { + %0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) + outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> + scf.if %arg3 { + %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x8xf32> + func.call @some_use(%1) : (tensor<4x8xf32>) -> () + } + return %0 : tensor<?x?xf32> +} + +// Check conditionally reachable cast is not folded into producer.
+// CHECK-LABEL: func @linalgop_with_cond_cast_consumer +// CHECK-SAME: (%[[ARG0:.*]]: tensor<?x?xf32>, %[[ARG1:.*]]: tensor<?x?xf32>, %[[ARG2:.*]]: tensor<?x?xf32>, %[[ARG3:.*]]: i1) +// CHECK: %[[RES:.*]] = linalg.matmul ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) +// CHECK-SAME: outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32> +// CHECK: scf.if %[[ARG3]] { +// CHECK: %[[CAST:.*]] = tensor.cast %[[RES]] : tensor<?x?xf32> to tensor<4x8xf32> +// CHECK: func.call @some_use(%[[CAST]]) : (tensor<4x8xf32>) -> () +// CHECK: } +// CHECK: return %[[RES]] : tensor<?x?xf32> + + // ----- func.func @fold_conv_op_with_cast_consumer(%arg0 : tensor<?x?x?x?xf32>,