diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -181,6 +181,9 @@
           .getTypes();
   LinalgOp clonedOp = producerOp.clone(b, loc, resultTypes, tiledOperands);
+  // Shift all IndexOp results by the tile offset.
+  addTileLoopIvsToIndexOpResults(b, clonedOp, allIvs);
+
   return clonedOp;
 }
 
@@ -325,10 +328,6 @@
   if (!producerResult || !isa<LinalgOp>(producerResult.getOwner()))
     return failure();
 
-  // TODO: support producers that have index semantics.
-  if (cast<LinalgOp>(producerResult.getOwner()).hasIndexSemantics())
-    return failure();
-
   // Compute the slice dimensions tiled by `tileLoopNest`.
   SmallVector<int64_t> tiledSliceDims =
       getTiledSliceDims(producerResult, rootOpOperand, loopDims);
diff --git a/mlir/test/Dialect/Linalg/tile-and-fuse-on-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-fuse-on-tensors.mlir
--- a/mlir/test/Dialect/Linalg/tile-and-fuse-on-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-and-fuse-on-tensors.mlir
@@ -188,3 +188,45 @@
   %2 = linalg.matmul ins(%0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%1 : tensor<24x25xf32>) -> tensor<24x25xf32>
   return %2 : tensor<24x25xf32>
 }
+
+// -----
+
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+
+// CHECK: fuse_indexed
+// CHECK-SAME: %[[ARG1:[0-9a-zA-Z]*]]: tensor<12x25xi32>
+builtin.func @fuse_indexed(%arg0: tensor<24x12xi32>,
+                           %arg1: tensor<12x25xi32>,
+                           %arg2: tensor<24x25xi32>) -> tensor<24x25xi32> {
+  %c0 = constant 0 : index
+  %c12 = constant 12 : index
+  %c25 = constant 25 : index
+  %c24 = constant 24 : index
+  %c4 = constant 4 : index
+  %0 = linalg.generic {indexing_maps = [#map0], iterator_types = ["parallel", "parallel"]} outs(%arg1 : tensor<12x25xi32>) {
+  ^bb0(%arg3: i32):  // no predecessors
+    %6 = linalg.index 0 : index
+    %7 = linalg.index 1 : index
+    %8 = addi %6, %7 : index
+    %9 = index_cast %8 : index to i32
+    linalg.yield %9 : i32
+  } -> tensor<12x25xi32>
+
+  // CHECK: scf.for %[[IV0:[0-9a-zA-Z]*]] =
+  // CHECK:   scf.for %[[IV1:[0-9a-zA-Z]*]] =
+  // CHECK:     scf.for %[[IV2:[0-9a-zA-Z]*]] =
+
+  // Shift the indexes by the slice offsets and swap the offsets due to the transposed indexing map.
+  // CHECK:       %[[T1:.*]] = tensor.extract_slice %[[ARG1]]
+  // CHECK-SAME:                                    %[[IV2]], %[[IV0]]
+  // CHECK:       linalg.generic {{.*}} outs(%[[T1]]
+  // CHECK:         %[[IDX0:.*]] = linalg.index 0
+  // CHECK:         %[[IDX0_SHIFTED:.*]] = affine.apply #[[MAP0]](%[[IDX0]], %[[IV0]])
+  // CHECK:         %[[IDX1:.*]] = linalg.index 1
+  // CHECK:         %[[IDX1_SHIFTED:.*]] = affine.apply #[[MAP0]](%[[IDX1]], %[[IV2]])
+  // CHECK:         %{{.*}} = addi %[[IDX0_SHIFTED]], %[[IDX1_SHIFTED]]
+  %1 = linalg.matmul ins(%arg0, %0 : tensor<24x12xi32>, tensor<12x25xi32>) outs(%arg2 : tensor<24x25xi32>) -> tensor<24x25xi32>
+  return %1 : tensor<24x25xi32>
+}
+