diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp --- a/mlir/lib/Analysis/AffineAnalysis.cpp +++ b/mlir/lib/Analysis/AffineAnalysis.cpp @@ -857,20 +857,20 @@ AffineForOp forOp, unsigned maxLoopDepth, std::vector> *depCompsVec) { // Collect all load and store ops in loop nest rooted at 'forOp'. - SmallVector loadAndStoreOpInsts; - forOp.getOperation()->walk([&](Operation *opInst) { - if (isa(opInst)) - loadAndStoreOpInsts.push_back(opInst); + SmallVector loadAndStoreOps; + forOp.getOperation()->walk([&](Operation *op) { + if (isa(op)) + loadAndStoreOps.push_back(op); }); - unsigned numOps = loadAndStoreOpInsts.size(); + unsigned numOps = loadAndStoreOps.size(); for (unsigned d = 1; d <= maxLoopDepth; ++d) { for (unsigned i = 0; i < numOps; ++i) { - auto *srcOpInst = loadAndStoreOpInsts[i]; - MemRefAccess srcAccess(srcOpInst); + auto *srcOp = loadAndStoreOps[i]; + MemRefAccess srcAccess(srcOp); for (unsigned j = 0; j < numOps; ++j) { - auto *dstOpInst = loadAndStoreOpInsts[j]; - MemRefAccess dstAccess(dstOpInst); + auto *dstOp = loadAndStoreOps[j]; + MemRefAccess dstAccess(dstOp); FlatAffineConstraints dependenceConstraints; SmallVector depComps; diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopTiling.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopTiling.cpp --- a/mlir/lib/Dialect/Affine/Transforms/LoopTiling.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/LoopTiling.cpp @@ -157,6 +157,75 @@ } } +/// This function checks whether hyper-rectangular loop tiling of the nest +/// represented by `origLoops` is valid. The validity condition is from Irigoin +/// and Triolet, which states that two tiles cannot depend on each other. We +/// simplify this condition to just checking whether there is any negative +/// dependence direction, since we have the prior knowledge that the tiling +/// results will be hyper-rectangles, which are scheduled in the +/// lexicographically increasing order on the vector of loop indices.
This +/// function will return failure when any dependence component is negative along +/// any of `origLoops`. `tileSizes` is currently unused. +static LogicalResult +checkTilingLegality(MutableArrayRef origLoops, + ArrayRef tileSizes) { + assert(!origLoops.empty() && "no original loops provided"); + + // We first find out all dependences we intend to check. + SmallVector loadAndStoreOps; + origLoops[0].getOperation()->walk([&](Operation *op) { + if (isa(op)) + loadAndStoreOps.push_back(op); + }); + + unsigned numOps = loadAndStoreOps.size(); + unsigned numLoops = origLoops.size(); + FlatAffineConstraints dependenceConstraints; + for (unsigned d = 1; d <= numLoops + 1; ++d) { + for (unsigned i = 0; i < numOps; ++i) { + Operation *srcOp = loadAndStoreOps[i]; + MemRefAccess srcAccess(srcOp); + for (unsigned j = 0; j < numOps; ++j) { + Operation *dstOp = loadAndStoreOps[j]; + MemRefAccess dstAccess(dstOp); + + SmallVector depComps; + dependenceConstraints.reset(); + DependenceResult result = checkMemrefAccessDependence( + srcAccess, dstAccess, d, &dependenceConstraints, &depComps); + + // Skip if there is no dependence in this case. + if (!hasDependence(result)) + continue; + + // Check whether there is any negative direction vector in the + // dependence components found above, which means that dependence is + // violated by the default hyper-rect tiling method.
+ LLVM_DEBUG(llvm::dbgs() << "Checking whether tiling legality violated " + "for dependence at depth: " + << Twine(d) << " between:\n";); + LLVM_DEBUG(srcAccess.opInst->dump();); + LLVM_DEBUG(dstAccess.opInst->dump();); + for (const DependenceComponent &depComp : depComps) { + // All distances are negative iff ub < 0, even when lb == ub. + if (depComp.lb.hasValue() && depComp.ub.hasValue() && + depComp.ub.getValue() < 0) { + LLVM_DEBUG(llvm::dbgs() + << "Dependence component lb = " + << Twine(depComp.lb.getValue()) + << " ub = " << Twine(depComp.ub.getValue()) + << " is negative at depth: " << Twine(d) + << " and thus violates the legality rule.\n"); + return failure(); + } + } + } + } + } + + return success(); +} + /// Tiles the specified band of perfectly nested loops creating tile-space loops /// and intra-tile loops. A band is a contiguous set of loops. // TODO: handle non hyper-rectangular spaces. @@ -172,6 +241,10 @@ auto origLoops = input; + // Perform tiling legality test. + if (failed(checkTilingLegality(origLoops, tileSizes))) + origLoops[0].emitRemark("tiled code is illegal due to dependences"); + AffineForOp rootAffineForOp = origLoops[0]; auto loc = rootAffineForOp.getLoc(); // Note that width is at least one since band isn't empty. diff --git a/mlir/test/Dialect/Affine/loop-tiling-validity.mlir b/mlir/test/Dialect/Affine/loop-tiling-validity.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Affine/loop-tiling-validity.mlir @@ -0,0 +1,71 @@ +// RUN: mlir-opt %s -split-input-file -affine-loop-tile="tile-size=32" -verify-diagnostics | FileCheck %s + +// ----- + +// There is no dependence violated in this case. No remark should be emitted.
+ +// CHECK-DAG: [[$LB:#map[0-9]+]] = affine_map<(d0) -> (d0)> +// CHECK-DAG: [[$UB:#map[0-9]+]] = affine_map<(d0) -> (d0 + 32)> +// CHECK-DAG: [[$C0:#map[0-9]+]] = affine_map<() -> (0)> +// CHECK-DAG: [[$C64:#map[0-9]+]] = affine_map<() -> (64)> + +// CHECK-LABEL: func @legal_loop() +func @legal_loop() { + %0 = alloc() : memref<64xf32> + + affine.for %i = 0 to 64 { + %1 = affine.load %0[%i] : memref<64xf32> + %2 = addf %1, %1 : f32 + affine.store %2, %0[%i] : memref<64xf32> + } + + return +} + +// CHECK: affine.for %{{.*}} = 0 to 64 step 32 { +// CHECK-NEXT: affine.for %{{.*}} = [[$LB]](%{{.*}}) to [[$UB]](%{{.*}}) { + +// ----- + +// There are dependences along the diagonal of the 2d iteration space, +// specifically, they are of direction (+, -). +// The default tiling method (hyper-rect) will violate tiling legality. +// We expect a remark pointing out this issue to be emitted. + +// CHECK-LABEL: func @illegal_loop_with_diag_dependence +func @illegal_loop_with_diag_dependence() { + %A = alloc() : memref<64x64xf32> + + affine.for %i = 0 to 64 { + // expected-remark@above {{tiled code is illegal due to dependences}} + affine.for %j = 0 to 64 { + %0 = affine.load %A[%j, %i] : memref<64x64xf32> + %1 = affine.load %A[%i, %j - 1] : memref<64x64xf32> + %2 = addf %0, %1 : f32 + affine.store %2, %A[%i, %j] : memref<64x64xf32> + } + } + + return +} + +// ----- + +// There is a dependence with the constant distance vector (1, -1), i.e., its +// component along the inner loop is exactly -1. The default tiling method +// (hyper-rect) also violates this dependence. + +// CHECK-LABEL: func @illegal_loop_with_constant_negative_dependence +func @illegal_loop_with_constant_negative_dependence() { + %A = alloc() : memref<64x64xf32> + + affine.for %i = 1 to 64 { + // expected-remark@above {{tiled code is illegal due to dependences}} + affine.for %j = 0 to 63 { + %0 = affine.load %A[%i - 1, %j + 1] : memref<64x64xf32> + affine.store %0, %A[%i, %j] : memref<64x64xf32> + } + } + + return +}