diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
new file
--- /dev/null
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
@@ -0,0 +1,40 @@
+//===- Transforms.h - MemRef Dialect transformations ------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+/// This header declares functions that assist transformations in the MemRef
+/// dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MEMREF_TRANSFORMS_TRANSFORMS_H
+#define MLIR_DIALECT_MEMREF_TRANSFORMS_TRANSFORMS_H
+
+namespace mlir {
+class RewritePatternSet;
+
+namespace memref {
+/// Appends patterns for extracting address computations from instructions
+/// with memory accesses, so that these memory accesses use only a base
+/// pointer.
+///
+/// For instance,
+/// ```mlir
+/// memref.load %base[%off0, ...]
+/// ```
+///
+/// will be rewritten into:
+/// ```mlir
+/// %new_base = memref.subview %base[%off0,...][1,...][1,...]
+/// memref.load %new_base[%c0,...]
+/// ```
+void populateExtractAddressComputationsPatterns(RewritePatternSet &patterns);
+
+} // namespace memref
+} // namespace mlir
+
+#endif
diff --git a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
@@ -4,6 +4,7 @@
   ExpandOps.cpp
   ExpandStridedMetadata.cpp
   EmulateWideInt.cpp
+  ExtractAddressComputations.cpp
   FoldMemRefAliasOps.cpp
   MultiBuffer.cpp
   NormalizeMemRefs.cpp
@@ -27,6 +28,7 @@
   MLIRInferTypeOpInterface
   MLIRLoopLikeInterface
   MLIRMemRefDialect
+  MLIRNVGPUDialect
   MLIRPass
   MLIRTensorDialect
   MLIRTransforms
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp
new file
--- /dev/null
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp
@@ -0,0 +1,307 @@
+//===- ExtractAddressComputations.cpp - Extract address computations -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+/// This transformation pass rewrites loads/stores from/to a memref with
+/// offsets into loads/stores from/to a subview, so that the memory access
+/// itself carries no offset.
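+/// The extracted subview carries the address computation (base and offsets),
+/// while the rewritten memory access indexes the subview with zeros only.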
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/PatternMatch.h"
+
+using namespace mlir;
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// Helper functions for the `load base[off0...]`
+// => `load (subview base[off0...])[0...]` pattern.
+//===----------------------------------------------------------------------===//
+
+// Matches getFailureOrSrcMemRef specs for LoadOp.
+// \see LoadStoreLikeOpRewriter.
+static FailureOr<Value> getLoadOpSrcMemRef(memref::LoadOp loadOp) {
+  return loadOp.getMemRef();
+}
+
+// Matches rebuildOpFromAddressAndIndices specs for LoadOp.
+// \see LoadStoreLikeOpRewriter.
+static memref::LoadOp rebuildLoadOp(RewriterBase &rewriter,
+                                    memref::LoadOp loadOp, Value srcMemRef,
+                                    ArrayRef<Value> indices) {
+  Location loc = loadOp.getLoc();
+  return rewriter.create<memref::LoadOp>(loc, srcMemRef, indices,
+                                         loadOp.getNontemporal());
+}
+
+// Matches getViewSizeForEachDim specs for LoadOp.
+// \see LoadStoreLikeOpRewriter.
+static SmallVector<OpFoldResult>
+getLoadOpViewSizeForEachDim(RewriterBase &rewriter, memref::LoadOp loadOp) {
+  MemRefType ldTy = loadOp.getMemRefType();
+  unsigned loadRank = ldTy.getRank();
+  return SmallVector<OpFoldResult>(loadRank, rewriter.getIndexAttr(1));
+}
+
+//===----------------------------------------------------------------------===//
+// Helper functions for the `store val, base[off0...]`
+// => `store val, (subview base[off0...])[0...]` pattern.
+//===----------------------------------------------------------------------===//
+
+// Matches getFailureOrSrcMemRef specs for StoreOp.
+// \see LoadStoreLikeOpRewriter.
+static FailureOr<Value> getStoreOpSrcMemRef(memref::StoreOp storeOp) {
+  return storeOp.getMemRef();
+}
+
+// Matches rebuildOpFromAddressAndIndices specs for StoreOp.
+// \see LoadStoreLikeOpRewriter.
+static memref::StoreOp rebuildStoreOp(RewriterBase &rewriter,
+                                      memref::StoreOp storeOp, Value srcMemRef,
+                                      ArrayRef<Value> indices) {
+  Location loc = storeOp.getLoc();
+  return rewriter.create<memref::StoreOp>(loc, storeOp.getValueToStore(),
+                                          srcMemRef, indices,
+                                          storeOp.getNontemporal());
+}
+
+// Matches getViewSizeForEachDim specs for StoreOp.
+// \see LoadStoreLikeOpRewriter.
+static SmallVector<OpFoldResult>
+getStoreOpViewSizeForEachDim(RewriterBase &rewriter, memref::StoreOp storeOp) {
+  MemRefType storeTy = storeOp.getMemRefType();
+  unsigned storeRank = storeTy.getRank();
+  return SmallVector<OpFoldResult>(storeRank, rewriter.getIndexAttr(1));
+}
+
+//===----------------------------------------------------------------------===//
+// Helper functions for the `ldmatrix base[off0...]`
+// => `ldmatrix (subview base[off0...])[0...]` pattern.
+//===----------------------------------------------------------------------===//
+
+// Matches getFailureOrSrcMemRef specs for LdMatrixOp.
+// \see LoadStoreLikeOpRewriter.
+static FailureOr<Value> getLdMatrixOpSrcMemRef(nvgpu::LdMatrixOp ldMatrixOp) {
+  return ldMatrixOp.getSrcMemref();
+}
+
+// Matches rebuildOpFromAddressAndIndices specs for LdMatrixOp.
+// \see LoadStoreLikeOpRewriter.
+static nvgpu::LdMatrixOp rebuildLdMatrixOp(RewriterBase &rewriter,
+                                           nvgpu::LdMatrixOp ldMatrixOp,
+                                           Value srcMemRef,
+                                           ArrayRef<Value> indices) {
+  Location loc = ldMatrixOp.getLoc();
+  return rewriter.create<nvgpu::LdMatrixOp>(
+      loc, ldMatrixOp.getResult().getType(), srcMemRef, indices,
+      ldMatrixOp.getTranspose(), ldMatrixOp.getNumTiles());
+}
+
+//===----------------------------------------------------------------------===//
+// Helper functions for the `transfer_read base[off0...]`
+// => `transfer_read (subview base[off0...])[0...]` pattern.
+//===----------------------------------------------------------------------===//
+
+// Matches getFailureOrSrcMemRef specs for TransferReadOp and TransferWriteOp.
+// \see LoadStoreLikeOpRewriter.
+template <typename TransferLikeOp>
+static FailureOr<Value>
+getTransferLikeOpSrcMemRef(TransferLikeOp transferLikeOp) {
+  Value src = transferLikeOp.getSource();
+  if (src.getType().isa<MemRefType>())
+    return src;
+  return failure();
+}
+
+// Matches rebuildOpFromAddressAndIndices specs for TransferReadOp.
+// \see LoadStoreLikeOpRewriter.
+static vector::TransferReadOp
+rebuildTransferReadOp(RewriterBase &rewriter,
+                      vector::TransferReadOp transferReadOp, Value srcMemRef,
+                      ArrayRef<Value> indices) {
+  Location loc = transferReadOp.getLoc();
+  return rewriter.create<vector::TransferReadOp>(
+      loc, transferReadOp.getResult().getType(), srcMemRef, indices,
+      transferReadOp.getPermutationMap(), transferReadOp.getPadding(),
+      transferReadOp.getMask(), transferReadOp.getInBoundsAttr());
+}
+
+//===----------------------------------------------------------------------===//
+// Helper functions for the `transfer_write base[off0...]`
+// => `transfer_write (subview base[off0...])[0...]` pattern.
+//===----------------------------------------------------------------------===//
+
+// Matches rebuildOpFromAddressAndIndices specs for TransferWriteOp.
+// \see LoadStoreLikeOpRewriter.
+static vector::TransferWriteOp
+rebuildTransferWriteOp(RewriterBase &rewriter,
+                       vector::TransferWriteOp transferWriteOp, Value srcMemRef,
+                       ArrayRef<Value> indices) {
+  Location loc = transferWriteOp.getLoc();
+  return rewriter.create<vector::TransferWriteOp>(
+      loc, transferWriteOp.getValue(), srcMemRef, indices,
+      transferWriteOp.getPermutationMapAttr(), transferWriteOp.getMask(),
+      transferWriteOp.getInBoundsAttr());
+}
+
+//===----------------------------------------------------------------------===//
+// Generic helper functions used as default implementation in
+// LoadStoreLikeOpRewriter.
+//===----------------------------------------------------------------------===//
+
+/// Helper function to get the src memref.
+/// It uses the already defined getFailureOrSrcMemRef but asserts
+/// that the source is a memref.
+template <typename LoadStoreLikeOp,
+          FailureOr<Value> (*getFailureOrSrcMemRef)(LoadStoreLikeOp)>
+static Value getSrcMemRef(LoadStoreLikeOp loadStoreLikeOp) {
+  FailureOr<Value> failureOrSrcMemRef = getFailureOrSrcMemRef(loadStoreLikeOp);
+  assert(!failed(failureOrSrcMemRef) && "Generic getSrcMemRef cannot be used");
+  return *failureOrSrcMemRef;
+}
+
+/// Helper function to get the sizes of the resulting view.
+/// This function gets the sizes of the source memref then subtracts the
+/// offsets used within \p loadStoreLikeOp. This gives the maximal (for
+/// inbound) sizes for the view.
+/// The source memref is retrieved using getSrcMemRef on \p loadStoreLikeOp.
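+/// In other words, for each dimension the returned size is
+/// `source size - offset`, folded via an affine apply when the operands are
+/// not all constant.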
+template <typename LoadStoreLikeOp, Value (*getSrcMemRef)(LoadStoreLikeOp)>
+static SmallVector<OpFoldResult>
+getGenericOpViewSizeForEachDim(RewriterBase &rewriter,
+                               LoadStoreLikeOp loadStoreLikeOp) {
+  Location loc = loadStoreLikeOp.getLoc();
+  auto extractStridedMetadataOp =
+      rewriter.create<memref::ExtractStridedMetadataOp>(
+          loc, getSrcMemRef(loadStoreLikeOp));
+  SmallVector<OpFoldResult> srcSizes =
+      extractStridedMetadataOp.getConstifiedMixedSizes();
+  SmallVector<OpFoldResult> indices =
+      getAsOpFoldResult(loadStoreLikeOp.getIndices());
+  SmallVector<OpFoldResult> finalSizes;
+
+  AffineExpr s0 = rewriter.getAffineSymbolExpr(0);
+  AffineExpr s1 = rewriter.getAffineSymbolExpr(1);
+
+  for (auto [srcSize, indice] : llvm::zip(srcSizes, indices)) {
+    finalSizes.push_back(makeComposedFoldedAffineApply(rewriter, loc, s0 - s1,
+                                                       {srcSize, indice}));
+  }
+  return finalSizes;
+}
+
+/// Rewrite a store/load-like op so that all its indices are zeros.
+/// E.g., %ld = memref.load %base[%off0, ..., %offN]
+/// =>
+/// %new_base = memref.subview %base[%off0, ..., %offN][1, ..., 1][1, ..., 1]
+/// %ld = memref.load %new_base[0, ..., 0] :
+///     memref<1x..x1xTy, strided<[1, ..., 1], offset: ?>>
+///
+/// `getFailureOrSrcMemRef` returns the source memref for the given load-like
+/// operation, or failure if the source is not a memref.
+///
+/// `getViewSizeForEachDim` returns the sizes of the view that is going to feed
+/// the new operation. This must return one size per dimension of the view.
+/// The sizes of the view need to be at least as big as what is actually
+/// going to be accessed. Use the provided `loadStoreOp` to get the right
+/// sizes.
+///
+/// Using the given rewriter, `rebuildOpFromAddressAndIndices` creates a new
+/// LoadStoreLikeOp that reads from srcMemRef[indices].
+/// The returned operation will be used to replace loadStoreOp.
+template <
+    typename LoadStoreLikeOp,
+    FailureOr<Value> (*getFailureOrSrcMemRef)(LoadStoreLikeOp),
+    LoadStoreLikeOp (*rebuildOpFromAddressAndIndices)(
+        RewriterBase & /*rewriter*/, LoadStoreLikeOp /*loadStoreOp*/,
+        Value /*srcMemRef*/, ArrayRef<Value> /*indices*/),
+    SmallVector<OpFoldResult> (*getViewSizeForEachDim)(
+        RewriterBase & /*rewriter*/, LoadStoreLikeOp /*loadStoreOp*/) =
+        getGenericOpViewSizeForEachDim<
+            LoadStoreLikeOp,
+            getSrcMemRef<LoadStoreLikeOp, getFailureOrSrcMemRef>>>
+struct LoadStoreLikeOpRewriter : public OpRewritePattern<LoadStoreLikeOp> {
+  using OpRewritePattern<LoadStoreLikeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(LoadStoreLikeOp loadStoreLikeOp,
+                                PatternRewriter &rewriter) const override {
+    FailureOr<Value> failureOrSrcMemRef =
+        getFailureOrSrcMemRef(loadStoreLikeOp);
+    if (failed(failureOrSrcMemRef))
+      return failure();
+    Value srcMemRef = *failureOrSrcMemRef;
+    auto ldStTy = srcMemRef.getType().cast<MemRefType>();
+    unsigned loadStoreRank = ldStTy.getRank();
+    // Don't waste compile time if there is nothing to rewrite.
+    if (loadStoreRank == 0)
+      return failure();
+
+    // If the load/store already uses only zeros as indices, there is nothing
+    // to do.
+    SmallVector<OpFoldResult> indices =
+        getAsOpFoldResult(loadStoreLikeOp.getIndices());
+    if (std::all_of(indices.begin(), indices.end(),
+                    [](const OpFoldResult &opFold) {
+                      return isConstantIntValue(opFold, 0);
+                    })) {
+      return failure();
+    }
+
+    // Create the array of ones of the right size.
+    SmallVector<OpFoldResult> ones(loadStoreRank, rewriter.getIndexAttr(1));
+    SmallVector<OpFoldResult> sizes =
+        getViewSizeForEachDim(rewriter, loadStoreLikeOp);
+    assert(sizes.size() == loadStoreRank &&
+           "Expected one size per load dimension");
+    Location loc = loadStoreLikeOp.getLoc();
+    auto subview =
+        rewriter.create<memref::SubViewOp>(loc, /*source=*/srcMemRef,
+                                           /*offsets=*/indices,
+                                           /*sizes=*/sizes, /*strides=*/ones);
+    // Rewrite the load/store with the subview as the base pointer.
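+    // All the indices of the rebuilt operation are zero: the offsets now
+    // live in the subview itself.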
+    SmallVector<Value> zeros(loadStoreRank,
+                             rewriter.create<arith::ConstantIndexOp>(loc, 0));
+    LoadStoreLikeOp newLoadStore = rebuildOpFromAddressAndIndices(
+        rewriter, loadStoreLikeOp, subview.getResult(), zeros);
+    rewriter.replaceOp(loadStoreLikeOp, newLoadStore->getResults());
+    return success();
+  }
+};
+} // namespace
+
+void memref::populateExtractAddressComputationsPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<
+      LoadStoreLikeOpRewriter<
+          memref::LoadOp,
+          /*getFailureOrSrcMemRef=*/getLoadOpSrcMemRef,
+          /*rebuildOpFromAddressAndIndices=*/rebuildLoadOp,
+          /*getViewSizeForEachDim=*/getLoadOpViewSizeForEachDim>,
+      LoadStoreLikeOpRewriter<
+          memref::StoreOp,
+          /*getFailureOrSrcMemRef=*/getStoreOpSrcMemRef,
+          /*rebuildOpFromAddressAndIndices=*/rebuildStoreOp,
+          /*getViewSizeForEachDim=*/getStoreOpViewSizeForEachDim>,
+      LoadStoreLikeOpRewriter<
+          nvgpu::LdMatrixOp,
+          /*getFailureOrSrcMemRef=*/getLdMatrixOpSrcMemRef,
+          /*rebuildOpFromAddressAndIndices=*/rebuildLdMatrixOp>,
+      LoadStoreLikeOpRewriter<
+          vector::TransferReadOp,
+          /*getFailureOrSrcMemRef=*/getTransferLikeOpSrcMemRef<vector::TransferReadOp>,
+          /*rebuildOpFromAddressAndIndices=*/rebuildTransferReadOp>,
+      LoadStoreLikeOpRewriter<
+          vector::TransferWriteOp,
+          /*getFailureOrSrcMemRef=*/getTransferLikeOpSrcMemRef<vector::TransferWriteOp>,
+          /*rebuildOpFromAddressAndIndices=*/rebuildTransferWriteOp>>(
+      patterns.getContext());
+}
diff --git a/mlir/test/Dialect/MemRef/extract-address-computations.mlir b/mlir/test/Dialect/MemRef/extract-address-computations.mlir
new file
--- /dev/null
+++ b/mlir/test/Dialect/MemRef/extract-address-computations.mlir
@@ -0,0 +1,289 @@
+// RUN: mlir-opt -test-extract-address-computations %s --split-input-file | FileCheck %s
+
+// Simple test: check that we extract the address computation of a load into
+// a dedicated subview.
+// The resulting load will load from the subview and have only indices
+// set to zero.
+
+// CHECK-LABEL: @test_load(
+// CHECK-SAME: %[[BASE:[^:]*]]: memref{{[^,]*}},
+// CHECK-SAME: %[[DYN_OFFSET:.*]]: index)
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[SUBVIEW:.*]] = memref.subview %[[BASE]][%[[DYN_OFFSET]], 0, 8] [1, 1, 1] [1, 1, 1] : memref<2x16x16xf32> to memref<1x1x1xf32, strided<[256, 16, 1], offset: ?>>
+// CHECK: %[[LOADED_VAL:.*]] = memref.load %[[SUBVIEW]][%[[C0]], %[[C0]], %[[C0]]] : memref<1x1x1xf32, strided<[256, 16, 1], offset: ?>>
+// CHECK: return %[[LOADED_VAL]] : f32
+func.func @test_load(%base : memref<2x16x16xf32>, %offset : index) -> f32 {
+  %c0 = arith.constant 0 : index
+  %c8 = arith.constant 8 : index
+  %loaded_val = memref.load %base[%offset, %c0, %c8] : memref<2x16x16xf32>
+  return %loaded_val : f32
+}
+
+// -----
+
+// Same as previous @test_load but with the nontemporal flag.
+
+// CHECK-LABEL: @test_load_nontemporal(
+// CHECK-SAME: %[[BASE:[^:]*]]: memref{{[^,]*}},
+// CHECK-SAME: %[[DYN_OFFSET:.*]]: index)
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[SUBVIEW:.*]] = memref.subview %[[BASE]][%[[DYN_OFFSET]], 0, 8] [1, 1, 1] [1, 1, 1] : memref<2x16x16xf32> to memref<1x1x1xf32, strided<[256, 16, 1], offset: ?>>
+// CHECK: %[[LOADED_VAL:.*]] = memref.load %[[SUBVIEW]][%[[C0]], %[[C0]], %[[C0]]] {nontemporal = true} : memref<1x1x1xf32, strided<[256, 16, 1], offset: ?>>
+// CHECK: return %[[LOADED_VAL]] : f32
+func.func @test_load_nontemporal(%base : memref<2x16x16xf32>, %offset : index) -> f32 {
+  %c0 = arith.constant 0 : index
+  %c8 = arith.constant 8 : index
+  %loaded_val = memref.load %base[%offset, %c0, %c8] {nontemporal = true } : memref<2x16x16xf32>
+  return %loaded_val : f32
+}
+
+// -----
+
+// Simple test: check that we extract the address computation of a store into
+// a dedicated subview.
+// The resulting store will use the address from the subview and have only
+// indices set to zero.
+
+// CHECK-LABEL: @test_store(
+// CHECK-SAME: %[[BASE:[^:]*]]: memref{{[^,]*}},
+// CHECK-SAME: %[[DYN_OFFSET:.*]]: index)
+// CHECK-DAG: %[[CF0:.*]] = arith.constant 0.0{{0*e\+00}} : f32
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[SUBVIEW:.*]] = memref.subview %[[BASE]][%[[DYN_OFFSET]], 0, 8] [1, 1, 1] [1, 1, 1] : memref<2x16x16xf32> to memref<1x1x1xf32, strided<[256, 16, 1], offset: ?>>
+// CHECK: memref.store %[[CF0]], %[[SUBVIEW]][%[[C0]], %[[C0]], %[[C0]]] : memref<1x1x1xf32, strided<[256, 16, 1], offset: ?>>
+// CHECK: return
+func.func @test_store(%base : memref<2x16x16xf32>, %offset : index) -> () {
+  %cf0 = arith.constant 0.0 : f32
+  %c0 = arith.constant 0 : index
+  %c8 = arith.constant 8 : index
+  memref.store %cf0, %base[%offset, %c0, %c8] : memref<2x16x16xf32>
+  return
+}
+
+// -----
+
+// Same as @test_store but check that the nontemporal flag is preserved.
+
+// CHECK-LABEL: @test_store_nontemporal(
+// CHECK-SAME: %[[BASE:[^:]*]]: memref{{[^,]*}},
+// CHECK-SAME: %[[DYN_OFFSET:.*]]: index)
+// CHECK-DAG: %[[CF0:.*]] = arith.constant 0.0{{0*e\+00}} : f32
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[SUBVIEW:.*]] = memref.subview %[[BASE]][%[[DYN_OFFSET]], 0, 8] [1, 1, 1] [1, 1, 1] : memref<2x16x16xf32> to memref<1x1x1xf32, strided<[256, 16, 1], offset: ?>>
+// CHECK: memref.store %[[CF0]], %[[SUBVIEW]][%[[C0]], %[[C0]], %[[C0]]] {nontemporal = true} : memref<1x1x1xf32, strided<[256, 16, 1], offset: ?>>
+// CHECK: return
+func.func @test_store_nontemporal(%base : memref<2x16x16xf32>, %offset : index) -> () {
+  %cf0 = arith.constant 0.0 : f32
+  %c0 = arith.constant 0 : index
+  %c8 = arith.constant 8 : index
+  memref.store %cf0, %base[%offset, %c0, %c8] { nontemporal = true } : memref<2x16x16xf32>
+  return
+}
+
+// -----
+// For this test, we made the source memref fully dynamic.
+// The gist of the check remains the same as the simple test:
+// The address computation is extracted into its own subview.
+// CHECK-LABEL: @testWithLoop(
+// CHECK-SAME: %[[BASE:[^:]*]]: memref
+// CHECK: %[[SUM_ALL:.*]] = arith.constant 0.0{{0*e\+00}} : f32
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[UPPER_BOUND0:.*]] = memref.dim %[[BASE]], %[[C0]] : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
+// CHECK: %[[UPPER_BOUND1:.*]] = memref.dim %[[BASE]], %[[C1]] : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
+// CHECK: %[[UPPER_BOUND2:.*]] = memref.dim %[[BASE]], %[[C2]] : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
+// CHECK: %[[SUM_RES2:.*]] = scf.for %[[IV2:.*]] = %[[C0]] to %[[UPPER_BOUND2]] step %[[C1]] iter_args(%[[SUM_ITER2:.*]] = %[[SUM_ALL]]) -> (f32) {
+// CHECK: %[[SUM_RES1:.*]] = scf.for %[[IV1:.*]] = %[[C0]] to %[[UPPER_BOUND1]] step %[[C1]] iter_args(%[[SUM_ITER1:.*]] = %[[SUM_ITER2]]) -> (f32) {
+// CHECK: %[[SUM_RES0:.*]] = scf.for %[[IV0:.*]] = %[[C0]] to %[[UPPER_BOUND0]] step %[[C1]] iter_args(%[[SUM_ITER0:.*]] = %[[SUM_ITER1]]) -> (f32) {
+// CHECK: %[[SUBVIEW:.*]] = memref.subview %[[BASE]][%[[IV0]], %[[IV1]], %[[IV2]]] [1, 1, 1] [1, 1, 1] : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>> to memref<1x1x1xf32, strided<[?, ?, ?], offset: ?>>
+// CHECK: %[[LOADED_VAL:.*]] = memref.load %[[SUBVIEW]][%[[C0]], %[[C0]], %[[C0]]] : memref<1x1x1xf32, strided<[?, ?, ?], offset: ?>>
+// CHECK: %[[RES:.*]] = arith.addf %[[LOADED_VAL]], %[[SUM_ITER2]] : f32
+// CHECK: scf.yield %[[RES]] : f32
+// CHECK: }
+// CHECK: scf.yield %[[SUM_RES0]] : f32
+// CHECK: }
+// CHECK: scf.yield %[[SUM_RES1]] : f32
+// CHECK: }
+// CHECK: return %[[SUM_RES2]] : f32
+func.func @testWithLoop(%base : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>) -> f32 {
+  %sum_all = arith.constant 0.0 : f32
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %upper_bound0 = memref.dim %base, %c0 : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
+  %upper_bound1 = memref.dim %base, %c1 : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
+  %upper_bound2 = memref.dim %base, %c2 : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
+  %sum_res2 = scf.for %iv2 = %c0 to %upper_bound2 step %c1 iter_args(%sum_iter2 = %sum_all) -> (f32) {
+    %sum_res1 = scf.for %iv1 = %c0 to %upper_bound1 step %c1 iter_args(%sum_iter1 = %sum_iter2) -> (f32) {
+      %sum_res0 = scf.for %iv0 = %c0 to %upper_bound0 step %c1 iter_args(%sum_iter0 = %sum_iter1) -> (f32) {
+        %loaded_val = memref.load %base[%iv0, %iv1, %iv2] : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
+        %res = arith.addf %loaded_val, %sum_iter2 : f32
+        scf.yield %res : f32
+      }
+      scf.yield %sum_res0 : f32
+    }
+    scf.yield %sum_res1 : f32
+  }
+  return %sum_res2 : f32
+}
+
+// -----
+
+// Simple test: check that we extract the address computation of an ldmatrix
+// into a dedicated subview.
+// The resulting ldmatrix loads from the subview and has only indices set
+// to zero.
+// Also the sizes of the view are adjusted to `original size - offset`.
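+// With a static source shape, these adjusted sizes show up as constant
+// affine maps (e.g., `-offset + 4`).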
+
+// CHECK-DAG: #[[$FOUR_MINUS_OFF_MAP:.*]] = affine_map<()[s0] -> (-s0 + 4)>
+// CHECK-DAG: #[[$THIRTY_TWO_MINUS_OFF_MAP:.*]] = affine_map<()[s0] -> (-s0 + 32)>
+// CHECK-LABEL: @test_ldmatrix(
+// CHECK-SAME: %[[BASE:[^:]*]]: memref<{{[^,]*}}, 3>,
+// CHECK-SAME: %[[DYN_OFFSET0:[^:]*]]: index,
+// CHECK-SAME: %[[DYN_OFFSET1:[^:]*]]: index,
+// CHECK-SAME: %[[DYN_OFFSET2:[^:]*]]: index)
+// CHECK-DAG: %[[DYN_SIZE0:.*]] = affine.apply #[[$FOUR_MINUS_OFF_MAP]]()[%[[DYN_OFFSET0]]]
+// CHECK-DAG: %[[DYN_SIZE1:.*]] = affine.apply #[[$THIRTY_TWO_MINUS_OFF_MAP]]()[%[[DYN_OFFSET1]]]
+// CHECK-DAG: %[[DYN_SIZE2:.*]] = affine.apply #[[$THIRTY_TWO_MINUS_OFF_MAP]]()[%[[DYN_OFFSET2]]]
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[SUBVIEW:.*]] = memref.subview %[[BASE]][%[[DYN_OFFSET0]], %[[DYN_OFFSET1]], %[[DYN_OFFSET2]]] [%[[DYN_SIZE0]], %[[DYN_SIZE1]], %[[DYN_SIZE2]]] [1, 1, 1] : memref<4x32x32xf16, 3> to memref<?x?x?xf16, strided<[1024, 32, 1], offset: ?>, 3>
+// CHECK: %[[LOADED_VAL:.*]] = nvgpu.ldmatrix %[[SUBVIEW]][%[[C0]], %[[C0]], %[[C0]]] {numTiles = 4 : i32, transpose = false} : memref<?x?x?xf16, strided<[1024, 32, 1], offset: ?>, 3> -> vector<4x2xf16>
+// CHECK: return %[[LOADED_VAL]] : vector<4x2xf16>
+func.func @test_ldmatrix(%base : memref<4x32x32xf16, 3>,
+                         %offset0 : index, %offset1: index, %offset2: index)
+    -> vector<4x2xf16> {
+  %loaded_val = nvgpu.ldmatrix
+    %base[%offset0, %offset1, %offset2]
+    {numTiles = 4 : i32, transpose = false}
+      : memref<4x32x32xf16, 3> -> vector<4x2xf16>
+  return %loaded_val : vector<4x2xf16>
+}
+
+// -----
+
+// Same as test_ldmatrix but with a fully dynamic memref.
+
+// CHECK-DAG: #[[$A_MINUS_B_MAP:.*]] = affine_map<()[s0, s1] -> (s0 - s1)>
+// CHECK-LABEL: @test_ldmatrix(
+// CHECK-SAME: %[[BASE:[^:]*]]: memref<{{[^,]*}}, 3>,
+// CHECK-SAME: %[[DYN_OFFSET0:[^:]*]]: index,
+// CHECK-SAME: %[[DYN_OFFSET1:[^:]*]]: index,
+// CHECK-SAME: %[[DYN_OFFSET2:[^:]*]]: index)
+// CHECK-DAG: {{.*}}, {{.*}}, %[[DYN_SIZES:.*]]:3, {{.*}} = memref.extract_strided_metadata %[[BASE]]
+// CHECK-DAG: %[[DYN_SIZE0:.*]] = affine.apply #[[$A_MINUS_B_MAP]]()[%[[DYN_SIZES]]#0, %[[DYN_OFFSET0]]]
+// CHECK-DAG: %[[DYN_SIZE1:.*]] = affine.apply #[[$A_MINUS_B_MAP]]()[%[[DYN_SIZES]]#1, %[[DYN_OFFSET1]]]
+// CHECK-DAG: %[[DYN_SIZE2:.*]] = affine.apply #[[$A_MINUS_B_MAP]]()[%[[DYN_SIZES]]#2, %[[DYN_OFFSET2]]]
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[SUBVIEW:.*]] = memref.subview %[[BASE]][%[[DYN_OFFSET0]], %[[DYN_OFFSET1]], %[[DYN_OFFSET2]]] [%[[DYN_SIZE0]], %[[DYN_SIZE1]], %[[DYN_SIZE2]]] [1, 1, 1] : memref<?x?x?xf16, 3> to memref<?x?x?xf16, strided<[?, ?, 1], offset: ?>, 3>
+// CHECK: %[[LOADED_VAL:.*]] = nvgpu.ldmatrix %[[SUBVIEW]][%[[C0]], %[[C0]], %[[C0]]] {numTiles = 4 : i32, transpose = false} : memref<?x?x?xf16, strided<[?, ?, 1], offset: ?>, 3> -> vector<4x2xf16>
+// CHECK: return %[[LOADED_VAL]] : vector<4x2xf16>
+func.func @test_ldmatrix(%base : memref<?x?x?xf16, 3>,
+                         %offset0 : index, %offset1: index, %offset2: index)
+    -> vector<4x2xf16> {
+  %loaded_val = nvgpu.ldmatrix
+    %base[%offset0, %offset1, %offset2]
+    {numTiles = 4 : i32, transpose = false}
+      : memref<?x?x?xf16, 3> -> vector<4x2xf16>
+  return %loaded_val : vector<4x2xf16>
+}
+
+// -----
+
+// Simple test for vector.transfer_read with a fully dynamic memref.
+// We also set a permutation map to make sure it is properly preserved.
+
+// CHECK-DAG: #[[$A_MINUS_B_MAP:.*]] = affine_map<()[s0, s1] -> (s0 - s1)>
+// CHECK-DAG: #[[$PERMUTATION_MAP:.*]] = affine_map<(d0, d1, d2) -> (d2, d0)>
+// CHECK-LABEL: @test_transfer_read_op(
+// CHECK-SAME: %[[BASE:[^:]*]]: memref<{{[^,]*}}>,
+// CHECK-SAME: %[[DYN_OFFSET0:[^:]*]]: index,
+// CHECK-SAME: %[[DYN_OFFSET1:[^:]*]]: index,
+// CHECK-SAME: %[[DYN_OFFSET2:[^:]*]]: index)
+// CHECK-DAG: {{.*}}, {{.*}}, %[[DYN_SIZES:.*]]:3, {{.*}} = memref.extract_strided_metadata %[[BASE]]
+// CHECK-DAG: %[[DYN_SIZE0:.*]] = affine.apply #[[$A_MINUS_B_MAP]]()[%[[DYN_SIZES]]#0, %[[DYN_OFFSET0]]]
+// CHECK-DAG: %[[DYN_SIZE1:.*]] = affine.apply #[[$A_MINUS_B_MAP]]()[%[[DYN_SIZES]]#1, %[[DYN_OFFSET1]]]
+// CHECK-DAG: %[[DYN_SIZE2:.*]] = affine.apply #[[$A_MINUS_B_MAP]]()[%[[DYN_SIZES]]#2, %[[DYN_OFFSET2]]]
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[CF0:.*]] = arith.constant 0.0{{0*e\+00}} : f16
+// CHECK-DAG: %[[SUBVIEW:.*]] = memref.subview %[[BASE]][%[[DYN_OFFSET0]], %[[DYN_OFFSET1]], %[[DYN_OFFSET2]]] [%[[DYN_SIZE0]], %[[DYN_SIZE1]], %[[DYN_SIZE2]]] [1, 1, 1] : memref<?x?x?xf16> to memref<?x?x?xf16, strided<[?, ?, 1], offset: ?>>
+// CHECK: %[[LOADED_VAL:.*]] = vector.transfer_read %[[SUBVIEW]][%[[C0]], %[[C0]], %[[C0]]], %[[CF0]] {permutation_map = #[[$PERMUTATION_MAP]]} : memref<?x?x?xf16, strided<[?, ?, 1], offset: ?>>, vector<4x2xf16>
+// CHECK: return %[[LOADED_VAL]] : vector<4x2xf16>
+func.func @test_transfer_read_op(%base : memref<?x?x?xf16>,
+                                 %offset0 : index, %offset1: index, %offset2: index)
+    -> vector<4x2xf16> {
+  %cf0 = arith.constant 0.0 : f16
+  %loaded_val = vector.transfer_read %base[%offset0, %offset1, %offset2], %cf0 { permutation_map = affine_map<(d0,d1,d2) -> (d2,d0)> } : memref<?x?x?xf16>, vector<4x2xf16>
+  return %loaded_val : vector<4x2xf16>
+}
+
+// -----
+
+// Same as test_transfer_read_op but with tensors.
+// Right now this rewrite is not supported but we still shouldn't choke on it.
+
+// CHECK: #[[$PERMUTATION_MAP:.*]] = affine_map<(d0, d1, d2) -> (d2, d0)>
+// CHECK-LABEL: @test_transfer_read_op_with_tensor(
+// CHECK-SAME: %[[BASE:[^:]*]]: tensor<{{[^,]*}}>,
+// CHECK-SAME: %[[DYN_OFFSET0:[^:]*]]: index,
+// CHECK-SAME: %[[DYN_OFFSET1:[^:]*]]: index,
+// CHECK-SAME: %[[DYN_OFFSET2:[^:]*]]: index)
+// CHECK: %[[CF0:.*]] = arith.constant 0.0{{0*e\+00}} : f16
+// CHECK: %[[LOADED_VAL:.*]] = vector.transfer_read %[[BASE]][%[[DYN_OFFSET0]], %[[DYN_OFFSET1]], %[[DYN_OFFSET2]]], %[[CF0]] {permutation_map = #[[$PERMUTATION_MAP]]} : tensor<?x?x?xf16>, vector<4x2xf16>
+// CHECK: return %[[LOADED_VAL]] : vector<4x2xf16>
+func.func @test_transfer_read_op_with_tensor(%base : tensor<?x?x?xf16>,
+                                             %offset0 : index, %offset1: index, %offset2: index)
+    -> vector<4x2xf16> {
+  %cf0 = arith.constant 0.0 : f16
+  %loaded_val = vector.transfer_read %base[%offset0, %offset1, %offset2], %cf0 { permutation_map = affine_map<(d0,d1,d2) -> (d2,d0)> } : tensor<?x?x?xf16>, vector<4x2xf16>
+  return %loaded_val : vector<4x2xf16>
+}
+
+// -----
+
+// Simple test for vector.transfer_write with a fully dynamic memref.
+// We also set a permutation map to make sure it is properly preserved.
+
+// CHECK-DAG: #[[$A_MINUS_B_MAP:.*]] = affine_map<()[s0, s1] -> (s0 - s1)>
+// CHECK-DAG: #[[$PERMUTATION_MAP:.*]] = affine_map<(d0, d1, d2) -> (d2, d0)>
+// CHECK-LABEL: @test_transfer_write_op(
+// CHECK-SAME: %[[BASE:[^:]*]]: memref<{{[^,]*}}>,
+// CHECK-SAME: %[[DYN_OFFSET0:[^:]*]]: index,
+// CHECK-SAME: %[[DYN_OFFSET1:[^:]*]]: index,
+// CHECK-SAME: %[[DYN_OFFSET2:[^:]*]]: index)
+// CHECK-DAG: {{.*}}, {{.*}}, %[[DYN_SIZES:.*]]:3, {{.*}} = memref.extract_strided_metadata %[[BASE]]
+// CHECK-DAG: %[[DYN_SIZE0:.*]] = affine.apply #[[$A_MINUS_B_MAP]]()[%[[DYN_SIZES]]#0, %[[DYN_OFFSET0]]]
+// CHECK-DAG: %[[DYN_SIZE1:.*]] = affine.apply #[[$A_MINUS_B_MAP]]()[%[[DYN_SIZES]]#1, %[[DYN_OFFSET1]]]
+// CHECK-DAG: %[[DYN_SIZE2:.*]] = affine.apply #[[$A_MINUS_B_MAP]]()[%[[DYN_SIZES]]#2, %[[DYN_OFFSET2]]]
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[VCF0:.*]] = arith.constant dense<0.0{{0*e\+00}}> : vector<4x2xf16>
+// CHECK-DAG: %[[SUBVIEW:.*]] = memref.subview %[[BASE]][%[[DYN_OFFSET0]], %[[DYN_OFFSET1]], %[[DYN_OFFSET2]]] [%[[DYN_SIZE0]], %[[DYN_SIZE1]], %[[DYN_SIZE2]]] [1, 1, 1] : memref<?x?x?xf16> to memref<?x?x?xf16, strided<[?, ?, 1], offset: ?>>
+// CHECK: vector.transfer_write %[[VCF0]], %[[SUBVIEW]][%[[C0]], %[[C0]], %[[C0]]] {permutation_map = #[[$PERMUTATION_MAP]]} : vector<4x2xf16>, memref<?x?x?xf16, strided<[?, ?, 1], offset: ?>>
+// CHECK: return
+func.func @test_transfer_write_op(%base : memref<?x?x?xf16>,
+                                  %offset0 : index, %offset1: index, %offset2: index) {
+  %vcf0 = arith.constant dense<0.000000e+00> : vector<4x2xf16>
+  vector.transfer_write %vcf0, %base[%offset0, %offset1, %offset2] { permutation_map = affine_map<(d0,d1,d2) -> (d2,d0)> } : vector<4x2xf16>, memref<?x?x?xf16>
+  return
+}
+
+// -----
+
+// Same as test_transfer_write_op but with tensors.
+// Right now this rewrite is not supported but we still shouldn't choke on it.
+
+// CHECK: #[[$PERMUTATION_MAP:.*]] = affine_map<(d0, d1, d2) -> (d2, d0)>
+// CHECK-LABEL: @test_transfer_write_op_with_tensor(
+// CHECK-SAME: %[[BASE:[^:]*]]: tensor<{{[^,]*}}>,
+// CHECK-SAME: %[[DYN_OFFSET0:[^:]*]]: index,
+// CHECK-SAME: %[[DYN_OFFSET1:[^:]*]]: index,
+// CHECK-SAME: %[[DYN_OFFSET2:[^:]*]]: index)
+// CHECK-DAG: %[[VCF0:.*]] = arith.constant dense<0.0{{0*e\+00}}> : vector<4x2xf16>
+// CHECK: %[[RES:.*]] = vector.transfer_write %[[VCF0]], %[[BASE]][%[[DYN_OFFSET0]], %[[DYN_OFFSET1]], %[[DYN_OFFSET2]]] {permutation_map = #[[$PERMUTATION_MAP]]} : vector<4x2xf16>, tensor<?x?x?xf16>
+// CHECK: return %[[RES]] : tensor<?x?x?xf16>
+func.func @test_transfer_write_op_with_tensor(%base : tensor<?x?x?xf16>,
+                                              %offset0 : index, %offset1: index, %offset2: index) -> tensor<?x?x?xf16> {
+  %vcf0 = arith.constant dense<0.000000e+00> : vector<4x2xf16>
+  %res = vector.transfer_write %vcf0, %base[%offset0, %offset1, %offset2] { permutation_map = affine_map<(d0,d1,d2) -> (d2,d0)> } : vector<4x2xf16>, tensor<?x?x?xf16>
+  return %res : tensor<?x?x?xf16>
+}
diff --git a/mlir/test/lib/Dialect/MemRef/CMakeLists.txt b/mlir/test/lib/Dialect/MemRef/CMakeLists.txt
--- a/mlir/test/lib/Dialect/MemRef/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/MemRef/CMakeLists.txt
@@ -1,14 +1,17 @@
 # Exclude tests from libMLIR.so
 add_mlir_library(MLIRMemRefTestPasses
   TestComposeSubView.cpp
+  TestExtractAddressComputations.cpp
   TestMultiBuffer.cpp
 
   EXCLUDE_FROM_LIBMLIR
 
   LINK_LIBS PUBLIC
   MLIRPass
+  MLIRAffineDialect
   MLIRMemRefDialect
   MLIRMemRefTransforms
+  MLIRNVGPUDialect
   MLIRTestDialect
 )
diff --git a/mlir/test/lib/Dialect/MemRef/TestExtractAddressComputations.cpp b/mlir/test/lib/Dialect/MemRef/TestExtractAddressComputations.cpp
new file
--- /dev/null
+++ b/mlir/test/lib/Dialect/MemRef/TestExtractAddressComputations.cpp
@@ -0,0 +1,65 @@
+//===- TestExtractAddressComputations.cpp - Test extract addr computations-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to test the extract address computations
+// patterns.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#define PASS_NAME "test-extract-address-computations"
+
+using namespace mlir;
+
+namespace {
+
+struct TestExtractAddressComputations
+    : public PassWrapper<TestExtractAddressComputations,
+                         OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestExtractAddressComputations)
+
+  StringRef getArgument() const final { return PASS_NAME; }
+  StringRef getDescription() const final {
+    return "Tests extract address computations patterns.";
+  }
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<affine::AffineDialect, memref::MemRefDialect,
+                    nvgpu::NVGPUDialect>();
+  }
+  TestExtractAddressComputations() = default;
+  TestExtractAddressComputations(const TestExtractAddressComputations &pass)
+      : PassWrapper(pass){};
+
+  void runOnOperation() override;
+};
+
+} // namespace
+
+void TestExtractAddressComputations::runOnOperation() {
+  RewritePatternSet patterns(&getContext());
+  memref::populateExtractAddressComputationsPatterns(patterns);
+  if (failed(
+          applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) {
+    return signalPassFailure();
+  }
+}
+
+namespace mlir {
+namespace test {
+void registerTestExtractAddressComputations() {
+  PassRegistration<TestExtractAddressComputations>();
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -84,6 +84,7 @@
 void registerTestDominancePass();
 void registerTestDynamicPipelinePass();
 void registerTestExpandMathPass();
+void registerTestExtractAddressComputations();
 void registerTestFooAnalysisPass();
 void registerTestComposeSubView();
 void registerTestMultiBuffering();
@@ -196,6 +197,7 @@
   mlir::test::registerTestDominancePass();
   mlir::test::registerTestDynamicPipelinePass();
   mlir::test::registerTestExpandMathPass();
+  mlir::test::registerTestExtractAddressComputations();
   mlir::test::registerTestFooAnalysisPass();
   mlir::test::registerTestComposeSubView();
   mlir::test::registerTestMultiBuffering();
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -10096,6 +10096,7 @@
         ":LoopLikeInterface",
        ":MemRefDialect",
        ":MemRefPassIncGen",
+        ":NVGPUDialect",
        ":Pass",
        ":RuntimeVerifiableOpInterface",
        ":TensorDialect",