diff --git a/mlir/include/mlir/Analysis/Utils.h b/mlir/include/mlir/Analysis/Utils.h
--- a/mlir/include/mlir/Analysis/Utils.h
+++ b/mlir/include/mlir/Analysis/Utils.h
@@ -80,6 +80,12 @@
   // Returns failure if we cannot add loop bounds because of unsupported cases.
   LogicalResult getAsConstraints(FlatAffineConstraints *cst);
 
+  // Adds to 'cst' constraints which represent the original loop bounds on
+  // 'ivs' in 'this'. This corresponds to the original domain of the loop nest
+  // from which the slice is being computed. Returns failure if we cannot add
+  // loop bounds because of unsupported cases.
+  LogicalResult getSourceAsConstraints(FlatAffineConstraints *cst);
+
   // Clears all bounds and operands in slice state.
   void clearBounds();
 
@@ -93,6 +99,17 @@
   // information hasn't changed.
   Optional<bool> isMaximal() const;
 
+  // Checks the validity of the slice computed. This is done by constructing the
+  // new domain of the slice that would be created if fusion succeeds, and then
+  // projecting out the dimensions of the destination loop from it to express it
+  // only in terms of the source loop IVs. Then, a set difference between the
+  // iterations of the new domain and the original domain of the source loop is
+  // considered. If this difference is empty, the slice is declared valid.
+  // Otherwise, returns false as it implies that the effective fusion results in
+  // at least one iteration of the slice that was not originally in the source's
+  // domain. If the validity cannot be determined, returns llvm::None.
+  Optional<bool> isSliceValid();
+
   void dump() const;
 
 private:
diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp
--- a/mlir/lib/Analysis/AffineStructures.cpp
+++ b/mlir/lib/Analysis/AffineStructures.cpp
@@ -2128,13 +2128,23 @@
       continue;
     }
 
-    if (lbMap && failed(addLowerOrUpperBound(pos, lbMap, operands, /*eq=*/false,
-                                             /*lower=*/true)))
-      return failure();
-
-    if (ubMap && failed(addLowerOrUpperBound(pos, ubMap, operands, /*eq=*/false,
-                                             /*lower=*/false)))
-      return failure();
+    // If lower or upper bound maps are null or provide no results, it implies
+    // that the source loop was not at all sliced, and the entire loop will be a
+    // part of the slice.
+    if (lbMap && lbMap.getNumResults() != 0 && ubMap &&
+        ubMap.getNumResults() != 0) {
+      if (failed(addLowerOrUpperBound(pos, lbMap, operands, /*eq=*/false,
+                                      /*lower=*/true)))
+        return failure();
+      if (failed(addLowerOrUpperBound(pos, ubMap, operands, /*eq=*/false,
+                                      /*lower=*/false)))
+        return failure();
+    } else {
+      auto loop = getForInductionVarOwner(values[i]);
+      assert(loop && "Expected affine for");
+      if (failed(this->addAffineForOpDomain(loop)))
+        return failure();
+    }
   }
   return success();
 }
diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp
--- a/mlir/lib/Analysis/Utils.cpp
+++ b/mlir/lib/Analysis/Utils.cpp
@@ -61,6 +61,21 @@
   std::reverse(ops->begin(), ops->end());
 }
 
