diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/TransformUtils.h b/mlir/include/mlir/Dialect/Tensor/Transforms/TransformUtils.h
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/TransformUtils.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/TransformUtils.h
@@ -14,6 +14,37 @@
 namespace mlir {
 namespace tensor {
 
+/// Fills `combinedOffsets`, `combinedSizes`, and `combinedStrides` with the
+/// values to use when combining a producer slice **into** a consumer slice.
+///
+/// This function performs the following computation:
+/// - Combined offsets = consumer_offsets * producer_strides + producer_offsets
+/// - Combined sizes = consumer_sizes
+/// - Combined strides = consumer_strides * producer_strides
+LogicalResult
+mergeOffsetsSizesAndStrides(OpBuilder &builder, Location loc,
+                            ArrayRef<OpFoldResult> producerOffsets,
+                            ArrayRef<OpFoldResult> producerSizes,
+                            ArrayRef<OpFoldResult> producerStrides,
+                            const llvm::SmallBitVector &droppedProducerDims,
+                            ArrayRef<OpFoldResult> consumerOffsets,
+                            ArrayRef<OpFoldResult> consumerSizes,
+                            ArrayRef<OpFoldResult> consumerStrides,
+                            SmallVector<OpFoldResult> &combinedOffsets,
+                            SmallVector<OpFoldResult> &combinedSizes,
+                            SmallVector<OpFoldResult> &combinedStrides);
+
+/// Fills `combinedOffsets`, `combinedSizes`, and `combinedStrides` with the
+/// values to use when combining a `producer` slice op **into** a `consumer` slice op.
+LogicalResult
+mergeOffsetsSizesAndStrides(OpBuilder &builder, Location loc,
+                            OffsetSizeAndStrideOpInterface producer,
+                            OffsetSizeAndStrideOpInterface consumer,
+                            const llvm::SmallBitVector &droppedProducerDims,
+                            SmallVector<OpFoldResult> &combinedOffsets,
+                            SmallVector<OpFoldResult> &combinedSizes,
+                            SmallVector<OpFoldResult> &combinedStrides);
+
 //===----------------------------------------------------------------------===//
 // Extract slice from `tensor.collapse_shape`
 //===----------------------------------------------------------------------===//
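To make the formulas concrete, here is a hand-written sketch of two consecutive `tensor.extract_slice` ops and the single op they can merge into (the names `%src`, `%o0`, `%o1`, and `%sz` are illustrative, not taken from the patch):

```mlir
%0 = tensor.extract_slice %src[%o0] [100] [2] : tensor<?xf32> to tensor<100xf32>
%1 = tensor.extract_slice %0[%o1] [%sz] [3] : tensor<100xf32> to tensor<?xf32>

// Merged slice: offset = %o1 * 2 + %o0, size = %sz, stride = 3 * 2 = 6.
%off = affine.apply affine_map<()[s0, s1] -> (s0 * 2 + s1)>()[%o1, %o0]
%m = tensor.extract_slice %src[%off] [%sz] [6] : tensor<?xf32> to tensor<?xf32>
```

Element `i` of `%1` is `%0[%o1 + 3 * i]`, which is `%src[%o0 + 2 * (%o1 + 3 * i)]`; regrouping gives offset `%o1 * 2 + %o0` and stride `6`, matching the formulas above.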
diff --git a/mlir/lib/Dialect/Tensor/Transforms/MergeConsecutiveInsertExtractSlicePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/MergeConsecutiveInsertExtractSlicePatterns.cpp
--- a/mlir/lib/Dialect/Tensor/Transforms/MergeConsecutiveInsertExtractSlicePatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/MergeConsecutiveInsertExtractSlicePatterns.cpp
@@ -7,8 +7,8 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
-#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Transforms/TransformUtils.h"
 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/OpDefinition.h"
@@ -17,29 +17,101 @@
 using namespace mlir;
 using namespace mlir::tensor;
 
-/// Adds each corresponding pair of offsets in `offsets1` and `offsets2` and
-/// returns the results.
-static SmallVector<OpFoldResult> mergeOffsets(Location loc,
-                                              ArrayRef<OpFoldResult> offsets1,
-                                              ArrayRef<OpFoldResult> offsets2,
-                                              OpBuilder &builder) {
-  SmallVector<OpFoldResult> foldedOffsets;
-  assert(offsets1.size() == offsets2.size());
-  foldedOffsets.reserve(offsets1.size());
-
-  AffineExpr dim1, dim2;
-  bindDims(builder.getContext(), dim1, dim2);
-
-  for (const auto &pair : llvm::zip(offsets1, offsets2)) {
-    auto offset0 =
-        getValueOrCreateConstantIndexOp(builder, loc, std::get<0>(pair));
-    auto offset1 =
-        getValueOrCreateConstantIndexOp(builder, loc, std::get<1>(pair));
-    auto foldedOffset =
-        makeComposedAffineApply(builder, loc, dim1 + dim2, {offset0, offset1});
-    foldedOffsets.push_back(foldedOffset.getResult());
+/// Creates an AffineExpr from `ofr`: if the OpFoldResult is a Value, creates
+/// an AffineSymbolExpr and appends it to `symbols`; otherwise creates an
+/// AffineConstantExpr.
+static AffineExpr getAffineExpr(OpFoldResult ofr,
+                                SmallVector<Value> &symbols) {
+  if (auto attr = ofr.dyn_cast<Attribute>()) {
+    return getAffineConstantExpr(attr.cast<IntegerAttr>().getInt(),
+                                 attr.getContext());
   }
-  return foldedOffsets;
+  Value v = ofr.get<Value>();
+  AffineExpr expr = getAffineSymbolExpr(symbols.size(), v.getContext());
+  symbols.push_back(v);
+  return expr;
+}
+
+/// Builds an AffineExpr incrementally for arithmetic operations.
+static AffineExpr add(AffineExpr expr, OpFoldResult ofr,
+                      SmallVector<Value> &symbols) {
+  return expr + getAffineExpr(ofr, symbols);
+}
+static AffineExpr mul(OpFoldResult lhs, OpFoldResult rhs,
+                      SmallVector<Value> &symbols) {
+  return getAffineExpr(lhs, symbols) * getAffineExpr(rhs, symbols);
+}
+
+/// Converts an AffineExpr to an OpFoldResult by generating an `affine.apply`
+/// op and folding it.
+static OpFoldResult getOpFoldResult(OpBuilder &builder, Location loc,
+                                    AffineExpr expr,
+                                    SmallVector<Value> &symbols) {
+  AffineMap m = AffineMap::get(0, symbols.size(), expr);
+  return makeComposedFoldedAffineApply(builder, loc, m, symbols);
+}
+
+LogicalResult tensor::mergeOffsetsSizesAndStrides(
+    OpBuilder &builder, Location loc, ArrayRef<OpFoldResult> producerOffsets,
+    ArrayRef<OpFoldResult> producerSizes,
+    ArrayRef<OpFoldResult> producerStrides,
+    const llvm::SmallBitVector &droppedProducerDims,
+    ArrayRef<OpFoldResult> consumerOffsets,
+    ArrayRef<OpFoldResult> consumerSizes,
+    ArrayRef<OpFoldResult> consumerStrides,
+    SmallVector<OpFoldResult> &combinedOffsets,
+    SmallVector<OpFoldResult> &combinedSizes,
+    SmallVector<OpFoldResult> &combinedStrides) {
+  combinedOffsets.resize(producerOffsets.size());
+  combinedSizes.resize(producerOffsets.size());
+  combinedStrides.resize(producerOffsets.size());
+  unsigned consumerPos = 0;
+  for (auto i : llvm::seq<unsigned>(0, producerOffsets.size())) {
+    if (droppedProducerDims.test(i)) {
+      // For dropped dims, get the values from the producer.
+      combinedOffsets[i] = producerOffsets[i];
+      combinedSizes[i] = producerSizes[i];
+      combinedStrides[i] = producerStrides[i];
+      continue;
+    }
+    SmallVector<Value> offsetSymbols, strideSymbols;
+    // The combined offset is computed as
+    //    consumer_offset * producer_stride + producer_offset.
+    combinedOffsets[i] =
+        getOpFoldResult(builder, loc,
+                        add(mul(consumerOffsets[consumerPos],
+                                producerStrides[i], offsetSymbols),
+                            producerOffsets[i], offsetSymbols),
+                        offsetSymbols);
+    combinedSizes[i] = consumerSizes[consumerPos];
+    // The combined stride is computed as
+    //    consumer_stride * producer_stride.
+    combinedStrides[i] = getOpFoldResult(
+        builder, loc,
+        mul(consumerStrides[consumerPos], producerStrides[i], strideSymbols),
+        strideSymbols);
+    consumerPos++;
+  }
+  return success();
+}
+
+LogicalResult tensor::mergeOffsetsSizesAndStrides(
+    OpBuilder &builder, Location loc, OffsetSizeAndStrideOpInterface producer,
+    OffsetSizeAndStrideOpInterface consumer,
+    const llvm::SmallBitVector &droppedProducerDims,
+    SmallVector<OpFoldResult> &combinedOffsets,
+    SmallVector<OpFoldResult> &combinedSizes,
+    SmallVector<OpFoldResult> &combinedStrides) {
+  SmallVector<OpFoldResult> consumerOffsets = consumer.getMixedOffsets();
+  SmallVector<OpFoldResult> consumerSizes = consumer.getMixedSizes();
+  SmallVector<OpFoldResult> consumerStrides = consumer.getMixedStrides();
+  SmallVector<OpFoldResult> producerOffsets = producer.getMixedOffsets();
+  SmallVector<OpFoldResult> producerSizes = producer.getMixedSizes();
+  SmallVector<OpFoldResult> producerStrides = producer.getMixedStrides();
+  return tensor::mergeOffsetsSizesAndStrides(
+      builder, loc, producerOffsets, producerSizes, producerStrides,
+      droppedProducerDims, consumerOffsets, consumerSizes, consumerStrides,
+      combinedOffsets, combinedSizes, combinedStrides);
 }
 
 namespace {
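A note on the expression building (with hypothetical operands): dynamic operands become affine symbols while static ones are folded into the map as constants, so mixed static/dynamic inputs partially fold. For a dynamic consumer offset `%o1`, a static producer stride `2`, and a dynamic producer offset `%o0`, the combined offset materializes as a single `affine.apply`:

```mlir
%off = affine.apply affine_map<()[s0, s1] -> (s0 * 2 + s1)>()[%o1, %o0]
```

When every operand is static, `symbols` stays empty, the expression is an `AffineConstantExpr`, and `makeComposedFoldedAffineApply` folds the result to an attribute without creating any op.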
@@ -53,24 +125,15 @@
     if (!prevOp)
       return failure();
 
-    if (!prevOp.hasUnitStride() || !nextOp.hasUnitStride())
+    SmallVector<OpFoldResult> newOffsets, newSizes, newStrides;
+    if (failed(mergeOffsetsSizesAndStrides(rewriter, nextOp.getLoc(), prevOp,
+                                           nextOp, prevOp.getDroppedDims(),
+                                           newOffsets, newSizes, newStrides)))
       return failure();
 
-    auto prevResultType = prevOp.getType().cast<RankedTensorType>();
-    if (prevOp.getSourceType().getRank() != prevResultType.getRank())
-      return rewriter.notifyMatchFailure(
-          prevOp, "rank-reducing producder case unimplemented");
-
-    Location loc = nextOp.getLoc();
-
-    SmallVector<OpFoldResult> prevOffsets = prevOp.getMixedOffsets();
-    SmallVector<OpFoldResult> nextOffsets = nextOp.getMixedOffsets();
-    SmallVector<OpFoldResult> foldedOffsets =
-        mergeOffsets(loc, prevOffsets, nextOffsets, rewriter);
-
-    rewriter.replaceOpWithNewOp<ExtractSliceOp>(
-        nextOp, nextOp.getType(), prevOp.getSource(), foldedOffsets,
-        nextOp.getMixedSizes(), nextOp.getMixedStrides());
+    rewriter.replaceOpWithNewOp<ExtractSliceOp>(nextOp, nextOp.getType(),
+                                                prevOp.getSource(), newOffsets,
+                                                newSizes, newStrides);
     return success();
   }
 };
diff --git a/mlir/test/Dialect/Tensor/fold-consecutive-insert-extract-slice.mlir b/mlir/test/Dialect/Tensor/fold-consecutive-insert-extract-slice.mlir
--- a/mlir/test/Dialect/Tensor/fold-consecutive-insert-extract-slice.mlir
+++ b/mlir/test/Dialect/Tensor/fold-consecutive-insert-extract-slice.mlir
@@ -9,10 +9,12 @@
 
 // CHECK-LABEL: func.func @extract_slice_same_rank
 // CHECK-SAME: (%[[SOURCE:.+]]: tensor<?x?x?x?xf32>, %[[OFFSET0:.+]]: index, %[[OFFSET1:.+]]: index, %{{.+}}: index, %[[SIZE1:.+]]: index)
-// CHECK: %[[OFFSET:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%[[OFFSET0]], %[[OFFSET1]]]
+// CHECK: %[[OFFSET:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%[[OFFSET1]], %[[OFFSET0]]]
 // CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[SOURCE]][7, 9, 11, %[[OFFSET]]] [8, 16, 32, %[[SIZE1]]] [1, 1, 1, 1]
 // CHECK: return %[[EXTRACT]] : tensor<8x16x32x?xf32>
 
+// -----
+
 func.func @extract_slice_rank_reducing_consumer(
     %src: tensor<?x?x?x?xf32>, %offset0: index, %offset1: index, %size0: index, %size1: index) -> tensor<16x?xf32> {
   %0 = tensor.extract_slice %src[0, 1, 2, %offset0] [128, 128, 128, %size0] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<128x128x128x?xf32>
   %1 = tensor.extract_slice %0[7, 9, 11, %offset1] [1, 16, 1, %size1] [1, 1, 1, 1] : tensor<128x128x128x?xf32> to tensor<16x?xf32>
   return %1: tensor<16x?xf32>
 }
@@ -23,6 +25,8 @@
 // CHECK-LABEL: func.func @extract_slice_rank_reducing_consumer
 // CHECK: tensor.extract_slice %{{.+}}[7, 9, 11, %{{.+}}] [1, 16, 1, %{{.+}}] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<16x?xf32>
 
+// -----
+
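For rank-reducing producers like the test below, the merge walks producer dimensions: dropped unit dims have no counterpart in the consumer, so they keep the producer's offset/size/stride, while the consumer's operands are consumed in order by the remaining dims. A hand-written sketch (names and shapes are illustrative, not from the test suite):

```mlir
// The producer drops unit dims 0 and 2.
%0 = tensor.extract_slice %src[5, %o0, 6, %o1] [1, 64, 1, 32] [1, 1, 1, 1]
    : tensor<?x?x?x?xf32> to tensor<64x32xf32>
%1 = tensor.extract_slice %0[%o2, 4] [16, 8] [1, 1] : tensor<64x32xf32> to tensor<16x8xf32>

// Merged: dims 0 and 2 keep (5, 1, 1) and (6, 1, 1) from the producer; dims 1
// and 3 combine with the consumer's operands.
%off1 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%o2, %o0]
%off3 = affine.apply affine_map<()[s0] -> (s0 + 4)>()[%o1]
%m = tensor.extract_slice %src[5, %off1, 6, %off3] [1, 16, 1, 8] [1, 1, 1, 1]
    : tensor<?x?x?x?xf32> to tensor<16x8xf32>
```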
 func.func @extract_slice_rank_reducing_producer(
     %src: tensor<?x?x?x?xf32>, %offset0: index, %offset1: index, %size0: index, %size1: index) -> tensor<8x?xf32> {
   %0 = tensor.extract_slice %src[0, 1, 2, %offset0] [1, 128, 1, %size0] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<128x?xf32>
@@ -30,8 +34,27 @@
   return %1: tensor<8x?xf32>
 }
 
-// CHECK-LABEL: func.func @extract_slice_rank_reducing_producer
-// CHECK-COUNT-2: tensor.extract_slice
+// CHECK-LABEL: func.func @extract_slice_rank_reducing_producer
+// CHECK-SAME: (%[[SRC:.+]]: tensor<?x?x?x?xf32>, %[[OFFSET0:.+]]: index, %[[OFFSET1:.+]]: index, %{{.+}}: index, %[[SIZE1:.+]]: index)
+// CHECK: %[[OFFSET:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%[[OFFSET1]], %[[OFFSET0]]]
+// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[SRC]][0, 8, 2, %[[OFFSET]]] [1, 8, 1, %[[SIZE1]]] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<8x?xf32>
+// CHECK: return %[[EXTRACT]] : tensor<8x?xf32>
+
+// -----
+
+func.func @extract_slice_non_one_stride(
+    %src: tensor<?xf32>, %offset0: index, %offset1: index, %size0: index, %size1: index, %stride0: index, %stride1: index) -> tensor<?xf32> {
+  %0 = tensor.extract_slice %src[%offset0] [%size0] [%stride0] : tensor<?xf32> to tensor<?xf32>
+  %1 = tensor.extract_slice %0[%offset1] [%size1] [%stride1] : tensor<?xf32> to tensor<?xf32>
+  return %1: tensor<?xf32>
+}
+
+// CHECK-LABEL: func.func @extract_slice_non_one_stride
+// CHECK-SAME: (%[[SRC:.+]]: tensor<?xf32>, %[[OFFSET0:.+]]: index, %[[OFFSET1:.+]]: index, %{{.+}}: index, %[[SIZE1:.+]]: index, %[[STRIDE0:.+]]: index, %[[STRIDE1:.+]]: index)
+// CHECK: %[[OFFSET:.+]] = affine.apply affine_map<()[s0, s1, s2] -> (s0 * s1 + s2)>()[%[[OFFSET1]], %[[STRIDE0]], %[[OFFSET0]]]
+// CHECK: %[[STRIDE:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%[[STRIDE1]], %[[STRIDE0]]]
+// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[SRC]][%[[OFFSET]]] [%[[SIZE1]]] [%[[STRIDE]]] : tensor<?xf32> to tensor<?xf32>
+// CHECK: return %[[EXTRACT]] : tensor<?xf32>
 
 // -----
 
@@ -47,6 +70,8 @@
 // CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[SRC]] into %[[DST]][6, 7, 8, %[[IDX]]] [1, 1, 16, 1] [1, 1, 1, 1]
 // CHECK: return %[[INSERT]]
 
+// -----
+
 func.func @insert_slice_rank_reducing_dynamic_shape(
     %dst: tensor<128x128x128x128xf32>, %mid: tensor<1x?x1xf32>, %src: tensor<?xf32>, %offset: index, %size: index) -> tensor<128x128x128x128xf32> {
   %0 = tensor.insert_slice %src into %mid[0, 0, 0] [1, %size, 1] [1, 1, 1] : tensor<?xf32> into tensor<1x?x1xf32>
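As a quick hand check of the non-unit-stride arithmetic above (numbers chosen for illustration, not from the test suite): with a producer slice at offset 3 / stride 2 and a consumer slice at offset 2 / stride 4, element `i` of the final result reads source index `(2 + 4 * i) * 2 + 3 = 8 * i + 7`, i.e. combined offset `2 * 2 + 3 = 7` and combined stride `4 * 2 = 8`:

```mlir
%0 = tensor.extract_slice %src[3] [50] [2] : tensor<?xf32> to tensor<50xf32>
%1 = tensor.extract_slice %0[2] [10] [4] : tensor<50xf32> to tensor<10xf32>
// Folds to:
%m = tensor.extract_slice %src[7] [10] [8] : tensor<?xf32> to tensor<10xf32>
```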