diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -28,10 +28,6 @@
 /// Implementation of fusion of generic ops and indexed_generic ops.
 static bool areElementwiseOpsFusable(LinalgOp producer, LinalgOp consumer,
                                      unsigned consumerIdx) {
-  // TODO: remove once index ops are supported.
-  if (producer.hasIndexSemantics() || consumer.hasIndexSemantics())
-    return false;
-
   // Producer and consumer must have tensor semantics.
   if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics())
     return false;
@@ -138,7 +134,7 @@
   // 1. Map consumer indices to fusedBlock indices 1-1.
   mapper.map(consumerBlock.getArguments().take_front(numConsumerIndices),
              fusedBlock->getArguments().take_front(numConsumerIndices));
-  // 2. Embed producer indices into fusedBlock index space 1-1.
+  // 2a. Embed producer indices into fusedBlock index space 1-1.
   for (auto it :
        llvm::zip(producerBlock.getArguments().take_front(numProducerIndices),
                  fusedBlock->getArguments().take_front(numProducerIndices))) {
@@ -148,6 +144,28 @@
         fusedBlock->getArguments().take_front(numFusedOpIndices));
     mapper.map(std::get<0>(it), newIndex);
   }
+  // 2b. Replace the producer index operations by index operations placed in
+  // the fused block using the `consumerToProducerLoopsMap` to map the index
+  // spaces.
+  unsigned numFusedOpLoops =
+      std::max(producer.getNumLoops(), consumer.getNumLoops());
+  if (producer.hasIndexSemantics()) {
+    SmallVector<Value> fusedIndices;
+    fusedIndices.reserve(numFusedOpLoops);
+    llvm::transform(llvm::seq<int64_t>(0, numFusedOpLoops),
+                    std::back_inserter(fusedIndices), [&](int64_t dim) {
+                      return rewriter.create<IndexOp>(producer.getLoc(), dim);
+                    });
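+    // Traverse the producer index operations using an early-increment range
+    // since `replaceOp` below erases the operation currently being visited.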
+    for (IndexOp indexOp :
+         llvm::make_early_inc_range(producerBlock.getOps<IndexOp>())) {
+      Value newIndex = rewriter.create<AffineApplyOp>(
+          producer.getLoc(),
+          consumerToProducerLoopsMap.getSubMap(indexOp.dim()), fusedIndices);
+      // Replace the producer index operation by the index value computed in
+      // the fused block. All remaining operations in the producer block are
+      // cloned to the fused block later on.
+      rewriter.replaceOp(indexOp, newIndex);
+    }
+  }
   // TODO: allow fusing the producer of an output operand.
   assert(consumerIdx < consumer.getNumInputs() &&
          "expected producer of input operand");
@@ -329,8 +347,8 @@
       invProducerResultIndexMap.compose(consumerResultIndexMap);

   generateFusedElementwiseOpRegion(rewriter, fusedOp, producer, consumer,
-                                   consumerToProducerLoopsMap, consumerIdx,
-                                   consumer.getNumLoops());
+                                   consumerToProducerLoopsMap, consumerIdx,
+                                   consumer.getNumLoops());
   return SmallVector<Value>(fusedOp->getResults());
 }

@@ -602,17 +620,16 @@
   return success();
 }

-/// To expand an indexed_generic operation, the body of the indexed generic op
-/// need to be modified appropriately. Specifically, uses of arguments for
-/// induction variables in the original operation need to be replaced with
-/// linearization of the corresponding arguments in the expanded op. That
-/// requires the shape of the expanded dimensions (at least all but the most
-/// significant. For now check that these are all statically sized. Note that
-/// this could be extended to handle dynamic case, but the implementation below
-/// uses `affine.apply` which seems to have issues when the shapes are not
-/// static.
-LogicalResult isIndexedGenericOpExpandable(LinalgOp linalgOp,
-                                           const ExpansionInfo &expansionInfo) {
+/// Expanding the body of a linalg operation requires adaptations of the
+/// accessed loop indices. Specifically, accesses of indices in the original
+/// operation need to be replaced with linearizations of indices in the
+/// expanded op. That requires the shape of the expanded dimensions to be
+/// static (at least all but the most significant). For now check that these
+/// are all statically sized. Note that this could be extended to handle the
+/// dynamic case, but the implementation below uses `affine.apply` which seems
+/// to have issues when the shapes are not static.
+LogicalResult isIndexedOpExpandable(LinalgOp linalgOp,
+                                    const ExpansionInfo &expansionInfo) {
   for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
     ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
     if (expandedShape.size() == 1)
@@ -734,6 +751,49 @@
                           argReplacements);
 }

+/// Update the body of an expanded linalg operation having index semantics. The
+/// indices of the original operation need to be recovered by linearizing the
+/// indices of the corresponding dimensions of the expanded operation. For now
+/// it is assumed that the shapes of the expanded operation needed for
+/// linearization are static.
+static void updateExpandedIndexOpRegion(PatternRewriter &rewriter, Location loc,
+                                        Region &fusedRegion,
+                                        const ExpansionInfo &expansionInfo) {
+  // Replace the original indices by the linearization of the expanded indices.
+  for (IndexOp indexOp :
+       llvm::make_early_inc_range(fusedRegion.front().getOps<IndexOp>())) {
+    ArrayRef<int64_t> expandedDims =
+        expansionInfo.getExpandedDims(indexOp.dim());
+    assert(!expandedDims.empty() && "expected valid expansion info");
+
+    // Skip index operations that are not affected by the expansion.
+    if (expandedDims.size() == 1 &&
+        expandedDims.front() == (int64_t)indexOp.dim())
+      continue;
+
+    // Linearize the expanded indices of the original index dimension.
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointAfter(indexOp);
+    ArrayRef<int64_t> expandedDimsShape =
+        expansionInfo.getExpandedShapeOfDim(indexOp.dim()).drop_front();
+    SmallVector<Value> expandedIndices;
+    expandedIndices.reserve(expandedDims.size() - 1);
+    llvm::transform(
+        expandedDims.drop_front(), std::back_inserter(expandedIndices),
+        [&](int64_t dim) { return rewriter.create<IndexOp>(loc, dim); });
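+    // Example (mirrors the reshape fusion tests below): if dimension i of the
+    // original operation expands into dimensions (i0, i1, i2) with shapes
+    // (?, 4, 5), the original index is recovered as i2 + 5 * (i1 + 4 * i0),
+    // built with one affine.apply per expanded dimension.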
+    Value newIndex = rewriter.create<IndexOp>(loc, expandedDims.front());
+    for (auto it : llvm::zip(expandedDimsShape, expandedIndices)) {
+      assert(!ShapedType::isDynamic(std::get<0>(it)));
+      AffineExpr idx, acc;
+      bindDims(rewriter.getContext(), idx, acc);
+      newIndex = rewriter.create<AffineApplyOp>(
+          indexOp.getLoc(), idx + acc * std::get<0>(it),
+          ValueRange{std::get<1>(it), newIndex});
+    }
+    rewriter.replaceOp(indexOp, newIndex);
+  }
+}
+
 /// Implements the fusion of a tensor_reshape op and a generic/indexed_generic
 /// op as explained in `isFusableWithReshapeByExpansion`. Assumes that those
 /// conditions have been satisfied.
@@ -748,6 +808,8 @@
       reshapeOp.getSrcType().getRank() < reshapeOp.getResultType().getRank();
   RankedTensorType expandedType = isExpanding ? reshapeOp.getResultType()
                                               : reshapeOp.getSrcType();
+  bool hasIndexSemantics = linalgOp.hasIndexSemantics() ||
+                           isa<IndexedGenericOp>(linalgOp.getOperation());

   ExpansionInfo expansionInfo;
   if (failed(expansionInfo.compute(linalgOp, fusedTensorIndex,
@@ -755,8 +817,8 @@
                                    expandedType.getShape())))
     return llvm::None;

-  if (isa<IndexedGenericOp>(linalgOp.getOperation()) &&
-      failed(isIndexedGenericOpExpandable(linalgOp, expansionInfo)))
+  if (hasIndexSemantics &&
+      failed(isIndexedOpExpandable(linalgOp, expansionInfo)))
     return llvm::None;

   SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>(
@@ -823,6 +885,10 @@
                                   fusedRegion, expansionInfo);
   }

+  // Update the index accesses after the expansion.
+  if (linalgOp.hasIndexSemantics())
+    updateExpandedIndexOpRegion(rewriter, loc, fusedRegion, expansionInfo);
+
   // Reshape the result values to their original shape if this is a collapsing
   // reshape folded into its consumer.
   SmallVector<Value> resultVals;
@@ -1261,6 +1327,7 @@
       context, options.controlElementwiseOpsFusionFn);
   populateFoldReshapeOpsByExpansionPatterns(
       patterns, options.allowFoldingUnitDimReshapes);
+  AffineApplyOp::getCanonicalizationPatterns(patterns, context);
   GenericOp::getCanonicalizationPatterns(patterns, context);
   IndexedGenericOp::getCanonicalizationPatterns(patterns, context);
   TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
diff --git a/mlir/test/Dialect/Linalg/fusion-tensor.mlir b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
--- a/mlir/test/Dialect/Linalg/fusion-tensor.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
@@ -359,6 +359,58 @@

 // -----

+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+func @producer_indexed_consumer_fusion(%arg0: tensor<?x?xi32>,
+                                       %arg1: tensor<?x?xi32>) -> tensor<?x?xi32> {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %0 = memref.dim %arg0, %c0 : tensor<?x?xi32>
+  %1 = memref.dim %arg0, %c1 : tensor<?x?xi32>
+  %2 = linalg.init_tensor [%0, %1] : tensor<?x?xi32>
+  %3 = linalg.generic {
+    indexing_maps = [#map0, #map0, #map0],
+    iterator_types = ["parallel", "parallel"] }
+    ins(%arg0, %arg1 : tensor<?x?xi32>, tensor<?x?xi32>)
+    outs(%2 : tensor<?x?xi32>) {
+    ^bb0(%arg2: i32, %arg3: i32, %arg4: i32):  // no predecessors
+      %10 = addi %arg2, %arg3 : i32
+      linalg.yield %10 : i32
+  } -> tensor<?x?xi32>
+  %4 = linalg.generic {
+    indexing_maps = [#map0, #map0],
+    iterator_types = ["parallel", "parallel"] }
+    ins(%3 : tensor<?x?xi32>)
+    outs(%2 : tensor<?x?xi32>) {
+    ^bb0(%arg2: i32, %arg3: i32):  // no predecessors
+      %idx0 = linalg.index 0 : index
+      %idx1 = linalg.index 1 : index
+      %5 = index_cast %idx0 : index to i32
+      %6 = index_cast %idx1 : index to i32
+      %7 = addi %arg2, %5 : i32
+      %8 = subi %7, %6 : i32
+      linalg.yield %8 : i32
+  } -> tensor<?x?xi32>
+  return %4 : tensor<?x?xi32>
+}
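+// The producer has no index semantics, so the consumer's linalg.index ops are
+// cloned into the fused body unchanged, after the producer computation.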
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func @producer_indexed_consumer_fusion
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP0]]]
+// CHECK: ^{{[a-zA-Z0-9_]*}}
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: i32
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: i32
+// CHECK: %[[VAL1:.+]] = addi %[[ARG0]], %[[ARG1]] : i32
+// CHECK: %[[IDX0:.+]] = linalg.index 0 : index
+// CHECK: %[[IDX1:.+]] = linalg.index 1 : index
+// CHECK: %[[ADD_OPERAND:.+]] = index_cast %[[IDX0]] : index to i32
+// CHECK: %[[SUB_OPERAND:.+]] = index_cast %[[IDX1]] : index to i32
+// CHECK: %[[VAL2:.+]] = addi %[[VAL1]], %[[ADD_OPERAND]] : i32
+// CHECK: %[[VAL3:.+]] = subi %[[VAL2]], %[[SUB_OPERAND]] : i32
+// CHECK: linalg.yield %[[VAL3]] : i32
+// CHECK-NOT: linalg.generic
+
+// -----
+
 #map0 = affine_map<(d0, d1) -> (d0, d1)>
 func @indexed_generic_op_generic_op_fusion(%arg0: tensor<?x?xi32>,
                                            %arg1: tensor<?x?xi32>) -> tensor<?x?xi32> {
@@ -409,6 +461,58 @@

 // -----

+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+func @indexed_producer_consumer_fusion(%arg0: tensor<?x?xi32>,
+                                       %arg1: tensor<?x?xi32>) -> tensor<?x?xi32> {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %0 = memref.dim %arg0, %c0 : tensor<?x?xi32>
+  %1 = memref.dim %arg0, %c1 : tensor<?x?xi32>
+  %2 = linalg.init_tensor [%0, %1] : tensor<?x?xi32>
+  %3 = linalg.generic {
+    indexing_maps = [#map0, #map0],
+    iterator_types = ["parallel", "parallel"] }
+    ins(%arg0 : tensor<?x?xi32>)
+    outs(%2 : tensor<?x?xi32>) {
+    ^bb0(%arg4: i32, %arg5: i32):  // no predecessors
+      %idx0 = linalg.index 0 : index
+      %idx1 = linalg.index 1 : index
+      %4 = index_cast %idx0 : index to i32
+      %5 = index_cast %idx1 : index to i32
+      %6 = addi %arg4, %4 : i32
+      %7 = subi %6, %5 : i32
+      linalg.yield %7 : i32
+  } -> tensor<?x?xi32>
+  %4 = linalg.generic {
+    indexing_maps = [#map0, #map0, #map0],
+    iterator_types = ["parallel", "parallel"] }
+    ins(%3, %arg1 : tensor<?x?xi32>, tensor<?x?xi32>)
+    outs(%2 : tensor<?x?xi32>) {
+    ^bb0(%arg2: i32, %arg3: i32, %arg4: i32):  // no predecessors
+      %10 = addi %arg2, %arg3 : i32
+      linalg.yield %10 : i32
+  } -> tensor<?x?xi32>
+  return %4 : tensor<?x?xi32>
+}
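+// Producer and consumer iterate over the same identity space, so the remapped
+// producer indices canonicalize back to plain linalg.index ops.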
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func @indexed_producer_consumer_fusion
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP0]]]
+// CHECK: ^{{[a-zA-Z0-9_]*}}
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: i32
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: i32
+// CHECK: %[[IDX0:.+]] = linalg.index 0 : index
+// CHECK: %[[IDX1:.+]] = linalg.index 1 : index
+// CHECK: %[[ADD_OPERAND:.+]] = index_cast %[[IDX0]] : index to i32
+// CHECK: %[[SUB_OPERAND:.+]] = index_cast %[[IDX1]] : index to i32
+// CHECK: %[[VAL1:.+]] = addi %[[ARG0]], %[[ADD_OPERAND]] : i32
+// CHECK: %[[VAL2:.+]] = subi %[[VAL1]], %[[SUB_OPERAND]] : i32
+// CHECK: %[[VAL3:.+]] = addi %[[VAL2]], %[[ARG1]] : i32
+// CHECK: linalg.yield %[[VAL3]] : i32
+// CHECK-NOT: linalg.generic
+
+// -----
+
 // The indices of the first indexed_generic op are swapped after fusion.
 #map0 = affine_map<(d0, d1) -> (d1, d0)>
 #map1 = affine_map<(d0, d1) -> (d0, d1)>
@@ -465,6 +569,69 @@

 // -----

+// The indices of the first generic op are swapped after fusion.
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1) -> (d0, d1)>
+func @indexed_producer_indexed_consumer_fusion(%arg0: tensor<?x?xi32>)
+                                               -> tensor<?x?xi32> {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %0 = memref.dim %arg0, %c0 : tensor<?x?xi32>
+  %1 = memref.dim %arg0, %c1 : tensor<?x?xi32>
+  %2 = linalg.init_tensor [%0, %1] : tensor<?x?xi32>
+  %3 = linalg.generic {
+    indexing_maps = [#map0, #map0],
+    iterator_types = ["parallel", "parallel"] }
+    ins(%arg0 : tensor<?x?xi32>)
+    outs(%2 : tensor<?x?xi32>) {
+    ^bb0(%arg2: i32, %arg3: i32):  // no predecessors
+      %idx0 = linalg.index 0 : index
+      %idx1 = linalg.index 1 : index
+      %4 = index_cast %idx0 : index to i32
+      %5 = index_cast %idx1 : index to i32
+      %6 = addi %arg2, %4 : i32
+      %7 = subi %5, %6 : i32
+      linalg.yield %7 : i32
+  } -> tensor<?x?xi32>
+  %4 = linalg.generic {
+    indexing_maps = [#map1, #map1],
+    iterator_types = ["parallel", "parallel"] }
+    ins(%3 : tensor<?x?xi32>)
+    outs(%2 : tensor<?x?xi32>) {
+    ^bb0(%arg2: i32, %arg3: i32):  // no predecessors
+      %idx0 = linalg.index 0 : index
+      %idx1 = linalg.index 1 : index
+      %5 = index_cast %idx0 : index to i32
+      %6 = index_cast %idx1 : index to i32
+      %7 = addi %arg2, %5 : i32
+      %8 = subi %7, %6 : i32
+      linalg.yield %8 : i32
+  } -> tensor<?x?xi32>
+  return %4 : tensor<?x?xi32>
+}
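+// The producer indices are rewritten through the consumer-to-producer loop map
+// (a transpose here), which is why IDX0 and IDX1 appear swapped in the first
+// half of the fused body.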
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func @indexed_producer_indexed_consumer_fusion
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP0]]]
+// CHECK: ^{{[a-zA-Z0-9_]*}}
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: i32
+// CHECK: %[[IDX0:.+]] = linalg.index 0 : index
+// CHECK: %[[IDX1:.+]] = linalg.index 1 : index
+// CHECK: %[[ADD_OPERAND1:.+]] = index_cast %[[IDX1]] : index to i32
+// CHECK: %[[SUB_OPERAND1:.+]] = index_cast %[[IDX0]] : index to i32
+// CHECK: %[[VAL1:.+]] = addi %[[ARG0]], %[[ADD_OPERAND1]] : i32
+// CHECK: %[[VAL2:.+]] = subi %[[SUB_OPERAND1]], %[[VAL1]] : i32
+// CHECK: %[[IDX2:.+]] = linalg.index 0 : index
+// CHECK: %[[IDX3:.+]] = linalg.index 1 : index
+// CHECK: %[[ADD_OPERAND2:.+]] = index_cast %[[IDX2]] : index to i32
+// CHECK: %[[SUB_OPERAND2:.+]] = index_cast %[[IDX3]] : index to i32
+// CHECK: %[[VAL3:.+]] = addi %[[VAL2]], %[[ADD_OPERAND2]] : i32
+// CHECK: %[[VAL4:.+]] = subi %[[VAL3]], %[[SUB_OPERAND2]] : i32
+// CHECK: linalg.yield %[[VAL4]] : i32
+// CHECK-NOT: linalg.generic
+
+// -----
+
 func @scalar_indexed_generic_fusion
   (%arg0: tensor<5x1x1xf32>, %arg1 : tensor<i32>) -> tensor<10xf32>
 {
@@ -507,6 +674,48 @@

 // -----

+func @scalar_generic_fusion
+  (%arg0: tensor<5x1x1xf32>, %arg1 : tensor<i32>) -> tensor<10xf32>
+{
+  %c0 = constant 0 : index
+  %cst = constant dense<1.000000e+00> : tensor<10xf32>
+  %0 = linalg.init_tensor [] : tensor<f32>
+  %1 = linalg.generic
+    {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>],
+     iterator_types = []}
+    ins(%arg1 : tensor<i32>) outs(%0 : tensor<f32>) {
+    ^bb0(%arg2: i32, %arg3: f32):  // no predecessors
+      %3 = index_cast %arg2 : i32 to index
+      %4 = tensor.extract %arg0[%3, %c0, %c0] : tensor<5x1x1xf32>
+      linalg.yield %4 : f32
+    } -> tensor<f32>
+  %2 = linalg.init_tensor [10] : tensor<10xf32>
+  %3 = linalg.generic
+    {indexing_maps = [affine_map<(d0) -> ()>, affine_map<(d0) -> (d0)>,
+                      affine_map<(d0) -> (d0)>],
+     iterator_types = ["parallel"]}
+    ins(%1, %cst : tensor<f32>, tensor<10xf32>) outs(%2 : tensor<10xf32>) {
+    ^bb0(%arg2: f32, %arg3: f32, %arg4: f32):  // no predecessors
+      %4 = mulf %arg2, %arg3 : f32
+      linalg.yield %4 : f32
+    } -> tensor<10xf32>
+  return %3 : tensor<10xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0) -> ()>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0)>
+// CHECK: func @scalar_generic_fusion
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<5x1x1xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<i32>
+// CHECK: %[[T0:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
+// CHECK-SAME: iterator_types = ["parallel"]
+// CHECK-SAME: ins(%[[ARG1]] : tensor<i32>)
+// CHECK: tensor.extract %[[ARG0]]
+// CHECK: linalg.yield
+// CHECK: return %[[T0]]
+
+// -----
+
 func @constant_fusion(%arg0 : tensor<4xf32>) -> (tensor<4xf32>) {
   %cst = constant dense<1.0> : tensor<4xf32>
   %1 = linalg.init_tensor [4] : tensor<4xf32>
@@ -655,32 +864,6 @@

 // -----

-// CHECK-LABEL: func @index_op(
-// CHECK-COUNT-2: linalg.generic
-func @index_op(%arg0: tensor<1x8xindex>, %arg1: tensor<1x8xindex>) -> tensor<1x8xindex> {
-  %0 = linalg.generic {
-    indexing_maps = [affine_map<(i, j) -> (i, j)>],
-    iterator_types = ["parallel", "parallel"]}
-  outs(%arg0 : tensor<1x8xindex>) {
-  ^bb0(%a: index):  // no predecessors
-    %2 = linalg.index 1 : index
-    linalg.yield %2 : index
-  } -> tensor<1x8xindex>
-  %1 = linalg.generic {
-    indexing_maps = [affine_map<(i, j) -> (i, j)>, affine_map<(i, j) -> (i, j)>],
-    iterator_types = ["parallel", "parallel"]}
-  ins(%0 : tensor<1x8xindex>)
-  outs(%arg1 : tensor<1x8xindex>) {
-  ^bb0(%a: index, %b: index):  // no predecessors
-    %2 = linalg.index 0 : index
-    %3 = addi %2, %a : index
-    linalg.yield %3 : index
-  } -> tensor<1x8xindex>
-  return %1 : tensor<1x8xindex>
-}
-
-// -----
-
 // CHECK-LABEL: func @no_fuse_constant_with_reduction
 func @no_fuse_constant_with_reduction() -> tensor<3xf32>
 {
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -237,6 +237,60 @@

 // -----

+#map0 = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
+#map1 = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
+func @indexed_consumer_reshape_producer_fusion(%arg0 : tensor<?x?x4x?xi32>,
+                                               %arg1 : tensor<?x?x?xi32>) ->
+                                               tensor<?x?x?xi32>
+{
+  %0 = linalg.tensor_reshape %arg0 [affine_map<(i, j, k, l) -> (i)>,
+                                    affine_map<(i, j, k, l) -> (j, k)>,
+                                    affine_map<(i, j, k, l) -> (l)>] :
+    tensor<?x?x4x?xi32> into tensor<?x?x?xi32>
+  %1 = linalg.generic {
+    indexing_maps = [#map0, #map1, #map1],
+    iterator_types = ["parallel", "parallel", "parallel"]}
+    ins(%0, %arg1 : tensor<?x?x?xi32>, tensor<?x?x?xi32>)
+    outs(%0 : tensor<?x?x?xi32>) {
+    ^bb0(%arg3: i32, %arg4: i32, %s: i32):
+      %idx0 = linalg.index 0 : index
+      %idx1 = linalg.index 1 : index
+      %idx2 = linalg.index 2 : index
+      %1 = muli %arg3, %arg4 : i32
+      %2 = index_cast %idx0 : index to i32
+      %3 = addi %1, %2 : i32
+      %4 = index_cast %idx1 : index to i32
+      %5 = addi %3, %4 : i32
+      %6 = index_cast %idx2 : index to i32
+      %7 = addi %5, %6 : i32
+      linalg.yield %7 : i32
+  } -> tensor<?x?x?xi32>
+  return %1 : tensor<?x?x?xi32>
+}
+
+// Only check the body in the indexed version of the test.
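+// The producer reshape collapses the pair (j, k) of shapes (?, 4) into a
+// single dimension; its index is recovered as d0 + d1 * 4 applied to
+// (IDX1, IDX0).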
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 4)>
+// CHECK: func @indexed_consumer_reshape_producer_fusion
+// CHECK: linalg.generic
+// CHECK: ^{{.*}}(
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: i32, %[[ARG4:[a-zA-Z0-9]+]]: i32,
+// CHECK-SAME: %[[ARG8:[a-zA-Z0-9]+]]: i32)
+// CHECK-DAG: %[[IDX0:.+]] = linalg.index 0 : index
+// CHECK-DAG: %[[IDX1:.+]] = linalg.index 1 : index
+// CHECK-DAG: %[[IDX2:.+]] = linalg.index 2 : index
+// CHECK-DAG: %[[IDX3:.+]] = linalg.index 3 : index
+// CHECK-DAG: %[[T3:.+]] = affine.apply #[[MAP]](%[[IDX1]], %[[IDX0]])
+// CHECK: %[[T4:.+]] = muli %[[ARG3]], %[[ARG4]]
+// CHECK: %[[T5:.+]] = index_cast %[[T3]]
+// CHECK: %[[T6:.+]] = addi %[[T4]], %[[T5]]
+// CHECK: %[[T7:.+]] = index_cast %[[IDX2]]
+// CHECK: %[[T8:.+]] = addi %[[T6]], %[[T7]]
+// CHECK: %[[T9:.+]] = index_cast %[[IDX3]]
+// CHECK: %[[T10:.+]] = addi %[[T8]], %[[T9]]
+// CHECK: linalg.yield %[[T10]]
+
+// -----
+
 #map0 = affine_map<(d0, d1) -> (d0, d1)>
 func @indexed_generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?xi32>,
                                                  %arg1 : tensor<?x?xi32>) ->
@@ -280,6 +334,53 @@

 // -----

+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+func @indexed_producer_reshape_consumer_fusion(%arg0 : tensor<?x?xi32>,
+                                               %arg1 : tensor<?x?xi32>) ->
+                                               tensor<?x?x4x5xi32>
+{
+  %0 = linalg.generic {
+    indexing_maps = [#map0, #map0, #map0],
+    iterator_types = ["parallel", "parallel"]}
+    ins(%arg0, %arg1 : tensor<?x?xi32>, tensor<?x?xi32>)
+    outs(%arg0 : tensor<?x?xi32>) {
+    ^bb0(%arg3: i32, %arg4: i32, %s: i32):  // no predecessors
+      %idx0 = linalg.index 0 : index
+      %idx1 = linalg.index 1 : index
+      %1 = muli %arg3, %arg4 : i32
+      %2 = index_cast %idx0 : index to i32
+      %3 = addi %1, %2 : i32
+      %4 = index_cast %idx1 : index to i32
+      %5 = addi %3, %4 : i32
+      linalg.yield %5 : i32
+  } -> tensor<?x?xi32>
+  %1 = linalg.tensor_reshape %0 [affine_map<(i, j, k, l) -> (i)>,
+                                 affine_map<(i, j, k, l) -> (j, k, l)>] :
+    tensor<?x?xi32> into tensor<?x?x4x5xi32>
+  return %1 : tensor<?x?x4x5xi32>
+}
+
+// Only check the body in the indexed version of the test.
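+// Dimension 1 of the producer expands into three dimensions of shapes
+// (?, 4, 5); its index is recovered as d0 + d1 * 5 + d2 * 20 applied to
+// (IDX3, IDX2, IDX1).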
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0 + d1 * 5 + d2 * 20)>
+// CHECK: func @indexed_producer_reshape_consumer_fusion
+// CHECK: linalg.generic
+// CHECK: ^{{.*}}(
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: i32, %[[ARG4:[a-zA-Z0-9]+]]: i32,
+// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: i32)
+// CHECK-DAG: %[[IDX0:.+]] = linalg.index 0 : index
+// CHECK-DAG: %[[IDX1:.+]] = linalg.index 1 : index
+// CHECK-DAG: %[[IDX2:.+]] = linalg.index 2 : index
+// CHECK-DAG: %[[IDX3:.+]] = linalg.index 3 : index
+// CHECK-DAG: %[[T3:.+]] = affine.apply #[[MAP]](%[[IDX3]], %[[IDX2]], %[[IDX1]])
+// CHECK: %[[T4:.+]] = muli %[[ARG3]], %[[ARG4]]
+// CHECK: %[[T5:.+]] = index_cast %[[IDX0]]
+// CHECK: %[[T6:.+]] = addi %[[T4]], %[[T5]]
+// CHECK: %[[T7:.+]] = index_cast %[[T3]]
+// CHECK: %[[T8:.+]] = addi %[[T6]], %[[T7]]
+// CHECK: linalg.yield %[[T8]]
+
+// -----
+
 func @reshape_as_consumer_permutation
   (%a : tensor<210x6x4xi32>, %b : tensor<210x4xi32>)
     -> tensor<2x3x4x5x6x7xi32> {
@@ -350,6 +451,82 @@

 // -----

+func @reshape_as_consumer_permutation
+  (%a : tensor<210x6x4xi32>, %b : tensor<210x4xi32>)
+    -> tensor<2x3x4x5x6x7xi32> {
+  %shape = linalg.init_tensor [6, 4, 210] : tensor<6x4x210xi32>
+  %c = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
+                     affine_map<(d0, d1, d2) -> (d1, d2)>,
+                     affine_map<(d0, d1, d2) -> (d0, d2, d1)>],
+    iterator_types = ["parallel", "parallel", "parallel"]}
+    ins(%a, %b : tensor<210x6x4xi32>, tensor<210x4xi32>)
+    outs(%shape : tensor<6x4x210xi32>) {
+    ^bb0(%arg3 : i32, %arg4: i32, %s: i32):
+      %idx0 = linalg.index 0 : index
+      %idx1 = linalg.index 1 : index
+      %idx2 = linalg.index 2 : index
+      %1 = addi %arg3, %arg4 : i32
+      %2 = index_cast %idx0 : index to i32
+      %3 = addi %1, %2 : i32
+      %4 = index_cast %idx1 : index to i32
+      %5 = addi %3, %4 : i32
+      %6 = index_cast %idx2 : index to i32
+      %7 = addi %5, %6 : i32
+      linalg.yield %7 : i32
+  } -> tensor<6x4x210xi32>
+  %d = linalg.tensor_reshape %c
+         [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1)>,
+          affine_map<(d0, d1, d2, d3, d4, d5) -> (d2)>,
+          affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>]
+       : tensor<6x4x210xi32> into tensor<2x3x4x5x6x7xi32>
+  return %d : tensor<2x3x4x5x6x7xi32>
+}
+
+
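+// Two index groups need linearization: dims (d0, d1) of shapes (2, 3) via
+// d0 + d1 * 3, and dims (d2, d3, d4) of shapes (5, 6, 7) via
+// d0 + d1 * 7 + d2 * 42.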
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d5)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d0, d1, d5)>
+// CHECK-DAG: #[[MAP6:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d5)>
+// CHECK-DAG: #[[MAP7:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d2, d3, d4)>
+// CHECK-DAG: #[[MAP8:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 3)>
+// CHECK-DAG: #[[MAP9:.+]] = affine_map<(d0, d1, d2) -> (d0 + d1 * 7 + d2 * 42)>
+// CHECK: func @reshape_as_consumer_permutation
+// CHECK-SAME: %[[ARG0:.+]]: tensor<210x6x4xi32>
+// CHECK-SAME: %[[ARG1:.+]]: tensor<210x4xi32>
+// CHECK-DAG: %[[T1:.+]] = linalg.tensor_reshape %[[ARG0]]
+// CHECK-SAME: [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
+// CHECK-DAG: %[[T2:.+]] = linalg.tensor_reshape %[[ARG1]]
+// CHECK-SAME: [#[[MAP3]], #[[MAP4]]]
+// CHECK-DAG: %[[T0:.+]] = linalg.init_tensor [2, 3, 4, 5, 6, 7]
+// CHECK: %[[T4:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP5]], #[[MAP6]], #[[MAP7]]]
+// CHECK-SAME: ins(%[[T1]], %[[T2]] : tensor<5x6x7x2x3x4xi32>, tensor<5x6x7x4xi32>)
+// CHECK-SAME: outs(%[[T0]] : tensor<2x3x4x5x6x7xi32>)
+// CHECK: ^{{.+}}(
+// CHECK-SAME: %[[ARG8:[a-zA-Z0-9]+]]: i32, %[[ARG9:[a-zA-Z0-9]+]]: i32,
+// CHECK-SAME: %[[ARG10:[a-zA-Z0-9]+]]: i32)
+// CHECK-DAG: %[[IDX0:.+]] = linalg.index 0 : index
+// CHECK-DAG: %[[IDX1:.+]] = linalg.index 1 : index
+// CHECK-DAG: %[[IDX2:.+]] = linalg.index 2 : index
+// CHECK-DAG: %[[IDX3:.+]] = linalg.index 3 : index
+// CHECK-DAG: %[[IDX4:.+]] = linalg.index 4 : index
+// CHECK-DAG: %[[IDX5:.+]] = linalg.index 5 : index
+// CHECK-DAG: %[[T5:.+]] = affine.apply #[[MAP8]](%[[IDX1]], %[[IDX0]])
+// CHECK-DAG: %[[T6:.+]] = affine.apply #[[MAP9]](%[[IDX4]], %[[IDX3]], %[[IDX2]])
+// CHECK-DAG: %[[T7:.+]] = addi %[[ARG8]], %[[ARG9]]
+// CHECK: %[[T8:.+]] = index_cast %[[T5]]
+// CHECK: %[[T9:.+]] = addi %[[T7]], %[[T8]]
+// CHECK: %[[T10:.+]] = index_cast %[[T6]]
+// CHECK: %[[T11:.+]] = addi %[[T9]], %[[T10]]
+// CHECK: %[[T12:.+]] = index_cast %[[IDX5]]
+// CHECK: %[[T13:.+]] = addi %[[T11]], %[[T12]]
+
+// -----
+
 func @reshape_as_producer_projected_permutation(
     %arg0 : tensor<33x8x?xi32>, %shape : tensor<264x?x4xi32>) -> tensor<264x?x4xi32>
 {
@@ -407,6 +584,66 @@

 // -----

+func @reshape_as_producer_projected_permutation(
+    %arg0 : tensor<33x8x?xi32>, %shape : tensor<264x?x4xi32>) -> tensor<264x?x4xi32>
+{
+  %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1, d2) -> (d0, d1)>,
+                                    affine_map<(d0, d1, d2) -> (d2)>]
+    : tensor<33x8x?xi32> into tensor<264x?xi32>
+  %1 = linalg.generic
+    {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>,
+                      affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+     iterator_types = ["parallel", "parallel", "parallel"]}
+    ins(%0 : tensor<264x?xi32>)
+    outs(%shape : tensor<264x?x4xi32>) {
+    ^bb0(%arg1: i32, %s: i32):  // no predecessors
+      %idx0 = linalg.index 0 : index
+      %idx1 = linalg.index 1 : index
+      %idx2 = linalg.index 2 : index
+      %2 = index_cast %idx0 : index to i32
+      %3 = addi %arg1, %2 : i32
+      %4 = index_cast %idx1 : index to i32
+      %5 = addi %3, %4 : i32
+      %6 = index_cast %idx2 : index to i32
+      %7 = addi %5, %6 : i32
+      linalg.yield %7 : i32
+  } -> tensor<264x?x4xi32>
+  return %1 : tensor<264x?x4xi32>
+}
+
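+// Only the (d0, d1) pair collapsed by the producer reshape is linearized
+// (d0 + d1 * 8); the remaining indices are used directly.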
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 8)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
+// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3) -> (d2)>
+// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+// CHECK: @reshape_as_producer_projected_permutation
+// CHECK-SAME: %[[ARG0:.+]]: tensor<33x8x?xi32>
+// CHECK: %[[RES:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
+// CHECK-SAME: ins(%[[ARG0]] : tensor<33x8x?xi32>)
+// CHECK: ^{{.+}}(
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: i32,
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: i32)
+// CHECK-DAG: %[[IDX0:.+]] = linalg.index 0 : index
+// CHECK-DAG: %[[IDX1:.+]] = linalg.index 1 : index
+// CHECK-DAG: %[[IDX2:.+]] = linalg.index 2 : index
+// CHECK-DAG: %[[IDX3:.+]] = linalg.index 3 : index
+// CHECK-DAG: %[[T0:.+]] = affine.apply #[[MAP2]](%[[IDX1]], %[[IDX0]])
+// CHECK: %[[T1:.+]] = index_cast %[[T0]] : index to i32
+// CHECK: %[[T2:.+]] = addi %[[ARG1]], %[[T1]] : i32
+// CHECK: %[[T3:.+]] = index_cast %[[IDX2]] : index to i32
+// CHECK: %[[T4:.+]] = addi %[[T2]], %[[T3]] : i32
+// CHECK: %[[T5:.+]] = index_cast %[[IDX3]] : index to i32
+// CHECK: %[[T6:.+]] = addi %[[T4]], %[[T5]] : i32
+// CHECK: linalg.yield %[[T6]] : i32
+// CHECK: %[[RES2:.+]] = linalg.tensor_reshape %[[RES]]
+// CHECK-SAME: [#[[MAP3]], #[[MAP4]], #[[MAP5]]]
+// CHECK-SAME: : tensor<33x8x?x4xi32> into tensor<264x?x4xi32>
+// CHECK: return %[[RES2]] : tensor<264x?x4xi32>
+
+// -----
+
 #map0 = affine_map<(d0, d1) -> (d0, d1)>
 #map1 = affine_map<(d0, d1) -> (d1, d0)>
 func @generic_op_reshape_consumer_fusion_projected(%arg0 : tensor<?x?xf32>,