diff --git a/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp b/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp
--- a/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/AffineExprVisitor.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/IntegerSet.h"
@@ -31,16 +32,31 @@
 using namespace mlir;
 using namespace presburger;
 
+/// Value to pass for the position when the reduction in the loop
+/// happens through a memref.
+static constexpr unsigned kMemRefReduction = ~0u;
+
 /// Get the value that is being reduced by `pos`-th reduction in the loop if
 /// such a reduction can be performed by affine parallel loops. This assumes
 /// floating-point operations are commutative. On success, `kind` will be the
 /// reduction kind suitable for use in affine parallel loop builder. If the
-/// reduction is not supported, returns null.
-static Value getSupportedReduction(AffineForOp forOp, unsigned pos,
-                                   arith::AtomicRMWKind &kind) {
+/// reduction is not supported, returns null. For memref-based reductions,
+/// `memRefCombinerOps` carries the combiner ops used to accumulate into
+/// the memref `memRefReducedVal`.
+static Value
+getSupportedReduction(AffineForOp forOp, unsigned pos,
+                      arith::AtomicRMWKind &kind,
+                      const SmallVectorImpl<Operation *> *memRefCombinerOps,
+                      Value memRefReducedVal) {
   SmallVector<Operation *> combinerOps;
-  Value reducedVal =
-      matchReduction(forOp.getRegionIterArgs(), pos, combinerOps);
+  Value reducedVal;
+  if (forOp.getNumRegionIterArgs() > 0) {
+    reducedVal = matchReduction(forOp.getRegionIterArgs(), pos, combinerOps);
+  } else {
+    // Memref-based reduction. Initialize the combiner ops.
+    combinerOps.assign(memRefCombinerOps->begin(), memRefCombinerOps->end());
+    reducedVal = memRefReducedVal;
+  }
   if (!reducedVal)
     return nullptr;
 
@@ -75,17 +91,89 @@
   return reducedVal;
 }
 
+/// Matcher that looks for memory-based reductions and returns the memrefs
+/// being used for the reductions along with the associated combiner ops.
+static void
+matchReductionViaMemRef(Region &loopBodyBlock,
+                        SmallVectorImpl<Operation *> &allCombinerOps,
+                        SmallVectorImpl<Value> &reductionMemRefs) {
+  DenseMap<Value, AffineWriteOpInterface> potentialRedStore;
+  const auto &ops = loopBodyBlock.getOps();
+  for (Operation &op : ops) {
+    AffineWriteOpInterface storeOp = dyn_cast<AffineWriteOpInterface>(op);
+    if (!storeOp)
+      continue;
+    auto storeMemRef = storeOp.getMemRef();
+    auto *defOp = storeMemRef.getDefiningOp();
+    if (!defOp || !isa<memref::AllocOp>(defOp))
+      continue;
+    if (potentialRedStore.find(storeMemRef) == potentialRedStore.end()) {
+      potentialRedStore[storeMemRef] = storeOp;
+    } else {
+      // Only one store to the reduction memref is expected.
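+      // A second store to the same memref makes the reduction ambiguous,
+      // so reset the entry to drop this memref from the candidate set.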
+      potentialRedStore[storeMemRef] = nullptr;
+    }
+  }
+  // Go over all the potential stores to see if there is a match for the
+  // following reduction pattern:
+  //   val = load from reduction memref
+  //   reduced-val = combinerOp(val)
+  //   store reduced-val to reduction memref
+  for (auto pair : potentialRedStore) {
+    // Skip memrefs that were disqualified above.
+    if (!pair.second)
+      continue;
+    auto *combinerOp = pair.second.getValueToStore().getDefiningOp();
+    Operation *redRegionOp = pair.second->getParentOp();
+    if (!combinerOp || !MemoryEffectOpInterface::hasNoEffect(combinerOp) ||
+        combinerOp->getNumResults() != 1 || !combinerOp->hasOneUse() ||
+        combinerOp->getParentOp() != redRegionOp ||
+        combinerOp->getNumOperands() != 2)
+      continue;
+    // Limit matching to a single combiner op since the callee only expects
+    // one.
+    auto loadOp1 = dyn_cast_or_null<AffineReadOpInterface>(
+        combinerOp->getOperand(0).getDefiningOp());
+    auto loadOp2 = dyn_cast_or_null<AffineReadOpInterface>(
+        combinerOp->getOperand(1).getDefiningOp());
+    if ((loadOp1 && loadOp1.getMemRef() == pair.first) ||
+        (loadOp2 && loadOp2.getMemRef() == pair.first)) {
+      reductionMemRefs.push_back(pair.first);
+      allCombinerOps.push_back(combinerOp);
+    }
+  }
+}
+
 /// Populate `supportedReductions` with descriptors of the supported
 /// reductions.
 void mlir::getSupportedReductions(
     AffineForOp forOp, SmallVectorImpl<LoopReduction> &supportedReductions) {
   unsigned numIterArgs = forOp.getNumIterOperands();
-  if (numIterArgs == 0)
-    return;
-  supportedReductions.reserve(numIterArgs);
-  for (unsigned i = 0; i < numIterArgs; ++i) {
+  unsigned numReductions = 0;
+  bool reductionViaMemRef = false;
+  SmallVector<Value> allReducedVals;
+  SmallVector<Operation *> allCombinerOps;
+  if (numIterArgs == 0) {
+    // Possibly memory-based reductions. Find all the memrefs being used for
+    // these reductions and the associated combiner ops.
+    matchReductionViaMemRef(forOp.getLoopBody(), allCombinerOps,
+                            allReducedVals);
+    reductionViaMemRef = true;
+    numReductions = allReducedVals.size();
+  } else {
+    numReductions = numIterArgs;
+  }
+  supportedReductions.reserve(numReductions);
+  for (unsigned i = 0; i < numReductions; ++i) {
     arith::AtomicRMWKind kind;
-    if (Value value = getSupportedReduction(forOp, i, kind))
-      supportedReductions.emplace_back(LoopReduction{kind, i, value});
+    if (reductionViaMemRef) {
+      SmallVector<Operation *> combinerOps(1, allCombinerOps[i]);
+      if (Value value = getSupportedReduction(forOp, i, kind, &combinerOps,
+                                              allReducedVals[i]))
+        supportedReductions.emplace_back(
+            LoopReduction{kind, kMemRefReduction, value});
+    } else {
+      if (Value value =
+              getSupportedReduction(forOp, i, kind, nullptr, nullptr))
+        supportedReductions.emplace_back(LoopReduction{kind, i, value});
+    }
   }
 }
 
@@ -106,7 +194,7 @@
     getSupportedReductions(forOp, *parallelReductions);
     // Return later to allow for identifying all parallel reductions even if
     // the loop is not parallel.
-    if (parallelReductions->size() != numIterArgs)
+    if (numIterArgs != 0 && parallelReductions->size() != numIterArgs)
       return false;
   }
diff --git a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
--- a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
@@ -1110,6 +1110,21 @@
   auto getSliceLoop = [&](unsigned i) {
     return isBackwardSlice ? srcLoopIVs[i] : dstLoopIVs[i];
   };
+  if (isa<AffineWriteOpInterface>(depSourceOp) &&
+      isa<AffineReadOpInterface>(depSinkOp)) {
+    // Slices of source loop nests that contain reductions are reset to their
+    // original bounds.
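+    // Reduction loops carry a dependence through memory, so they are
+    // recorded as sequential loops here; their slice bounds are cleared
+    // further below so the reduction keeps its full trip count.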
+    for (unsigned i = 0; i < numSliceLoopIVs; ++i) {
+      AffineForOp srcForOp = getSliceLoop(i);
+      SmallVector<LoopReduction> reductions;
+      (void)isLoopParallel(srcForOp, &reductions);
+      if (!reductions.empty()) {
+        getSequentialLoops(srcForOp, &sequentialLoops);
+        break;
+      }
+    }
+  }
   auto isInnermostInsertion = [&]() {
     return (isBackwardSlice ? loopDepth >= srcLoopIVs.size()
                             : loopDepth >= dstLoopIVs.size());
diff --git a/mlir/test/Transforms/loop-fusion-4.mlir b/mlir/test/Transforms/loop-fusion-4.mlir
--- a/mlir/test/Transforms/loop-fusion-4.mlir
+++ b/mlir/test/Transforms/loop-fusion-4.mlir
@@ -1,5 +1,6 @@
 // RUN: mlir-opt -allow-unregistered-dialect %s -affine-loop-fusion="mode=producer" -split-input-file | FileCheck %s --check-prefix=PRODUCER-CONSUMER
 // RUN: mlir-opt -allow-unregistered-dialect %s -affine-loop-fusion="fusion-maximal mode=sibling" -split-input-file | FileCheck %s --check-prefix=SIBLING-MAXIMAL
+// RUN: mlir-opt -allow-unregistered-dialect %s -affine-loop-fusion="fusion-maximal mode=producer" -split-input-file | FileCheck %s --check-prefix=PRODUCER-CONSUMER-MAXIMAL
 // Part I of fusion tests in mlir/test/Transforms/loop-fusion.mlir.
 // Part II of fusion tests in mlir/test/Transforms/loop-fusion-2.mlir
@@ -141,3 +142,53 @@
 // SIBLING-MAXIMAL-NEXT:      affine.for %[[idx_1:.*]] = 0 to 64 {
 // SIBLING-MAXIMAL-NEXT:        %[[result_1:.*]] = affine.for %[[idx_2:.*]] = 0 to 32 iter_args(%[[iter_0:.*]] = %[[cst_1]]) -> (f32) {
 // SIBLING-MAXIMAL-NEXT:          %[[result_0:.*]] = affine.for %[[idx_3:.*]] = 0 to 64 iter_args(%[[iter_1:.*]] = %[[cst_0]]) -> (f32) {
+
+// -----
+
+// The source loop nest %i1 is a reduction, but fusing the preceding loop %i0
+// into it turns the fused nest into a producer for %i3. Check that the
+// producer fusion happens at depth 2 with the original bounds.
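+// The fused nest must keep the full trip count (0 to 1024) of the inner
+// reduction loops rather than a single-iteration slice.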
+// PRODUCER-CONSUMER-MAXIMAL-LABEL: func @reduction_producer_consumer(
+func.func @reduction_producer_consumer(%arg0: memref<1024xf32, 1>, %arg1: memref<1xf32, 1>, %arg2: memref<1xf32, 1>, %arg3: memref<1xf32, 1>) {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = memref.alloc() : memref<f32>
+  %1 = memref.alloc() : memref<f32>
+  %2 = memref.alloc() : memref<1024xf32, 1>
+  affine.for %i0 = 0 to 1024 {
+    %4 = affine.load %arg0[%i0] : memref<1024xf32, 1>
+    %6 = arith.addf %4, %4 : f32
+    affine.store %6, %2[%i0] : memref<1024xf32, 1>
+  }
+  affine.for %i1 = 0 to 1 {
+    affine.store %cst, %1[] : memref<f32>
+    affine.for %i2 = 0 to 1024 {
+      %5 = affine.load %1[] : memref<f32>
+      %6 = affine.load %2[%i2] : memref<1024xf32, 1>
+      %7 = arith.addf %5, %6 : f32
+      affine.store %7, %1[] : memref<f32>
+    }
+    %4 = affine.load %1[] : memref<f32>
+    affine.store %4, %arg2[%i1] : memref<1xf32, 1>
+  }
+  affine.for %i3 = 0 to 1 {
+    affine.store %cst, %0[] : memref<f32>
+    affine.for %i4 = 0 to 1024 {
+      %5 = affine.load %0[] : memref<f32>
+      %6 = affine.load %2[%i4] : memref<1024xf32, 1>
+      %7 = arith.addf %5, %6 : f32
+      affine.store %7, %0[] : memref<f32>
+    }
+    %4 = affine.load %0[] : memref<f32>
+    affine.store %4, %arg3[%i3] : memref<1xf32, 1>
+  }
+  return
+}
+// PRODUCER-CONSUMER-MAXIMAL:      %[[cst0:.*]] = arith.constant 0.000000e+00 : f32
+// PRODUCER-CONSUMER-MAXIMAL:      %[[tmp0:.*]] = memref.alloc() : memref<f32>
+// PRODUCER-CONSUMER-MAXIMAL:      %[[tmp1:.*]] = memref.alloc() : memref<f32>
+// PRODUCER-CONSUMER-MAXIMAL:      affine.for %[[i0:.*]] = 0 to 1 {
+// PRODUCER-CONSUMER-MAXIMAL:        affine.store %[[cst0]], %[[tmp0]][] : memref<f32>
+// PRODUCER-CONSUMER-MAXIMAL:        affine.for %[[i1:.*]] = 0 to 1024 {
+// PRODUCER-CONSUMER-MAXIMAL:          affine.store %[[cst0]], %[[tmp1]][] : memref<f32>
+// Producer reduction is inserted with its original bounds.
+// PRODUCER-CONSUMER-MAXIMAL:          affine.for %[[i2:.*]] = 0 to 1024 {