diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -145,6 +145,31 @@
 })));
 }
 
+/// Update the index accesses of linalg operations having index semantics.
+template <typename GenericOpTy>
+static void replaceUnitDimIndexOps(GenericOpTy op,
+                                   const DenseSet<unsigned> &unitDims,
+                                   PatternRewriter &rewriter) {
+  assert(op->getNumRegions() == 1 && op->getRegion(0).getBlocks().size() == 1 &&
+         "expected generic operation to have one block.");
+  Block &block = op->getRegion(0).front();
+
+  for (IndexOp indexOp : llvm::make_early_inc_range(block.getOps<IndexOp>())) {
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPoint(indexOp);
+    if (unitDims.count(indexOp.dim()) != 0) {
+      rewriter.replaceOpWithNewOp<ConstantIndexOp>(indexOp, 0);
+    } else {
+      // Update the dimension of the index operation if needed.
+      unsigned droppedDims = llvm::count_if(
+          unitDims, [&](unsigned dim) { return dim < indexOp.dim(); });
+      if (droppedDims != 0)
+        rewriter.replaceOpWithNewOp<IndexOp>(indexOp,
+                                             indexOp.dim() - droppedDims);
+    }
+  }
+}
+
 /// Modify the region of indexed generic op to drop arguments corresponding to
 /// loops that are unit trip count.
 template <typename OpTy>
@@ -177,10 +202,6 @@
   using OpRewritePattern<GenericOpTy>::OpRewritePattern;
   LogicalResult matchAndRewrite(GenericOpTy op,
                                 PatternRewriter &rewriter) const override {
-    // TODO: remove once index ops are supported.
-    if (op.hasIndexSemantics())
-      return failure();
-
     SmallVector<AffineMap, 4> indexingMaps = op.getIndexingMaps();
     if (indexingMaps.empty())
       return failure();
@@ -253,6 +274,7 @@
     op.indexing_mapsAttr(newIndexingMapAttr);
     op.iterator_typesAttr(ArrayAttr::get(context, newIteratorTypes));
     (void)replaceBlockArgForUnitDimLoops(op, unitDims, rewriter);
+    replaceUnitDimIndexOps(op, unitDims, rewriter);
     rewriter.finalizeRootUpdate(op);
     return success();
   }
@@ -325,10 +347,6 @@
   using OpRewritePattern<GenericOpTy>::OpRewritePattern;
   LogicalResult matchAndRewrite(GenericOpTy op,
                                 PatternRewriter &rewriter) const override {
-    // TODO: remove once index ops are supported.
-    if (op.hasIndexSemantics())
-      return failure();
-
     if (!op.hasTensorSemantics())
       return failure();
 
diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -78,6 +78,56 @@
 
 // -----
 
+#accesses = [
+  affine_map<(i, j, k, l, m) -> (i, k, m)>,
+  affine_map<(i, j, k, l, m) -> (i, k, j, l, m)>
+]
+
+#trait = {
+  iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"],
+  indexing_maps = #accesses,
+  library_call = "some_external_func"
+}
+
+func @drop_one_trip_loops_indexed
+  (%arg0 : tensor<?x1x?xi32>, %shape: tensor<?x1x?x1x?xi32>) -> tensor<?x1x?x1x?xi32>
+{
+  %0 = linalg.generic #trait
+     ins(%arg0 : tensor<?x1x?xi32>)
+    outs(%shape: tensor<?x1x?x1x?xi32>) {
+       ^bb0(%arg6 : i32, %arg7 : i32) :
+         %idx0 = linalg.index 0 : index
+         %idx1 = linalg.index 1 : index
+         %idx2 = linalg.index 2 : index
+         %idx3 = linalg.index 3 : index
+         %idx4 = linalg.index 4 : index
+         %1 = addi %idx0, %idx1 : index
+         %2 = subi %1, %idx2 : index
+         %3 = subi %2, %idx3 : index
+         %4 = addi %3, %idx4 : index
+         %5 = index_cast %4 : index to i32
+         %6 = addi %5, %arg6 : i32
+         linalg.yield %6 : i32
+       } -> tensor<?x1x?x1x?xi32>
+  return %0 : tensor<?x1x?x1x?xi32>
+}
+// The subtractions disappear because the access map of the output tensor maps
+// its unit dimensions 1 and 3 to the dropped loop dimensions 2 and 3.
+// CHECK-LABEL: func @drop_one_trip_loops_indexed
+// CHECK: linalg.generic
+// CHECK: ^{{.+}}(
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: i32, %{{.*}}: i32)
+// CHECK: %[[IDX0:.+]] = linalg.index 0 : index
+// CHECK: %[[IDX1:.+]] = linalg.index 1 : index
+// CHECK: %[[IDX2:.+]] = linalg.index 2 : index
+// CHECK: %[[T3:.+]] = addi %[[IDX0]], %[[IDX1]]
+// CHECK: %[[T4:.+]] = addi %[[T3]], %[[IDX2]]
+// CHECK: %[[T5:.+]] = index_cast %[[T4]] : index to i32
+// CHECK: %[[T6:.+]] = addi %[[T5]], %[[ARG4]] : i32
+// CHECK: linalg.yield %[[T6]] : i32
+
+// -----
+
 #map0 = affine_map<(i, j) -> (i, j)>
 #access = [#map0, #map0]
 #trait = {
   iterator_types = ["parallel", "parallel"],
   indexing_maps = #access,
   library_call = "some_external_func"
 }
@@ -134,6 +184,37 @@
 
 // -----
 
+#map0 = affine_map<(i, j) -> (i, j)>
+#access = [#map0, #map0]
+#trait = {
+  iterator_types = ["parallel", "parallel"],
+  indexing_maps = #access,
+  library_call = "some_external_func"
+}
+
+func @drop_all_loops_indexed
+  (%arg0 : tensor<1x1xi32>) -> tensor<1x1xi32>{
+  %0 = linalg.generic #trait
+    ins(%arg0 : tensor<1x1xi32>)
+    outs(%arg0 : tensor<1x1xi32>) {
+    ^bb0(%arg3: i32, %arg4: i32) :
+      %idx0 = linalg.index 0 : index
+      %idx1 = linalg.index 1 : index
+      %1 = addi %idx0, %idx1 : index
+      %2 = index_cast %1 : index to i32
+      %3 = addi %2, %arg3 : i32
+      linalg.yield %3 : i32
+  } -> tensor<1x1xi32>
+  return %0 : tensor<1x1xi32>
+}
+
+// CHECK-LABEL: func @drop_all_loops_indexed
+// CHECK: linalg.generic
+// CHECK: ^{{.+}}(%[[ARG1:.+]]: i32, %[[ARG2:.+]]: i32)
+// CHECK: linalg.yield %[[ARG1]] : i32
+
+// -----
+
 #accesses = [
   affine_map<(d0) -> (0, d0)>,
   affine_map<(d0) -> (d0)>
 ]
@@ -566,19 +647,3 @@
 // CHECK-SAME: outs(%[[FILL]] : tensor)
 // CHECK: %[[RESULT_RESHAPE:.+]] = linalg.tensor_reshape %[[RESULT]] [#[[MAP2]]]
 // CHECK: return %[[RESULT_RESHAPE]]
-
-// -----
-
-// CHECK: #{{.+}} = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK-LABEL: @index_op
-func @index_op(%arg0: memref<1x8xindex>) {
-  linalg.generic {
-    indexing_maps = [affine_map<(i, j) -> (i, j)>],
-    iterator_types = ["parallel", "parallel"]}
-    outs(%arg0 : memref<1x8xindex>) {
-  ^bb0(%arg1: index):  // no predecessors
-    %0 = linalg.index 1 : index
-    linalg.yield %0 : index
-  }
-  return
-}
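
For illustration, a minimal sketch (not part of the patch) of the rewrite that
replaceUnitDimIndexOps performs on a generic op payload, written with the same
std-dialect ops as the tests above. An index op for a dropped unit-trip
dimension is replaced by the constant zero, and every index op for a larger
dimension is renumbered down by the number of smaller dropped dimensions; the
greedy rewriter then folds the arithmetic on the zero constant, which is why
the subtractions disappear in the checks above. The SSA names below are
illustrative, assuming loop dimension 1 is the only dropped dimension:

  // Payload before dropping the unit-trip loop dimension 1:
  %i0 = linalg.index 0 : index
  %i1 = linalg.index 1 : index   // unit-trip loop, its index is always 0
  %i2 = linalg.index 2 : index

  // Payload after replaceUnitDimIndexOps:
  %i0 = linalg.index 0 : index   // dimension 0: no smaller dimension dropped
  %c0 = constant 0 : index       // replaces linalg.index 1
  %i1 = linalg.index 1 : index   // was linalg.index 2: one smaller dim dropped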