+// Populates 'cst' with FlatAffineConstraints which represent the original
+// domain of the loop bounds that define 'ivs'.
+LogicalResult
+ComputationSliceState::getSourceAsConstraints(FlatAffineConstraints *cst) {
+  assert(!ivs.empty());
+  cst->reset(/*numDims=*/ivs.size(), /*numSymbols=*/0, /*numLocals=*/0, ivs);
+  for (Value iv : ivs) {
+    AffineForOp loop = getForInductionVarOwner(iv);
+    assert(loop && "Expected affine for");
+    if (failed(cst->addAffineForOpDomain(loop)))
+      return failure();
+  }
+  return success();
+}
+
 // Populates 'cst' with FlatAffineConstraints which represent slice bounds.
 LogicalResult
 ComputationSliceState::getAsConstraints(FlatAffineConstraints *cst) {
@@ -75,9 +90,10 @@
   values.append(lbOperands[0].begin(), lbOperands[0].end());
   cst->reset(numDims, numSymbols, 0, values);
 
-  // Add loop bound constraints for values which are loop IVs and equality
-  // constraints for symbols which are constants.
-  for (const auto &value : values) {
+  // Add loop bound constraints for values which are loop IVs of the destination
+  // of fusion and equality constraints for symbols which are constants.
+  for (unsigned i = numDims, end = values.size(); i < end; ++i) {
+    Value value = values[i];
     assert(cst->containsId(value) && "value expected to be present");
     if (isValidSymbol(value)) {
       // Check if the symbol is a constant.
@@ -196,6 +212,59 @@
   return true;
 }
 
+/// Returns true if it is deterministically verified that every iteration of the
+/// slice created by fusing 'this' into its destination lies within the original
+/// iteration space of the source loop nest; returns llvm::None if undetermined.
+Optional<bool> ComputationSliceState::isSliceValid() {
+  // Create constraints for the original domain of the source loop nest.
+  FlatAffineConstraints srcConstraints;
+  // TODO: Store the source's domain to avoid computation at each depth.
+  if (failed(getSourceAsConstraints(&srcConstraints))) {
+    LLVM_DEBUG(llvm::dbgs() << "Unable to compute source's domain\n");
+    return llvm::None;
+  }
+  // As the set difference utility currently cannot handle symbols in its
+  // operands, validity of the slice cannot be determined.
+  if (srcConstraints.getNumSymbolIds() != 0) {
+    LLVM_DEBUG(llvm::dbgs() << "Cannot handle symbols in source domain\n");
+    return llvm::None;
+  }
+  // TODO: Handle local ids in the source domains while using the 'projectOut'
+  // utility below. Currently, aligning is not done assuming that there will be
+  // no local ids in the source domain.
+  if (srcConstraints.getNumLocalIds() != 0) {
+    LLVM_DEBUG(llvm::dbgs() << "Cannot handle locals in source domain\n");
+    return llvm::None;
+  }
+
+  // Create constraints for the slice loop nest that would be created if the
+  // fusion succeeds.
+  FlatAffineConstraints sliceConstraints;
+  if (failed(getAsConstraints(&sliceConstraints))) {
+    LLVM_DEBUG(llvm::dbgs() << "Unable to compute slice's domain\n");
+    return llvm::None;
+  }
+
+  // Project out every dimension other than the 'ivs' to express the slice's
+  // domain completely in terms of the source's IVs.
+  sliceConstraints.projectOut(ivs.size(),
+                              sliceConstraints.getNumIds() - ivs.size());
+
+  LLVM_DEBUG(srcConstraints.dump());
+  LLVM_DEBUG(sliceConstraints.dump());
+
+  // TODO: Store 'srcSet' to avoid recalculating for each depth.
+  PresburgerSet srcSet(srcConstraints);
+  PresburgerSet sliceSet(sliceConstraints);
+  PresburgerSet diffSet = sliceSet.subtract(srcSet);
+
+  if (!diffSet.isIntegerEmpty()) {
+    LLVM_DEBUG(llvm::dbgs() << "Incorrect slice\n");
+    return false;
+  }
+  return true;
+}
+
 /// Returns true if the computation slice encloses all the iterations of the
 /// sliced loop nest. Returns false if it does not. Returns llvm::None if it
 /// cannot determine if the slice is maximal or not.
@@ -868,6 +937,16 @@
   // canonicalization.
   sliceUnion->lbOperands.resize(numSliceLoopIVs, sliceBoundOperands);
   sliceUnion->ubOperands.resize(numSliceLoopIVs, sliceBoundOperands);
+
+  // Reject the slice only when it is proven to be invalid. When its validity
+  // cannot be determined, the slice union computation still succeeds.
+  Optional<bool> isSliceValid = sliceUnion->isSliceValid();
+  if (!isSliceValid.hasValue()) {
+    LLVM_DEBUG(llvm::dbgs() << "Cannot determine if the slice is valid\n");
+  } else if (!isSliceValid.getValue()) {
+    return failure();
+  }
+
   return success();
 }
 
diff --git a/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp b/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp
--- a/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp
@@ -400,7 +400,7 @@
     auto *parentForOp = forOp->getParentOp();
     if (!llvm::isa<FuncOp>(parentForOp)) {
       if (!isa<AffineForOp>(parentForOp)) {
-        LLVM_DEBUG(llvm::dbgs() << "Expected parent AffineForOp");
+        LLVM_DEBUG(llvm::dbgs() << "Expected parent AffineForOp\n");
         return WalkResult::interrupt();
       }
       // Add mapping to 'forOp' from its parent AffineForOp.
@@ -421,7 +421,7 @@
     Optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp);
     if (!maybeConstTripCount.hasValue()) {
       // Currently only constant trip count loop nests are supported.
-      LLVM_DEBUG(llvm::dbgs() << "Non-constant trip count unsupported");
+      LLVM_DEBUG(llvm::dbgs() << "Non-constant trip count unsupported\n");
       return WalkResult::interrupt();
     }
 
@@ -519,7 +519,11 @@
     auto *op = forOp.getOperation();
     AffineMap lbMap = slice.lbs[i];
     AffineMap ubMap = slice.ubs[i];
-    if (lbMap == AffineMap() || ubMap == AffineMap()) {
+    // If lower or upper bound maps are null or provide no results, it implies
+    // that the source loop was not at all sliced, and the entire loop will be a
+    // part of the slice.
+    if (lbMap == AffineMap() || ubMap == AffineMap() ||
+        lbMap.getNumResults() == 0 || ubMap.getNumResults() == 0) {
       // The iteration of src loop IV 'i' was not sliced. Use full loop bounds.
       if (forOp.hasConstantLowerBound() && forOp.hasConstantUpperBound()) {
         (*tripCountMap)[op] =
diff --git a/mlir/test/Transforms/loop-fusion-slice-computation.mlir b/mlir/test/Transforms/loop-fusion-slice-computation.mlir
--- a/mlir/test/Transforms/loop-fusion-slice-computation.mlir
+++ b/mlir/test/Transforms/loop-fusion-slice-computation.mlir
@@ -7,7 +7,6 @@
   %0 = alloc() : memref<100xf32>
   %cst = constant 7.000000e+00 : f32
   affine.for %i0 = 0 to 16 {
-    // expected-remark@-1 {{slice ( src loop: 1, dst loop: 0, depth: 1 : insert point: (1, 1) loop bounds: [(d0) -> (d0), (d0) -> (d0 + 1)] )}}
     affine.store %cst, %0[%i0] : memref<100xf32>
   }
   affine.for %i1 = 0 to 5 {
@@ -19,6 +18,22 @@
 
 // -----
 
+// CHECK-LABEL: func @forward_slice_slice_depth1_loop_nest() {
+func @forward_slice_slice_depth1_loop_nest() {
+  %0 = alloc() : memref<100xf32>
+  %cst = constant 7.000000e+00 : f32
+  affine.for %i0 = 0 to 5 {
+    // expected-remark@-1 {{slice ( src loop: 1, dst loop: 0, depth: 1 : insert point: (1, 1) loop bounds: [(d0) -> (d0), (d0) -> (d0 + 1)] )}}
+    affine.store %cst, %0[%i0] : memref<100xf32>
+  }
+  affine.for %i1 = 0 to 16 {
+    %1 = affine.load %0[%i1] : memref<100xf32>
+  }
+  return
+}
+
+// -----
+
 // Loop %i0 writes to locations [2, 17] and loop %i0 reads from locations [3, 6]
 // Slice loop bounds should be adjusted such that the load/store are for the
 // same location.
@@ -27,7 +42,6 @@
   %0 = alloc() : memref<100xf32>
   %cst = constant 7.000000e+00 : f32
   affine.for %i0 = 0 to 16 {
-    // expected-remark@-1 {{slice ( src loop: 1, dst loop: 0, depth: 1 : insert point: (1, 2) loop bounds: [(d0) -> (d0 + 3), (d0) -> (d0 + 4)] )}}
     %a0 = affine.apply affine_map<(d0) -> (d0 + 2)>(%i0)
     affine.store %cst, %0[%a0] : memref<100xf32>
   }
@@ -48,8 +62,6 @@
   %0 = alloc() : memref<100x100xf32>
   %cst = constant 7.000000e+00 : f32
   affine.for %i0 = 0 to 16 {
-    // expected-remark@-1 {{slice ( src loop: 1, dst loop: 0, depth: 1 : insert point: (1, 1) loop bounds: [(d0) -> (d0), (d0) -> (d0 + 1)] [(d0) -> (0), (d0) -> (8)] )}}
-    // expected-remark@-2 {{slice ( src loop: 1, dst loop: 0, depth: 2 : insert point: (2, 1) loop bounds: [(d0, d1) -> (d0), (d0, d1) -> (d0 + 1)] [(d0, d1) -> (d1), (d0, d1) -> (d1 + 1)] )}}
     affine.for %i1 = 0 to 16 {
       affine.store %cst, %0[%i0, %i1] : memref<100x100xf32>
     }
@@ -75,8 +87,6 @@
   %c0 = constant 0 : index
   %cst = constant 7.000000e+00 : f32
   affine.for %i0 = 0 to 16 {
-    // expected-remark@-1 {{slice ( src loop: 1, dst loop: 0, depth: 1 : insert point: (1, 1) loop bounds: [(d0) -> (d0), (d0) -> (d0 + 1)] [(d0) -> (0), (d0) -> (8)] )}}
-    // expected-remark@-2 {{slice ( src loop: 1, dst loop: 0, depth: 2 : insert point: (2, 1) loop bounds: [(d0, d1) -> (d0), (d0, d1) -> (d0 + 1)] [(d0, d1) -> (0), (d0, d1) -> (8)] )}}
     affine.for %i1 = 0 to 16 {
       affine.store %cst, %0[%i0, %i1] : memref<100x100xf32>
     }
@@ -103,7 +113,6 @@
   %c0 = constant 0 : index
   %cst = constant 7.000000e+00 : f32
   affine.for %i0 = 0 to 16 {
-    // expected-remark@-1 {{slice ( src loop: 1, dst loop: 0, depth: 1 : insert point: (1, 2) loop bounds: [(d0) -> (d0), (d0) -> (d0 + 1)] [(d0) -> (0), (d0) -> (8)] )}}
     affine.for %i1 = 0 to 16 {
       affine.store %cst, %0[%i0, %i1] : memref<100x100xf32>
     }
@@ -128,8 +137,6 @@
   %c0 = constant 0 : index
   %cst = constant 7.000000e+00 : f32
   affine.for %i0 = 0 to 16 {
-    // expected-remark@-1 {{slice ( src loop: 1, dst loop: 0, depth: 1 : insert point: (1, 1) loop bounds: [(d0) -> (d0), (d0) -> (d0 + 1)] [(d0) -> (0), (d0) -> (10)] )}}
-    // expected-remark@-2 {{slice ( src loop: 1, dst loop: 0, depth: 2 : insert point: (2, 1) loop bounds: [(d0, d1) -> (d0), (d0, d1) -> (d0 + 1)] [(d0, d1) -> (d1), (d0, d1) -> (d1 + 1)] )}}
     affine.for %i1 = 0 to 16 {
       affine.store %cst, %0[%i0, %i1] : memref<100x100xf32>
     }
diff --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir
--- a/mlir/test/Transforms/loop-fusion.mlir
+++ b/mlir/test/Transforms/loop-fusion.mlir
@@ -3068,3 +3068,30 @@
 // CHECK-LABEL: func @call_op_does_not_prevent_fusion
 // CHECK: affine.for
 // CHECK-NOT: affine.for
+
+// -----
+
+func @no_fusion_cannot_compute_valid_slice() {
+  %A = alloc() : memref<5xf32>
+  %B = alloc() : memref<6xf32>
+  %C = alloc() : memref<5xf32>
+  %cst = constant 0. : f32
+  affine.for %arg0 = 0 to 5 {
+    %a = affine.load %A[%arg0] : memref<5xf32>
+    affine.store %a, %B[%arg0 + 1] : memref<6xf32>
+  }
+  affine.for %arg0 = 0 to 5 {
+    %a = affine.load %B[%arg0] : memref<6xf32>
+    %b = mulf %a, %cst : f32
+    affine.store %b, %C[%arg0] : memref<5xf32>
+  }
+  return
+}
+// CHECK-LABEL: func @no_fusion_cannot_compute_valid_slice
+// CHECK: affine.for
+// CHECK-NEXT: affine.load
+// CHECK-NEXT: affine.store
+// CHECK: affine.for
+// CHECK-NEXT: affine.load
+// CHECK-NEXT: mulf
+// CHECK-NEXT: affine.store