diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
@@ -148,6 +148,11 @@
 /// easier to reason about operations.
 std::unique_ptr<Pass> createExpandStridedMetadataPass();
 
+/// Creates a pass that rewrites memref.load operations into a memref.subview
+/// (offset by the original load indices) followed by a memref.load of that
+/// subview at all-zero indices.
+std::unique_ptr<Pass> createRewriteAddressComputationPass();
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
@@ -202,5 +202,21 @@
       "AffineDialect", "memref::MemRefDialect"
   ];
 }
+
+def RewriteAddressComputation : Pass<"rewrite-address-computation"> {
+  let summary = "Move memref.load address computation into memref.subview";
+  let description = [{
+    Rewrites `memref.load` operations into a `memref.subview` of the loaded
+    memref, with sizes and strides of 1 and the load indices as offsets,
+    followed by a `memref.load` of that subview at all-zero indices.
+
+    The address computation is thus carried by the subview, which can later
+    be materialized as an affine computation by `expand-strided-metadata`.
+  }];
+  let constructor = "mlir::memref::createRewriteAddressComputationPass()";
+  let dependentDialects = [
+      "AffineDialect", "arith::ArithDialect", "memref::MemRefDialect"
+  ];
+}
 
 #endif // MLIR_DIALECT_MEMREF_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
@@ -8,6 +8,7 @@
   MultiBuffer.cpp
   NormalizeMemRefs.cpp
   ResolveShapedTypeResultDims.cpp
+  RewriteAddressComputation.cpp
   RuntimeOpVerification.cpp
 
   ADDITIONAL_HEADER_DIRS
