diff --git a/mlir/include/mlir/Dialect/Affine/ViewLikeInterfaceUtils.h b/mlir/include/mlir/Dialect/Affine/ViewLikeInterfaceUtils.h --- a/mlir/include/mlir/Dialect/Affine/ViewLikeInterfaceUtils.h +++ b/mlir/include/mlir/Dialect/Affine/ViewLikeInterfaceUtils.h @@ -9,6 +9,7 @@ #ifndef MLIR_DIALECT_AFFINE_VIEWLIKEINTERFACEUTILS_H #define MLIR_DIALECT_AFFINE_VIEWLIKEINTERFACEUTILS_H +#include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/OpDefinition.h" #include "mlir/Interfaces/ViewLikeInterface.h" @@ -22,7 +23,8 @@ /// - Combined offsets = producer_offsets * consumer_strides + consumer_offsets /// - Combined sizes = consumer_sizes /// - Combined strides = producer_strides * consumer_strides -// TODO: unify this API with resolveSourceIndicesOffsetsAndStrides or deprecate. +// TODO: unify this API with resolveIndicesIntoOpWithOffsetsAndStrides or +// deprecate. LogicalResult mergeOffsetsSizesAndStrides(OpBuilder &builder, Location loc, ArrayRef producerOffsets, @@ -38,7 +40,8 @@ /// Fills the `combinedOffsets`, `combinedSizes` and `combinedStrides` to use /// when combining a `producer` slice op **into** a `consumer` slice op. -// TODO: unify this API with resolveSourceIndicesOffsetsAndStrides or deprecate. +// TODO: unify this API with resolveIndicesIntoOpWithOffsetsAndStrides or +// deprecate. LogicalResult mergeOffsetsSizesAndStrides(OpBuilder &builder, Location loc, OffsetSizeAndStrideOpInterface producer, @@ -48,8 +51,8 @@ SmallVector &combinedSizes, SmallVector &combinedStrides); -/// Given the 'indicesVals' of a load/store operation operating on an op with -/// offsets and strides, return the combined indices. +/// Given the `consumerIndices` of a load/store operation operating on an op +/// with offsets and strides, return the combined indices.
/// /// For example, using `memref.load` and `memref.subview` as an illustration: /// @@ -64,13 +67,37 @@ /// /// ``` /// %2 = load %0[%arg0 + %i1 * %stride1][%arg1 + %i2 * %stride2] : /// memref<12x42xf32> /// ``` -void resolveSourceIndicesOffsetsAndStrides( - RewriterBase &rewriter, Location loc, ArrayRef mixedOffsets, - ArrayRef mixedStrides, - const llvm::SmallBitVector &rankReducedDims, ValueRange indicesVals, - SmallVectorImpl &sourceIndices); +void resolveIndicesIntoOpWithOffsetsAndStrides( + RewriterBase &rewriter, Location loc, + ArrayRef mixedSourceOffsets, + ArrayRef mixedSourceStrides, + const llvm::SmallBitVector &rankReducedDims, + ArrayRef consumerIndices, + SmallVectorImpl &resolvedIndices); + +inline void resolveIndicesIntoOpWithOffsetsAndStrides( + RewriterBase &rewriter, Location loc, + ArrayRef mixedSourceOffsets, + ArrayRef mixedSourceStrides, + const llvm::SmallBitVector &rankReducedDims, ValueRange consumerIndices, + SmallVectorImpl &resolvedIndices) { + return resolveIndicesIntoOpWithOffsetsAndStrides( + rewriter, loc, mixedSourceOffsets, mixedSourceStrides, rankReducedDims, + getAsOpFoldResult(consumerIndices), resolvedIndices); +} + +/// Given `sourceSizes`, `destSizes` and information about which dimensions are +/// dropped by the source: `rankReducedSourceDims`, compute the resolved sizes +/// that correspond to dest_op(source_op). +/// In practice, this amounts to filtering by `rankReducedSourceDims` and taking +/// from `sourceSizes` if a dimension is dropped, otherwise taking from +/// `destSizes`.
+void resolveSizesIntoOpWithSizes( + ArrayRef sourceSizes, ArrayRef destSizes, + const llvm::SmallBitVector &rankReducedSourceDims, + SmallVectorImpl &resolvedSizes); } // namespace mlir diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -1963,13 +1963,13 @@ }]; let builders = [ - // Build a SubViewOp with mixed static and dynamic entries and custom - // result type. If the type passed is nullptr, it is inferred. + // Build a SubViewOp with mixed static and dynamic entries and inferred + // result type. OpBuilder<(ins "Value":$source, "ArrayRef":$offsets, "ArrayRef":$sizes, "ArrayRef":$strides, CArg<"ArrayRef", "{}">:$attrs)>, - // Build a SubViewOp with mixed static and dynamic entries and inferred - // result type. + // Build a SubViewOp with mixed static and dynamic entries and custom + // result type. If the type passed is nullptr, it is inferred. OpBuilder<(ins "MemRefType":$resultType, "Value":$source, "ArrayRef":$offsets, "ArrayRef":$sizes, "ArrayRef":$strides, diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td @@ -823,17 +823,20 @@ }]; let builders = [ - // Build a InsertSliceOp with mixed static and dynamic entries. + // Build an InsertSliceOp with mixed static and dynamic entries and inferred + // result type. OpBuilder<(ins "Value":$source, "Value":$dest, "ArrayRef":$offsets, "ArrayRef":$sizes, "ArrayRef":$strides, CArg<"ArrayRef", "{}">:$attrs)>, - // Build a InsertSliceOp with dynamic entries. + // Build an InsertSliceOp with dynamic entries and inferred + // result type.
OpBuilder<(ins "Value":$source, "Value":$dest, "ValueRange":$offsets, "ValueRange":$sizes, "ValueRange":$strides, CArg<"ArrayRef", "{}">:$attrs)>, // Build an InsertSliceOp with mixed static and dynamic entries packed in - // a Range vector. + // a Range vector and inferred + // result type. OpBuilder<(ins "Value":$source, "Value":$dest, "ArrayRef":$ranges, CArg<"ArrayRef", "{}">:$attrs)> @@ -1450,6 +1453,10 @@ /// Return the OpResult of the enclosing ForallOp that is /// corresponding to this ParallelInsertSliceOp. OpResult getTiedOpResult(); + + /// Return the dimensions of the dest that are omitted to insert a source + /// when the result is rank-extended. + llvm::SmallBitVector getDroppedDims(); }]; let builders = [ diff --git a/mlir/lib/Dialect/Affine/Utils/ViewLikeInterfaceUtils.cpp b/mlir/lib/Dialect/Affine/Utils/ViewLikeInterfaceUtils.cpp --- a/mlir/lib/Dialect/Affine/Utils/ViewLikeInterfaceUtils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/ViewLikeInterfaceUtils.cpp @@ -77,32 +77,49 @@ combinedOffsets, combinedSizes, combinedStrides); } -void mlir::resolveSourceIndicesOffsetsAndStrides( - RewriterBase &rewriter, Location loc, ArrayRef mixedOffsets, - ArrayRef mixedStrides, - const llvm::SmallBitVector &rankReducedDims, ValueRange indicesVals, - SmallVectorImpl &sourceIndices) { +void mlir::resolveIndicesIntoOpWithOffsetsAndStrides( + RewriterBase &rewriter, Location loc, + ArrayRef mixedSourceOffsets, + ArrayRef mixedSourceStrides, + const llvm::SmallBitVector &rankReducedDims, + ArrayRef consumerIndices, + SmallVectorImpl &resolvedIndices) { OpFoldResult zero = rewriter.getIndexAttr(0); // For each dimension that is rank-reduced, add a zero to the indices. int64_t indicesDim = 0; SmallVector indices; - for (auto dim : llvm::seq(0, mixedOffsets.size())) { + for (auto dim : llvm::seq(0, mixedSourceOffsets.size())) { OpFoldResult ofr = - (rankReducedDims.test(dim)) ? zero : indicesVals[indicesDim++]; + (rankReducedDims.test(dim)) ?
zero : consumerIndices[indicesDim++]; indices.push_back(ofr); } - sourceIndices.resize(indices.size()); - sourceIndices.clear(); + resolvedIndices.clear(); + resolvedIndices.reserve(indices.size()); for (auto [offset, index, stride] : - llvm::zip_equal(mixedOffsets, indices, mixedStrides)) { + llvm::zip_equal(mixedSourceOffsets, indices, mixedSourceStrides)) { AffineExpr off, idx, str; bindSymbols(rewriter.getContext(), off, idx, str); OpFoldResult ofr = makeComposedFoldedAffineApply( rewriter, loc, AffineMap::get(0, 3, off + idx * str), {offset, index, stride}); - sourceIndices.push_back( + resolvedIndices.push_back( getValueOrCreateConstantIndexOp(rewriter, loc, ofr)); } } + +void mlir::resolveSizesIntoOpWithSizes( + ArrayRef sourceSizes, ArrayRef destSizes, + const llvm::SmallBitVector &rankReducedSourceDims, + SmallVectorImpl &resolvedSizes) { + int64_t dim = 0; + int64_t srcRank = sourceSizes.size(); + for (int64_t srcDim = 0; srcDim < srcRank; ++srcDim) { + if (rankReducedSourceDims[srcDim]) { + resolvedSizes.push_back(sourceSizes[srcDim]); + continue; + } + resolvedSizes.push_back(destSizes[dim++]); + } +} diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp --- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp @@ -248,48 +248,38 @@ LogicalResult matchAndRewrite(memref::SubViewOp subView, PatternRewriter &rewriter) const override { - Location loc = subView.getLoc(); auto srcSubView = subView.getSource().getDefiningOp(); if (!srcSubView) return failure(); - int64_t srcRank = srcSubView.getSourceType().getRank(); - - // TODO: Only stride 1 is supported. - for (auto s : {subView.getMixedStrides(), srcSubView.getMixedStrides()}) - if (!llvm::all_of( - s, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); })) - return failure(); - - // Get original offsets and sizes.
- SmallVector offsets = subView.getMixedOffsets(); - SmallVector srcOffsets = srcSubView.getMixedOffsets(); - SmallVector sizes = subView.getMixedSizes(); - SmallVector srcSizes = srcSubView.getMixedSizes(); - - // Compute new offsets and sizes. - llvm::SmallBitVector srcReducedDims = srcSubView.getDroppedDims(); - SmallVector newOffsets, newSizes; - int64_t dim = 0; - for (int64_t srcDim = 0; srcDim < srcRank; ++srcDim) { - if (srcReducedDims[srcDim]) { - // Dim is reduced in srcSubView. - assert(isConstantIntValue(srcSizes[srcDim], 1) && "expected size 1"); - newOffsets.push_back(srcOffsets[srcDim]); - newSizes.push_back(srcSizes[srcDim]); - continue; - } - AffineExpr sym0, sym1; - bindSymbols(subView.getContext(), sym0, sym1); - newOffsets.push_back(makeComposedFoldedAffineApply( - rewriter, loc, sym0 + sym1, {srcOffsets[srcDim], offsets[dim]})); - newSizes.push_back(sizes[dim]); - ++dim; + + // TODO: relax unit stride assumption. + if (!subView.hasUnitStride()) { + return rewriter.notifyMatchFailure(subView, "requires unit strides"); + } + if (!srcSubView.hasUnitStride()) { + return rewriter.notifyMatchFailure(srcSubView, "requires unit strides"); } + // Resolve sizes according to dropped dims. + SmallVector resolvedSizes; + llvm::SmallBitVector srcDroppedDims = srcSubView.getDroppedDims(); + resolveSizesIntoOpWithSizes(srcSubView.getMixedSizes(), + subView.getMixedSizes(), srcDroppedDims, + resolvedSizes); + + // Resolve offsets according to source offsets and strides. + SmallVector resolvedOffsets; + resolveIndicesIntoOpWithOffsetsAndStrides( + rewriter, subView.getLoc(), srcSubView.getMixedOffsets(), + srcSubView.getMixedStrides(), srcDroppedDims, subView.getMixedOffsets(), + resolvedOffsets); + // Replace original op.
rewriter.replaceOpWithNewOp( - subView, subView.getType(), srcSubView.getSource(), newOffsets, - newSizes, srcSubView.getMixedStrides()); + subView, subView.getType(), srcSubView.getSource(), + getAsOpFoldResult(resolvedOffsets), resolvedSizes, + srcSubView.getMixedStrides()); + return success(); } }; @@ -372,7 +362,7 @@ indices.assign(expandedIndices.begin(), expandedIndices.end()); } SmallVector sourceIndices; - resolveSourceIndicesOffsetsAndStrides( + resolveIndicesIntoOpWithOffsetsAndStrides( rewriter, loadOp.getLoc(), subViewOp.getMixedOffsets(), subViewOp.getMixedStrides(), subViewOp.getDroppedDims(), indices, sourceIndices); @@ -492,7 +482,7 @@ indices.assign(expandedIndices.begin(), expandedIndices.end()); } SmallVector sourceIndices; - resolveSourceIndicesOffsetsAndStrides( + resolveIndicesIntoOpWithOffsetsAndStrides( rewriter, storeOp.getLoc(), subViewOp.getMixedOffsets(), subViewOp.getMixedStrides(), subViewOp.getDroppedDims(), indices, sourceIndices); diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -3086,6 +3086,10 @@ InsertSliceOpSourceCastInserter>(context); } +llvm::SmallBitVector ParallelInsertSliceOp::getDroppedDims() { + return ::getDroppedDims(getSourceType().getShape(), getMixedSizes()); +} + //===----------------------------------------------------------------------===// // ScatterOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp b/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp --- a/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp @@ -12,6 +12,7 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include
"mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/Transforms/Passes.h" #include "mlir/Dialect/Tensor/Transforms/Transforms.h" @@ -21,6 +22,7 @@ #include "mlir/IR/BuiltinAttributes.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/TypeSwitch.h" +#include namespace mlir { namespace tensor { @@ -98,7 +100,7 @@ SmallVector indices(readOp.getIndices().begin(), readOp.getIndices().end()); SmallVector sourceIndices; - resolveSourceIndicesOffsetsAndStrides( + resolveIndicesIntoOpWithOffsetsAndStrides( rewriter, readOp.getLoc(), extractSliceOp.getMixedOffsets(), extractSliceOp.getMixedStrides(), extractSliceOp.getDroppedDims(), indices, sourceIndices); @@ -130,7 +132,7 @@ SmallVector indices(writeOp.getIndices().begin(), writeOp.getIndices().end()); SmallVector sourceIndices; - resolveSourceIndicesOffsetsAndStrides( + resolveIndicesIntoOpWithOffsetsAndStrides( rewriter, writeOp.getLoc(), insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedStrides(), insertSliceOp.getDroppedDims(), indices, sourceIndices); @@ -145,9 +147,80 @@ return success(); } +template +struct InsertSliceOfInsertSliceFolder : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(OpTy insertSliceOp, + PatternRewriter &rewriter) const override { + auto sourceInsertSliceOp = + insertSliceOp.getSource() + .template getDefiningOp(); + if (!sourceInsertSliceOp) + return failure(); + + // TODO: relax unit stride assumption where possible.
+ if (!insertSliceOp.hasUnitStride()) { + return rewriter.notifyMatchFailure(insertSliceOp, + "requires unit strides"); + } + if (!sourceInsertSliceOp.hasUnitStride()) { + return rewriter.notifyMatchFailure(sourceInsertSliceOp, + "requires unit strides"); + } + + int64_t srcDim = 0; + llvm::SmallBitVector droppedDims = insertSliceOp.getDroppedDims(); + for (int64_t d = 0, e = insertSliceOp.getDestType().getRank(); d < e; ++d) { + if (droppedDims[d]) + continue; + if (insertSliceOp.getMixedSizes()[d] != + sourceInsertSliceOp.getMixedSizes()[srcDim++]) { + return rewriter.notifyMatchFailure( + sourceInsertSliceOp, + "requires matching sizes to fold, otherwise a copy is needed"); + } + } + + // Resolve sizes according to dropped dims. + SmallVector resolvedSizes; + resolveSizesIntoOpWithSizes(insertSliceOp.getMixedSizes(), + sourceInsertSliceOp.getMixedSizes(), + droppedDims, resolvedSizes); + + // If we are inside an InParallel region, temporarily set the insertion + // point outside: only tensor.parallel_insert_slice ops are allowed in + // there. + if constexpr (std::is_same_v) { + rewriter.setInsertionPoint( + insertSliceOp->template getParentOfType()); + } + + // Resolve offsets according to source offsets and strides. + SmallVector resolvedOffsets; + resolveIndicesIntoOpWithOffsetsAndStrides( + rewriter, insertSliceOp.getLoc(), insertSliceOp.getMixedOffsets(), + insertSliceOp.getMixedStrides(), droppedDims, + sourceInsertSliceOp.getMixedOffsets(), resolvedOffsets); + + // Reset the insertion point. + rewriter.setInsertionPoint(insertSliceOp); + // Replace original op.
+ rewriter.replaceOpWithNewOp( + insertSliceOp, sourceInsertSliceOp.getSource(), insertSliceOp.getDest(), + getAsOpFoldResult(resolvedOffsets), resolvedSizes, + insertSliceOp.getMixedStrides()); + + return success(); + } +}; + void tensor::populateFoldTensorSubsetOpPatterns(RewritePatternSet &patterns) { patterns.add(patterns.getContext()); + InsertSliceOfTransferWriteOpFolder, + InsertSliceOfInsertSliceFolder, + InsertSliceOfInsertSliceFolder>( + patterns.getContext()); } //===----------------------------------------------------------------------===// // Pass registration diff --git a/mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir b/mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir --- a/mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir +++ b/mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -fold-tensor-subset-ops -split-input-file %s | FileCheck %s +// RUN: mlir-opt -fold-tensor-subset-ops -split-input-file --allow-unregistered-dialect %s | FileCheck %s func.func @fold_vector_transfer_read_with_rank_reduced_extract_slice( %arg0 : tensor, @@ -260,3 +260,125 @@ %1 = tensor.insert_slice %0 into %t1[4, 3, %s] [1, 5, 6] [1, 1, 1] : tensor<5x6xf32> into tensor return %1 : tensor } + +// ----- + +// CHECK: #[[$map:.*]] = affine_map<()[s0] -> (s0 + 2)> +// CHECK-LABEL: func @insert_slice_of_insert_slice( +// CHECK-SAME: %[[t:[0-9a-z]*]]: tensor +// CHECK-SAME: %[[r1:[0-9a-z]*]]: tensor<1x14xf32> +// CHECK-SAME: %[[pos:[0-9a-z]*]]: index +// CHECK: %[[add:.*]] = affine.apply #[[$map]]()[%[[pos]]] +// CHECK: tensor.insert_slice %[[t]] into %[[r1]][4, %[[add]]] [1, 1] [1, 1] : tensor into tensor<1x14xf32> +func.func @insert_slice_of_insert_slice(%t: tensor, %r0: tensor<1x1xf32>, %r1: tensor<1x14xf32>, %pos: index) + -> tensor<1x14xf32> +{ + %0 = tensor.insert_slice %t into %r0[1, 2] [1, 1] [1, 1] + : tensor into tensor<1x1xf32> + %1 = tensor.insert_slice %0 into %r1[3, %pos] [1, 1] [1, 1] + : tensor<1x1xf32> into
tensor<1x14xf32> + return %1 : tensor<1x14xf32> +} + +// ----- + +// CHECK-LABEL: func @insert_slice_of_insert_slice( +// CHECK-SAME: %[[t:[0-9a-z]*]]: tensor +// CHECK-SAME: %[[r1:[0-9a-z]*]]: tensor<1x14xf32> +// CHECK-SAME: %[[pos:[0-9a-z]*]]: index +// CHECK: tensor.insert_slice %[[t]] into %[[r1]][5, %[[pos]]] [1, 1] [1, 1] : tensor into tensor<1x14xf32> +func.func @insert_slice_of_insert_slice(%t: tensor, %r0: tensor<1xf32>, %r1: tensor<1x14xf32>, %pos: index) + -> tensor<1x14xf32> +{ + %0 = tensor.insert_slice %t into %r0[2] [1] [1] + : tensor into tensor<1xf32> + %1 = tensor.insert_slice %0 into %r1[3, %pos] [1, 1] [1, 1] + : tensor<1xf32> into tensor<1x14xf32> + return %1 : tensor<1x14xf32> +} + +// ----- + +// This test fails to fold because the size `4` and `%pos` do not match: +// this requires a copy. +// CHECK-LABEL: func @fail_insert_slice_of_insert_slice( +// CHECK: tensor.insert_slice +// CHECK: tensor.insert_slice +func.func @fail_insert_slice_of_insert_slice( + %t: tensor<4xf32>, %r0: tensor, %r1: tensor, %pos: index) + -> tensor +{ + %0 = tensor.insert_slice %t into %r0[%pos] [4] [1] + : tensor<4xf32> into tensor + %1 = tensor.insert_slice %0 into %r1[%pos, 423] [%pos, 1] [1, 1] + : tensor into tensor + return %1 : tensor +} + +// ----- + +// Here the sizes are the same and the folding occurs properly.
+// CHECK: #[[$map:.*]] = affine_map<()[s0] -> (s0 * 2)> +// CHECK-LABEL: func @insert_slice_of_insert_slice_dynamic( +// CHECK: %[[add:.*]] = affine.apply #[[$map]]()[%[[pos]]] +// CHECK: tensor.insert_slice %[[t]] into %[[r1]][%[[add]], 423] [%[[pos]], 1] [1, 1] : tensor into tensor +func.func @insert_slice_of_insert_slice_dynamic( + %t: tensor, %r0: tensor, %r1: tensor, %pos: index) + -> tensor +{ + %0 = tensor.insert_slice %t into %r0[%pos] [%pos] [1] + : tensor into tensor + %1 = tensor.insert_slice %0 into %r1[%pos, 423] [%pos, 1] [1, 1] + : tensor into tensor + return %1 : tensor +} + +// ----- + +// Duplicate of the previous test (TODO: remove): the sizes are the same and the folding occurs properly. +// CHECK: #[[$map:.*]] = affine_map<()[s0] -> (s0 * 2)> +// CHECK-LABEL: func @insert_slice_of_insert_slice_dynamic( +// CHECK: %[[add:.*]] = affine.apply #[[$map]]()[%[[pos]]] +// CHECK: tensor.insert_slice %[[t]] into %[[r1]][%[[add]], 423] [%[[pos]], 1] [1, 1] : tensor into tensor +func.func @insert_slice_of_insert_slice_dynamic( + %t: tensor, %r0: tensor, %r1: tensor, %pos: index) + -> tensor +{ + %0 = tensor.insert_slice %t into %r0[%pos] [%pos] [1] + : tensor into tensor + %1 = tensor.insert_slice %0 into %r1[%pos, 423] [%pos, 1] [1, 1] + : tensor into tensor + return %1 : tensor +} + +// ----- + +// CHECK: #[[$map:.*]] = affine_map<()[s0, s1] -> (s0 + s1)> +// CHECK-LABEL: func @parallel_insert_slice_of_insert_slice_dynamic( +// CHECK-SAME: %[[t:[0-9a-z]*]]: tensor<12x34xf32> +// CHECK-SAME: %[[o0:[0-9a-z]*]]: index +// CHECK-SAME: %[[o1:[0-9a-z]*]]: index +// CHECK-SAME: %[[sz0:[0-9a-z]*]]: index +// CHECK-SAME: %[[sz1:[0-9a-z]*]]: index +func.func @parallel_insert_slice_of_insert_slice_dynamic( + %t: tensor<12x34xf32>, %o0: index, %o1: index, %sz0: index, %sz1: index) + -> tensor<12x34xf32>{ + + // CHECK: scf.forall {{.*}} shared_outs(%[[out:.*]] = %[[t]] + %0 = scf.forall (%arg0, %arg1) in (27, 8) shared_outs(%arg2 = %t) -> (tensor<12x34xf32>) { + // CHECK: %[[tt:.*]] = "make_me_a_tensor"()
: () -> tensor + %tt = "make_me_a_tensor"() : () -> tensor + %tt2 = "make_me_another_tensor"() : () -> tensor + %inserted_slice = tensor.insert_slice %tt into %tt2[%o1, 0] [%sz0, %sz1] [1, 1] : tensor into tensor + + // CHECK: %[[add:.*]] = affine.apply #[[$map]]()[%[[o0]], %[[o1]]] + // CHECK: scf.forall.in_parallel + // CHECK: tensor.parallel_insert_slice %[[tt]] into %[[out]][%[[add]], %[[o1]]] [%[[sz0]], %[[sz1]]] [1, 1] + // CHECK-SAME: : tensor into tensor<12x34xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %inserted_slice into %arg2[%o0, %o1] [%sz0, %sz1] [1, 1] + : tensor into tensor<12x34xf32> + } + } + return %0: tensor<12x34xf32> +}