diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -413,13 +413,13 @@
 static bool isFusableWithReshapeByDimExpansion(LinalgOp linalgOp,
                                                unsigned fusedTensorIndex) {
   // Is fusable only if:
-  // - The linalgOp is a generic op.
+  // - The linalgOp is a generic op, or an indexed_generic op.
   // - All the indexing maps for operands in linalgOp are projected
   //   permutations.
   // - The indexing map at the position representing the fused tensor is a
   //   permutation.
   // - All the loops in linalgOp are parallel loops.
-  return isa<GenericOp>(linalgOp.getOperation()) &&
+  return isa<GenericOp, IndexedGenericOp>(linalgOp.getOperation()) &&
          linalgOp.hasTensorSemantics() &&
          llvm::all_of(linalgOp.indexing_maps().getValue().take_front(
                           linalgOp.getNumInputs()),
@@ -460,7 +460,7 @@
   ArrayRef<int64_t> expandedShape = expandedType.getShape();
   SmallVector<unsigned, 4> numFoldedDims(foldedType.getRank(), 0);
   SmallVector<SmallVector<int64_t, 4>, 4> expandedDimsShape(
-      expandedType.getRank());
+      foldedType.getRank());
   auto reassociationMaps = reshapeOp.getReassociationMaps();
   for (auto resultExpr : llvm::enumerate(fusedIndexMap.getResults())) {
     unsigned pos = resultExpr.value().cast<AffineDimExpr>().getPosition();
@@ -472,6 +472,26 @@
     expandedDimsShape[pos].assign(shape.begin(), shape.end());
   }
 
+  if (isa<IndexedGenericOp>(linalgOp.getOperation())) {
+    // For an indexed_generic op, the region contains arguments that represent
+    // the induction variable values of the loops. In the fused op these
+    // values are obtained by linearizing the expanded dimensions. For now
+    // just check that the extents used in the linearization (all the expanded
+    // dims except the front) are statically known. For the dynamic case, we
+    // would need shape information on these dimensions to compute them.
+    for (auto &expandedShape : expandedDimsShape) {
+      for (int64_t expandedDimShape : llvm::make_range(
+               std::next(expandedShape.begin()), expandedShape.end())) {
+        if (ShapedType::isDynamic(expandedDimShape)) {
+          linalgOp.emitError(
+              "unable to fuse indexed generic op where the expanded dim is "
+              "dynamic");
+          return llvm::None;
+        }
+      }
+    }
+  }
+
   // The remapping of the indices is then the prefix sum (inclusive) of the
   // numFoldedDims.
   SmallVector<unsigned, 4> remapping(numFoldedDims.size() + 1, 0);
@@ -563,10 +583,56 @@
       /*outputBuffers=*/ValueRange{}, /*initTensors=*/ValueRange{},
       expandedOpIndexingMaps, iteratorTypes);
   Region &fusedRegion = fusedOp.getOperation()->getRegion(0);
-  // TODO: Add support for indexed generic op, which would need mapping the
-  // expanded dimensions to the original dimension arguments.
-  rewriter.cloneRegionBefore(linalgOp.getOperation()->getRegion(0), fusedRegion,
-                             fusedRegion.begin());
+  Region &originalRegion = linalgOp.getOperation()->getRegion(0);
+
+  if (isa<GenericOp>(linalgOp.getOperation())) {
+    rewriter.cloneRegionBefore(originalRegion, fusedRegion,
+                               fusedRegion.begin());
+  } else {
+    assert(isa<IndexedGenericOp>(linalgOp.getOperation()));
+    // Create an entry block in the fused region with the same number of
+    // arguments as the fused op.
+    Block *fusedEntryBlock = new Block;
+    fusedRegion.push_back(fusedEntryBlock);
+    rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.end());
+
+    // Merge the entry block of the fused op with the cloned blocks. For this,
+    // compute the values for the arguments of the region in the original
+    // operation in terms of the arguments of the fused op.
+    // Since the original operation is expanded, the expanded dimensions need
+    // to be folded back to get the replacement values for the arguments
+    // corresponding to the iteration indices. For now this expects that all
+    // the loop ranges are constants, which is true if the shapes are all
+    // static. This has already been checked in the precondition.
+    using namespace edsc::op;
+    using namespace edsc::intrinsics;
+    OpBuilder::InsertionGuard guard(rewriter);
+    SmallVector<Value, 4> argReplacements(originalRegion.getNumArguments());
+    rewriter.setInsertionPointToStart(fusedEntryBlock);
+    edsc::ScopedContext scopedContext(rewriter, fusedOp.getLoc());
+    IndexType indexType = rewriter.getIndexType();
+    for (unsigned i : llvm::seq<unsigned>(0, numFoldedDims.size())) {
+      Value linearizedIndex = fusedEntryBlock->addArgument(indexType);
+      for (unsigned foldedDim = remapping[i] + 1;
+           foldedDim != remapping[i + 1]; foldedDim++) {
+        int64_t expandedDimExtent =
+            expandedDimsShape[i][foldedDim - remapping[i]];
+        assert(!ShapedType::isDynamic(expandedDimExtent));
+        linearizedIndex =
+            linearizedIndex * std_constant_index(expandedDimExtent);
+        linearizedIndex =
+            linearizedIndex + fusedEntryBlock->addArgument(indexType);
+      }
+      argReplacements[i] = linearizedIndex;
+    }
+    for (unsigned i :
+         llvm::seq<unsigned>(numFoldedDims.size(), argReplacements.size())) {
+      argReplacements[i] =
+          fusedEntryBlock->addArgument(originalRegion.getArgument(i).getType());
+    }
+    rewriter.mergeBlocks(fusedEntryBlock->getNextNode(), fusedEntryBlock,
+                         argReplacements);
+  }
 
   // Reshape the result values to their original shape if this is a collapsing
   // reshape folded into its consumer.
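To make the index remapping above concrete: the inclusive prefix sum gives, for each original loop dimension, the half-open range of expanded dimensions it maps to, and the block merge folds those expanded indices back with a Horner-style linearization. Below is a minimal standalone sketch (plain C++, no MLIR; `computeRemapping` and `linearize` are hypothetical helpers that mirror `remapping` and the muli/addi chain in the patch, not code from this change):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Inclusive prefix sum of the number of expanded dims per original dim:
// original dim i maps to expanded dims [remapping[i], remapping[i + 1]).
std::vector<unsigned>
computeRemapping(const std::vector<unsigned> &numFoldedDims) {
  std::vector<unsigned> remapping(numFoldedDims.size() + 1, 0);
  for (unsigned i = 0; i < numFoldedDims.size(); ++i)
    remapping[i + 1] = remapping[i] + numFoldedDims[i];
  return remapping;
}

// Fold the expanded indices of one original dim back into a linear index.
// Only the trailing extents are used as multipliers, which is why the patch
// requires them (but not the front extent) to be static.
int64_t linearize(const std::vector<int64_t> &indices,
                  const std::vector<int64_t> &extents) {
  assert(indices.size() == extents.size());
  int64_t linear = indices[0];
  for (size_t d = 1; d < indices.size(); ++d)
    linear = linear * extents[d] + indices[d];
  return linear;
}

int main() {
  // Three original dims; the middle one expands into two dims of extents
  // (?, 4), as in the producer-fusion test added below.
  assert((computeRemapping({1, 2, 1}) == std::vector<unsigned>{0, 1, 3, 4}));
  // Expanded indices (j, k) fold back to j * 4 + k; the dynamic front extent
  // (marked -1) is never multiplied in.
  assert((linearize({2, 3}, {-1, 4}) == 2 * 4 + 3));
  return 0;
}
```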
@@ -670,14 +736,15 @@
   }
 };
 
-/// Pattern to fuse a tensor_reshape op with its consumer generic op, when the
-/// reshape op is collapsing dimensions. The dimensionality of the loop in the
-/// consumer generic op is expanded.
+/// Pattern to fuse a tensor_reshape op with its consumer
+/// generic/indexed_generic op, when the reshape op is collapsing
+/// dimensions. The dimensionality of the loop in the consumer is expanded.
+template <typename GenericOpTy>
 struct FoldWithProducerReshapeOpByExpansion
-    : public OpRewritePattern<GenericOp> {
-  using OpRewritePattern<GenericOp>::OpRewritePattern;
+    : public OpRewritePattern<GenericOpTy> {
+  using OpRewritePattern<GenericOpTy>::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(GenericOp genericOp,
+  LogicalResult matchAndRewrite(GenericOpTy genericOp,
                                 PatternRewriter &rewriter) const override {
     LinalgOp linalgOp = cast<LinalgOp>(genericOp.getOperation());
     for (auto operand : llvm::enumerate(linalgOp.getInputs())) {
@@ -942,7 +1009,9 @@
 void mlir::populateFoldReshapeOpsByExpansionPatterns(
     MLIRContext *context, OwningRewritePatternList &patterns) {
   patterns.insert<FoldReshapeWithGenericOpByExpansion,
-                  FoldWithProducerReshapeOpByExpansion>(context);
+                  FoldWithProducerReshapeOpByExpansion<GenericOp>,
+                  FoldWithProducerReshapeOpByExpansion<IndexedGenericOp>>(
+      context);
 }
 
 void mlir::populateLinalgTensorOpsFusionPatterns(
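For context, this is how the two newly registered instantiations would typically be exercised. The wrapper below is an illustrative sketch (the function name `applyReshapeFusionByExpansion` is hypothetical, not part of this patch), assuming the greedy-rewrite entry points of this MLIR vintage:

```cpp
// Illustrative only: collect the expansion patterns, which now include
// FoldWithProducerReshapeOpByExpansion<GenericOp> and
// FoldWithProducerReshapeOpByExpansion<IndexedGenericOp>, and apply them
// greedily to a function.
void applyReshapeFusionByExpansion(FuncOp funcOp) {
  OwningRewritePatternList patterns;
  mlir::populateFoldReshapeOpsByExpansionPatterns(funcOp.getContext(),
                                                  patterns);
  applyPatternsAndFoldGreedily(funcOp.getOperation(), patterns);
}
```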
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -190,3 +190,157 @@
 // CHECK-SAME:   iterator_types = ["parallel", "parallel"]
 // CHECK-SAME:   ins(%[[T0]] : tensor<?x?xf32>)
 // CHECK:        return %[[T1]] : tensor<1x10xf32>
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
+#map1 = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
+func @indexed_generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x4x?xi32>,
+                                                 %arg1 : tensor<?x?x?xi32>) ->
+                                                 tensor<?x?x?xi32>
+{
+  %0 = linalg.tensor_reshape %arg0 [affine_map<(i, j, k, l) -> (i)>,
+                                    affine_map<(i, j, k, l) -> (j, k)>,
+                                    affine_map<(i, j, k, l) -> (l)>] :
+    tensor<?x?x4x?xi32> into tensor<?x?x?xi32>
+  %1 = linalg.indexed_generic {
+    indexing_maps = [#map0, #map1, #map1],
+    iterator_types = ["parallel", "parallel", "parallel"]}
+    ins(%0, %arg1 : tensor<?x?x?xi32>, tensor<?x?x?xi32>) {
+  ^bb0(%arg3 : index, %arg4 : index, %arg5 : index, %arg6: i32, %arg7: i32):
+    %1 = muli %arg6, %arg7 : i32
+    %2 = index_cast %arg3 : index to i32
+    %3 = addi %1, %2 : i32
+    %4 = index_cast %arg4 : index to i32
+    %5 = addi %3, %4 : i32
+    %6 = index_cast %arg5 : index to i32
+    %7 = addi %5, %6 : i32
+    linalg.yield %7 : i32
+  } -> tensor<?x?x?xi32>
+  return %1 : tensor<?x?x?xi32>
+}
+
+// The generic op version of the test checks the op structure. Only the
+// op body is checked here.
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0 * 4 + d1)>
+// CHECK: func @indexed_generic_op_reshape_producer_fusion
+// CHECK: linalg.indexed_generic
+// CHECK: ^{{.*}}(
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index, %[[ARG3:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index, %[[ARG5:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: i32, %[[ARG7:[a-zA-Z0-9]+]]: i32)
+// CHECK: %[[T3:.+]] = affine.apply #[[MAP]](%[[ARG2]], %[[ARG3]])
+// CHECK: %[[T4:.+]] = muli %[[ARG6]], %[[ARG7]]
+// CHECK: %[[T5:.+]] = index_cast %[[T3]]
+// CHECK: %[[T6:.+]] = addi %[[T4]], %[[T5]]
+// CHECK: %[[T7:.+]] = index_cast %[[ARG4]]
+// CHECK: %[[T8:.+]] = addi %[[T6]], %[[T7]]
+// CHECK: %[[T9:.+]] = index_cast %[[ARG5]]
+// CHECK: %[[T10:.+]] = addi %[[T8]], %[[T9]]
+// CHECK: linalg.yield %[[T10]]
+
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+func @indexed_generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?xi32>,
+                                                 %arg1 : tensor<?x?xi32>) ->
+                                                 tensor<?x?x4x5xi32>
+{
+  %0 = linalg.indexed_generic {
+    indexing_maps = [#map0, #map0, #map0],
+    iterator_types = ["parallel", "parallel"]}
+    ins(%arg0, %arg1 : tensor<?x?xi32>, tensor<?x?xi32>) {
+  ^bb0(%arg3 : index, %arg4 : index, %arg5: i32, %arg6: i32):  // no predecessors
+    %1 = muli %arg5, %arg6 : i32
+    %2 = index_cast %arg3 : index to i32
+    %3 = addi %1, %2 : i32
+    %4 = index_cast %arg4 : index to i32
+    %5 = addi %3, %4 : i32
+    linalg.yield %5 : i32
+  } -> tensor<?x?xi32>
+  %1 = linalg.tensor_reshape %0 [affine_map<(i, j, k, l) -> (i)>,
+                                 affine_map<(i, j, k, l) -> (j, k, l)>] :
+    tensor<?x?xi32> into tensor<?x?x4x5xi32>
+  return %1 : tensor<?x?x4x5xi32>
+}
+// The generic op version of the test checks the op structure. Only the
+// op body is checked here.
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0 * 20 + d1 * 5 + d2)>
+// CHECK: func @indexed_generic_op_reshape_consumer_fusion
+// CHECK: linalg.indexed_generic
+// CHECK: ^{{.*}}(
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index, %[[ARG3:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index, %[[ARG5:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: i32, %[[ARG7:[a-zA-Z0-9]+]]: i32)
+// CHECK: %[[T3:.+]] = affine.apply #[[MAP]](%[[ARG3]], %[[ARG4]], %[[ARG5]])
+// CHECK: %[[T4:.+]] = muli %[[ARG6]], %[[ARG7]]
+// CHECK: %[[T5:.+]] = index_cast %[[ARG2]]
+// CHECK: %[[T6:.+]] = addi %[[T4]], %[[T5]]
+// CHECK: %[[T7:.+]] = index_cast %[[T3]]
+// CHECK: %[[T8:.+]] = addi %[[T6]], %[[T7]]
+// CHECK: linalg.yield %[[T8]]
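A quick sanity check for the linearization constants in the maps above (`d0 * 4 + d1` and `d0 * 20 + d1 * 5 + d2`): within one reassociation group, the trailing static extents act as row-major strides. The helper below is hypothetical, written only to verify the arithmetic:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Row-major strides of an expanded dimension group. The front extent may be
// dynamic (-1 here) because it is never used as a multiplier.
std::vector<int64_t> groupStrides(const std::vector<int64_t> &extents) {
  assert(!extents.empty());
  std::vector<int64_t> strides(extents.size(), 1);
  for (size_t d = extents.size() - 1; d > 0; --d)
    strides[d - 1] = strides[d] * extents[d];
  return strides;
}

int main() {
  // Producer-fusion test: group (j, k) with extents (?, 4) -> d0 * 4 + d1.
  assert((groupStrides({-1, 4}) == std::vector<int64_t>{4, 1}));
  // Consumer-fusion test: group (j, k, l) with extents (?, 4, 5)
  // -> d0 * 20 + d1 * 5 + d2.
  assert((groupStrides({-1, 4, 5}) == std::vector<int64_t>{20, 5, 1}));
  return 0;
}
```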
+
+// -----
+
+func @reshape_as_consumer_permutation
+         (%a : tensor<210x6x4xi32>, %b : tensor<210x4xi32>)
+           -> tensor<2x3x4x5x6x7xi32> {
+  %c = linalg.indexed_generic {
+         indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
+                          affine_map<(d0, d1, d2) -> (d1, d2)>,
+                          affine_map<(d0, d1, d2) -> (d0, d2, d1)>],
+         iterator_types = ["parallel", "parallel", "parallel"]}
+         ins(%a, %b : tensor<210x6x4xi32>, tensor<210x4xi32>) {
+       ^bb0(%arg0 : index, %arg1 : index, %arg2 : index, %arg3 : i32, %arg4: i32):
+         %1 = addi %arg3, %arg4 : i32
+         %2 = index_cast %arg0 : index to i32
+         %3 = addi %1, %2 : i32
+         %4 = index_cast %arg1 : index to i32
+         %5 = addi %3, %4 : i32
+         %6 = index_cast %arg2 : index to i32
+         %7 = addi %5, %6 : i32
+         linalg.yield %7 : i32
+       } -> tensor<6x4x210xi32>
+  %d = linalg.tensor_reshape %c
+         [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1)>,
+          affine_map<(d0, d1, d2, d3, d4, d5) -> (d2)>,
+          affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>]
+       : tensor<6x4x210xi32> into tensor<2x3x4x5x6x7xi32>
+  return %d : tensor<2x3x4x5x6x7xi32>
+}
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d5)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1) -> (d0 * 3 + d1)>
+// CHECK-DAG: #[[MAP6:.+]] = affine_map<(d0, d1, d2) -> (d0 * 42 + d1 * 7 + d2)>
+// CHECK-DAG: #[[MAP7:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d0, d1, d5)>
+// CHECK-DAG: #[[MAP8:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d5)>
+// CHECK-DAG: #[[MAP9:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d2, d3, d4)>
+// CHECK: func @reshape_as_consumer_permutation
+// CHECK-SAME: %[[ARG0:.+]]: tensor<210x6x4xi32>
+// CHECK-SAME: %[[ARG1:.+]]: tensor<210x4xi32>
+// CHECK-DAG: %[[T0:.+]] = linalg.tensor_reshape %[[ARG0]]
+// CHECK-SAME: [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
+// CHECK-DAG: %[[T1:.+]] = linalg.tensor_reshape %[[ARG1]]
+// CHECK-SAME: [#[[MAP3]], #[[MAP4]]]
+// CHECK: %[[T2:.+]] = linalg.indexed_generic
+// CHECK-SAME: indexing_maps = [#[[MAP7]], #[[MAP8]], #[[MAP9]]]
+// CHECK-SAME: ins(%[[T0]], %[[T1]] : tensor<{{.+}}>, tensor<{{.+}}>)
+// CHECK: ^{{.+}}(
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index, %[[ARG3:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index, %[[ARG5:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: index, %[[ARG7:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG8:[a-zA-Z0-9]+]]: i32, %[[ARG9:[a-zA-Z0-9]+]]: i32)
+// CHECK-DAG: %[[T3:.+]] = affine.apply #[[MAP5]](%[[ARG2]], %[[ARG3]])
+// CHECK-DAG: %[[T4:.+]] = affine.apply #[[MAP6]](%[[ARG4]], %[[ARG5]], %[[ARG6]])
+// CHECK-DAG: %[[T5:.+]] = addi %[[ARG8]], %[[ARG9]]
+// CHECK: %[[T6:.+]] = index_cast %[[T3]]
+// CHECK: %[[T7:.+]] = addi %[[T5]], %[[T6]]
+// CHECK: %[[T8:.+]] = index_cast %[[T4]]
+// CHECK: %[[T9:.+]] = addi %[[T7]], %[[T8]]
+// CHECK: %[[T10:.+]] = index_cast %[[ARG7]]
+// CHECK: %[[T11:.+]] = addi %[[T9]], %[[T10]]