diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -205,21 +205,39 @@
   // Create procInfo so it dominates loops, if appropriate.
   OpBuilder &builder = edsc::ScopedContext::getBuilderRef();
   Location loc = edsc::ScopedContext::getLocation();
-  SmallVector<ProcInfo, 4> procInfo;
-  if (distributionOptions.hasValue())
-    procInfo = distributionOptions->procInfo(builder, loc, loopRanges);
+
+  SmallVector<ProcInfo, 4> procInfo;
+  SmallVector<DistributionMethod, 0> distributionMethod;
+  if (distributionOptions.hasValue()) {
+    // Collect loop ranges for parallel dimensions.
+    SmallVector<Range, 2> parallelLoopRanges;
+    for (auto iteratorType : enumerate(iteratorTypes))
+      if (isParallelIteratorType(iteratorType.value()))
+        parallelLoopRanges.push_back(loopRanges[iteratorType.index()]);
+
+    // Get their distribution schemes.
+    distributionMethod = distributionOptions->distributionMethod;
+    if (distributionMethod.size() < parallelLoopRanges.size())
+      parallelLoopRanges.resize(distributionMethod.size());
+    procInfo =
+        distributionOptions->procInfo(builder, loc, parallelLoopRanges);
+  }
 
   SmallVector<Value, 4> lbs, ubs, steps;
   unpackRanges(loopRanges, lbs, ubs, steps);
   LoopNest loopNest =
       edsc::loopNestBuilder(lbs, ubs, steps, iterArgInitValues, bodyBuilderFn);
 
-  if (!distributionOptions.hasValue() || loopNest.loops.empty())
+  if (!distributionOptions || loopNest.loops.empty())
     return;
 
-  // Only supports cyclic distribution for now.
-  for (auto it : llvm::zip(loopNest.loops, procInfo,
-                           distributionOptions->distributionMethod))
+  // Filter out scf.for loops that were created out of parallel dimensions.
+  SmallVector<scf::ForOp, 4> loops;
+  for (auto iteratorType : enumerate(iteratorTypes))
+    if (isParallelIteratorType(iteratorType.value()))
+      loops.push_back(loopNest.loops[iteratorType.index()]);
+
+  // Distribute - only supports cyclic distribution for now.
+  for (auto it : llvm::zip(loops, procInfo, distributionMethod))
     if (std::get<2>(it) == DistributionMethod::Cyclic)
       mapLoopToProcessorIds(std::get<0>(it), std::get<1>(it).procId,
                             std::get<1>(it).nprocs);
diff --git a/mlir/test/Dialect/Linalg/tile-and-distribute.mlir b/mlir/test/Dialect/Linalg/tile-and-distribute.mlir
--- a/mlir/test/Dialect/Linalg/tile-and-distribute.mlir
+++ b/mlir/test/Dialect/Linalg/tile-and-distribute.mlir
@@ -12,8 +12,8 @@
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<?x?xf32>
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<?x?xf32>
 // CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<?x?xf32>
-// CHECK: %[[BIDY:.*]] = "gpu.block_id"() {dimension = "y"}
-// CHECK: %[[BIDX:.*]] = "gpu.block_id"() {dimension = "x"}
+// CHECK-DAG: %[[BIDY:.*]] = "gpu.block_id"() {dimension = "y"}
+// CHECK-DAG: %[[BIDX:.*]] = "gpu.block_id"() {dimension = "x"}
 // CHECK: scf.for %[[ARG3:.*]] =
 // CHECK: %[[OFFSETY:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
 // CHECK: %[[SV1:.*]] = memref.subview %[[ARG0]][%[[OFFSETY]], %[[ARG3]]]
@@ -70,10 +70,10 @@
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<?x?xf32>
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<?x?xf32>
 // CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<?x?xf32>
-// CHECK: %[[BIDY:.*]] = "gpu.block_id"() {dimension = "y"}
-// CHECK: %[[NBLOCKSY:.*]] = "gpu.grid_dim"() {dimension = "y"}
-// CHECK: %[[BIDX:.*]] = "gpu.block_id"() {dimension = "x"}
-// CHECK: %[[NBLOCKSX:.*]] = "gpu.grid_dim"() {dimension = "x"}
+// CHECK-DAG: %[[BIDY:.*]] = "gpu.block_id"() {dimension = "y"}
+// CHECK-DAG: %[[NBLOCKSY:.*]] = "gpu.grid_dim"() {dimension = "y"}
+// CHECK-DAG: %[[BIDX:.*]] = "gpu.block_id"() {dimension = "x"}
+// CHECK-DAG: %[[NBLOCKSX:.*]] = "gpu.grid_dim"() {dimension = "x"}
 // CHECK: %[[LBY:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
 // CHECK: %[[STEPY:.*]] = affine.apply #[[MAP0]]()[%[[NBLOCKSY]]]
 // CHECK: %[[LBX:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
@@ -99,8 +99,8 @@
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<?x?xf32>
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<?x?xf32>
 // CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<?x?xf32>
