diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -82,96 +82,6 @@
   return std::make_tuple(res, loopIndexToRangeIndex);
 }
 
-// IndexedGenericOp explicitly uses induction variables in the loop body. The
-// values of the indices that are used in the loop body for any given access of
-// input/output memref before `subview` op was applied should be invariant with
-// respect to tiling.
-//
-// Therefore, if the operation is tiled, we have to transform the indices
-// accordingly, i.e. offset them by the values of the corresponding induction
-// variables that are captured implicitly in the body of the op.
-//
-// Example. `linalg.indexed_generic` before tiling:
-//
-//   #id_2d = (i, j) -> (i, j)
-//   #pointwise_2d_trait = {
-//     indexing_maps = [#id_2d, #id_2d],
-//     iterator_types = ["parallel", "parallel"],
-//     n_views = [1, 1]
-//   }
-//   linalg.indexed_generic #pointwise_2d_trait %operand, %result {
-//     ^bb0(%i: index, %j: index, %operand_in: f32, %result_in: f32):
-//       <some operations that use %i, %j>
-//   }: memref<50x100xf32>, memref<50x100xf32>
-//
-// After tiling pass with tiles sizes 10 and 25:
-//
-//   #strided = (i, j)[s0, s1, s2] -> (i * s1 + s0 + j * s2)
-//
-//   %c1 = constant 1 : index
-//   %c0 = constant 0 : index
-//   %c25 = constant 25 : index
-//   %c10 = constant 10 : index
-//   operand_dim_0 = dim %operand, 0 : memref<50x100xf32>
-//   operand_dim_1 = dim %operand, 1 : memref<50x100xf32>
-//   scf.for %k = %c0 to operand_dim_0 step %c10 {
-//     scf.for %l = %c0 to operand_dim_1 step %c25 {
-//       %4 = memref.subview %operand[%k, %l][%c10, %c25][%c1, %c1]
-//         : memref<50x100xf32> to memref<?x?xf32, #strided>
-//       %5 = memref.subview %result[%k, %l][%c10, %c25][%c1, %c1]
-//         : memref<50x100xf32> to memref<?x?xf32, #strided>
-//       linalg.indexed_generic pointwise_2d_trait %4, %5 {
-//       ^bb0(%i: index, %j: index, %operand_in: f32, %result_in: f32):
-//         // Indices `k` and `l` are implicitly captured in the body.
-//         %transformed_i = addi %i, %k : index // index `i` is offset by %k
-//         %transformed_j = addi %j, %l : index // index `j` is offset by %l
-//         // Every use of %i, %j is replaced with %transformed_i, %transformed_j
-//         <some operations that use %transformed_i, %transformed_j>
-//       }: memref<?x?xf32, #strided>, memref<?x?xf32, #strided>
-//     }
-//   }
-//
-// TODO: Investigate whether mixing implicit and explicit indices
-// does not lead to losing information.
-static void transformIndexedGenericOpIndices(
-    OpBuilder &b, LinalgOp op, SmallVectorImpl<Value> &ivs,
-    const LoopIndexToRangeIndexMap &loopIndexToRangeIndex) {
-  auto indexedGenericOp = dyn_cast<IndexedGenericOp>(op.getOperation());
-  if (!indexedGenericOp)
-    return;
-
-  // `linalg.indexed_generic` comes in two flavours. One has a region with a
-  // single block that defines the loop body. The other has a `fun` attribute
-  // that refers to an existing function symbol. The `fun` function call will be
-  // inserted in the loop body in that case.
-  //
-  // TODO: Add support for `linalg.indexed_generic` with `fun` attribute.
-  auto &region = indexedGenericOp.region();
-  if (region.empty()) {
-    indexedGenericOp.emitOpError("expected a region");
-    return;
-  }
-  auto &block = region.front();
-
-  OpBuilder::InsertionGuard g(b);
-  b.setInsertionPointToStart(&block);
-  for (unsigned i = 0; i < indexedGenericOp.getNumLoops(); ++i) {
-    auto rangeIndex = loopIndexToRangeIndex.find(i);
-    if (rangeIndex == loopIndexToRangeIndex.end())
-      continue;
-    Value oldIndex = block.getArgument(i);
-    // Offset the index argument `i` by the value of the corresponding induction
-    // variable and replace all uses of the previous value.
-    Value newIndex = b.create<AddIOp>(indexedGenericOp.getLoc(), oldIndex,
-                                      ivs[rangeIndex->second]);
-    for (auto &use : oldIndex.getUses()) {
-      if (use.getOwner() == newIndex.getDefiningOp())
-        continue;
-      use.set(newIndex);
-    }
-  }
-}
-
 // All indices returned by IndexOp should be invariant with respect to tiling.
 // Therefore, if an operation is tiled, we have to transform the indices
 // accordingly, i.e. offset them by the values of the corresponding induction
@@ -261,6 +171,10 @@
   if (llvm::all_of(tileSizes, isZero))
     return llvm::None;
 
+  // Canonicalize indexed generic operations before tiling.
+  if (isa<IndexedGenericOp>(op))
+    return llvm::None;
+
   if (auto convOp = dyn_cast<ConvOp>(op.getOperation())) {
     // For conv op only support tiling along batch dimension (which is the first
     // loop).
@@ -376,9 +290,7 @@
       },
       options.distribution);
 
-  // 3a. Transforms index arguments of `linalg.generic` w.r.t. to the tiling.
-  transformIndexedGenericOpIndices(b, res, ivs, loopIndexToRangeIndex);
-  // 3b. Transform IndexOp results w.r.t. the tiling.
+  // 3. Transform IndexOp results w.r.t. the tiling.
   transformIndexOps(b, res, ivs, loopIndexToRangeIndex);
 
   // 4. Gather the newly created loops and return them with the new op.
@@ -521,7 +433,7 @@
 /// Populate the given list with patterns that apply Linalg tiling.
 static void insertTilingPatterns(RewritePatternSet &patterns,
                                  const LinalgTilingOptions &options) {
-  RewritePatternList<GenericOp, IndexedGenericOp,
+  RewritePatternList<GenericOp,
 #define GET_OP_LIST
 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
                      >::insert(patterns, options);
diff --git a/mlir/test/Dialect/Linalg/tile-indexed-generic.mlir b/mlir/test/Dialect/Linalg/tile-indexed-generic.mlir
deleted file mode 100644
--- a/mlir/test/Dialect/Linalg/tile-indexed-generic.mlir
+++ /dev/null
@@ -1,116 +0,0 @@
-// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=10,25" | FileCheck %s -check-prefix=TILE-10n25
-// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=25,0" | FileCheck %s -check-prefix=TILE-25n0
-// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=0,25" | FileCheck %s -check-prefix=TILE-0n25
-
-#id_1d = affine_map<(i) -> (i)>
-#pointwise_1d_trait = {
-  args_in = 1,
-  args_out = 1,
-  indexing_maps = [#id_1d, #id_1d],
-  iterator_types = ["parallel"]
-}
-func @indexed_generic_vector(%operand: memref<50xf32>, %result: memref<50xf32>) {
-  linalg.indexed_generic #pointwise_1d_trait
-      ins(%operand : memref<50xf32>)
-      outs(%result : memref<50xf32>) {
-    ^bb0(%i: index, %operand_in: f32, %result_in: f32):
-      %i_int = index_cast %i : index to i32
-      %i_float = sitofp %i_int : i32 to f32
-      %out = addf %operand_in, %i_float : f32
-      linalg.yield %out : f32
-  }
-  return
-}
-// TILE-10n25-LABEL: func @indexed_generic_vector
-// TILE-10n25: %[[C10:.*]] = constant 10 : index
-// TILE-10n25: scf.for %[[J:.*]] = {{.*}} step %[[C10]]
-// TILE-10n25: linalg.generic
-// TILE-10n25: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32)
-// TILE-10n25: %[[I:.*]] = linalg.index 0 : index
-// TILE-10n25: %[[NEW_I:.*]] = addi %[[I]], %[[J]] : index
-// TILE-10n25: %[[NEW_I_INT:.*]] = index_cast %[[NEW_I]] : index to i32
-// TILE-10n25: %[[NEW_I_FLOAT:.*]] = sitofp %[[NEW_I_INT]] : i32 to f32
-// TILE-10n25: %[[OUT:.*]] = addf %[[IN]], %[[NEW_I_FLOAT]] : f32
-
-// TILE-25n0-LABEL: func @indexed_generic_vector
-// TILE-25n0: %[[C25:.*]] = constant 25 : index
-// TILE-25n0: scf.for %[[J:.*]] = {{.*}} step %[[C25]]
-// TILE-25n0: linalg.generic
-// TILE-25n0: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32)
-// TILE-25n0: %[[I:.*]] = linalg.index 0 : index
-// TILE-25n0: %[[NEW_I:.*]] = addi %[[I]], %[[J]] : index
-// TILE-25n0: %[[NEW_I_INT:.*]] = index_cast %[[NEW_I]] : index to i32
-// TILE-25n0: %[[NEW_I_FLOAT:.*]] = sitofp %[[NEW_I_INT]] : i32 to f32
-// TILE-25n0: %[[OUT:.*]] = addf %[[IN]], %[[NEW_I_FLOAT]] : f32
-
-// TILE-0n25-LABEL: func @indexed_generic_vector
-// TILE-0n25-NOT: scf.for %[[J:.*]] = {{.*}} step %
-// TILE-0n25: linalg.generic
-
-#combined_indices_trait = {
-  args_in = 1,
-  args_out = 1,
-  indexing_maps = [
-    affine_map<(i, j) -> (j, i + j)>,
-    affine_map<(i, j) -> (i, j)>
-  ],
-  iterator_types = ["parallel", "parallel"]
-}
-func @indexed_generic_matrix(%operand: memref<50x99xf32>, %result: memref<50x50xf32>) {
-  linalg.indexed_generic #combined_indices_trait
-      ins(%operand : memref<50x99xf32>)
-      outs(%result : memref<50x50xf32>) {
-    ^bb0(%i: index, %j: index, %operand_in: f32, %result_in: f32):
-      %i_int = index_cast %i : index to i32
-      %i_float = sitofp %i_int : i32 to f32
-      %j_int = index_cast %j : index to i32
-      %j_float = sitofp %j_int : i32 to f32
-      %out = addf %i_float, %j_float : f32
-      linalg.yield %out : f32
-  }
-  return
-}
-// TILE-10n25-LABEL: func @indexed_generic_matrix
-// TILE-10n25-DAG: %[[C25:.*]] = constant 25 : index
-// TILE-10n25-DAG: %[[C10:.*]] = constant 10 : index
-// TILE-10n25: scf.for %[[K:.*]] = {{.*}} step %[[C10]]
-// TILE-10n25: scf.for %[[L:.*]] = {{.*}} step %[[C25]]
-// TILE-10n25: linalg.generic
-// TILE-10n25: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
-// TILE-10n25: %[[I:.*]] = linalg.index 0 : index
-// TILE-10n25: %[[J:.*]] = linalg.index 1 : index
-// TILE-10n25: %[[NEW_I:.*]] = addi %[[I]], %[[K]] : index
-// TILE-10n25: %[[NEW_J:.*]] = addi %[[J]], %[[L]] : index
-// TILE-10n25: %[[NEW_INT_I:.*]] = index_cast %[[NEW_I]] : index to i32
-// TILE-10n25: %[[NEW_FLOAT_I:.*]] = sitofp %[[NEW_INT_I]] : i32 to f32
-// TILE-10n25: %[[NEW_INT_J:.*]] = index_cast %[[NEW_J]] : index to i32
-// TILE-10n25: %[[NEW_FLOAT_J:.*]] = sitofp %[[NEW_INT_J]] : i32 to f32
-// TILE-10n25: %[[OUT:.*]] = addf %[[NEW_FLOAT_I]], %[[NEW_FLOAT_J]] : f32
-
-// TILE-25n0-LABEL: func @indexed_generic_matrix
-// TILE-25n0: %[[C25:.*]] = constant 25 : index
-// TILE-25n0: scf.for %[[L:.*]] = {{.*}} step %[[C25]]
-// TILE-25n0: linalg.generic
-// TILE-25n0: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
-// TILE-25n0: %[[I:.*]] = linalg.index 0 : index
-// TILE-25n0: %[[J:.*]] = linalg.index 1 : index
-// TILE-25n0: %[[NEW_I:.*]] = addi %[[I]], %[[L]] : index
-// TILE-25n0: %[[NEW_INT_I:.*]] = index_cast %[[NEW_I]] : index to i32
-// TILE-25n0: %[[NEW_FLOAT_I:.*]] = sitofp %[[NEW_INT_I]] : i32 to f32
-// TILE-25n0: %[[INT_J:.*]] = index_cast %[[J]] : index to i32
-// TILE-25n0: %[[FLOAT_J:.*]] = sitofp %[[INT_J]] : i32 to f32
-// TILE-25n0: %[[OUT:.*]] = addf %[[NEW_FLOAT_I]], %[[FLOAT_J]] : f32
-
-// TILE-0n25-LABEL: func @indexed_generic_matrix
-// TILE-0n25: %[[C25:.*]] = constant 25 : index
-// TILE-0n25: scf.for %[[L:.*]] = {{.*}} step %[[C25]]
-// TILE-0n25: linalg.generic
-// TILE-0n25: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
-// TILE-0n25: %[[I:.*]] = linalg.index 0 : index
-// TILE-0n25: %[[J:.*]] = linalg.index 1 : index
-// TILE-0n25: %[[NEW_J:.*]] = addi %[[J]], %[[L]] : index
-// TILE-0n25: %[[INT_I:.*]] = index_cast %[[I]] : index to i32
-// TILE-0n25: %[[FLOAT_I:.*]] = sitofp %[[INT_I]] : i32 to f32
-// TILE-0n25: %[[NEW_INT_J:.*]] = index_cast %[[NEW_J]] : index to i32
-// TILE-0n25: %[[NEW_FLOAT_J:.*]] = sitofp %[[NEW_INT_J]] : i32 to f32
-// TILE-0n25: %[[OUT:.*]] = addf %[[FLOAT_I]], %[[NEW_FLOAT_J]] : f32
diff --git a/mlir/test/Dialect/Linalg/tile-tensors.mlir b/mlir/test/Dialect/Linalg/tile-tensors.mlir
--- a/mlir/test/Dialect/Linalg/tile-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-tensors.mlir
@@ -131,53 +131,6 @@
 
 // -----
 
-func @indexed_generic_op_tensors(
-    %arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
-  %c0 = constant 0 : index
-  %c1 = constant 1 : index
-  %c2 = constant 2 : index
-  %0 = memref.dim %arg0, %c0 : tensor<?x?x?xf32>
-  %1 = memref.dim %arg0, %c1 : tensor<?x?x?xf32>
-  %2 = memref.dim %arg0, %c2 : tensor<?x?x?xf32>
-  %3 = linalg.init_tensor [%0, %1, %2] : tensor<?x?x?xf32>
-  %4 = linalg.indexed_generic
-    {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
-                      affine_map<(d0, d1, d2) -> (d0, d2, d1)>,
-                      affine_map<(d0, d1, d2) -> (d2, d1, d0)>],
-     iterator_types = ["parallel", "parallel", "parallel"]}
-    ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
-    outs(%3 : tensor<?x?x?xf32>) {
-    ^bb0(%arg2 : index, %arg3 : index, %arg4 : index, %arg5 : f32, %arg6: f32, %arg7: f32):
-      %5 = addf %arg5, %arg6 : f32
-      linalg.yield %5 : f32
-    } -> tensor<?x?x?xf32>
-  return %4 : tensor<?x?x?xf32>
-}
-
-// CHECK-LABEL: func @indexed_generic_op_tensors
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
-// CHECK: %[[INIT:.+]] = linalg.init_tensor
-// CHECK: %[[TD0:.+]] = scf.for %{{.+}} to %{{.+}} step %{{.+}} iter_args(%[[TC0:.+]] = %[[INIT]]) -> (tensor<?x?x?xf32>) {
-// CHECK: %[[TD1:.+]] = scf.for %{{.+}} to %{{.+}} step %{{.+}} iter_args(%[[TC1:.+]] = %[[TC0]]) -> (tensor<?x?x?xf32>) {
-// CHECK: %[[TD2:.+]] = scf.for %{{.+}} to %{{.+}} step %{{.+}} iter_args(%[[TC2:.+]] = %[[TC1]]) -> (tensor<?x?x?xf32>) {
-// CHECK: %[[STARG0:.+]] = subtensor %[[ARG0]][{{.+}}] : tensor<?x?x?xf32> to tensor<?x?x?xf32>
-// CHECK: %[[STARG1:.+]] = subtensor %[[ARG1]][{{.+}}] : tensor<?x?x?xf32> to tensor<?x?x?xf32>
-// CHECK: %[[STARG2:.+]] = subtensor %[[TC2]][{{.+}}] : tensor<?x?x?xf32> to tensor<?x?x?xf32>
-// CHECK: %[[STRETURN:.+]] = linalg.generic
-// CHECK-SAME: ins(%[[STARG0]], %[[STARG1]] : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
-// CHECK-SAME: outs(%[[STARG2]] : tensor<?x?x?xf32>)
-// CHECK: %[[TD:.+]] = subtensor_insert %[[STRETURN]] into %[[TC2]]
-// CHECK: scf.yield %[[TD]]
-// CHECK: }
-// CHECK: scf.yield %[[TD2]]
-// CHECK: }
-// CHECK: scf.yield %[[TD1]]
-// CHECK: }
-// CHECK: return %[[TD0]]
-
-// -----
-
 func @fill_tensors(%arg0 : index, %arg1 : index, %arg2 : f32) -> tensor<?x?xf32> {
   %0 = linalg.init_tensor [%arg0, %arg1] : tensor<?x?xf32>
   %1 = linalg.fill(%0, %arg2) : tensor<?x?xf32>, f32 -> tensor<?x?xf32>
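
Note: with this change, tiling bails out on `linalg.indexed_generic`; the op is expected to be canonicalized to `linalg.generic` first, with loop indices materialized by `linalg.index` inside the body instead of being passed as leading block arguments. As a minimal sketch of that target form (mirroring the deleted `@indexed_generic_vector` test above; the function name `@generic_vector` is illustrative, and the pre-`arith` op spellings `constant`/`addi`/`index_cast`/`sitofp` follow the tests in this patch):

#id_1d = affine_map<(i) -> (i)>
func @generic_vector(%operand: memref<50xf32>, %result: memref<50xf32>) {
  linalg.generic {indexing_maps = [#id_1d, #id_1d],
                  iterator_types = ["parallel"]}
      ins(%operand : memref<50xf32>)
      outs(%result : memref<50xf32>) {
    ^bb0(%operand_in: f32, %result_in: f32):
      // The index is recovered in the body rather than as a block argument,
      // so tiling only needs to offset linalg.index results (transformIndexOps).
      %i = linalg.index 0 : index
      %i_int = index_cast %i : index to i32
      %i_float = sitofp %i_int : i32 to f32
      %out = addf %operand_in, %i_float : f32
      linalg.yield %out : f32
  }
  return
}

This is exactly the shape the retained FileCheck patterns verify after tiling, where each `linalg.index` result is offset by the corresponding loop induction variable (e.g. `%[[NEW_I]] = addi %[[I]], %[[J]]`).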