diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopTiling.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopTiling.cpp
--- a/mlir/lib/Dialect/Affine/Transforms/LoopTiling.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopTiling.cpp
@@ -157,6 +157,75 @@
   }
 }
 
+/// Checks whether the default hyper-rectangular tiling of the band `origLoops`
+/// is legal with respect to the memory accesses nested under `origLoops[0]`.
+/// Every ordered pair of collected accesses is tested for a dependence at each
+/// depth in [1, origLoops.size() + 1]. Returns failure() as soon as a
+/// dependence component with lb < ub < 0 is found, and success() otherwise.
+/// `origLoops` must be non-empty; `tileSizes` is currently unused.
+LogicalResult checkTilingLegality(MutableArrayRef<AffineForOp> origLoops,
+                                  ArrayRef<unsigned> tileSizes) {
+  // NOTE(review): the template argument lists in this function were missing
+  // from the text under review and were restored from usage; confirm them
+  // against the original patch (notably `tileSizes` and the isa<> op kinds).
+  // We assume that all the forOp's in origLoops are tiled.
+  assert(!origLoops.empty() && "no original loops provided");
+
+  // Collect every memory access op nested under the outermost loop; each
+  // ordered pair of them is checked for a dependence below.
+  SmallVector<Operation *, 8> loadAndStoreOpInsts;
+  origLoops[0].getOperation()->walk([&](Operation *opInst) {
+    if (isa<AffineReadOpInterface, AffineWriteOpInterface>(opInst))
+      loadAndStoreOpInsts.push_back(opInst);
+  });
+
+  unsigned numOps = loadAndStoreOpInsts.size();
+  for (unsigned d = 1; d <= origLoops.size() + 1; ++d) {
+    for (unsigned i = 0; i < numOps; ++i) {
+      Operation *srcOpInst = loadAndStoreOpInsts[i];
+      MemRefAccess srcAccess(srcOpInst);
+      for (unsigned j = 0; j < numOps; ++j) {
+        Operation *dstOpInst = loadAndStoreOpInsts[j];
+        MemRefAccess dstAccess(dstOpInst);
+
+        FlatAffineConstraints dependenceConstraints;
+        SmallVector<DependenceComponent, 2> depComps;
+        DependenceResult result = checkMemrefAccessDependence(
+            srcAccess, dstAccess, d, &dependenceConstraints, &depComps);
+
+        // Skip if there is no dependence in this case.
+        if (!hasDependence(result))
+          continue;
+
+        // Check whether there is any negative direction vector in the
+        // dependence components found above, which means that dependence is
+        // violated by the default hyper-rect tiling method.
+        LLVM_DEBUG(llvm::dbgs() << "Checking whether tiling legality violated "
+                                   "for dependence at depth: "
+                                << Twine(d) << " between:\n");
+        LLVM_DEBUG(srcAccess.opInst->dump());
+        LLVM_DEBUG(dstAccess.opInst->dump());
+        // NOTE(review): a component with lb == ub < 0 is not rejected by the
+        // strict `lb < ub` test below; confirm that this is intended.
+        for (const DependenceComponent &depComp : depComps) {
+          if (depComp.lb.hasValue() && depComp.ub.hasValue() &&
+              depComp.lb.getValue() < depComp.ub.getValue() &&
+              depComp.ub.getValue() < 0) {
+            LLVM_DEBUG(
+                llvm::dbgs()
+                << "Dependence component lb = " << Twine(depComp.lb.getValue())
+                << " ub = " << Twine(depComp.ub.getValue())
+                << " is negative and thus violates the legality rule.\n");
+            return failure();
+          }
+        }
+      }
+    }
+  }
+
+  return success();
+}
+
 /// Tiles the specified band of perfectly nested loops creating tile-space loops
 /// and intra-tile loops. A band is a contiguous set of loops.
 // TODO: handle non hyper-rectangular spaces.
@@ -172,6 +241,11 @@
 
   auto origLoops = input;
 
+  // Perform tiling legality test.
+  if (failed(checkTilingLegality(origLoops, tileSizes))) {
+    origLoops[0].emitRemark("tiled code is illegal due to dependences");
+  }
+
   AffineForOp rootAffineForOp = origLoops[0];
   auto loc = rootAffineForOp.getLoc();
   // Note that width is at least one since band isn't empty.
diff --git a/mlir/test/Dialect/Affine/loop-tiling-validity.mlir b/mlir/test/Dialect/Affine/loop-tiling-validity.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Affine/loop-tiling-validity.mlir
@@ -0,0 +1,49 @@
+// RUN: mlir-opt %s -split-input-file -affine-loop-tile="tile-size=32" -verify-diagnostics | FileCheck %s
+
+// -----
+
+// There is no dependence violated in this case. No remark should be emitted.
+
+// CHECK-DAG: [[$LB:#map[0-9]+]] = affine_map<(d0) -> (d0)>
+// CHECK-DAG: [[$UB:#map[0-9]+]] = affine_map<(d0) -> (d0 + 32)>
+// CHECK-DAG: [[$ID:#map[0-9]+]] = affine_map<() -> (0)>
+// CHECK-DAG: [[$ID_PLUS_21:#map[0-9]+]] = affine_map<() -> (64)>
+
+// CHECK-LABEL: func @legal_loop()
+func @legal_loop() {
+  %0 = alloc() : memref<64xf32>
+
+  affine.for %i = 0 to 64 {
+    %1 = affine.load %0[%i] : memref<64xf32>
+    %2 = addf %1, %1 : f32
+    affine.store %2, %0[%i] : memref<64xf32>
+  }
+
+  return
+}
+
+// CHECK: affine.for %{{.*}} = 0 to 64 step 32 {
+// CHECK-NEXT: affine.for %{{.*}} = [[$LB]](%{{.*}}) to [[$UB]](%{{.*}}) {
+
+// -----
+
+// There are dependences along the diagonal of the 2-d iteration space, so the
+// default (hyper-rectangular) tiling method violates tiling legality here.
+// We expect a remark pointing out this issue to be emitted on the outer loop.
+
+// CHECK-LABEL: func @illegal_loop_with_diag_dependence
+func @illegal_loop_with_diag_dependence() {
+  %A = alloc() : memref<64x64xf32>
+
+  affine.for %i = 0 to 64 {
+    // expected-remark@above {{tiled code is illegal due to dependences}}
+    affine.for %j = 0 to 64 {
+      %0 = affine.load %A[%j, %i] : memref<64x64xf32>
+      %1 = affine.load %A[%i, %j - 1] : memref<64x64xf32>
+      %2 = addf %0, %1 : f32
+      affine.store %2, %A[%i, %j] : memref<64x64xf32>
+    }
+  }
+
+  return
+}