diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -516,6 +516,61 @@
   return loops;
 }
 
+/// Replaces every index operation (`IndexOp`) in the body of the innermost
+/// loop of the generated loop nest by an induction variable of that nest.
+///
+/// `loopOps` is walked front to back to collect the induction variables: an
+/// `scf::ParallelOp` contributes all of its induction variables, `scf::ForOp`
+/// and `AffineForOp` contribute one each. The number of collected variables
+/// must equal `linalgOp.getNumLoops()`.
+///
+/// An index operation on dimension `d` is replaced by the induction variable
+/// at the position where `d` occurs in `interchangeVector`; if `d` does not
+/// occur there (in particular when the vector is empty), the `d`-th induction
+/// variable is used.
+static void replaceIndexOpsByInductionVariables(
+    LinalgOp linalgOp, PatternRewriter &rewriter, ArrayRef<Operation *> loopOps,
+    ArrayRef<unsigned> interchangeVector) {
+  // Collect the induction variables of the loop nest from outer to inner.
+  SmallVector<Value> allIvs;
+  for (Operation *loopOp : loopOps) {
+    llvm::TypeSwitch<Operation *>(loopOp)
+        .Case([&](scf::ParallelOp parallelOp) {
+          allIvs.append(parallelOp.getInductionVars().begin(),
+                        parallelOp.getInductionVars().end());
+        })
+        .Case([&](scf::ForOp forOp) {
+          allIvs.push_back(forOp.getInductionVar());
+        })
+        .Case([&](AffineForOp affineForOp) {
+          allIvs.push_back(affineForOp.getInductionVar());
+        })
+        .Default([&](Operation *op) { assert(false && "unexpected op"); });
+  }
+  assert(linalgOp.getNumLoops() == allIvs.size() &&
+         "expected the number of loops and induction variables to match");
+  // Without loops there is no loop body to search for index operations.
+  if (loopOps.empty())
+    return;
+  // The check does not depend on the visited index operation, so it is done
+  // once before the walk instead of once per index operation.
+  assert((interchangeVector.empty() ||
+          interchangeVector.size() == linalgOp.getNumLoops()) &&
+         "expected an empty or complete interchange vector");
+  // Replace the index operations in the body of the innermost loop op.
+  LoopLikeOpInterface innermostLoop = loopOps.back();
+  for (IndexOp indexOp : llvm::make_early_inc_range(
+           innermostLoop.getLoopBody().getOps<IndexOp>())) {
+    // Search the indexing dimension in the interchange vector; the search
+    // yields end() for an empty vector and the dimension is used unchanged.
+    const auto *it = llvm::find(interchangeVector, indexOp.dim());
+    uint64_t dim = it != interchangeVector.end()
+                       ? std::distance(interchangeVector.begin(), it)
+                       : indexOp.dim();
+    rewriter.replaceOp(indexOp, allIvs[dim]);
+  }
+}
+
 namespace {
 template <typename LoopType>
 class LinalgRewritePattern : public RewritePattern {
@@ -528,11 +583,14 @@
   LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const override {
     auto linalgOp = dyn_cast<LinalgOp>(op);
-    // TODO: remove hasIndexSemantics check once index ops are supported.
-    if (!linalgOp || linalgOp.hasIndexSemantics())
+    if (!linalgOp)
       return failure();
-    if (!linalgOpToLoopsImpl<LoopType>(op, rewriter, interchangeVector))
+    Optional<LinalgLoops> loopOps =
+        linalgOpToLoopsImpl<LoopType>(op, rewriter, interchangeVector);
+    if (!loopOps.hasValue())
       return failure();
+    replaceIndexOpsByInductionVariables(linalgOp, rewriter, loopOps.getValue(),
+                                        interchangeVector);
     rewriter.eraseOp(op);
     return success();
   }
diff --git a/mlir/test/Dialect/Linalg/loop-order.mlir b/mlir/test/Dialect/Linalg/loop-order.mlir --- a/mlir/test/Dialect/Linalg/loop-order.mlir +++ b/mlir/test/Dialect/Linalg/loop-order.mlir @@ -24,22 +24,49 @@ // ----- -func @index_op(%arg0: memref<4x8xindex>) { - linalg.generic { - indexing_maps = [affine_map<(i, j) -> (i, j)>], - iterator_types = ["parallel", "parallel"]} - outs(%arg0 : memref<4x8xindex>) { - ^bb0(%arg1: index): // no predecessors - %0 = linalg.index 1 : index - linalg.yield %0 : index +#map = affine_map<(i, j, k, l, m) -> (i, j, k, l, m)> +func @generic(%output: memref<1x2x3x4x5xindex>) { + linalg.generic {indexing_maps = [#map], + iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} + outs(%output : memref<1x2x3x4x5xindex>) { + ^bb0(%arg0 : index): + %i = linalg.index 0 : index + %j = linalg.index 1 : index + %k = linalg.index 2 : index + %l = linalg.index 3 : index + %m = linalg.index 4 : index + %0 = addi %i, %j : index + %1 = addi %0, %k : index + %2 = addi %1, %l : index + %3 = addi %2, %m : index + linalg.yield %3: index } return } -// LOOP-LABEL: @index_op -// LOOP: linalg.generic -// PARALLEL-LABEL: @index_op
-// PARALLEL: linalg.generic +// LOOP: scf.for %[[m:.*]] = %c0 to %c5 step %c1 +// LOOP: scf.for %[[i:.*]] = %c0 to %c1 step %c1 +// LOOP: scf.for %[[l:.*]] = %c0 to %c4 step %c1 +// LOOP: scf.for %[[j:.*]] = %c0 to %c2 step %c1 +// LOOP: scf.for %[[k:.*]] = %c0 to %c3 step %c1 +// LOOP: %{{.*}} = addi %[[i]], %[[j]] : index +// LOOP: %{{.*}} = addi %{{.*}}, %[[k]] : index +// LOOP: %{{.*}} = addi %{{.*}}, %[[l]] : index +// LOOP: %{{.*}} = addi %{{.*}}, %[[m]] : index -// AFFINE-LABEL: @index_op -// AFFINE: linalg.generic +// PARALLEL: scf.parallel (%[[m:.*]], %[[i:.*]], %[[l:.*]], %[[j:.*]], %[[k:.*]]) = +// PARALLEL-SAME: to (%c5, %c1, %c4, %c2, %c3) +// PARALLEL: %{{.*}} = addi %[[i]], %[[j]] : index +// PARALLEL: %{{.*}} = addi %{{.*}}, %[[k]] : index +// PARALLEL: %{{.*}} = addi %{{.*}}, %[[l]] : index +// PARALLEL: %{{.*}} = addi %{{.*}}, %[[m]] : index + +// AFFINE: affine.for %[[m:.*]] = 0 to 5 +// AFFINE: affine.for %[[i:.*]] = 0 to 1 +// AFFINE: affine.for %[[l:.*]] = 0 to 4 +// AFFINE: affine.for %[[j:.*]] = 0 to 2 +// AFFINE: affine.for %[[k:.*]] = 0 to 3 +// AFFINE: %{{.*}} = addi %[[i]], %[[j]] : index +// AFFINE: %{{.*}} = addi %{{.*}}, %[[k]] : index +// AFFINE: %{{.*}} = addi %{{.*}}, %[[l]] : index +// AFFINE: %{{.*}} = addi %{{.*}}, %[[m]] : index diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir --- a/mlir/test/Dialect/Linalg/loops.mlir +++ b/mlir/test/Dialect/Linalg/loops.mlir @@ -880,6 +880,61 @@ library_call = "some_external_function_name_2", doc = "B(i,j,k), C(i,k,j) = foo(A(i, j) * B(i,j,k), i * j * k + C(i,k,j))" } +func @generic_index_region( + %arg0: memref, + %arg1: memref, + %arg2: memref) { + linalg.generic #trait4 + ins(%arg0 : memref) + outs(%arg1, %arg2 : memref, + memref) { + ^bb0(%a: f32, %b: f32, %c: f32): + %i = linalg.index 0 : index + %j = linalg.index 1 : index + %k = linalg.index 2 : index + %result_1 = mulf %a, %b : f32 + + %ij = addi %i, %j : index + %ijk = addi %ij, %k : index + 
%ijk_int = index_cast %ijk : index to i32 + %ijk_float = sitofp %ijk_int : i32 to f32 + + %result_2 = addf %c, %ijk_float : f32 + linalg.yield %result_1, %result_2 : f32, f32 + } + return +} + +// CHECKLOOP-LABEL: @generic_index_region +// CHECKLOOP: scf.for %[[i:.*]] = {{.*}} +// CHECKLOOP: scf.for %[[j:.*]] = {{.*}} +// CHECKLOOP: scf.for %[[k:.*]] = {{.*}} +// CHECKLOOP: %[[a:.*]] = memref.load %{{.*}}[%[[i]], %[[j]]] +// CHECKLOOP: %[[b:.*]] = memref.load %{{.*}}[%[[i]], %[[j]], %[[k]]] +// CHECKLOOP: %[[c:.*]] = memref.load %{{.*}}[%[[i]], %[[k]], %[[j]]] +// CHECKLOOP: %[[result_1:.*]] = mulf %[[a]], %[[b]] : f32 +// CHECKLOOP: %[[ij:.*]] = addi %[[i]], %[[j]] : index +// CHECKLOOP: %[[ijk:.*]] = addi %[[ij]], %[[k]] : index +// CHECKLOOP: %[[ijk_int:.*]] = index_cast %[[ijk]] : index to i32 +// CHECKLOOP: %[[ijk_float:.*]] = sitofp %[[ijk_int]] : i32 to f32 +// CHECKLOOP: %[[result_2:.*]] = addf %[[c]], %[[ijk_float]] : f32 +// CHECKLOOP: store %[[result_1]], %{{.*}}[%[[i]], %[[j]], %[[k]]] +// CHECKLOOP: store %[[result_2]], %{{.*}}[%[[i]], %[[k]], %[[j]]] + +// CHECKPARALLEL-LABEL: @generic_index_region +// CHECKPARALLEL: scf.parallel (%[[i:[a-zA-Z0-9_]*]], %[[j:[a-zA-Z0-9_]*]], %[[k:[a-zA-Z0-9_]*]]) +// CHECKPARALLEL: %[[a:.*]] = memref.load %{{.*}}[%[[i]], %[[j]]] +// CHECKPARALLEL: %[[b:.*]] = memref.load %{{.*}}[%[[i]], %[[j]], %[[k]]] +// CHECKPARALLEL: %[[c:.*]] = memref.load %{{.*}}[%[[i]], %[[k]], %[[j]]] +// CHECKPARALLEL: %[[result_1:.*]] = mulf %[[a]], %[[b]] : f32 +// CHECKPARALLEL: %[[ij:.*]] = addi %[[i]], %[[j]] : index +// CHECKPARALLEL: %[[ijk:.*]] = addi %[[ij]], %[[k]] : index +// CHECKPARALLEL: %[[ijk_int:.*]] = index_cast %[[ijk]] : index to i32 +// CHECKPARALLEL: %[[ijk_float:.*]] = sitofp %[[ijk_int]] : i32 to f32 +// CHECKPARALLEL: %[[result_2:.*]] = addf %[[c]], %[[ijk_float]] : f32 +// CHECKPARALLEL: store %[[result_1]], %{{.*}}[%[[i]], %[[j]], %[[k]]] +// CHECKPARALLEL: store %[[result_2]], %{{.*}}[%[[i]], %[[k]], %[[j]]] + func 
@indexed_generic_region( %arg0: memref, %arg1: memref, @@ -973,6 +1028,43 @@ // CHECKPARALLEL: %[[a:.*]] = memref.load %[[ARG0]][] // CHECKPARALLEL: store %[[a]], %[[ARG1]][%[[i]], %[[j]]] +func @generic_index_op_zero_rank(%arg0: memref, %arg1: memref<3x4xi32>) +{ + linalg.generic #trait_broadcast + ins(%arg0 : memref) + outs(%arg1 : memref<3x4xi32>) { + ^bb(%a: i32, %b: i32) : + %i = linalg.index 0 : index + %j = linalg.index 1 : index + %ij = addi %i, %j : index + %ij_int = index_cast %ij : index to i32 + %result = addi %a, %ij_int : i32 + linalg.yield %result : i32 + } + return +} + +// CHECKLOOP-LABEL: @generic_index_op_zero_rank +// CHECKLOOP-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref +// CHECKLOOP-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<3x4xi32> +// CHECKLOOP: scf.for %[[i:.*]] = {{.*}} +// CHECKLOOP: scf.for %[[j:.*]] = {{.*}} +// CHECKLOOP: %[[a:.*]] = memref.load %[[ARG0]][ +// CHECKLOOP: %[[ij:.*]] = addi %[[i]], %[[j]] : index +// CHECKLOOP: %[[ij_int:.*]] = index_cast %[[ij]] : index to i32 +// CHECKLOOP: %[[result:.*]] = addi %[[a]], %[[ij_int]] : i32 +// CHECKLOOP: store %[[result]], %[[ARG1]][%[[i]], %[[j]]] + +// CHECKPARALLEL-LABEL: @generic_index_op_zero_rank +// CHECKPARALLEL-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref +// CHECKPARALLEL-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<3x4xi32> +// CHECKPARALLEL: scf.parallel (%[[i:[a-zA-Z0-9_]*]], %[[j:[a-zA-Z0-9_]*]]) +// CHECKPARALLEL: %[[a:.*]] = memref.load %[[ARG0]][ +// CHECKPARALLEL: %[[ij:.*]] = addi %[[i]], %[[j]] : index +// CHECKPARALLEL: %[[ij_int:.*]] = index_cast %[[ij]] : index to i32 +// CHECKPARALLEL: %[[result:.*]] = addi %[[a]], %[[ij_int]] : i32 +// CHECKPARALLEL: store %[[result]], %[[ARG1]][%[[i]], %[[j]]] + func @indexed_generic_op_zero_rank(%arg0: memref, %arg1: memref<3x4xi32>) { linalg.indexed_generic #trait_broadcast @@ -1065,6 +1157,47 @@ library_call = "some_reduce_external_fn" } +func @generic_index_op_1D_reduce(%arg0: memref, + %arg1: memref, + %arg2: memref) +{ + linalg.generic 
#trait_reduce_init_1D + ins(%arg0, %arg1 : memref, memref) + outs(%arg2 : memref) { + ^bb(%a: f32, %b: f32, %c: f32) : + %i = linalg.index 0 : index + %0 = constant 0 : index + %1 = cmpi eq, %0, %i : index + %2 = select %1, %b, %c : f32 + %3 = addf %a, %2 : f32 + linalg.yield %3 : f32 + } + return +} +// CHECKLOOP-LABEL: @generic_index_op_1D_reduce +// CHECKLOOP-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref +// CHECKLOOP-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref +// CHECKLOOP-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref +// CHECKLOOP: scf.for %[[i:.*]] = {{.*}} +// CHECKLOOP: %[[a:.*]] = memref.load %[[ARG0]][%[[i]]] +// CHECKLOOP: %[[b:.*]] = memref.load %[[ARG1]][] +// CHECKLOOP: %[[c:.*]] = memref.load %[[ARG2]][] +// CHECKLOOP: %[[d:.*]] = select %{{.*}}, %[[b]], %[[c]] +// CHECKLOOP: %[[e:.*]] = addf %[[a]], %[[d]] +// CHECKLOOP: store %[[e]], %[[ARG2]][] + +// CHECKPARALLEL-LABEL: @generic_index_op_1D_reduce +// CHECKPARALLEL-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref +// CHECKPARALLEL-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref +// CHECKPARALLEL-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref +// CHECKPARALLEL: scf.for %[[i:.*]] = {{.*}} +// CHECKPARALLEL: %[[a:.*]] = memref.load %[[ARG0]][%[[i]]] +// CHECKPARALLEL: %[[b:.*]] = memref.load %[[ARG1]][] +// CHECKPARALLEL: %[[c:.*]] = memref.load %[[ARG2]][] +// CHECKPARALLEL: %[[d:.*]] = select %{{.*}}, %[[b]], %[[c]] +// CHECKPARALLEL: %[[e:.*]] = addf %[[a]], %[[d]] +// CHECKPARALLEL: store %[[e]], %[[ARG2]][] + func @indexed_generic_op_1D_reduce(%arg0: memref, %arg1: memref, %arg2: memref)