diff --git a/mlir/include/mlir/Analysis/Utils.h b/mlir/include/mlir/Analysis/Utils.h --- a/mlir/include/mlir/Analysis/Utils.h +++ b/mlir/include/mlir/Analysis/Utils.h @@ -54,6 +54,18 @@ void getSequentialLoops(AffineForOp forOp, llvm::SmallDenseSet *sequentialLoops); +/// Enumerates different result statuses of slice computation by +/// `computeSliceUnion`. +// TODO: Identify and add different kinds of failures during slice computation. +struct SliceComputationResult { + enum ResultEnum { + Success, + IncorrectSliceFailure, // Slice is computed, but it is incorrect. + GenericFailure, // Slice could not be computed or verified. + } value; + SliceComputationResult(ResultEnum v) : value(v) {} +}; + /// ComputationSliceState aggregates loop IVs, loop bound AffineMaps and their /// associated operands for a set of loops within a loop nest (typically the /// set of loops surrounding a store operation). Loop bound AffineMaps which @@ -80,6 +92,12 @@ // Returns failure if we cannot add loop bounds because of unsupported cases. LogicalResult getAsConstraints(FlatAffineConstraints *cst); + /// Adds to 'cst' constraints which represent the original loop bounds on + /// 'ivs' in 'this'. This corresponds to the original domain of the loop nest + /// from which the slice is being computed. Returns failure if we cannot add + /// loop bounds because of unsupported cases. + LogicalResult getSourceAsConstraints(FlatAffineConstraints &cst); + // Clears all bounds and operands in slice state. void clearBounds(); @@ -93,6 +111,22 @@ // information hasn't changed. Optional isMaximal() const; + /// Checks the validity of the slice computed. This is done using the + /// following steps: + /// 1. Get the new domain of the slice that would be created if fusion + /// succeeds. This domain gets constructed with source loop IVs and + /// destination loop IVs as dimensions. + /// 2.
Project out the dimensions of the destination loop from the domain + /// calculated above in step (1) to express it purely in terms of the source + /// loop IVs. + /// 3. Calculate a set difference between the iterations of the new domain and + /// the original domain of the source loop. + /// If this difference is empty, the slice is declared to be valid. Otherwise, + /// return false as it implies that the effective fusion results in at least + /// one iteration of the slice that was not originally in the source's domain. + /// If the validity cannot be determined, returns llvm::None. + Optional isSliceValid(); + void dump() const; private: @@ -151,21 +185,21 @@ ComputationSliceState *sliceState); /// Computes in 'sliceUnion' the union of all slice bounds computed at -/// 'loopDepth' between all dependent pairs of ops in 'opsA' and 'opsB'. -/// The parameter 'numCommonLoops' is the number of loops common to the -/// operations in 'opsA' and 'opsB'. -/// If 'isBackwardSlice' is true, computes slice bounds for loop nest -/// surrounding ops in 'opsA', as a function of IVs and symbols of loop nest -/// surrounding ops in 'opsB' at 'loopDepth'. -/// If 'isBackwardSlice' is false, computes slice bounds for loop nest -/// surrounding ops in 'opsB', as a function of IVs and symbols of loop nest -/// surrounding ops in 'opsA' at 'loopDepth'. -/// Returns 'success' if union was computed, 'failure' otherwise. +/// 'loopDepth' between all dependent pairs of ops in 'opsA' and 'opsB', and +/// then verifies that it is valid. The parameter 'numCommonLoops' is the number +/// of loops common to the operations in 'opsA' and 'opsB'. If 'isBackwardSlice' +/// is true, computes slice bounds for loop nest surrounding ops in 'opsA', as a +/// function of IVs and symbols of loop nest surrounding ops in 'opsB' at +/// 'loopDepth'.
If 'isBackwardSlice' is false, computes slice bounds for loop +/// nest surrounding ops in 'opsB', as a function of IVs and symbols of loop +/// nest surrounding ops in 'opsA' at 'loopDepth'. Returns +/// 'SliceComputationResult::Success' if the union was computed and is valid, +/// 'IncorrectSliceFailure' if it is invalid, and 'GenericFailure' otherwise. // TODO: Change this API to take 'forOpA'/'forOpB'. -LogicalResult computeSliceUnion(ArrayRef opsA, - ArrayRef opsB, unsigned loopDepth, - unsigned numCommonLoops, bool isBackwardSlice, - ComputationSliceState *sliceUnion); +SliceComputationResult +computeSliceUnion(ArrayRef opsA, ArrayRef opsB, + unsigned loopDepth, unsigned numCommonLoops, + bool isBackwardSlice, ComputationSliceState *sliceUnion); /// Creates a clone of the computation contained in the loop nest surrounding /// 'srcOpInst', slices the iteration space of src loop based on slice bounds diff --git a/mlir/include/mlir/Transforms/LoopFusionUtils.h b/mlir/include/mlir/Transforms/LoopFusionUtils.h --- a/mlir/include/mlir/Transforms/LoopFusionUtils.h +++ b/mlir/include/mlir/Transforms/LoopFusionUtils.h @@ -35,6 +35,7 @@ FailBlockDependence, // Fusion would violate another dependence in block. FailFusionDependence, // Fusion would reverse dependences between loops. FailComputationSlice, // Unable to compute src loop computation slice. + FailIncorrectSlice, // Slice is computed, but it is incorrect.
} value; FusionResult(ResultEnum v) : value(v) {} }; diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -2128,13 +2128,22 @@ continue; } - if (lbMap && failed(addLowerOrUpperBound(pos, lbMap, operands, /*eq=*/false, - /*lower=*/true))) - return failure(); - - if (ubMap && failed(addLowerOrUpperBound(pos, ubMap, operands, /*eq=*/false, - /*lower=*/false))) - return failure(); + // If lower or upper bound maps are null or provide no results, it implies + // that the source loop was not at all sliced, and the entire loop will be a + // part of the slice; bail out if 'values[i]' is not an affine.for IV. + if (lbMap && lbMap.getNumResults() != 0 && ubMap && + ubMap.getNumResults() != 0) { + if (failed(addLowerOrUpperBound(pos, lbMap, operands, /*eq=*/false, + /*lower=*/true))) + return failure(); + if (failed(addLowerOrUpperBound(pos, ubMap, operands, /*eq=*/false, + /*lower=*/false))) + return failure(); + } else { + auto loop = getForInductionVarOwner(values[i]); + if (!loop || failed(this->addAffineForOpDomain(loop))) + return failure(); + } } return success(); } diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -61,6 +61,21 @@ std::reverse(ops->begin(), ops->end()); } +// Populates 'cst' with FlatAffineConstraints which represent the original +// domain of the loops whose induction variables are 'ivs'. +LogicalResult +ComputationSliceState::getSourceAsConstraints(FlatAffineConstraints &cst) { + assert(!ivs.empty() && "Cannot have a slice without its IVs"); + cst.reset(/*numDims=*/ivs.size(), /*numSymbols=*/0, /*numLocals=*/0, ivs); + for (Value iv : ivs) { + AffineForOp loop = getForInductionVarOwner(iv); + assert(loop && "Expected affine for"); + if (failed(cst.addAffineForOpDomain(loop))) + return failure(); + } + return success(); +} + // Populates 'cst' with FlatAffineConstraints which represent slice bounds.
LogicalResult ComputationSliceState::getAsConstraints(FlatAffineConstraints *cst) { @@ -75,9 +90,10 @@ values.append(lbOperands[0].begin(), lbOperands[0].end()); cst->reset(numDims, numSymbols, 0, values); - // Add loop bound constraints for values which are loop IVs and equality - // constraints for symbols which are constants. - for (const auto &value : values) { + // Add loop bound constraints for values which are loop IVs of the destination + // of fusion and equality constraints for symbols which are constants. + for (unsigned i = numDims, end = values.size(); i < end; ++i) { + Value value = values[i]; assert(cst->containsId(value) && "value expected to be present"); if (isValidSymbol(value)) { // Check if the symbol is a constant. @@ -196,6 +212,76 @@ return true; } +/// Returns true if it is deterministically verified that the iteration space +/// of 'this' slice, once fused into its destination, lies entirely within the +/// original domain of its source loop nest. Returns llvm::None if undecided. +Optional ComputationSliceState::isSliceValid() { + // Fast check to determine if the slice is valid. If the following conditions + // are verified to be true, slice is declared valid by the fast check: + // 1. Each slice loop is a single iteration loop bound in terms of a single + // destination loop IV. + // 2. Loop bounds of the destination loop IV (from above) and those of the + // source loop IV are exactly the same. + // If the fast check is inconclusive or false, we proceed with a more + // expensive analysis. + // TODO: Store the result of the fast check, as it might be used again in + // `canRemoveSrcNodeAfterFusion`. + Optional isValidFastCheck = isSliceMaximalFastCheck(); + if (isValidFastCheck.hasValue() && isValidFastCheck.getValue()) + return true; + + // Build the original domain of the source loop nest the slice is taken from. + FlatAffineConstraints srcConstraints; + // TODO: Store the source's domain to avoid computation at each depth.
+ if (failed(getSourceAsConstraints(srcConstraints))) { + LLVM_DEBUG(llvm::dbgs() << "Unable to compute source's domain\n"); + return llvm::None; + } + // As the set difference utility currently cannot handle symbols in its + // operands, validity of the slice cannot be determined. + if (srcConstraints.getNumSymbolIds() > 0) { + LLVM_DEBUG(llvm::dbgs() << "Cannot handle symbols in source domain\n"); + return llvm::None; + } + // TODO: Handle local ids in the source domains while using the 'projectOut' + // utility below. Currently, aligning is not done assuming that there will be + // no local ids in the source domain. + if (srcConstraints.getNumLocalIds() != 0) { + LLVM_DEBUG(llvm::dbgs() << "Cannot handle locals in source domain\n"); + return llvm::None; + } + + // Create constraints for the slice loop nest that would be created if the + // fusion succeeds. + FlatAffineConstraints sliceConstraints; + if (failed(getAsConstraints(&sliceConstraints))) { + LLVM_DEBUG(llvm::dbgs() << "Unable to compute slice's domain\n"); + return llvm::None; + } + + // Project out every dimension other than the 'ivs' to express the slice's + // domain completely in terms of the source's IVs. + sliceConstraints.projectOut(ivs.size(), + sliceConstraints.getNumIds() - ivs.size()); + + LLVM_DEBUG(llvm::dbgs() << "Domain of the source of the slice:\n"); + LLVM_DEBUG(srcConstraints.dump()); + LLVM_DEBUG(llvm::dbgs() << "Domain of the slice if this fusion succeeds " + "(expressed in terms of its source's IVs):\n"); + LLVM_DEBUG(sliceConstraints.dump()); + + // TODO: Store 'srcSet' to avoid recalculating for each depth. + PresburgerSet srcSet(srcConstraints); + PresburgerSet sliceSet(sliceConstraints); + PresburgerSet diffSet = sliceSet.subtract(srcSet); + + if (!diffSet.isIntegerEmpty()) { + LLVM_DEBUG(llvm::dbgs() << "Incorrect slice\n"); + return false; + } + return true; +} + /// Returns true if the computation slice encloses all the iterations of the /// sliced loop nest.
Returns false if it does not. Returns llvm::None if it /// cannot determine if the slice is maximal or not. @@ -715,14 +801,14 @@ } /// Computes in 'sliceUnion' the union of all slice bounds computed at -/// 'loopDepth' between all dependent pairs of ops in 'opsA' and 'opsB'. -/// Returns 'Success' if union was computed, 'failure' otherwise. -LogicalResult mlir::computeSliceUnion(ArrayRef opsA, - ArrayRef opsB, - unsigned loopDepth, - unsigned numCommonLoops, - bool isBackwardSlice, - ComputationSliceState *sliceUnion) { +/// 'loopDepth' between all dependent pairs of ops in 'opsA' and 'opsB', and +/// then verifies that it is valid. Returns 'SliceComputationResult::Success' if +/// the union was computed and is valid, an appropriate failure otherwise. +SliceComputationResult +mlir::computeSliceUnion(ArrayRef opsA, ArrayRef opsB, + unsigned loopDepth, unsigned numCommonLoops, + bool isBackwardSlice, + ComputationSliceState *sliceUnion) { // Compute the union of slice bounds between all pairs in 'opsA' and // 'opsB' in 'sliceUnionCst'.
FlatAffineConstraints sliceUnionCst; @@ -738,7 +824,7 @@ if ((!isBackwardSlice && loopDepth > getNestingDepth(opsA[i])) || (isBackwardSlice && loopDepth > getNestingDepth(opsB[j]))) { LLVM_DEBUG(llvm::dbgs() << "Invalid loop depth\n"); - return failure(); + return SliceComputationResult::GenericFailure; } bool readReadAccesses = isa(srcAccess.opInst) && @@ -751,7 +837,7 @@ /*allowRAR=*/readReadAccesses); if (result.value == DependenceResult::Failure) { LLVM_DEBUG(llvm::dbgs() << "Dependence check failed\n"); - return failure(); + return SliceComputationResult::GenericFailure; } if (result.value == DependenceResult::NoDependence) continue; @@ -768,7 +854,7 @@ if (failed(tmpSliceState.getAsConstraints(&sliceUnionCst))) { LLVM_DEBUG(llvm::dbgs() << "Unable to compute slice bound constraints\n"); - return failure(); + return SliceComputationResult::GenericFailure; } assert(sliceUnionCst.getNumDimAndSymbolIds() > 0); continue; @@ -779,7 +865,7 @@ if (failed(tmpSliceState.getAsConstraints(&tmpSliceCst))) { LLVM_DEBUG(llvm::dbgs() << "Unable to compute slice bound constraints\n"); - return failure(); + return SliceComputationResult::GenericFailure; } // Align coordinate spaces of 'sliceUnionCst' and 'tmpSliceCst' if needed. @@ -802,9 +888,9 @@ // to unionBoundingBox below expects constraints for each Loop IV, even // if they are the unsliced full loop bounds added here. if (failed(addMissingLoopIVBounds(sliceUnionIVs, &sliceUnionCst))) - return failure(); + return SliceComputationResult::GenericFailure; if (failed(addMissingLoopIVBounds(tmpSliceIVs, &tmpSliceCst))) - return failure(); + return SliceComputationResult::GenericFailure; } // Compute union bounding box of 'sliceUnionCst' and 'tmpSliceCst'.
if (sliceUnionCst.getNumLocalIds() > 0 || @@ -812,14 +898,14 @@ failed(sliceUnionCst.unionBoundingBox(tmpSliceCst))) { LLVM_DEBUG(llvm::dbgs() << "Unable to compute union bounding box of slice bounds\n"); - return failure(); + return SliceComputationResult::GenericFailure; } } } // Empty union. if (sliceUnionCst.getNumDimAndSymbolIds() == 0) - return failure(); + return SliceComputationResult::GenericFailure; // Gather loops surrounding ops from loop nest where slice will be inserted. SmallVector ops; @@ -831,7 +917,7 @@ getInnermostCommonLoopDepth(ops, &surroundingLoops); if (loopDepth > innermostCommonLoopDepth) { LLVM_DEBUG(llvm::dbgs() << "Exceeds max loop depth\n"); - return failure(); + return SliceComputationResult::GenericFailure; } // Store 'numSliceLoopIVs' before converting dst loop IVs to dims. @@ -868,7 +954,18 @@ // canonicalization. sliceUnion->lbOperands.resize(numSliceLoopIVs, sliceBoundOperands); sliceUnion->ubOperands.resize(numSliceLoopIVs, sliceBoundOperands); - return success(); + + // Check whether the computed slice is valid. Return success only if it is + // verified to be valid; otherwise return the appropriate failure status. + Optional isSliceValid = sliceUnion->isSliceValid(); + if (!isSliceValid.hasValue()) { + LLVM_DEBUG(llvm::dbgs() << "Cannot determine if the slice is valid\n"); + return SliceComputationResult::GenericFailure; + } + if (!isSliceValid.getValue()) + return SliceComputationResult::IncorrectSliceFailure; + + return SliceComputationResult::Success; } const char *const kSliceFusionBarrierAttrName = "slice_fusion_barrier"; diff --git a/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp b/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp --- a/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp @@ -347,12 +347,18 @@ // Compute union of computation slices computed between all pairs of ops // from 'forOpA' and 'forOpB'.
- if (failed(mlir::computeSliceUnion(strategyOpsA, opsB, dstLoopDepth, - numCommonLoops, isSrcForOpBeforeDstForOp, - srcSlice))) { + SliceComputationResult sliceComputationResult = + mlir::computeSliceUnion(strategyOpsA, opsB, dstLoopDepth, numCommonLoops, + isSrcForOpBeforeDstForOp, srcSlice); + if (sliceComputationResult.value == SliceComputationResult::GenericFailure) { LLVM_DEBUG(llvm::dbgs() << "computeSliceUnion failed\n"); return FusionResult::FailPrecondition; } + if (sliceComputationResult.value == + SliceComputationResult::IncorrectSliceFailure) { + LLVM_DEBUG(llvm::dbgs() << "Incorrect slice computation\n"); + return FusionResult::FailIncorrectSlice; + } return FusionResult::Success; } @@ -400,7 +406,7 @@ auto *parentForOp = forOp->getParentOp(); if (!llvm::isa(parentForOp)) { if (!isa(parentForOp)) { - LLVM_DEBUG(llvm::dbgs() << "Expected parent AffineForOp"); + LLVM_DEBUG(llvm::dbgs() << "Expected parent AffineForOp\n"); return WalkResult::interrupt(); } // Add mapping to 'forOp' from its parent AffineForOp. @@ -421,7 +427,7 @@ Optional maybeConstTripCount = getConstantTripCount(forOp); if (!maybeConstTripCount.hasValue()) { // Currently only constant trip count loop nests are supported. - LLVM_DEBUG(llvm::dbgs() << "Non-constant trip count unsupported"); + LLVM_DEBUG(llvm::dbgs() << "Non-constant trip count unsupported\n"); return WalkResult::interrupt(); } @@ -519,7 +525,11 @@ auto *op = forOp.getOperation(); AffineMap lbMap = slice.lbs[i]; AffineMap ubMap = slice.ubs[i]; - if (lbMap == AffineMap() || ubMap == AffineMap()) { + // If lower or upper bound maps are null or provide no results, it implies + // that the source loop was not sliced at all, and the entire loop will be + // a part of the slice. + if (!lbMap || lbMap.getNumResults() == 0 || !ubMap || + ubMap.getNumResults() == 0) { // The iteration of src loop IV 'i' was not sliced. Use full loop bounds.
if (forOp.hasConstantLowerBound() && forOp.hasConstantUpperBound()) { (*tripCountMap)[op] = diff --git a/mlir/test/Transforms/loop-fusion-slice-computation.mlir b/mlir/test/Transforms/loop-fusion-slice-computation.mlir --- a/mlir/test/Transforms/loop-fusion-slice-computation.mlir +++ b/mlir/test/Transforms/loop-fusion-slice-computation.mlir @@ -7,7 +7,7 @@ %0 = memref.alloc() : memref<100xf32> %cst = constant 7.000000e+00 : f32 affine.for %i0 = 0 to 16 { - // expected-remark@-1 {{slice ( src loop: 1, dst loop: 0, depth: 1 : insert point: (1, 1) loop bounds: [(d0) -> (d0), (d0) -> (d0 + 1)] )}} + // expected-remark@-1 {{Incorrect slice ( src loop: 1, dst loop: 0, depth: 1 : insert point: (1, 1) loop bounds: [(d0) -> (d0), (d0) -> (d0 + 1)] )}} affine.store %cst, %0[%i0] : memref<100xf32> } affine.for %i1 = 0 to 5 { @@ -19,6 +19,23 @@ // ----- +// CHECK-LABEL: func @forward_slice_slice_depth1_loop_nest() { +func @forward_slice_slice_depth1_loop_nest() { + %0 = memref.alloc() : memref<100xf32> + %cst = constant 7.000000e+00 : f32 + affine.for %i0 = 0 to 5 { + // expected-remark@-1 {{slice ( src loop: 1, dst loop: 0, depth: 1 : insert point: (1, 1) loop bounds: [(d0) -> (d0), (d0) -> (d0 + 1)] )}} + affine.store %cst, %0[%i0] : memref<100xf32> + } + affine.for %i1 = 0 to 16 { + // expected-remark@-1 {{Incorrect slice ( src loop: 0, dst loop: 1, depth: 1 : insert point: (1, 0) loop bounds: [(d0) -> (d0), (d0) -> (d0 + 1)] )}} + %1 = affine.load %0[%i1] : memref<100xf32> + } + return +} + +// ----- + // Loop %i0 writes to locations [2, 17] and loop %i0 reads from locations [3, 6] // Slice loop bounds should be adjusted such that the load/store are for the // same location.
@@ -27,7 +44,7 @@ %0 = memref.alloc() : memref<100xf32> %cst = constant 7.000000e+00 : f32 affine.for %i0 = 0 to 16 { - // expected-remark@-1 {{slice ( src loop: 1, dst loop: 0, depth: 1 : insert point: (1, 2) loop bounds: [(d0) -> (d0 + 3), (d0) -> (d0 + 4)] )}} + // expected-remark@-1 {{Incorrect slice ( src loop: 1, dst loop: 0, depth: 1 : insert point: (1, 2) loop bounds: [(d0) -> (d0 + 3), (d0) -> (d0 + 4)] )}} %a0 = affine.apply affine_map<(d0) -> (d0 + 2)>(%i0) affine.store %cst, %0[%a0] : memref<100xf32> } @@ -48,8 +65,8 @@ %0 = memref.alloc() : memref<100x100xf32> %cst = constant 7.000000e+00 : f32 affine.for %i0 = 0 to 16 { - // expected-remark@-1 {{slice ( src loop: 1, dst loop: 0, depth: 1 : insert point: (1, 1) loop bounds: [(d0) -> (d0), (d0) -> (d0 + 1)] [(d0) -> (0), (d0) -> (8)] )}} - // expected-remark@-2 {{slice ( src loop: 1, dst loop: 0, depth: 2 : insert point: (2, 1) loop bounds: [(d0, d1) -> (d0), (d0, d1) -> (d0 + 1)] [(d0, d1) -> (d1), (d0, d1) -> (d1 + 1)] )}} + // expected-remark@-1 {{Incorrect slice ( src loop: 1, dst loop: 0, depth: 1 : insert point: (1, 1) loop bounds: [(d0) -> (d0), (d0) -> (d0 + 1)] [(d0) -> (0), (d0) -> (8)] )}} + // expected-remark@-2 {{Incorrect slice ( src loop: 1, dst loop: 0, depth: 2 : insert point: (2, 1) loop bounds: [(d0, d1) -> (d0), (d0, d1) -> (d0 + 1)] [(d0, d1) -> (d1), (d0, d1) -> (d1 + 1)] )}} affine.for %i1 = 0 to 16 { affine.store %cst, %0[%i0, %i1] : memref<100x100xf32> } @@ -75,8 +92,8 @@ %c0 = constant 0 : index %cst = constant 7.000000e+00 : f32 affine.for %i0 = 0 to 16 { - // expected-remark@-1 {{slice ( src loop: 1, dst loop: 0, depth: 1 : insert point: (1, 1) loop bounds: [(d0) -> (d0), (d0) -> (d0 + 1)] [(d0) -> (0), (d0) -> (8)] )}} - // expected-remark@-2 {{slice ( src loop: 1, dst loop: 0, depth: 2 : insert point: (2, 1) loop bounds: [(d0, d1) -> (d0), (d0, d1) -> (d0 + 1)] [(d0, d1) -> (0), (d0, d1) -> (8)] )}} + // expected-remark@-1 {{Incorrect slice ( src loop: 1, dst loop: 0,
depth: 1 : insert point: (1, 1) loop bounds: [(d0) -> (d0), (d0) -> (d0 + 1)] [(d0) -> (0), (d0) -> (8)] )}} + // expected-remark@-2 {{Incorrect slice ( src loop: 1, dst loop: 0, depth: 2 : insert point: (2, 1) loop bounds: [(d0, d1) -> (d0), (d0, d1) -> (d0 + 1)] [(d0, d1) -> (0), (d0, d1) -> (16)] )}} affine.for %i1 = 0 to 16 { affine.store %cst, %0[%i0, %i1] : memref<100x100xf32> } @@ -103,7 +120,7 @@ %c0 = constant 0 : index %cst = constant 7.000000e+00 : f32 affine.for %i0 = 0 to 16 { - // expected-remark@-1 {{slice ( src loop: 1, dst loop: 0, depth: 1 : insert point: (1, 2) loop bounds: [(d0) -> (d0), (d0) -> (d0 + 1)] [(d0) -> (0), (d0) -> (8)] )}} + // expected-remark@-1 {{Incorrect slice ( src loop: 1, dst loop: 0, depth: 1 : insert point: (1, 2) loop bounds: [(d0) -> (d0), (d0) -> (d0 + 1)] [(d0) -> (0), (d0) -> (8)] )}} affine.for %i1 = 0 to 16 { affine.store %cst, %0[%i0, %i1] : memref<100x100xf32> } @@ -128,8 +145,8 @@ %c0 = constant 0 : index %cst = constant 7.000000e+00 : f32 affine.for %i0 = 0 to 16 { - // expected-remark@-1 {{slice ( src loop: 1, dst loop: 0, depth: 1 : insert point: (1, 1) loop bounds: [(d0) -> (d0), (d0) -> (d0 + 1)] [(d0) -> (0), (d0) -> (10)] )}} - // expected-remark@-2 {{slice ( src loop: 1, dst loop: 0, depth: 2 : insert point: (2, 1) loop bounds: [(d0, d1) -> (d0), (d0, d1) -> (d0 + 1)] [(d0, d1) -> (d1), (d0, d1) -> (d1 + 1)] )}} + // expected-remark@-1 {{Incorrect slice ( src loop: 1, dst loop: 0, depth: 1 : insert point: (1, 1) loop bounds: [(d0) -> (d0), (d0) -> (d0 + 1)] [(d0) -> (0), (d0) -> (10)] )}} + // expected-remark@-2 {{Incorrect slice ( src loop: 1, dst loop: 0, depth: 2 : insert point: (2, 1) loop bounds: [(d0, d1) -> (d0), (d0, d1) -> (d0 + 1)] [(d0, d1) -> (d1), (d0, d1) -> (d1 + 1)] )}} affine.for %i1 = 0 to 16 { affine.store %cst, %0[%i0, %i1] : memref<100x100xf32> } diff --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir --- a/mlir/test/Transforms/loop-fusion.mlir +++
b/mlir/test/Transforms/loop-fusion.mlir @@ -3068,3 +3068,50 @@ // CHECK-LABEL: func @call_op_does_not_prevent_fusion // CHECK: affine.for // CHECK-NOT: affine.for + +// ----- + +// Fusion is avoided when the computed slice is invalid. Comments below describe +// incorrect backward slice computation. Similar logic applies for forward slice +// as well. +func @no_fusion_cannot_compute_valid_slice() { + %A = memref.alloc() : memref<5xf32> + %B = memref.alloc() : memref<6xf32> + %C = memref.alloc() : memref<5xf32> + %cst = constant 0. : f32 + + affine.for %arg0 = 0 to 5 { + %a = affine.load %A[%arg0] : memref<5xf32> + affine.store %a, %B[%arg0 + 1] : memref<6xf32> + } + + affine.for %arg0 = 0 to 5 { + // Backward slice computed will be: + // slice ( src loop: 0, dst loop: 1, depth: 1 : insert point: (1, 0) + // loop bounds: [(d0) -> (d0 - 1), (d0) -> (d0)] ) + + // Resulting fusion would be as below. For %arg0 = 0 the slice loop runs at + // %arg1 = -1, so its 'affine.load' accesses index -1 out of bounds. + + // #map0 = affine_map<(d0) -> (d0 - 1)> + // #map1 = affine_map<(d0) -> (d0)> + // affine.for %arg1 = #map0(%arg0) to #map1(%arg0) { + // %5 = affine.load %1[%arg1] : memref<5xf32> + // ... + // ... + // } + + %a = affine.load %B[%arg0] : memref<6xf32> + %b = mulf %a, %cst : f32 + affine.store %b, %C[%arg0] : memref<5xf32> + } + return +} +// CHECK-LABEL: func @no_fusion_cannot_compute_valid_slice +// CHECK: affine.for +// CHECK-NEXT: affine.load +// CHECK-NEXT: affine.store +// CHECK: affine.for +// CHECK-NEXT: affine.load +// CHECK-NEXT: mulf +// CHECK-NEXT: affine.store diff --git a/mlir/test/lib/Transforms/TestLoopFusion.cpp b/mlir/test/lib/Transforms/TestLoopFusion.cpp --- a/mlir/test/lib/Transforms/TestLoopFusion.cpp +++ b/mlir/test/lib/Transforms/TestLoopFusion.cpp @@ -99,10 +99,11 @@ return os.str(); } -// Computes fusion slice union on 'loops[i]' and 'loops[j]' at loop depths -// in range ['loopDepth' + 1, 'maxLoopDepth'].
-// Emits a string representation of the slice union as a remark on 'loops[j]'. -// Returns false as IR is not transformed. +/// Computes fusion slice union on 'loops[i]' and 'loops[j]' at loop depths +/// in range ['loopDepth' + 1, 'maxLoopDepth']. +/// Emits a string representation of the slice union as a remark on 'loops[j]' +/// and labels it "Incorrect slice" if the slice is invalid. Returns false as +/// the IR is not transformed. static bool testSliceComputation(AffineForOp forOpA, AffineForOp forOpB, unsigned i, unsigned j, unsigned loopDepth, unsigned maxLoopDepth) { @@ -113,6 +114,10 @@ forOpB->emitRemark("slice (") << " src loop: " << i << ", dst loop: " << j << ", depth: " << d << " : " << getSliceStr(sliceUnion) << ")"; + } else if (result.value == FusionResult::FailIncorrectSlice) { + forOpB->emitRemark("Incorrect slice (") + << " src loop: " << i << ", dst loop: " << j << ", depth: " << d + << " : " << getSliceStr(sliceUnion) << ")"; } } return false;