diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
@@ -53,6 +53,10 @@
 /// (sizes, offset, strides) of a memref into easier to analyze constructs.
 void populateExpandStridedMetadataPatterns(RewritePatternSet &patterns);
 
+/// Appends patterns for resolving `memref.extract_strided_metadata` into
+/// `memref.extract_strided_metadata` of its source.
+void populateResolveExtractStridedMetadataPatterns(RewritePatternSet &patterns);
+
 /// Appends patterns for emulating wide integer memref operations with ops over
 /// narrower integer types.
 void populateMemRefWideIntEmulationPatterns(
diff --git a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
@@ -23,6 +23,7 @@
   MLIRArithDialect
   MLIRArithTransforms
   MLIRBufferizationDialect
+  MLIRDialectUtils
   MLIRFuncDialect
   MLIRGPUOps
   MLIRInferTypeOpInterface
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -34,6 +35,115 @@
 using namespace mlir;
 
 namespace {
+
+struct StridedMetadata {
+  Value basePtr;
+  OpFoldResult offset;
+  SmallVector<OpFoldResult> sizes;
+  SmallVector<OpFoldResult> strides;
+};
+
+/// From `subview(memref, subOffset, subSizes, subStrides)` compute
+///
+/// \verbatim
+/// baseBuffer, baseOffset, baseSizes, baseStrides =
+///     extract_strided_metadata(memref)
+/// strides#i = baseStrides#i * subStrides#i
+/// offset = baseOffset + sum(subOffset#i * baseStrides#i)
+/// sizes = subSizes
+/// \endverbatim
+///
+/// and return {baseBuffer, offset, sizes, strides}.
+static FailureOr<StridedMetadata>
+resolveSubviewStridedMetadata(RewriterBase &rewriter,
+                              memref::SubViewOp subview) {
+  // Build a plain extract_strided_metadata(memref) from subview(memref).
+  Location origLoc = subview.getLoc();
+  Value source = subview.getSource();
+  auto sourceType = source.getType().cast<MemRefType>();
+  unsigned sourceRank = sourceType.getRank();
+
+  auto newExtractStridedMetadata =
+      rewriter.create<memref::ExtractStridedMetadataOp>(origLoc, source);
+
+  auto [sourceStrides, sourceOffset] = getStridesAndOffset(sourceType);
+
+  // Compute the new strides and offset from the base strides and offset:
+  // newStride#i = baseStride#i * subStride#i
+  // offset = baseOffset + sum(subOffsets#i * baseStrides#i)
+  SmallVector<OpFoldResult> strides;
+  SmallVector<OpFoldResult> subStrides = subview.getMixedStrides();
+  auto origStrides = newExtractStridedMetadata.getStrides();
+
+  // Hold the affine symbols and values for the computation of the offset.
+  SmallVector<OpFoldResult> values(2 * sourceRank + 1);
+  SmallVector<AffineExpr> symbols(2 * sourceRank + 1);
+
+  bindSymbolsList(rewriter.getContext(), MutableArrayRef{symbols});
+  AffineExpr expr = symbols.front();
+  values[0] = ShapedType::isDynamic(sourceOffset)
+                  ? getAsOpFoldResult(newExtractStridedMetadata.getOffset())
+                  : rewriter.getIndexAttr(sourceOffset);
+  SmallVector<OpFoldResult> subOffsets = subview.getMixedOffsets();
+
+  AffineExpr s0 = rewriter.getAffineSymbolExpr(0);
+  AffineExpr s1 = rewriter.getAffineSymbolExpr(1);
+  for (unsigned i = 0; i < sourceRank; ++i) {
+    // Compute the stride.
+    OpFoldResult origStride =
+        ShapedType::isDynamic(sourceStrides[i])
+            ? origStrides[i]
+            : OpFoldResult(rewriter.getIndexAttr(sourceStrides[i]));
+    strides.push_back(makeComposedFoldedAffineApply(
+        rewriter, origLoc, s0 * s1, {subStrides[i], origStride}));
+
+    // Build up the computation of the offset.
+    unsigned baseIdxForDim = 1 + 2 * i;
+    unsigned subOffsetForDim = baseIdxForDim;
+    unsigned origStrideForDim = baseIdxForDim + 1;
+    expr = expr + symbols[subOffsetForDim] * symbols[origStrideForDim];
+    values[subOffsetForDim] = subOffsets[i];
+    values[origStrideForDim] = origStride;
+  }
+
+  // Compute the offset.
+  OpFoldResult finalOffset =
+      makeComposedFoldedAffineApply(rewriter, origLoc, expr, values);
+
+  // The final result is <baseBuffer, offset, sizes, strides>.
+  // Thus we need 1 + 1 + subview.getRank() + subview.getRank() values to hold
+  // them all.
+  auto subType = subview.getType().cast<MemRefType>();
+  unsigned subRank = subType.getRank();
+
+  // The sizes of the final type are defined directly by the input sizes of
+  // the subview.
+  // Moreover subviews can drop some dimensions, some strides and sizes may
+  // not end up in the final <base, offset, sizes, strides> value that we are
+  // replacing.
+  // Do the filtering here.
+  SmallVector<OpFoldResult> subSizes = subview.getMixedSizes();
+  llvm::SmallBitVector droppedDims = subview.getDroppedDims();
+
+  SmallVector<OpFoldResult> finalSizes;
+  finalSizes.reserve(subRank);
+
+  SmallVector<OpFoldResult> finalStrides;
+  finalStrides.reserve(subRank);
+
+  for (unsigned i = 0; i < sourceRank; ++i) {
+    if (droppedDims.test(i))
+      continue;
+
+    finalSizes.push_back(subSizes[i]);
+    finalStrides.push_back(strides[i]);
+  }
+  assert(finalSizes.size() == subRank &&
+         "Should have populated all the values at this point");
+  return StridedMetadata{newExtractStridedMetadata.getBaseBuffer(), finalOffset,
+                         finalSizes, finalStrides};
+}
+
 /// Replace `dst = subview(memref, subOffset, subSizes, subStrides))`
 /// With
 ///
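To make the recipe above concrete, here is a hypothetical before/after for the subview rewrite. All value names, shapes, and the affine map are invented for illustration; they come from neither the patch nor its tests. With base strides [16, 1], sub-offsets [%o0, %o1], and sub-strides [1, 1], the formulas give strides [1*16, 1*1] = [16, 1] and offset 0 + %o0*16 + %o1*1:

```mlir
// Before: a subview with dynamic offsets (hypothetical input).
%sv = memref.subview %src[%o0, %o1] [2, 4] [1, 1]
    : memref<8x16xf32> to memref<2x4xf32, strided<[16, 1], offset: ?>>

// After: metadata of the *source*, plus a reinterpret_cast that reapplies
// the computed offset/sizes/strides.
%base, %off, %sizes:2, %strides:2 = memref.extract_strided_metadata %src
    : memref<8x16xf32> -> memref<f32>, index, index, index, index, index
%new_off = affine.apply affine_map<()[s0, s1] -> (s0 * 16 + s1)>()[%o0, %o1]
%dst = memref.reinterpret_cast %base to
    offset: [%new_off], sizes: [2, 4], strides: [16, 1]
    : memref<f32> to memref<2x4xf32, strided<[16, 1], offset: ?>>
```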
@@ -54,96 +164,62 @@
   LogicalResult matchAndRewrite(memref::SubViewOp subview,
                                 PatternRewriter &rewriter) const override {
-    // Build a plain extract_strided_metadata(memref) from subview(memref).
-    Location origLoc = subview.getLoc();
-    Value source = subview.getSource();
-    auto sourceType = source.getType().cast<MemRefType>();
-    unsigned sourceRank = sourceType.getRank();
-
-    auto newExtractStridedMetadata =
-        rewriter.create<memref::ExtractStridedMetadataOp>(origLoc, source);
-
-    auto [sourceStrides, sourceOffset] = getStridesAndOffset(sourceType);
-
-    // Compute the new strides and offset from the base strides and offset:
-    // newStride#i = baseStride#i * subStride#i
-    // offset = baseOffset + sum(subOffsets#i * newStrides#i)
-    SmallVector<OpFoldResult> strides;
-    SmallVector<OpFoldResult> subStrides = subview.getMixedStrides();
-    auto origStrides = newExtractStridedMetadata.getStrides();
-
-    // Hold the affine symbols and values for the computation of the offset.
-    SmallVector<OpFoldResult> values(2 * sourceRank + 1);
-    SmallVector<AffineExpr> symbols(2 * sourceRank + 1);
-
-    bindSymbolsList(rewriter.getContext(), MutableArrayRef{symbols});
-    AffineExpr expr = symbols.front();
-    values[0] = ShapedType::isDynamic(sourceOffset)
-                    ? getAsOpFoldResult(newExtractStridedMetadata.getOffset())
-                    : rewriter.getIndexAttr(sourceOffset);
-    SmallVector<OpFoldResult> subOffsets = subview.getMixedOffsets();
-
-    AffineExpr s0 = rewriter.getAffineSymbolExpr(0);
-    AffineExpr s1 = rewriter.getAffineSymbolExpr(1);
-    for (unsigned i = 0; i < sourceRank; ++i) {
-      // Compute the stride.
-      OpFoldResult origStride =
-          ShapedType::isDynamic(sourceStrides[i])
-              ? origStrides[i]
-              : OpFoldResult(rewriter.getIndexAttr(sourceStrides[i]));
-      strides.push_back(makeComposedFoldedAffineApply(
-          rewriter, origLoc, s0 * s1, {subStrides[i], origStride}));
-
-      // Build up the computation of the offset.
-      unsigned baseIdxForDim = 1 + 2 * i;
-      unsigned subOffsetForDim = baseIdxForDim;
-      unsigned origStrideForDim = baseIdxForDim + 1;
-      expr = expr + symbols[subOffsetForDim] * symbols[origStrideForDim];
-      values[subOffsetForDim] = subOffsets[i];
-      values[origStrideForDim] = origStride;
+    FailureOr<StridedMetadata> stridedMetadata =
+        resolveSubviewStridedMetadata(rewriter, subview);
+    if (failed(stridedMetadata)) {
+      return rewriter.notifyMatchFailure(subview,
+                                         "failed to resolve subview metadata");
     }
 
-    // Compute the offset.
-    OpFoldResult finalOffset =
-        makeComposedFoldedAffineApply(rewriter, origLoc, expr, values);
-
-    // The final result is <baseBuffer, offset, sizes, strides>.
-    // Thus we need 1 + 1 + subview.getRank() + subview.getRank(), to hold all
-    // the values.
-    auto subType = subview.getType().cast<MemRefType>();
-    unsigned subRank = subType.getRank();
-
-    // The sizes of the final type are defined directly by the input sizes of
-    // the subview.
-    // Moreover subviews can drop some dimensions, some strides and sizes may
-    // not end up in the final <base, offset, sizes, strides> value that we are
-    // replacing.
-    // Do the filtering here.
-    SmallVector<OpFoldResult> subSizes = subview.getMixedSizes();
-    llvm::SmallBitVector droppedDims = subview.getDroppedDims();
-
-    SmallVector<OpFoldResult> finalSizes;
-    finalSizes.reserve(subRank);
+    rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
+        subview, subview.getType(), stridedMetadata->basePtr,
+        stridedMetadata->offset, stridedMetadata->sizes,
+        stridedMetadata->strides);
+    return success();
+  }
+};
 
-    SmallVector<OpFoldResult> finalStrides;
-    finalStrides.reserve(subRank);
+/// Pattern to replace `extract_strided_metadata(subview)`
+/// With
+///
+/// \verbatim
+/// baseBuffer, baseOffset, baseSizes, baseStrides =
+///     extract_strided_metadata(memref)
+/// strides#i = baseStrides#i * subStrides#i
+/// offset = baseOffset + sum(subOffset#i * baseStrides#i)
+/// sizes = subSizes
+/// \endverbatim
+///
+/// with `baseBuffer`, `offset`, `sizes` and `strides` being
+/// the replacements for the original `extract_strided_metadata`.
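A hypothetical sketch of this new folder in action (names, shapes, and constants invented, assuming all subview parameters are static so the affine computation folds): the metadata users are rewired to the source's metadata plus constants, while the subview op itself is left in place for its other consumers:

```mlir
// Before (hypothetical input):
%sv = memref.subview %src[4, 2] [2, 4] [1, 1]
    : memref<8x16xf32> to memref<2x4xf32, strided<[16, 1], offset: 66>>
%base, %off, %sz:2, %str:2 = memref.extract_strided_metadata %sv
    : memref<2x4xf32, strided<[16, 1], offset: 66>>
    -> memref<f32>, index, index, index, index, index

// After: the metadata is re-derived from %src; %sv is not rewritten.
%base2, %off2, %sz2:2, %str2:2 = memref.extract_strided_metadata %src
    : memref<8x16xf32> -> memref<f32>, index, index, index, index, index
%c66 = arith.constant 66 : index  // offset = 0 + 4 * 16 + 2 * 1
%c2 = arith.constant 2 : index    // sizes = subSizes = [2, 4]
%c4 = arith.constant 4 : index
%c16 = arith.constant 16 : index  // strides = [1 * 16, 1 * 1]
%c1 = arith.constant 1 : index
// Uses of (%base, %off, %sz#0, %sz#1, %str#0, %str#1) become
// (%base2, %c66, %c2, %c4, %c16, %c1).
```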
@@ -634,6 +710,77 @@
   }
 };
 
+/// Replace `base, offset, sizes, strides =
+///             extract_strided_metadata(get_global)`
+///
+/// With
+///
+/// ```
+/// base = reinterpret_cast get_global to a flat memref<eltTy>
+/// offset = 0
+/// sizes = allocSizes
+/// strides#i = prod(allocSizes#j, for j in {i+1..rank-1})
+/// ```
+///
+/// It is expected that the memref.get_global op has static shapes
+/// and an identity affine_map for the layout.
+struct ExtractStridedMetadataOpGetGlobalFolder
+    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
+                                PatternRewriter &rewriter) const override {
+    auto getGlobalOp = op.getSource().getDefiningOp<memref::GetGlobalOp>();
+    if (!getGlobalOp)
+      return failure();
+
+    auto memRefType = getGlobalOp.getResult().getType().cast<MemRefType>();
+    if (!memRefType.getLayout().isIdentity()) {
+      return rewriter.notifyMatchFailure(
+          getGlobalOp,
+          "get-global operation result should have been normalized");
+    }
+
+    Location loc = op.getLoc();
+    int rank = memRefType.getRank();
+
+    // Collect the sizes.
+    ArrayRef<int64_t> sizes = memRefType.getShape();
+    assert(!llvm::any_of(sizes, ShapedType::isDynamic) &&
+           "unexpected dynamic shape for result of `memref.get_global` op");
+
+    // Strides (just creates identity strides).
+    SmallVector<int64_t> strides = computeSuffixProduct(sizes);
+
+    // Put all the values together to replace the results.
+    SmallVector<Value> results;
+    results.reserve(rank * 2 + 2);
+
+    auto baseBufferType = op.getBaseBuffer().getType().cast<MemRefType>();
+    int64_t offset = 0;
+    if (getGlobalOp.getType() == baseBufferType)
+      results.push_back(getGlobalOp);
+    else
+      results.push_back(rewriter.create<memref::ReinterpretCastOp>(
+          loc, baseBufferType, getGlobalOp, offset,
+          /*sizes=*/ArrayRef<int64_t>(),
+          /*strides=*/ArrayRef<int64_t>()));
+
+    // Offset.
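A hypothetical rank-3 sketch of this folder (global name and shape invented) to show the suffix-product stride computation; for sizes [2, 3, 4] the identity strides are [3*4, 4, 1] = [12, 4, 1], and the base buffer is re-cast to a flat 0-d memref:

```mlir
// Before (hypothetical input):
%A = memref.get_global @some_global : memref<2x3x4xf32>
%base, %off, %sz:3, %str:3 = memref.extract_strided_metadata %A
    : memref<2x3x4xf32>
    -> memref<f32>, index, index, index, index, index, index, index

// After:
%flat = memref.reinterpret_cast %A to offset: [0], sizes: [], strides: []
    : memref<2x3x4xf32> to memref<f32>
%c0 = arith.constant 0 : index    // offset
%c2 = arith.constant 2 : index    // sizes = [2, 3, 4]
%c3 = arith.constant 3 : index
%c4 = arith.constant 4 : index
%c12 = arith.constant 12 : index  // strides = [12, 4, 1]
%c1 = arith.constant 1 : index
```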
+struct ExtractStridedMetadataOpSubviewFolder
+    : OpRewritePattern<memref::ExtractStridedMetadataOp> {
+  using OpRewritePattern::OpRewritePattern;
 
-    for (unsigned i = 0; i < sourceRank; ++i) {
-      if (droppedDims.test(i))
-        continue;
+  LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
+                                PatternRewriter &rewriter) const override {
+    auto subviewOp = op.getSource().getDefiningOp<memref::SubViewOp>();
+    if (!subviewOp)
+      return failure();
 
-      finalSizes.push_back(subSizes[i]);
-      finalStrides.push_back(strides[i]);
+    FailureOr<StridedMetadata> stridedMetadata =
+        resolveSubviewStridedMetadata(rewriter, subviewOp);
+    if (failed(stridedMetadata)) {
+      return rewriter.notifyMatchFailure(
+          op, "failed to resolve metadata in terms of source subview op");
     }
-    assert(finalSizes.size() == subRank &&
-           "Should have populated all the values at this point");
 
+    Location loc = subviewOp.getLoc();
+    SmallVector<Value> results;
+    results.reserve(subviewOp.getType().getRank() * 2 + 2);
+    results.push_back(stridedMetadata->basePtr);
+    results.push_back(getValueOrCreateConstantIndexOp(rewriter, loc,
+                                                      stridedMetadata->offset));
+    results.append(
+        getValueOrCreateConstantIndexOp(rewriter, loc, stridedMetadata->sizes));
+    results.append(getValueOrCreateConstantIndexOp(rewriter, loc,
+                                                   stridedMetadata->strides));
+    rewriter.replaceOp(op, results);
-    auto memrefDesc = rewriter.create<memref::ReinterpretCastOp>(
-        origLoc, subType, newExtractStridedMetadata.getBaseBuffer(),
-        finalOffset,
-        /*sizes=*/finalSizes,
-        /*strides=*/finalStrides);
-    rewriter.replaceOp(subview, memrefDesc.getResult());
     return success();
   }
 };
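On the difference between the two entry points registered below, a hypothetical sketch (IR names invented): the expand flavor also rewrites the view-producing ops themselves (e.g. `SubviewFolder` turns `memref.subview` into `memref.reinterpret_cast`), whereas the new resolve flavor registers only the `extract_strided_metadata` folders, so view ops survive:

```mlir
// Input (hypothetical):
%sv = memref.subview %src[4, 2] [2, 4] [1, 1]
    : memref<8x16xf32> to memref<2x4xf32, strided<[16, 1], offset: 66>>
%m:6 = memref.extract_strided_metadata %sv
    : memref<2x4xf32, strided<[16, 1], offset: 66>>
    -> memref<f32>, index, index, index, index, index

// populateExpandStridedMetadataPatterns: %m is resolved *and* %sv itself
// is replaced by a memref.reinterpret_cast of the base buffer.
// populateResolveExtractStridedMetadataPatterns: only %m is resolved
// (via ExtractStridedMetadataOpSubviewFolder); %sv stays as-is.
```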
+    results.push_back(rewriter.create<arith::ConstantIndexOp>(loc, offset));
+
+    for (auto size : sizes)
+      results.push_back(rewriter.create<arith::ConstantIndexOp>(loc, size));
+
+    for (auto stride : strides)
+      results.push_back(rewriter.create<arith::ConstantIndexOp>(loc, stride));
+
+    rewriter.replaceOp(op, results);
+    return success();
+  }
+};
+
 /// Rewrite memref.extract_aligned_pointer_as_index of a ViewLikeOp to the
 /// source of the ViewLikeOp.
 class RewriteExtractAlignedPointerAsIndexOfViewLikeOp
@@ -758,6 +905,19 @@
                              getCollapsedStride>,
                ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
                ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
+               ExtractStridedMetadataOpGetGlobalFolder,
+               RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
+               ExtractStridedMetadataOpReinterpretCastFolder,
+               ExtractStridedMetadataOpExtractStridedMetadataFolder>(
+      patterns.getContext());
+}
+
+void memref::populateResolveExtractStridedMetadataPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
+               ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
+               ExtractStridedMetadataOpGetGlobalFolder,
+               ExtractStridedMetadataOpSubviewFolder,
                RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
                ExtractStridedMetadataOpReinterpretCastFolder,
                ExtractStridedMetadataOpExtractStridedMetadataFolder>(
       patterns.getContext());
diff --git a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
--- a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
+++ b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
@@ -1290,3 +1290,82 @@
     index, index, index, index
 }
+
+// -----
+
+// Check that `memref.get_global` -> `memref.extract_strided_metadata` resolves,
+// with the consumer replaced by the strides, sizes, and offset computed from
+// `memref.get_global`. Since the result of `memref.get_global` is always
+// statically shaped, there is no need to check for dynamic shapes.
+
+// CHECK-LABEL: func @extract_strided_metadata_of_get_global()
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C384:.+]] = arith.constant 384 : index
+// CHECK-DAG: %[[C512:.+]] = arith.constant 512 : index
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[GET_GLOBAL:.+]] = memref.get_global @const_i32
+// CHECK: %[[CAST:.+]] = memref.reinterpret_cast %[[GET_GLOBAL]]
+// CHECK-SAME: offset: [0], sizes: [], strides: []
+// CHECK: return %[[CAST]], %[[C0]], %[[C512]], %[[C384]], %[[C384]], %[[C1]]
+memref.global "private" constant @const_i32 : memref<512x384xi32> = dense<42>
+
+func.func @extract_strided_metadata_of_get_global()
+    -> (memref<i32>, index, index, index, index, index) {
+
+  %A = memref.get_global @const_i32 : memref<512x384xi32>
+
+  %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %A :
+    memref<512x384xi32> -> memref<i32>, index, index, index, index, index
+
+  return %base, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 :
+    memref<i32>, index, index, index, index, index
+}
+
+// -----
+
+// Check that `memref.get_global` -> `memref.extract_strided_metadata` does not
+// resolve when the strides are not identity. This is an unhandled case that
+// could be covered in the future.
+
+// CHECK-LABEL: func @extract_strided_metadata_of_get_global_with_strides()
+// CHECK: %[[GET_GLOBAL:.+]] = memref.get_global @const_i32
+// CHECK: memref.extract_strided_metadata %[[GET_GLOBAL]]
+memref.global "private" constant @const_i32 : memref<512x384xi32, strided<[420, 1], offset: 0>> = dense<42>
+
+func.func @extract_strided_metadata_of_get_global_with_strides()
+    -> (memref<i32>, index, index, index, index, index) {
+
+  %A = memref.get_global @const_i32 : memref<512x384xi32, strided<[420, 1], offset: 0>>
+
+  %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %A :
+    memref<512x384xi32, strided<[420, 1], offset: 0>>
+    -> memref<i32>, index, index, index, index, index
+
+  return %base, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 :
+    memref<i32>, index, index, index, index, index
+}
+
+// -----
+
+// Check that `memref.get_global` -> `memref.extract_strided_metadata` does not
+// resolve when the offset is non-zero. This is an unhandled case that could
+// be covered in the future.
+
+// CHECK-LABEL: func @extract_strided_metadata_of_get_global_with_offset()
+// CHECK: %[[GET_GLOBAL:.+]] = memref.get_global @const_i32
+// CHECK: memref.extract_strided_metadata %[[GET_GLOBAL]]
+memref.global "private" constant @const_i32 : memref<512x384xi32, strided<[384, 1], offset: 20>> = dense<42>
+
+func.func @extract_strided_metadata_of_get_global_with_offset()
+    -> (memref<i32>, index, index, index, index, index) {
+
+  %A = memref.get_global @const_i32 : memref<512x384xi32, strided<[384, 1], offset: 20>>
+
+  %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %A :
+    memref<512x384xi32, strided<[384, 1], offset: 20>>
+    -> memref<i32>, index, index, index, index, index
+
+  return %base, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 :
+    memref<i32>, index, index, index, index, index
+}