diff --git a/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
--- a/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
@@ -487,14 +487,103 @@
     return success(atLeastOneReplacement);
   }
 };
+
+/// Replace `base, offset, sizes, strides =
+///              extract_strided_metadata(allocLikeOp)`
+///
+/// With
+///
+/// ```
+/// base = reinterpret_cast allocLikeOp(allocSizes) to a flat memref<eltTy>
+/// offset = 0
+/// sizes = allocSizes
+/// strides#i = prod(allocSizes#j, for j in {i+1..rank-1})
+/// ```
+///
+/// The transformation only applies if the allocLikeOp has been normalized.
+/// In other words, the affine_map must be an identity.
+template <typename AllocLikeOp>
+struct ExtractStridedMetadataOpAllocFolder
+    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
+public:
+  using OpRewritePattern<memref::ExtractStridedMetadataOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
+                                PatternRewriter &rewriter) const override {
+    auto allocLikeOp = op.getSource().getDefiningOp<AllocLikeOp>();
+    if (!allocLikeOp)
+      return failure();
+
+    auto memRefType =
+        allocLikeOp.getResult().getType().template cast<MemRefType>();
+    if (!memRefType.getLayout().isIdentity())
+      return rewriter.notifyMatchFailure(
+          allocLikeOp, "alloc-like operations should have been normalized");
+
+    Location loc = op.getLoc();
+    int rank = memRefType.getRank();
+
+    // Collect the sizes.
+    ValueRange dynamic = allocLikeOp.getDynamicSizes();
+    SmallVector<OpFoldResult> sizes;
+    sizes.reserve(rank);
+    unsigned dynamicPos = 0;
+    for (int64_t size : memRefType.getShape()) {
+      if (ShapedType::isDynamic(size))
+        sizes.push_back(dynamic[dynamicPos++]);
+      else
+        sizes.push_back(rewriter.getIndexAttr(size));
+    }
+
+    // Strides: stride#i = prod(sizes#j, for j in {i+1..rank-1}).
+    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
+    AffineExpr expr = rewriter.getAffineConstantExpr(1);
+    unsigned symbolNumber = 0;
+    for (int i = rank - 2; i >= 0; --i) {
+      expr = expr * rewriter.getAffineSymbolExpr(symbolNumber++);
+      assert(i + 1 + symbolNumber == sizes.size() &&
+             "The ArrayRef should encompass the last #symbolNumber sizes");
+      ArrayRef<OpFoldResult> sizesInvolvedInStride(&sizes[i + 1], symbolNumber);
+      strides[i] = makeComposedFoldedAffineApply(rewriter, loc, expr,
+                                                 sizesInvolvedInStride);
+    }
+
+    // Put all the values together to replace the results.
+    SmallVector<Value> results;
+    results.reserve(rank * 2 + 2);
+
+    auto baseBufferType = op.getBaseBuffer().getType().cast<MemRefType>();
+    int64_t offset = 0;
+    if (allocLikeOp.getType() == baseBufferType)
+      results.push_back(allocLikeOp);
+    else
+      results.push_back(rewriter.create<memref::ReinterpretCastOp>(
+          loc, baseBufferType, allocLikeOp, offset,
+          /*sizes=*/ArrayRef<int64_t>(),
+          /*strides=*/ArrayRef<int64_t>()));
+
+    // Offset.
+    results.push_back(rewriter.create<arith::ConstantIndexOp>(loc, offset));
+
+    for (OpFoldResult size : sizes)
+      results.push_back(getValueOrCreateConstantIndexOp(rewriter, loc, size));
+
+    for (OpFoldResult stride : strides)
+      results.push_back(getValueOrCreateConstantIndexOp(rewriter, loc, stride));
+
+    rewriter.replaceOp(op, results);
+    return success();
+  }
+};
 } // namespace
 
 void memref::populateSimplifyExtractStridedMetadataOpPatterns(
     RewritePatternSet &patterns) {
-  patterns
-      .add<ExtractStridedMetadataOpSubviewFolder,
-           ExtractStridedMetadataOpExpandShapeFolder>(
-          patterns.getContext());
+  patterns.add<ExtractStridedMetadataOpSubviewFolder,
+               ExtractStridedMetadataOpExpandShapeFolder,
+               ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
+               ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>>(
+      patterns.getContext());
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir b/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
--- a/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
+++ b/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
@@ -527,3 +527,226 @@
     index, index, index, index, index,
     index, index, index, index, index
 }
