diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -756,6 +756,7 @@
   let assemblyFormat = [{
     $source `:` type($source) `->` type(results) attr-dict
   }];
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1163,6 +1163,132 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// ExtractStridedMetadataOp
+//===----------------------------------------------------------------------===//
+namespace {
+/// Replace `baseBuffer, offset, sizes, strides =
+///              extract_strided_metadata(subview(memref, subOffset,
+///                                               subSizes, subStrides))`
+/// With
+///
+/// \verbatim
+/// baseBuffer, baseOffset, baseSizes, baseStrides =
+///     extract_strided_metadata(memref)
+/// strides#i = baseStrides#i * subStrides#i
+/// offset = baseOffset + sum(subOffset#i * baseStrides#i)
+/// sizes = subSizes
+/// \endverbatim
+///
+/// The subview offsets are positions in `memref`, so they are scaled by the
+/// strides of `memref` (baseStrides), not by the strides of the subview.
+/// Dimensions dropped by a rank-reducing subview still contribute to the
+/// offset, but produce no entry in the returned sizes and strides.
+///
+/// In other words, get rid of the subview in that expression and canonicalize
+/// on its effects on the offset, the sizes, and the strides.
+struct ExtractStridedMetadataOpSubviewFolder
+    : public OpRewritePattern<ExtractStridedMetadataOp> {
+public:
+  using OpRewritePattern<ExtractStridedMetadataOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExtractStridedMetadataOp op,
+                                PatternRewriter &rewriter) const override {
+    auto subview = op.getSource().getDefiningOp<SubViewOp>();
+    if (!subview)
+      return failure();
+
+    // Build a plain extract_strided_metadata(memref) from
+    // extract_strided_metadata(subview(memref)).
+    Location origLoc = op.getLoc();
+    IndexType indexType = rewriter.getIndexType();
+    Value source = subview.getSource();
+    auto sourceType = source.getType().cast<MemRefType>();
+    int sourceRank = sourceType.getRank();
+    SmallVector<Type> sizeStrideTypes(sourceRank, indexType);
+
+    auto newExtractStridedMetadata = rewriter.create<ExtractStridedMetadataOp>(
+        origLoc, op.getBaseBuffer().getType(), indexType, sizeStrideTypes,
+        sizeStrideTypes, source);
+    auto baseStrides = newExtractStridedMetadata.getStrides();
+
+    // Compute the new strides of the dimensions that survive the subview:
+    //   newStride#i = baseStride#i * subStride#i
+    // Dimensions dropped by a rank-reducing subview have no stride in the
+    // result, so they are skipped.
+    SmallVector<Value> finalStrides;
+    llvm::SmallBitVector droppedDims = subview.getDroppedDims();
+    for (int i = 0; i < sourceRank; ++i) {
+      if (droppedDims.test(i))
+        continue;
+      Value subStride;
+      if (subview.isDynamicStride(i))
+        subStride = subview.getDynamicStride(i);
+      else
+        subStride = rewriter.create<arith::ConstantIndexOp>(
+            origLoc, subview.getStaticStride(i));
+      finalStrides.push_back(rewriter.createOrFold<arith::MulIOp>(
+          origLoc, baseStrides[i], subStride));
+    }
+
+    // Compute the offset. The subview offsets are positions in `source`, so
+    // they are scaled by the base strides (not by the new strides):
+    //   offset = baseOffset + sum(subOffset#i * baseStride#i)
+    // Every dimension, including dropped ones, contributes to the offset.
+    Value offset = newExtractStridedMetadata.getOffset();
+    for (int i = 0; i < sourceRank; ++i) {
+      Value subOffset;
+      if (subview.isDynamicOffset(i)) {
+        subOffset = subview.getDynamicOffset(i);
+      } else {
+        int64_t offsetVal = subview.getStaticOffset(i);
+        // Adding 0 is a no-op; do not materialize useless arith operations.
+        if (offsetVal == 0)
+          continue;
+        subOffset = rewriter.create<arith::ConstantIndexOp>(origLoc, offsetVal);
+      }
+
+      Value scaledSubOffset = rewriter.createOrFold<arith::MulIOp>(
+          origLoc, subOffset, baseStrides[i]);
+      offset = rewriter.createOrFold<arith::AddIOp>(origLoc, offset,
+                                                    scaledSubOffset);
+    }
+
+    // The final result is <baseBuffer, offset, sizes, strides>, where sizes
+    // and strides only cover the dimensions kept by the subview.
+    SmallVector<Value> results;
+    results.reserve(op->getNumResults());
+
+    results.push_back(newExtractStridedMetadata.getBaseBuffer());
+    results.push_back(offset);
+
+    // The sizes of the final type are defined directly by the input sizes of
+    // the subview.
+    for (int i = 0; i < sourceRank; ++i) {
+      if (droppedDims.test(i))
+        continue;
+      Value currentSize;
+      if (subview.isDynamicSize(i))
+        currentSize = subview.getDynamicSize(i);
+      else
+        currentSize = rewriter.create<arith::ConstantIndexOp>(
+            origLoc, subview.getStaticSize(i));
+      results.push_back(currentSize);
+    }
+
+    results.append(finalStrides);
+
+    rewriter.replaceOp(op, results);
+    return success();
+  }
+};
+} // namespace
+
+void ExtractStridedMetadataOp::getCanonicalizationPatterns(
+    RewritePatternSet &results, MLIRContext *context) {
+  results.add<ExtractStridedMetadataOpSubviewFolder>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // GenericAtomicRMWOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -812,3 +812,174 @@
 // CHECK-SAME: %[[ARG1:.+]]: index
 // CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG0]][0, 0] [1, %[[ARG1]]] [1, 1]
 // CHECK-SAME: memref<8x?xf32> to memref<?xf32>
+
+// -----
+
+// Check that we canonicalize extract_strided_metadata of subview to
+// base_buf, base_offset, base_sizes, base_strides = extract_strided_metadata
+// strides = base_stride_i * subview_stride_i
+// offset = base_offset + sum(subview_offsets_i * base_strides_i).
+//
+// This test also checks that we don't create useless arith operations
+// when subview_offsets_i is 0.
+//
+// CHECK-LABEL: func @extract_strided_metadata_of_subview
+// CHECK-SAME: (%[[ARG:.*]]: memref<5x4xf32>)
+//
+// Materialize the offset for dimension 1.
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+//
+// Plain extract_strided_metadata.
+// CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
+//
+// Compute the offset for dimension 1.
+// (It is scaled by the base stride of dimension 1, not by a new stride.)
+// CHECK: %[[DIM1_OFFSET:.*]] = arith.muli %[[STRIDES]]#1, %[[C2]]
+//
+// Add the base offset.
+// CHECK: %[[FINAL_OFFSET:.*]] = arith.addi %[[OFFSET]], %[[DIM1_OFFSET]]
+//
+// Return the new tuple.
+// CHECK: return %[[BASE]], %[[FINAL_OFFSET]], %[[C2]], %[[C2]], %[[STRIDES]]#0, %[[STRIDES]]#1
+func.func @extract_strided_metadata_of_subview(%base: memref<5x4xf32>) -> (memref<f32>, index, index, index, index, index) {
+  %subview = memref.subview %base[0, 2][2, 2][1, 1] : memref<5x4xf32> to memref<2x2xf32, strided<[4, 1], offset: 2>>
+  %base_buffer, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %subview : memref<2x2xf32, strided<[4,1], offset:2>> -> memref<f32>, index, index, index, index, index
+  return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 : memref<f32>, index, index, index, index, index
+}
+
+// -----
+
+// Check that we canonicalize extract_strided_metadata of subview properly
+// when dynamic sizes are involved.
+// See extract_strided_metadata_of_subview for an explanation of the actual
+// expansion.
+//
+// CHECK-LABEL: func @extract_strided_metadata_of_subview_with_dynamic_size
+// CHECK-SAME: (%[[ARG:.*]]: memref<8x16x4xf32>,
+// CHECK-SAME: %[[DYN_SIZE:.*]]: index)
+//
+// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index
+//
+// CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:3, %[[STRIDES:.*]]:3 = memref.extract_strided_metadata %[[ARG]]
+//
+// CHECK-DAG: %[[DIM0_OFFSET:.*]] = arith.muli %[[STRIDES]]#0, %[[C3]]
+// CHECK-DAG: %[[DIM1_OFFSET:.*]] = arith.muli %[[STRIDES]]#1, %[[C4]]
+// CHECK-DAG: %[[DIM2_OFFSET:.*]] = arith.muli %[[STRIDES]]#2, %[[C2]]
+//
+// CHECK-DAG: %[[PARTIAL_OFFSET0:.*]] = arith.addi %[[OFFSET]], %[[DIM0_OFFSET]]
+// CHECK-DAG: %[[PARTIAL_OFFSET1:.*]] = arith.addi %[[PARTIAL_OFFSET0]], %[[DIM1_OFFSET]]
+// CHECK-DAG: %[[FINAL_OFFSET:.*]] = arith.addi %[[PARTIAL_OFFSET1]], %[[DIM2_OFFSET]]
+//
+// CHECK: return %[[BASE]], %[[FINAL_OFFSET]], %[[DYN_SIZE]], %[[C6]], %[[C3]], %[[STRIDES]]#0, %[[STRIDES]]#1, %[[STRIDES]]#2
+func.func @extract_strided_metadata_of_subview_with_dynamic_size(%base: memref<8x16x4xf32>, %size: index) -> (memref<f32>, index, index, index, index, index, index, index) {
+  %subview = memref.subview %base[3, 4, 2][%size, 6, 3][1, 1, 1] : memref<8x16x4xf32> to memref<?x6x3xf32, strided<[64, 4, 1], offset: 210>>
+  %base_buffer, %offset, %sizes:3, %strides:3 = memref.extract_strided_metadata %subview : memref<?x6x3xf32, strided<[64, 4, 1], offset: 210>> -> memref<f32>, index, index, index, index, index, index, index
+  return %base_buffer, %offset, %sizes#0, %sizes#1, %sizes#2, %strides#0, %strides#1, %strides#2 : memref<f32>, index, index, index, index, index, index, index
+}
+
+// -----
+
+// Check that we canonicalize extract_strided_metadata of subview properly
+// when the subview reduces the ranks.
+// In particular the returned strides must come from #1 and #2 of the %strides
+// value of the new extract_strided_metadata, not #0 and #1.
+// See extract_strided_metadata_of_subview for an explanation of the actual
+// expansion.
+//
+// CHECK-LABEL: func @extract_strided_metadata_of_rank_reduced_subview
+// CHECK-SAME: (%[[ARG:.*]]: memref<8x16x4xf32>)
+//
+// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index
+//
+// CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:3, %[[STRIDES:.*]]:3 = memref.extract_strided_metadata %[[ARG]]
+//
+// CHECK-DAG: %[[DIM0_OFFSET:.*]] = arith.muli %[[STRIDES]]#0, %[[C3]]
+// CHECK-DAG: %[[DIM1_OFFSET:.*]] = arith.muli %[[STRIDES]]#1, %[[C4]]
+// CHECK-DAG: %[[DIM2_OFFSET:.*]] = arith.muli %[[STRIDES]]#2, %[[C2]]
+//
+// CHECK-DAG: %[[PARTIAL_OFFSET0:.*]] = arith.addi %[[OFFSET]], %[[DIM0_OFFSET]]
+// CHECK-DAG: %[[PARTIAL_OFFSET1:.*]] = arith.addi %[[PARTIAL_OFFSET0]], %[[DIM1_OFFSET]]
+// CHECK-DAG: %[[FINAL_OFFSET:.*]] = arith.addi %[[PARTIAL_OFFSET1]], %[[DIM2_OFFSET]]
+//
+// CHECK: return %[[BASE]], %[[FINAL_OFFSET]], %[[C6]], %[[C3]], %[[STRIDES]]#1, %[[STRIDES]]#2
+func.func @extract_strided_metadata_of_rank_reduced_subview(%base: memref<8x16x4xf32>) -> (memref<f32>, index, index, index, index, index) {
+  %subview = memref.subview %base[3, 4, 2][1, 6, 3][1, 1, 1] : memref<8x16x4xf32> to memref<6x3xf32, strided<[4, 1], offset: 210>>
+  %base_buffer, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %subview : memref<6x3xf32, strided<[4,1], offset: 210>> -> memref<f32>, index, index, index, index, index
+  return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 : memref<f32>, index, index, index, index, index
+}
+
+// -----
+
+// Check that we canonicalize extract_strided_metadata of subview properly
+// when the subview reduces the rank and some of the strides are variable.
+// In particular, we check that:
+// A. The dynamic stride is multiplied with the base stride to create the new
+//    stride for dimension 1.
+// B. The first returned stride is the value computed in #A.
+// C. The offset of dimension 1 is scaled by the base stride, not by the
+//    new stride computed in #A.
+// See extract_strided_metadata_of_subview for an explanation of the actual
+// expansion.
+//
+// CHECK-LABEL: func @extract_strided_metadata_of_rank_reduced_subview_w_variable_strides
+// CHECK-SAME: (%[[ARG:.*]]: memref<8x16x4xf32>,
+// CHECK-SAME: %[[DYN_STRIDE:.*]]: index)
+//
+// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index
+//
+// CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:3, %[[STRIDES:.*]]:3 = memref.extract_strided_metadata %[[ARG]]
+//
+// CHECK-DAG: %[[DIM1_STRIDE:.*]] = arith.muli %[[STRIDES]]#1, %[[DYN_STRIDE]]
+//
+// CHECK-DAG: %[[DIM0_OFFSET:.*]] = arith.muli %[[STRIDES]]#0, %[[C3]]
+// CHECK-DAG: %[[DIM1_OFFSET:.*]] = arith.muli %[[STRIDES]]#1, %[[C4]]
+// CHECK-DAG: %[[DIM2_OFFSET:.*]] = arith.muli %[[STRIDES]]#2, %[[C2]]
+//
+// CHECK-DAG: %[[PARTIAL_OFFSET0:.*]] = arith.addi %[[OFFSET]], %[[DIM0_OFFSET]]
+// CHECK-DAG: %[[PARTIAL_OFFSET1:.*]] = arith.addi %[[PARTIAL_OFFSET0]], %[[DIM1_OFFSET]]
+// CHECK-DAG: %[[FINAL_OFFSET:.*]] = arith.addi %[[PARTIAL_OFFSET1]], %[[DIM2_OFFSET]]
+//
+// CHECK: return %[[BASE]], %[[FINAL_OFFSET]], %[[C6]], %[[C3]], %[[DIM1_STRIDE]], %[[STRIDES]]#2
+func.func @extract_strided_metadata_of_rank_reduced_subview_w_variable_strides(%base: memref<8x16x4xf32>, %stride: index) -> (memref<f32>, index, index, index, index, index) {
+  %subview = memref.subview %base[3, 4, 2][1, 6, 3][1, %stride, 1] : memref<8x16x4xf32> to memref<6x3xf32, strided<[?, 1], offset: 210>>
+  %base_buffer, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %subview : memref<6x3xf32, strided<[?, 1], offset: 210>> -> memref<f32>, index, index, index, index, index
+  return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 : memref<f32>, index, index, index, index, index
+}
+
+// -----
+
+// Check that we canonicalize extract_strided_metadata of subview properly
+// when the subview uses variable offsets.
+// See extract_strided_metadata_of_subview for an explanation of the actual
+// expansion.
+//
+// CHECK-LABEL: func @extract_strided_metadata_of_subview_w_variable_offset
+// CHECK-SAME: (%[[ARG:.*]]: memref<384x128xf32>,
+// CHECK-SAME: %[[DYN_OFFSET0:.*]]: index,
+// CHECK-SAME: %[[DYN_OFFSET1:.*]]: index)
+//
+// CHECK-DAG: %[[C64:.*]] = arith.constant 64 : index
+// CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
+//
+// CHECK-DAG: %[[DIM0_OFFSET:.*]] = arith.muli %[[DYN_OFFSET0]], %[[STRIDES]]#0
+// CHECK-DAG: %[[DIM1_OFFSET:.*]] = arith.muli %[[DYN_OFFSET1]], %[[STRIDES]]#1
+//
+// CHECK-DAG: %[[PARTIAL_OFFSET0:.*]] = arith.addi %[[OFFSET]], %[[DIM0_OFFSET]]
+// CHECK-DAG: %[[FINAL_OFFSET:.*]] = arith.addi %[[PARTIAL_OFFSET0]], %[[DIM1_OFFSET]]
+//
+// CHECK: return %[[BASE]], %[[FINAL_OFFSET]], %[[C64]], %[[C64]], %[[STRIDES]]#0, %[[STRIDES]]#1
+#map0 = affine_map<(d0, d1)[s0] -> (d0 * 128 + s0 + d1)>
+func.func @extract_strided_metadata_of_subview_w_variable_offset(%arg0: memref<384x128xf32>, %arg1 : index, %arg2 : index) -> (memref<f32>, index, index, index, index, index) {
+  %subview = memref.subview %arg0[%arg1, %arg2] [64, 64] [1, 1] : memref<384x128xf32> to memref<64x64xf32, #map0>
+  %base_buffer, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %subview : memref<64x64xf32, #map0> -> memref<f32>, index, index, index, index, index
+  return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 : memref<f32>, index, index, index, index, index
+}