diff --git a/mlir/lib/Dialect/MemRef/Transforms/RewriteAddressComputation.cpp b/mlir/lib/Dialect/MemRef/Transforms/RewriteAddressComputation.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/MemRef/Transforms/RewriteAddressComputation.cpp
@@ -0,0 +1,91 @@
+//===- RewriteAddressComputation.cpp - Rewrite address computation --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This transformation pass rewrites memref.load operations into
+// memref.subview + memref.load such that the indices of the resulting load
+// are all zeros; the address computation is carried by the subview instead.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "rewrite-address-computation"
+
+namespace mlir {
+namespace memref {
+#define GEN_PASS_DEF_REWRITEADDRESSCOMPUTATION
+#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
+} // namespace memref
+} // namespace mlir
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Pass registration
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+struct RewriteAddressComputationPass final
+    : public memref::impl::RewriteAddressComputationBase<
+          RewriteAddressComputationPass> {
+  void runOnOperation() override;
+};
+
+} // namespace
+
+// Rewrite a load so that all its indices are zeros.
+// E.g., %ld = memref.load %base[%off0, ..., %offN]
+// =>
+// %new_base = subview %base[%off0,.., %offN][1,..,1][1,..,1]
+// %ld = memref.load %new_base[0,..,0] :
+//    memref<1x..x1xTy, strided<[S0,..,SN], offset: ?>>
+// where S0,..,SN are the strides of %base (the subview strides are all 1).
+//
+// Ultimately we want to produce an affine map with the address computation.
+// This will be taken care of by the expand-strided-metadata pass.
+static void rewriteLoad(RewriterBase &rewriter, memref::LoadOp loadOp) {
+  MemRefType ldTy = loadOp.getMemRefType();
+  unsigned loadRank = ldTy.getRank();
+  // Nothing to rewrite at rank 0; subview needs a strided source layout.
+  if (loadRank == 0 || !isStrided(ldTy))
+    return;
+
+  // Restores the caller's insertion point when this function returns.
+  RewriterBase::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPoint(loadOp);
+  Location loc = loadOp.getLoc();
+  // Unit sizes/strides: the subview holds exactly the loaded element.
+  SmallVector<OpFoldResult> ones(loadRank, rewriter.getIndexAttr(1));
+  auto subview = rewriter.create<memref::SubViewOp>(
+      loc, /*source=*/loadOp.getMemRef(),
+      /*offsets=*/getAsOpFoldResult(loadOp.getIndices()),
+      /*sizes=*/ones, /*strides=*/ones);
+  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  SmallVector<Value> zeros(loadRank, zero);
+  rewriter.replaceOpWithNewOp<memref::LoadOp>(loadOp, subview.getResult(),
+                                              /*indices=*/zeros);
+}
+
+void RewriteAddressComputationPass::runOnOperation() {
+  Operation *funcOp = getOperation();
+  IRRewriter rewriter(&getContext());
+  // Post-order walk: the callback is allowed to erase the visited load.
+  funcOp->walk([&](memref::LoadOp loadOp) {
+    LLVM_DEBUG(llvm::dbgs() << "Found load:\n" << loadOp << '\n');
+    rewriteLoad(rewriter, loadOp);
+  });
+}
+
+std::unique_ptr<Pass> memref::createRewriteAddressComputationPass() {
+  return std::make_unique<RewriteAddressComputationPass>();
+}
diff --git a/mlir/test/Dialect/MemRef/rewrite-address-computation.mlir b/mlir/test/Dialect/MemRef/rewrite-address-computation.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/MemRef/rewrite-address-computation.mlir
@@ -0,0 +1,48 @@
+// RUN: mlir-opt -rewrite-address-computation %s --split-input-file | FileCheck %s
+// TODO: run with expand-strided-metadata + decompose affine
+
+// TODO: check the lowering to affine map.
+// The resulting address computation is:
+// %offset * 16 * 16 + 0 * 16 + 8
+
+// CHECK-LABEL: @test
+// CHECK-SAME: (%[[BASE:.*]]: memref{{[^,]*}},
+// CHECK-SAME: %[[DYN_OFFSET:.*]]: index)
+// CHECK: %[[SUBVIEW:.*]] = memref.subview %[[BASE]][%[[DYN_OFFSET]], 0, 8] [1, 1, 1] [1, 1, 1] : memref<2x16x16xf32> to memref<1x1x1xf32, strided<[256, 16, 1], offset: ?>>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[LOADED_VAL:.*]] = memref.load %[[SUBVIEW]][%[[C0]], %[[C0]], %[[C0]]] : memref<1x1x1xf32, strided<[256, 16, 1], offset: ?>>
+// CHECK: return %[[LOADED_VAL]] : f32
+func.func @test(%base : memref<2x16x16xf32>, %offset : index) -> f32 {
+  %c0 = arith.constant 0 : index
+  %c8 = arith.constant 8 : index
+  %loaded_val = memref.load %base[%offset, %c0, %c8] : memref<2x16x16xf32>
+  return %loaded_val : f32
+}
+
+// -----
+
+// Note: the scf.for loops are purposely nested in reverse (dim2 -> dim0 instead
+// of dim0 -> dim2) to make the ordering from the decompose of affine ops obvious.
+// CHECK-LABEL: @testWithLoop
+// CHECK: memref.subview %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}] [1, 1, 1] [1, 1, 1]
+func.func @testWithLoop(%base : memref<?x?x?xf32, strided<[?,?,?], offset: ?>>) -> f32 {
+  %sum_all = arith.constant 0.0 : f32
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %upper_bound0 = memref.dim %base, %c0 : memref<?x?x?xf32, strided<[?,?,?], offset: ?>>
+  %upper_bound1 = memref.dim %base, %c1 : memref<?x?x?xf32, strided<[?,?,?], offset: ?>>
+  %upper_bound2 = memref.dim %base, %c2 : memref<?x?x?xf32, strided<[?,?,?], offset: ?>>
+  %sum_res2 = scf.for %iv2 = %c0 to %upper_bound2 step %c1 iter_args(%sum_iter2 = %sum_all) -> (f32) {
+    %sum_res1 = scf.for %iv1 = %c0 to %upper_bound1 step %c1 iter_args(%sum_iter1 = %sum_iter2) -> (f32) {
+      %sum_res0 = scf.for %iv0 = %c0 to %upper_bound0 step %c1 iter_args(%sum_iter0 = %sum_iter1) -> (f32) {
+        %loaded_val = memref.load %base[%iv0, %iv1, %iv2] : memref<?x?x?xf32, strided<[?,?,?], offset: ?>>
+        %res = arith.addf %loaded_val, %sum_iter0 : f32
+        scf.yield %res : f32
+      }
+      scf.yield %sum_res0 : f32
+    }
+    scf.yield %sum_res1 : f32
+  }
+  return %sum_res2 : f32
+}