+
+// -----
+
+// Check that we simplify extract_strided_metadata(alloc)
+// into simply the alloc with the information extracted from
+// the memref type and arguments of the alloc.
+//
+// baseBuffer = reinterpret_cast alloc
+// offset = 0
+// sizes = shape(memref)
+// strides = strides(memref)
+//
+// For dynamic shapes, we simply use the values that feed the alloc.
+//
+// Simple rank 0 test: we don't need a reinterpret_cast here.
+// CHECK-LABEL: func @extract_strided_metadata_of_alloc_all_static_0_rank
+//
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[ALLOC:.*]] = memref.alloc()
+// CHECK: return %[[ALLOC]], %[[C0]] : memref<i16>, index
+func.func @extract_strided_metadata_of_alloc_all_static_0_rank()
+    -> (memref<i16>, index) {
+
+  %A = memref.alloc() : memref<i16>
+  %base, %offset = memref.extract_strided_metadata %A :
+    memref<i16>
+    -> memref<i16>, index
+
+  return %base, %offset :
+    memref<i16>, index
+}
+
+// -----
+
+// Simplification of extract_strided_metadata(alloc).
+// Check that we properly use the dynamic sizes to
+// create the new sizes and strides.
+// size 0 = dyn_size0
+// size 1 = 4
+// size 2 = dyn_size2
+// size 3 = dyn_size3
+//
+// stride 0 = size 1 * size 2 * size 3
+//          = 4 * dyn_size2 * dyn_size3
+// stride 1 = size 2 * size 3
+//          = dyn_size2 * dyn_size3
+// stride 2 = size 3
+//          = dyn_size3
+// stride 3 = 1
+//
+// CHECK-DAG: #[[$STRIDE0_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) * 4)>
+// CHECK-DAG: #[[$STRIDE1_MAP:.*]] = affine_map<()[s0, s1] -> (s0 * s1)>
+// CHECK-LABEL: extract_strided_metadata_of_alloc_dyn_size
+// CHECK-SAME: (%[[DYN_SIZE0:.*]]: index, %[[DYN_SIZE2:.*]]: index, %[[DYN_SIZE3:.*]]: index)
+//
+// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[ALLOC:.*]] = memref.alloc(%[[DYN_SIZE0]], %[[DYN_SIZE2]], %[[DYN_SIZE3]])
+//
+// CHECK-DAG: %[[STRIDE0:.*]] = affine.apply #[[$STRIDE0_MAP]]()[%[[DYN_SIZE2]], %[[DYN_SIZE3]]]
+// CHECK-DAG: %[[STRIDE1:.*]] = affine.apply #[[$STRIDE1_MAP]]()[%[[DYN_SIZE2]], %[[DYN_SIZE3]]]
+//
+// CHECK-DAG: %[[CASTED_ALLOC:.*]] = memref.reinterpret_cast %[[ALLOC]] to offset: [0], sizes: [], strides: [] : memref<?x4x?x?xi16> to memref<i16>
+//
+// CHECK: return %[[CASTED_ALLOC]], %[[C0]], %[[DYN_SIZE0]], %[[C4]], %[[DYN_SIZE2]], %[[DYN_SIZE3]], %[[STRIDE0]], %[[STRIDE1]], %[[DYN_SIZE3]], %[[C1]]
+func.func @extract_strided_metadata_of_alloc_dyn_size(
+    %dyn_size0 : index, %dyn_size2 : index, %dyn_size3 : index)
+    -> (memref<i16>, index,
+        index, index, index, index,
+        index, index, index, index) {
+
+  %A = memref.alloc(%dyn_size0, %dyn_size2, %dyn_size3) : memref<?x4x?x?xi16>
+
+  %base, %offset, %sizes:4, %strides:4 = memref.extract_strided_metadata %A :
+    memref<?x4x?x?xi16>
+    -> memref<i16>, index,
+       index, index, index, index,
+       index, index, index, index
+
+  return %base, %offset,
+    %sizes#0, %sizes#1, %sizes#2, %sizes#3,
+    %strides#0, %strides#1, %strides#2, %strides#3 :
+    memref<i16>, index,
+    index, index, index, index,
+    index, index, index, index
+}
+
+// -----
+
+// Same check as extract_strided_metadata_of_alloc_dyn_size but alloca
+// instead of alloc. Just to make sure we handle allocas the same way
+// we do with alloc.
+// While at it, test a slightly different shape than
+// extract_strided_metadata_of_alloc_dyn_size.
+//
+// size 0 = dyn_size0
+// size 1 = dyn_size1
+// size 2 = 4
+// size 3 = dyn_size3
+//
+// stride 0 = size 1 * size 2 * size 3
+//          = dyn_size1 * 4 * dyn_size3
+// stride 1 = size 2 * size 3
+//          = 4 * dyn_size3
+// stride 2 = size 3
+//          = dyn_size3
+// stride 3 = 1
+//
+// CHECK-DAG: #[[$STRIDE0_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) * 4)>
+// CHECK-DAG: #[[$STRIDE1_MAP:.*]] = affine_map<()[s0] -> (s0 * 4)>
+// CHECK-LABEL: extract_strided_metadata_of_alloca_dyn_size
+// CHECK-SAME: (%[[DYN_SIZE0:.*]]: index, %[[DYN_SIZE1:.*]]: index, %[[DYN_SIZE3:.*]]: index)
+//
+// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[ALLOCA:.*]] = memref.alloca(%[[DYN_SIZE0]], %[[DYN_SIZE1]], %[[DYN_SIZE3]])
+//
+// CHECK-DAG: %[[STRIDE0:.*]] = affine.apply #[[$STRIDE0_MAP]]()[%[[DYN_SIZE1]], %[[DYN_SIZE3]]]
+// CHECK-DAG: %[[STRIDE1:.*]] = affine.apply #[[$STRIDE1_MAP]]()[%[[DYN_SIZE3]]]
+//
+// CHECK-DAG: %[[CASTED_ALLOCA:.*]] = memref.reinterpret_cast %[[ALLOCA]] to offset: [0], sizes: [], strides: [] : memref<?x?x4x?xi16> to memref<i16>
+//
+// CHECK: return %[[CASTED_ALLOCA]], %[[C0]], %[[DYN_SIZE0]], %[[DYN_SIZE1]], %[[C4]], %[[DYN_SIZE3]], %[[STRIDE0]], %[[STRIDE1]], %[[DYN_SIZE3]], %[[C1]]
+func.func @extract_strided_metadata_of_alloca_dyn_size(
+    %dyn_size0 : index, %dyn_size1 : index, %dyn_size3 : index)
+    -> (memref<i16>, index,
+        index, index, index, index,
+        index, index, index, index) {
+
+  %A = memref.alloca(%dyn_size0, %dyn_size1, %dyn_size3) : memref<?x?x4x?xi16>
+
+  %base, %offset, %sizes:4, %strides:4 = memref.extract_strided_metadata %A :
+    memref<?x?x4x?xi16>
+    -> memref<i16>, index,
+       index, index, index, index,
+       index, index, index, index
+
+  return %base, %offset,
+    %sizes#0, %sizes#1, %sizes#2, %sizes#3,
+    %strides#0, %strides#1, %strides#2, %strides#3 :
+    memref<i16>, index,
+    index, index, index, index,
+    index, index, index, index
+}
+
+// -----
+
+// The following few alloc tests are negative tests (the simplification
+// doesn't happen) to make sure non-trivial memref types are treated
+// as "not normalized".
+// CHECK-LABEL: extract_strided_metadata_of_alloc_with_variable_offset
+// CHECK: %[[ALLOC:.*]] = memref.alloc
+// CHECK: %[[BASE:[^,]*]], {{.*}} = memref.extract_strided_metadata %[[ALLOC]]
+// CHECK: return %[[BASE]]
+#map0 = affine_map<(d0)[s0] -> (d0 + s0)>
+func.func @extract_strided_metadata_of_alloc_with_variable_offset(%arg : index)
+    -> (memref<i16>, index, index, index) {
+
+  %A = memref.alloc()[%arg] : memref<4xi16, #map0>
+  %base, %offset, %size, %stride = memref.extract_strided_metadata %A :
+    memref<4xi16, #map0>
+    -> memref<i16>, index, index, index
+
+  return %base, %offset, %size, %stride :
+    memref<i16>, index, index, index
+}
+
+// -----
+
+// CHECK-LABEL: extract_strided_metadata_of_alloc_with_cst_offset
+// CHECK: %[[ALLOC:.*]] = memref.alloc
+// CHECK: %[[BASE:[^,]*]], {{.*}} = memref.extract_strided_metadata %[[ALLOC]]
+// CHECK: return %[[BASE]]
+#map0 = affine_map<(d0) -> (d0 + 12)>
+func.func @extract_strided_metadata_of_alloc_with_cst_offset(%arg : index)
+    -> (memref<i16>, index, index, index) {
+
+  %A = memref.alloc() : memref<4xi16, #map0>
+  %base, %offset, %size, %stride = memref.extract_strided_metadata %A :
+    memref<4xi16, #map0>
+    -> memref<i16>, index, index, index
+
+  return %base, %offset, %size, %stride :
+    memref<i16>, index, index, index
+}
+
+// -----
+
+// CHECK-LABEL: extract_strided_metadata_of_alloc_with_cst_offset_in_type
+// CHECK: %[[ALLOC:.*]] = memref.alloc
+// CHECK: %[[BASE:[^,]*]], {{.*}} = memref.extract_strided_metadata %[[ALLOC]]
+// CHECK: return %[[BASE]]
+func.func @extract_strided_metadata_of_alloc_with_cst_offset_in_type(%arg : index)
+    -> (memref<i16>, index, index, index) {
+
+  %A = memref.alloc() : memref<4xi16, strided<[1], offset : 10>>
+  %base, %offset, %size, %stride = memref.extract_strided_metadata %A :
+    memref<4xi16, strided<[1], offset : 10>>
+    -> memref<i16>, index, index, index
+
+  return %base, %offset, %size, %stride :
+    memref<i16>, index, index, index
+}
+
+// -----
+
+// CHECK-LABEL: extract_strided_metadata_of_alloc_with_strided
+// CHECK: %[[ALLOC:.*]] = memref.alloc
+// CHECK: %[[BASE:[^,]*]], {{.*}} = memref.extract_strided_metadata %[[ALLOC]]
+// CHECK: return %[[BASE]]
+func.func @extract_strided_metadata_of_alloc_with_strided(%arg : index)
+    -> (memref<i16>, index, index, index) {
+
+  %A = memref.alloc() : memref<4xi16, strided<[12]>>
+  %base, %offset, %size, %stride = memref.extract_strided_metadata %A :
+    memref<4xi16, strided<[12]>>
+    -> memref<i16>, index, index, index
+
+  return %base, %offset, %size, %stride :
+    memref<i16>, index, index, index
+}