diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -564,6 +564,7 @@
   ];
 
   let hasCanonicalizer = 1;
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1276,6 +1276,40 @@
 };
 } // namespace
 
+/// Fold a parallel_insert_slice source coming from a tensor.cast op.
+///
+/// Example:
+/// ```
+///   %0 = scf.foreach_thread (%arg0) in (%c2) -> (tensor<128xf32>) {
+///     %1 = compute_some_tensor() : tensor<64xf32>
+///     %2 = tensor.cast %1 : tensor<64xf32> to tensor<?xf32>
+///     scf.foreach_thread.perform_concurrently {
+///       scf.foreach_thread.parallel_insert_slice %2 into %out[...] [64] [1] :
+///         tensor<?xf32> into tensor<128xf32>
+///     }
+///   }
+/// ```
+///
+/// is folded into:
+/// ```
+///   %0 = scf.foreach_thread (%arg0) in (%c2) -> (tensor<128xf32>) {
+///     %1 = compute_some_tensor() : tensor<64xf32>
+///     scf.foreach_thread.perform_concurrently {
+///       scf.foreach_thread.parallel_insert_slice %1 into %out[...] [64] [1] :
+///         tensor<64xf32> into tensor<128xf32>
+///     }
+///   }
+/// ```
+LogicalResult
+ParallelInsertSliceOp::fold(ArrayRef<Attribute> operands,
+                            SmallVectorImpl<OpFoldResult> &results) {
+  auto sourceCast = getSource().getDefiningOp<tensor::CastOp>();
+  if (!sourceCast)
+    return failure();
+  getSourceMutable().assign(sourceCast.getSource());
+  return success();
+}
+
 void ParallelInsertSliceOp::getCanonicalizationPatterns(
     RewritePatternSet &results, MLIRContext *context) {
   results.add<ParallelInsertSliceOpConstantArgumentFolder>(context);
diff --git a/mlir/test/Dialect/SCF/foreach-thread-canonicalization.mlir b/mlir/test/Dialect/SCF/foreach-thread-canonicalization.mlir
--- a/mlir/test/Dialect/SCF/foreach-thread-canonicalization.mlir
+++ b/mlir/test/Dialect/SCF/foreach-thread-canonicalization.mlir
@@ -26,9 +26,8 @@
       linalg.yield %14 : f32
     } -> tensor<?xf32>
 
-    // TODO: canonicalize this cast away.
-    // CHECK: %[[dyn_casted:.*]] = tensor.cast %{{.*}} : tensor<64xf32> to tensor<?xf32>
-    // CHECK: scf.foreach_thread.parallel_insert_slice %[[dyn_casted:.*]] into %{{.*}}[%{{.*}}] [64] [1] : tensor<?xf32> into tensor<128xf32>
+    // CHECK-NOT: tensor.cast
+    // CHECK: scf.foreach_thread.parallel_insert_slice %{{.*}} into %{{.*}}[%{{.*}}] [64] [1] : tensor<64xf32> into tensor<128xf32>
     scf.foreach_thread.perform_concurrently {
       scf.foreach_thread.parallel_insert_slice %13 into %3[%9] [%10] [1] : tensor<?xf32> into tensor<128xf32>
     }
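
Note (not part of the patch): `let hasFolder = 1` only registers the fold hook; the cast is actually removed when a driver calls `fold`, e.g. mlir-opt's `-canonicalize` pass, which is presumably what the updated test exercises. Below is a minimal sketch of triggering the same fold programmatically; the helper name `foldParallelInsertSliceCasts` is made up for illustration. It relies on the greedy rewrite driver attempting each op's registered fold hook even when the pattern set is empty.

```
// Sketch only, not part of the patch; assumes the usual MLIR headers are available.
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Runs the greedy rewrite driver over `root` with an empty pattern set; the
// driver still invokes each op's fold hook, which now strips tensor.cast
// producers of parallel_insert_slice sources in place.
static mlir::LogicalResult foldParallelInsertSliceCasts(mlir::Operation *root) {
  mlir::RewritePatternSet patterns(root->getContext());
  return mlir::applyPatternsAndFoldGreedily(root, std::move(patterns));
}
```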