-// CHECK: %[[BIDY:.*]] = "gpu.block_id"() {dimension = "y"}
-// CHECK: %[[BIDX:.*]] = "gpu.block_id"() {dimension = "x"}
+// CHECK-DAG: %[[BIDY:.*]] = "gpu.block_id"() {dimension = "y"}
+// CHECK-DAG: %[[BIDX:.*]] = "gpu.block_id"() {dimension = "x"}
 // CHECK: %[[LBX:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
 // CHECK: %[[INBOUNDS:.*]] = cmpi slt, %[[LBX]], %{{.*}}
 // CHECK: scf.if %[[INBOUNDS]]
@@ -128,9 +128,9 @@
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<?x?xf32>
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<?x?xf32>
 // CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<?x?xf32>
-// CHECK: %[[BIDY:.*]] = "gpu.block_id"() {dimension = "y"}
-// CHECK: %[[BIDX:.*]] = "gpu.block_id"() {dimension = "x"}
-// CHECK: %[[NBLOCKSX:.*]] = "gpu.grid_dim"() {dimension = "x"}
+// CHECK-DAG: %[[BIDY:.*]] = "gpu.block_id"() {dimension = "y"}
+// CHECK-DAG: %[[BIDX:.*]] = "gpu.block_id"() {dimension = "x"}
+// CHECK-DAG: %[[NBLOCKSX:.*]] = "gpu.grid_dim"() {dimension = "x"}
 // CHECK: %[[LBY:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
 // CHECK: %[[LBX:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
 // CHECK: %[[STEPX:.*]] = affine.apply #[[MAP0]]()[%[[NBLOCKSX]]]
@@ -159,9 +159,9 @@
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<?x?xf32>
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<?x?xf32>
 // CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<?x?xf32>
-// CHECK: %[[BIDY:.*]] = "gpu.block_id"() {dimension = "y"}
-// CHECK: %[[NBLOCKSY:.*]] = "gpu.grid_dim"() {dimension = "y"}
-// CHECK: %[[BIDX:.*]] = "gpu.block_id"() {dimension = "x"}
+// CHECK-DAG: %[[BIDY:.*]] = "gpu.block_id"() {dimension = "y"}
+// CHECK-DAG: %[[NBLOCKSY:.*]] = "gpu.grid_dim"() {dimension = "y"}
+// CHECK-DAG: %[[BIDX:.*]] = "gpu.block_id"() {dimension = "x"}
 // CHECK: %[[LBY:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
 // CHECK: %[[STEPY:.*]] = affine.apply #[[MAP0]]()[%[[NBLOCKSY]]]
 // CHECK: scf.parallel (%[[ARG3:.*]]) = (%[[LBY]]) to (%{{.*}}) step (%[[STEPY]])
@@ -186,10 +186,10 @@
     -> tensor<?x?xf32> {
 // CHECK-DAG: %[[C8:.*]] = constant 8 : index
 // CHECK-DAG: %[[C0:.*]] = constant 0 : index
-// CHECK: %[[BIDY:.*]] = "gpu.block_id"() {dimension = "y"}
-// CHECK: %[[NBLOCKSY:.*]] = "gpu.grid_dim"() {dimension = "y"}
-// CHECK: %[[BIDX:.*]] = "gpu.block_id"() {dimension = "x"}
-// CHECK: %[[NBLOCKSX:.*]] = "gpu.grid_dim"() {dimension = "x"}
+// CHECK-DAG: %[[BIDY:.*]] = "gpu.block_id"() {dimension = "y"}
+// CHECK-DAG: %[[NBLOCKSY:.*]] = "gpu.grid_dim"() {dimension = "y"}
+// CHECK-DAG: %[[BIDX:.*]] = "gpu.block_id"() {dimension = "x"}
+// CHECK-DAG: %[[NBLOCKSX:.*]] = "gpu.grid_dim"() {dimension = "x"}
 // CHECK: %[[MUL:.+]] = affine.apply #[[MULMAP]]()[%[[BIDY]], %[[C8]]]
 // CHECK: %[[LBY:.+]] = affine.apply #[[ADDMAP]]()[%[[MUL]], %[[C0]]]
 // CHECK: %[[STEPY:.+]] = affine.apply #[[MULMAP]]()[%[[NBLOCKSY]], %[[C8]]]
diff --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
--- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
@@ -333,12 +333,15 @@
 template <typename IdOp, typename NProcsOp>
 static SmallVector<ProcInfo, 2>
 getGpuProcIds(OpBuilder &b, Location loc, ArrayRef<Range> parallelLoopRanges) {
+  size_t count = std::min<size_t>(3, parallelLoopRanges.size());
+  SmallVector<ProcInfo, 2> procInfo(count);
+  const char *xyz[] = {"x", "y", "z"};
   Type indexType = b.getIndexType();
-  SmallVector<ProcInfo, 2> procInfo(2);
-  procInfo[0] = {b.create<IdOp>(loc, indexType, b.getStringAttr("y")),
-                 b.create<NProcsOp>(loc, indexType, b.getStringAttr("y"))};
-  procInfo[1] = {b.create<IdOp>(loc, indexType, b.getStringAttr("x")),
-                 b.create<NProcsOp>(loc, indexType, b.getStringAttr("x"))};
+  for (unsigned i = 0; i < count; ++i) {
+    procInfo[count - 1 - i] = {
+        b.create<IdOp>(loc, indexType, b.getStringAttr(xyz[i])),
+        b.create<NProcsOp>(loc, indexType, b.getStringAttr(xyz[i]))};
+  }
   return procInfo;
 }
 