diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopTiling.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopTiling.cpp --- a/mlir/lib/Dialect/Affine/Transforms/LoopTiling.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/LoopTiling.cpp @@ -177,13 +177,12 @@ // TODO(bondhugula): handle non hyper-rectangular spaces. LogicalResult mlir::tileCodeGen(MutableArrayRef band, ArrayRef tileSizes) { - assert(!band.empty()); - assert(band.size() == tileSizes.size() && "Incorrect number of tile sizes"); + assert(!band.empty() && "no loops in band"); + assert(band.size() == tileSizes.size() && "Too few/many tile sizes"); // Check if the supplied for op's are all successively nested. - for (unsigned i = 1, e = band.size(); i < e; i++) { - assert(band[i].getParentOp() == band[i - 1].getOperation()); - } + for (unsigned i = 1, e = band.size(); i < e; i++) + assert(band[i].getParentOp() == band[i - 1] && "not a perfect nest / band"); auto origLoops = band; @@ -192,11 +191,11 @@ // Note that width is at least one since band isn't empty. unsigned width = band.size(); - SmallVector newLoops(2 * width); - AffineForOp innermostPointLoop; + SmallVector tiledLoops(2 * width); // The outermost among the loops as we add more.. auto *topLoop = rootAffineForOp.getOperation(); + AffineForOp innermostPointLoop; // Add intra-tile (or point) loops. for (unsigned i = 0; i < width; i++) { @@ -206,7 +205,7 @@ pointLoop.getBody()->getOperations().splice( pointLoop.getBody()->begin(), topLoop->getBlock()->getOperations(), topLoop); - newLoops[2 * width - 1 - i] = pointLoop; + tiledLoops[2 * width - 1 - i] = pointLoop; topLoop = pointLoop.getOperation(); if (i == 0) innermostPointLoop = pointLoop; @@ -220,7 +219,7 @@ tileSpaceLoop.getBody()->getOperations().splice( tileSpaceLoop.getBody()->begin(), topLoop->getBlock()->getOperations(), topLoop); - newLoops[2 * width - i - 1] = tileSpaceLoop; + tiledLoops[2 * width - i - 1] = tileSpaceLoop; topLoop = tileSpaceLoop.getOperation(); } @@ -234,16 +233,17 @@ getIndexSet(band, &cst); if (!cst.isHyperRectangular(0, width)) { - rootAffineForOp.emitError("tiled code generation unimplemented for the " - "non-hyperrectangular case"); + llvm::dbgs() << "tiled code generation unimplemented for the " + "non-hyperrectangular case, op:" + << *rootAffineForOp << "\n"; return failure(); } - constructTiledIndexSetHyperRect(origLoops, newLoops, tileSizes); - // In this case, the point loop IVs just replace the original ones. - for (unsigned i = 0; i < width; i++) { - origLoopIVs[i].replaceAllUsesWith(newLoops[i + width].getInductionVar()); - } + constructTiledIndexSetHyperRect(origLoops, tiledLoops, tileSizes); + + // Replace original IVs with intra-tile loop IVs. + for (unsigned i = 0; i < width; i++) + origLoopIVs[i].replaceAllUsesWith(tiledLoops[i + width].getInductionVar()); // Erase the old loop nest. rootAffineForOp.erase(); @@ -381,6 +381,7 @@ std::vector> bands; getTileableBands(getFunction(), &bands); + // Tile each band. for (auto &band : bands) { // Set up tile sizes; fill missing tile sizes at the end with default tile // size or clTileSize if one was provided. @@ -389,7 +390,7 @@ if (llvm::DebugFlag) { auto diag = band[0].emitRemark("using tile sizes ["); for (auto tSize : tileSizes) - diag << tSize << " "; + diag << tSize << ' '; diag << "]\n"; } if (failed(tileCodeGen(band, tileSizes)))