diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -90,6 +90,26 @@
 // TileUsingSCFForOp pattern implementation.
 //===----------------------------------------------------------------------===//
 
+/// Returns true when `loopRange.stride` is statically known to divide the
+/// extent `loopRange.size - loopRange.offset` without remainder.
+/// Returns false when `getConstantIntValue` yields no value for the offset,
+/// the size or the stride, and also when the constant stride is zero, since
+/// the remainder is undefined for a zero divisor.
+static bool tileDividesIterationDomain(Range loopRange) {
+  Optional<int64_t> offsetAsInt = getConstantIntValue(loopRange.offset);
+  if (!offsetAsInt)
+    return false;
+  Optional<int64_t> sizeAsInt = getConstantIntValue(loopRange.size);
+  if (!sizeAsInt)
+    return false;
+  Optional<int64_t> strideAsInt = getConstantIntValue(loopRange.stride);
+  // A zero stride would make the modulo below undefined; answer "does not
+  // divide" so the caller falls back to its bounded tile size.
+  if (!strideAsInt || strideAsInt.value() == 0)
+    return false;
+  return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0);
+}
+
 /// Generate an empty loop nest that represents the tiled loop nest shell.
 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
 /// - `tileSizeVals` is the tile sizes to use. Zero represent untiled loops.
@@ -134,9 +154,17 @@
         loc, offset, size, tileSizeVals[loopRange.index()], ValueRange{},
         [&](OpBuilder &bodyBuilder, Location bodyLoc, Value iv,
             ValueRange /*iterArgs*/) {
-          Value boundedTileSize = builder.create<AffineMinOp>(
-              bodyLoc, minMap,
-              ValueRange{iv, tileSizeVals[loopRange.index()], size});
+          // A tile size that statically divides the loop extent never
+          // overruns the upper bound, so the affine.min clamp is unnecessary.
+          bool canAvoidMap = tileDividesIterationDomain(
+              Range{loopRange.value().offset, loopRange.value().size,
+                    tileSizeVals[loopRange.index()]});
+          Value boundedTileSize =
+              (canAvoidMap)
+                  ? tileSizeVals[loopRange.index()]
+                  : builder.create<AffineMinOp>(
+                        bodyLoc, minMap,
+                        ValueRange{iv, tileSizeVals[loopRange.index()], size});
           sizes[loopRange.index()] = boundedTileSize;
           builder.create<scf::YieldOp>(loc);
         });
diff --git a/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir
--- a/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir
@@ -101,7 +101,6 @@
   return %0#0, %0#1 : tensor<128x300x200xf32>, tensor<300x128x200xf32>
 }
 // CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)>
 // CHECK: func.func @multi_result(
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<128x200x300xf32>)
 // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
@@ -116,20 +115,19 @@
 // CHECK: %[[TS_Y:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[C10]], %[[C128]]]
 // CHECK: %[[INNER:[a-zA-Z0-9]+]]:2 = scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[C300]] step %[[C20]]
 // CHECK-SAME: iter_args(%[[ARG3:[a-zA-Z0-9]+]] = %[[ARG1]], %[[ARG4:[a-zA-Z0-9]+]] = %[[ARG2]])
-// CHECK: %[[TS_X:.+]] = affine.min #[[MAP1]](%[[IV1]])[%[[C20]], %[[C300]]]
 // CHECK-DAG: %[[ARG_TILE:.+]] = tensor.extract_slice %[[ARG0]]
-// CHECK-SAME: [%[[IV0]], 0, %[[IV1]]] [%[[TS_Y]], 200, %[[TS_X]]] [1, 1, 1]
+// CHECK-SAME: [%[[IV0]], 0, %[[IV1]]] [%[[TS_Y]], 200, 20] [1, 1, 1]
 // CHECK-DAG: %[[INIT0_TILE:.+]] = tensor.extract_slice %[[ARG3]]
-// CHECK-SAME: [%[[IV0]], %[[IV1]], 0] [%[[TS_Y]], %[[TS_X]], 200] [1, 1, 1]
+// CHECK-SAME: [%[[IV0]], %[[IV1]], 0] [%[[TS_Y]], 20, 200] [1, 1, 1]
 // CHECK-DAG: %[[INIT1_TILE:.+]] = tensor.extract_slice %[[ARG4]]
-// CHECK-SAME: [%[[IV1]], %[[IV0]], 0] [%[[TS_X]], %[[TS_Y]], 200] [1, 1, 1]
+// CHECK-SAME: [%[[IV1]], %[[IV0]], 0] [20, %[[TS_Y]], 200] [1, 1, 1]
 // CHECK: %[[RESULT_TILE:.+]]:2 = linalg.generic
 // CHECK-SAME: ins(%[[ARG_TILE]] :
 // CHECK-SAME: outs(%[[INIT0_TILE]], %[[INIT1_TILE]] :
 // CHECK: %[[UPDATE0:.+]] = tensor.insert_slice %[[RESULT_TILE]]#0 into %[[ARG3]]
-// CHECK-SAME: [%[[IV0]], %[[IV1]], 0] [%[[TS_Y]], %[[TS_X]], 200] [1, 1, 1]
+// CHECK-SAME: [%[[IV0]], %[[IV1]], 0] [%[[TS_Y]], 20, 200] [1, 1, 1]
 // CHECK: %[[UPDATE1:.+]] = tensor.insert_slice %[[RESULT_TILE]]#1 into %[[ARG4]]
-// CHECK-SAME: [%[[IV1]], %[[IV0]], 0] [%[[TS_X]], %[[TS_Y]], 200] [1, 1, 1]
+// CHECK-SAME: [%[[IV1]], %[[IV0]], 0] [20, %[[TS_Y]], 200] [1, 1, 1]
 // CHECK: scf.yield %[[UPDATE0]], %[[UPDATE1]]
 // CHECK: scf.yield %[[INNER]]#0, %[[INNER]]#1
 // CHECK: return %[[OUTER]]#0, %[[OUTER]]#1