diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -1645,12 +1645,20 @@ ArrayRef staticOffsets, ArrayRef staticSizes, ArrayRef staticStrides); - static Type inferRankReducedResultType(unsigned resultRank, + + /// A rank-reducing result type can be inferred from the desired result + /// shape. Only the layout map is inferred. + /// + /// Note: The result shape cannot be inferred with just the result rank + /// and the desired sizes. In case there are more "ones" among the sizes + /// than the difference in source/result rank, it is not clear which dims of + /// size one should be dropped. + static Type inferRankReducedResultType(ArrayRef resultShape, MemRefType sourceMemRefType, ArrayRef staticOffsets, ArrayRef staticSizes, ArrayRef staticStrides); - static Type inferRankReducedResultType(unsigned resultRank, + static Type inferRankReducedResultType(ArrayRef resultShape, MemRefType sourceMemRefType, ArrayRef staticOffsets, ArrayRef staticSizes, diff --git a/mlir/lib/Dialect/Bufferization/Transforms/AllocTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/AllocTensorElimination.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/AllocTensorElimination.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/AllocTensorElimination.cpp @@ -215,25 +215,10 @@ /*rewriteFunc=*/ [](OpBuilder &b, Location loc, OpOperand &operand) { auto insertOp = cast(operand.getOwner()); - // Expand offsets, sizes and strides to the full rank to handle the - // rank-reducing case. 
- SmallVector mixedOffsets = insertOp.getMixedOffsets(); - SmallVector mixedSizes = insertOp.getMixedSizes(); - SmallVector mixedStrides = insertOp.getMixedStrides(); - OffsetSizeAndStrideOpInterface::expandToRank( - insertOp.getDest(), mixedOffsets, mixedSizes, mixedStrides, - [&](Value target, int64_t dim) -> OpFoldResult { - auto shapedType = target.getType().cast(); - if (shapedType.isDynamicDim(dim)) - return b.create(loc, target, dim).getResult(); - return b.getIndexAttr(shapedType.getDimSize(dim)); - }); - auto t = tensor::ExtractSliceOp::inferCanonicalRankReducedResultType( - insertOp.getSourceType().getRank(), - insertOp.getDest().getType().cast(), mixedOffsets, - mixedSizes, mixedStrides); auto extractOp = b.create( - loc, t, insertOp.getDest(), mixedOffsets, mixedSizes, mixedStrides); + loc, insertOp.getSourceType(), insertOp.getDest(), + insertOp.getMixedOffsets(), insertOp.getMixedSizes(), + insertOp.getMixedStrides()); return extractOp.getResult(); }); } diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -2145,7 +2145,7 @@ staticSizes, staticStrides); } -Type SubViewOp::inferRankReducedResultType(unsigned resultRank, +Type SubViewOp::inferRankReducedResultType(ArrayRef resultShape, MemRefType sourceRankedTensorType, ArrayRef offsets, ArrayRef sizes, @@ -2153,27 +2153,26 @@ auto inferredType = inferResultType(sourceRankedTensorType, offsets, sizes, strides) .cast(); - assert(inferredType.getRank() >= resultRank && "expected "); - int rankDiff = inferredType.getRank() - resultRank; - if (rankDiff > 0) { - auto shape = inferredType.getShape(); - llvm::SmallBitVector dimsToProject = - getPositionsOfShapeOne(rankDiff, shape); - SmallVector projectedShape; - for (unsigned pos = 0, e = shape.size(); pos < e; ++pos) - if (!dimsToProject.test(pos)) - projectedShape.push_back(shape[pos]); - - AffineMap map = - 
getProjectedMap(inferredType.getLayout().getAffineMap(), dimsToProject); - inferredType = - MemRefType::get(projectedShape, inferredType.getElementType(), map, - inferredType.getMemorySpace()); - } - return inferredType; -} - -Type SubViewOp::inferRankReducedResultType(unsigned resultRank, + assert(inferredType.getRank() >= resultShape.size() && "expected "); + if (inferredType.getRank() == resultShape.size()) + return inferredType; + + // Compute which dimensions are dropped. + Optional> dimsToProject = + computeRankReductionMask(inferredType.getShape(), resultShape); + assert(dimsToProject.hasValue() && "invalid rank reduction"); + llvm::SmallBitVector dimsToProjectVector(inferredType.getRank()); + for (unsigned dim : *dimsToProject) + dimsToProjectVector.set(dim); + + // Compute layout map and result type. + AffineMap map = getProjectedMap(inferredType.getLayout().getAffineMap(), + dimsToProjectVector); + return MemRefType::get(resultShape, inferredType.getElementType(), map, + inferredType.getMemorySpace()); +} + +Type SubViewOp::inferRankReducedResultType(ArrayRef resultShape, MemRefType sourceRankedTensorType, ArrayRef offsets, ArrayRef sizes, @@ -2187,9 +2186,10 @@ dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides, ShapedType::kDynamicStrideOrOffset); return SubViewOp::inferRankReducedResultType( - resultRank, sourceRankedTensorType, staticOffsets, staticSizes, + resultShape, sourceRankedTensorType, staticOffsets, staticSizes, staticStrides); } + // Build a SubViewOp with mixed static and dynamic entries and custom result // type. If the type passed is nullptr, it is inferred. 
void SubViewOp::build(OpBuilder &b, OperationState &result, diff --git a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp --- a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp @@ -44,7 +44,7 @@ } builder.setInsertionPoint(subviewUse); Type newType = memref::SubViewOp::inferRankReducedResultType( - subviewUse.getType().getRank(), val.getType().cast(), + subviewUse.getType().getShape(), val.getType().cast(), extractFromI64ArrayAttr(subviewUse.static_offsets()), extractFromI64ArrayAttr(subviewUse.static_sizes()), extractFromI64ArrayAttr(subviewUse.static_strides())); @@ -136,7 +136,7 @@ sizes.push_back(builder.getIndexAttr(size)); auto dstMemref = memref::SubViewOp::inferRankReducedResultType( - allocOp.getType().getRank(), newMemref, offsets, sizes, strides) + allocOp.getType().getShape(), newMemref, offsets, sizes, strides) .cast(); Value subview = builder.create(loc, dstMemref, newAlloc, offsets, sizes, strides); diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp @@ -278,36 +278,24 @@ LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options) const { auto extractSliceOp = cast(op); + SmallVector mixedOffsets = extractSliceOp.getMixedOffsets(); + SmallVector mixedSizes = extractSliceOp.getMixedSizes(); + SmallVector mixedStrides = extractSliceOp.getMixedStrides(); Location loc = extractSliceOp.getLoc(); - // Even if this op was decided to bufferize out-of-place, do not insert the - // buffer copy yet. This is done later in this function. + // Get source buffer. 
FailureOr srcMemref = getBuffer(rewriter, extractSliceOp.getSource(), options); if (failed(srcMemref)) return failure(); auto srcMemrefType = srcMemref->getType().cast(); - auto dstTensorType = - extractSliceOp.getResult().getType().cast(); - // Expand offsets, sizes and strides to the full rank to handle the - // rank-reducing case. - SmallVector mixedOffsets = extractSliceOp.getMixedOffsets(); - SmallVector mixedSizes = extractSliceOp.getMixedSizes(); - SmallVector mixedStrides = extractSliceOp.getMixedStrides(); - OffsetSizeAndStrideOpInterface::expandToRank( - *srcMemref, mixedOffsets, mixedSizes, mixedStrides, - [&](Value target, int64_t dim) -> OpFoldResult { - auto shapedType = target.getType().cast(); - if (shapedType.isDynamicDim(dim)) - return rewriter.create(loc, target, dim).result(); - return rewriter.getIndexAttr(shapedType.getDimSize(dim)); - }); - // Bufferize to subview. - auto subviewMemRefType = memref::SubViewOp::inferRankReducedResultType( - dstTensorType.getRank(), srcMemrefType, - mixedOffsets, mixedSizes, mixedStrides) - .cast(); + // Take a subview of the source buffer. + auto subviewMemRefType = + memref::SubViewOp::inferRankReducedResultType( + extractSliceOp.getType().getShape(), srcMemrefType, mixedOffsets, + mixedSizes, mixedStrides) + .cast(); Value subView = rewriter.create( loc, subviewMemRefType, *srcMemref, mixedOffsets, mixedSizes, mixedStrides); @@ -690,30 +678,22 @@ // catastrophically bad scheduling decision. // TODO: be very loud about it or even consider failing the pass. auto insertSliceOp = cast(op); + SmallVector mixedOffsets = insertSliceOp.getMixedOffsets(); + SmallVector mixedSizes = insertSliceOp.getMixedSizes(); + SmallVector mixedStrides = insertSliceOp.getMixedStrides(); Location loc = insertSliceOp.getLoc(); + + // Get destination buffer. 
FailureOr dstMemref = getBuffer(rewriter, insertSliceOp.getDest(), options); if (failed(dstMemref)) return failure(); - // Expand offsets, sizes and strides to the full rank to handle the - // rank-reducing case. - SmallVector mixedOffsets = insertSliceOp.getMixedOffsets(); - SmallVector mixedSizes = insertSliceOp.getMixedSizes(); - SmallVector mixedStrides = insertSliceOp.getMixedStrides(); - OffsetSizeAndStrideOpInterface::expandToRank( - *dstMemref, mixedOffsets, mixedSizes, mixedStrides, - [&](Value target, int64_t dim) -> OpFoldResult { - auto shapedType = target.getType().cast(); - if (shapedType.isDynamicDim(dim)) - return rewriter.create(loc, target, dim).result(); - return rewriter.getIndexAttr(shapedType.getDimSize(dim)); - }); - // Take a subview of the dst. + // Take a subview of the destination buffer. auto dstMemrefType = dstMemref->getType().cast(); auto subviewMemRefType = memref::SubViewOp::inferRankReducedResultType( - insertSliceOp.getSourceType().getRank(), dstMemrefType, + insertSliceOp.getSourceType().getShape(), dstMemrefType, mixedOffsets, mixedSizes, mixedStrides) .cast(); Value subView = rewriter.create( @@ -946,11 +926,22 @@ getBuffer(rewriter, parallelInsertSliceOp.getSource(), options); if (failed(srcBuffer)) return failure(); + + // Take a subview of the destination buffer. + auto destBufferType = destBuffer->getType().cast(); + auto subviewMemRefType = + memref::SubViewOp::inferRankReducedResultType( + parallelInsertSliceOp.getSourceType().getShape(), destBufferType, + parallelInsertSliceOp.getMixedOffsets(), + parallelInsertSliceOp.getMixedSizes(), + parallelInsertSliceOp.getMixedStrides()) + .cast(); Value subview = rewriter.create( - parallelInsertSliceOp.getLoc(), *destBuffer, + parallelInsertSliceOp.getLoc(), subviewMemRefType, *destBuffer, parallelInsertSliceOp.getMixedOffsets(), parallelInsertSliceOp.getMixedSizes(), parallelInsertSliceOp.getMixedStrides()); + // This memcpy will fold away if everything bufferizes in-place. 
if (failed(options.createMemCpy(rewriter, parallelInsertSliceOp.getLoc(), *srcBuffer, subview))) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp @@ -216,8 +216,10 @@ static MemRefType dropUnitDims(MemRefType inputType, ArrayRef offsets, ArrayRef sizes, ArrayRef strides) { + SmallVector targetShape = llvm::to_vector( + llvm::make_filter_range(sizes, [](int64_t sz) { return sz != 1; })); Type rankReducedType = memref::SubViewOp::inferRankReducedResultType( - 0, inputType, offsets, sizes, strides); + targetShape, inputType, offsets, sizes, strides); return canonicalizeStridedLayout(rankReducedType.cast()); } diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir --- a/mlir/test/Dialect/Tensor/bufferize.mlir +++ b/mlir/test/Dialect/Tensor/bufferize.mlir @@ -292,7 +292,7 @@ // CHECK-SAME: %[[t1:.*]]: tensor, %[[t2:.*]]: tensor, // CHECK-SAME: %[[idx1:.*]]: index, %[[idx2:.*]]: index func.func @tensor.insert_slice(%t1: tensor, %t2: tensor, - %idx1: index, %idx2: index) -> tensor { + %idx1: index, %idx2: index) -> tensor { // CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index // CHECK-DAG: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref @@ -313,6 +313,40 @@ // ----- +// CHECK: #[[$MAP11:.*]] = affine_map<()[s0] -> (s0)> + +// CHECK-LABEL: func @tensor.insert_slice_rank_reducing_1( +func.func @tensor.insert_slice_rank_reducing_1( + %t1: tensor, %f: tensor, %idx1: index, %idx2: index) + -> tensor +{ + // CHECK: %[[alloc:.*]] = memref.alloc{{.*}} : memref + // CHECK: memref.subview %[[alloc]][%{{.*}}, %{{.*}}] [1, 1] [1, 1] : memref to memref + // CHECK: memref.copy {{.*}} : memref to memref + %0 = tensor.insert_slice %f into %t1[%idx1, 
%idx2][1, 1][1, 1] + : tensor into tensor + return %0 : tensor +} + +// ----- + +// CHECK: #[[$MAP12:.*]] = affine_map<(d0, d1, d2, d3, d4)[s0, s1, s2, s3, s4, s5] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3 * s4 + d4 * s5)> + +// CHECK-LABEL: func @tensor.insert_slice_rank_reducing_2( +func.func @tensor.insert_slice_rank_reducing_2( + %t1: tensor, %t2: tensor<2x1x4x1x1xf32>, %i: index) + -> tensor +{ + // CHECK: %[[alloc:.*]] = memref.alloc{{.*}} : memref + // CHECK: memref.subview %[[alloc]][{{.*}}] [1, 2, 1, 4, 1, 1, 1] [1, 1, 1, 1, 1, 1, 1] : memref to memref<2x1x4x1x1xf32, #[[$MAP12]]> + // CHECK: memref.copy {{.*}} : memref<2x1x4x1x1xf32> to memref<2x1x4x1x1xf32, #[[$MAP12]]> + %0 = tensor.insert_slice %t2 into %t1[%i, %i, %i, %i, %i, %i, %i][1, 2, 1, 4, 1, 1, 1][1, 1, 1, 1, 1, 1, 1] + : tensor<2x1x4x1x1xf32> into tensor + return %0 : tensor +} + +// ----- + // CHECK-LABEL: func @tensor.insert( // CHECK-SAME: %[[t1:.*]]: tensor<5xf32>, %[[idx1:.*]]: index, // CHECK-SAME: %[[f:.*]]: f32 diff --git a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir --- a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir +++ b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir @@ -193,3 +193,27 @@ } return %5: tensor } + +// ----- + +// CHECK: #[[$MAP0:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)> +// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)> + +// CHECK-LABEL: func.func @rank_reducing_parallel_insert_slice +func.func @rank_reducing_parallel_insert_slice(%in: tensor<100xf32>, %out: tensor<200x100xf32>) { + %c1 = arith.constant 1 : index + %num_threads = arith.constant 100 : index + + // CHECK: scf.foreach_thread {{.*}} { + %result = scf.foreach_thread (%thread_idx) in (%num_threads) -> tensor<200x100xf32> { + %1 = tensor.extract_slice %in[%thread_idx][1][1] : tensor<100xf32> to tensor<1xf32> + scf.foreach_thread.perform_concurrently { + // CHECK: memref.subview %{{.*}}[%{{.*}}] [1] [1] : 
memref<100xf32, #[[$MAP0]]> to memref<1xf32, #[[$MAP0]]> + // CHECK: memref.subview %{{.*}}[1, %{{.*}}] [1, 1] [1, 1] : memref<200x100xf32, #[[$MAP1]]> to memref<1xf32, #[[$MAP0]]> + tensor.parallel_insert_slice %1 into %out[1, %thread_idx][1, 1][1, 1] : + tensor<1xf32> into tensor<200x100xf32> + } + } + // CHECK: } + return +} diff --git a/mlir/unittests/Dialect/MemRef/InferShapeTest.cpp b/mlir/unittests/Dialect/MemRef/InferShapeTest.cpp --- a/mlir/unittests/Dialect/MemRef/InferShapeTest.cpp +++ b/mlir/unittests/Dialect/MemRef/InferShapeTest.cpp @@ -21,7 +21,7 @@ OpBuilder b(&ctx); auto sourceMemref = MemRefType::get({10, 5}, b.getIndexType()); auto reducedType = SubViewOp::inferRankReducedResultType( - /*resultRank=*/1, sourceMemref, {2, 3}, {1, 2}, {1, 1}); + /*resultShape=*/{2}, sourceMemref, {2, 3}, {1, 2}, {1, 1}); AffineExpr dim0; bindDims(&ctx, dim0); auto expectedType = @@ -38,7 +38,7 @@ auto sourceMemref = MemRefType::get({10, 5}, b.getIndexType(), AffineMap::get(2, 0, 1000 * dim0 + dim1)); auto reducedType = SubViewOp::inferRankReducedResultType( - /*resultRank=*/1, sourceMemref, {2, 3}, {1, 2}, {1, 1}); + /*resultShape=*/{2}, sourceMemref, {2, 3}, {1, 2}, {1, 1}); auto expectedType = MemRefType::get({2}, b.getIndexType(), AffineMap::get(1, 0, dim0 + 2003)); EXPECT_EQ(reducedType, expectedType); @@ -52,7 +52,7 @@ auto sourceMemref = MemRefType::get({10, 5}, b.getIndexType(), AffineMap::get(2, 0, 1000 * dim0 + dim1)); auto reducedType = SubViewOp::inferRankReducedResultType( - /*resultRank=*/0, sourceMemref, {2, 3}, {1, 1}, {1, 1}); + /*resultShape=*/{}, sourceMemref, {2, 3}, {1, 1}, {1, 1}); auto expectedType = MemRefType::get({}, b.getIndexType(), AffineMap::get(0, 0, b.getAffineConstantExpr(2003)));