diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -26,6 +26,38 @@
 using namespace mlir;
 using namespace mlir::linalg;
 
+/// Append to `fusedOpIndexingMapAttrs` the indexing maps for the operands of
+/// the `producer` to use in the fused operation given the indexing map of the
+/// result of the producer in the consumer.
+static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
+    OpOperand *producerOpOperand, AffineMap producerResultIndexMap,
+    AffineMap fusedConsumerArgIndexMap) {
+  // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map
+  // from consumer loop -> consumer arg tensor index/producer result tensor
+  // index. The fused loop is same as the consumer loop. For each producer arg
+  // the indexing map to be computed is a map from consumer loop -> producer
+  // arg tensor index.
+  // producerResultIndexMap is a map from producer loop -> tensor index.
+  // Compute the inverse to get map from tensor index -> producer loop.
+  // The inverse is a map from producer result tensor index -> producer loop.
+  AffineMap invProducerResultIndexMap =
+      inversePermutation(producerResultIndexMap);
+  assert(invProducerResultIndexMap &&
+         "expected producer result indexig map to be invertible");
+
+  LinalgOp producer = cast<LinalgOp>(producerOpOperand->getOwner());
+  // argMap is a map from producer loop -> producer arg tensor index.
+  AffineMap argMap = producer.getTiedIndexingMap(producerOpOperand);
+
+  // Compose argMap with invProducerResultIndexMap to get a map from
+  // producer result tensor index -> producer arg tensor index.
+  AffineMap t1 = argMap.compose(invProducerResultIndexMap);
+
+  // Compose t1 with fusedConsumerArgIndexMap gives an indexing map from
+  // consumer loop/ fused loop -> producer arg tensor index.
+  return t1.compose(fusedConsumerArgIndexMap);
+}
+
 /// Conditions for elementwise fusion of generic operations.
 static bool areElementwiseOpsFusable(GenericOp producer, GenericOp consumer,
                                      OpOperand *consumerOpOperand) {
@@ -57,39 +89,42 @@
   // verify it is a permutation.
   AffineMap producerResultIndexMap =
       producer.getTiedIndexingMap(producer.getOutputOperand(0));
-  return producerResultIndexMap.isPermutation();
-}
+  if (!producerResultIndexMap.isPermutation())
+    return false;
 
-/// Append to `fusedOpIndexingMapAttrs` the indexing maps for the operands of
-/// the `producer` to use in the fused operation given the indexing map of the
-/// result of the producer in the consumer.
-static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
-    OpOperand *producerOpOperand, AffineMap producerResultIndexMap,
-    AffineMap fusedConsumerArgIndexMap) {
-  // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map
-  // from consumer loop -> consumer arg tensor index/producer result tensor
-  // index. The fused loop is same as the consumer loop. For each producer arg
-  // the indexing map to be computed is a map from consumer loop -> producer
-  // arg tensor index.
-  // producerResultIndexMap is a map from producer loop -> tensor index.
-  // Compute the inverse to get map from tensor index -> producer loop.
-  // The inverse is a map from producer result tensor index -> producer loop.
-  AffineMap invProducerResultIndexMap =
-      inversePermutation(producerResultIndexMap);
-  assert(invProducerResultIndexMap &&
-         "expected producer result indexig map to be invertible");
+  // Ensure that the fusion does not remove size information required to
+  // get the loop bounds. For non-reduction generics, this is trivially the
+  // case due to the output operand. For reductions, we need to check that after
+  // the fusion, each loop dimension has at least one input that defines it.
+  if ((consumer.getNumReductionLoops())) {
+    llvm::BitVector coveredDims(consumer.getNumLoops(), false);
+
+    auto addToCoveredDims = [&](AffineMap map) {
+      for (auto result : map.getResults())
+        if (auto dimExpr = result.dyn_cast<AffineDimExpr>())
+          coveredDims[dimExpr.getPosition()] = true;
+    };
 
-  LinalgOp producer = cast<LinalgOp>(producerOpOperand->getOwner());
-  // argMap is a map from producer loop -> producer arg tensor index.
-  AffineMap argMap = producer.getTiedIndexingMap(producerOpOperand);
+    for (auto pair :
+         llvm::zip(consumer->getOperands(), consumer.getIndexingMaps())) {
+      Value operand = std::get<0>(pair);
+      if (operand == consumerOpOperand->get())
+        continue;
+      AffineMap operandMap = std::get<1>(pair);
+      addToCoveredDims(operandMap);
+    }
 
-  // Compose argMap with invProducerResultIndexMap to get a map from
-  // producer result tensor index -> producer arg tensor index.
-  AffineMap t1 = argMap.compose(invProducerResultIndexMap);
+    for (OpOperand *operand : producer.getInputOperands()) {
+      AffineMap newIndexingMap =
+          getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
+              operand, producerResultIndexMap, consumerIndexMap);
+      addToCoveredDims(newIndexingMap);
+    }
+    if (!coveredDims.all())
+      return false;
+  }
 
-  // Compose t1 with fusedConsumerArgIndexMap gives an indexing map from
-  // consumer loop/ fused loop -> producer arg tensor index.
-  return t1.compose(fusedConsumerArgIndexMap);
+  return true;
 }
 
 /// Generate the region of the fused tensor operation. The region of the fused
diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
--- a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
@@ -907,3 +907,35 @@
   } -> tensor<3x2xf32>
   return %1 : tensor<3x2xf32>
 }
+
+// -----
+
+#map0 = affine_map<(d0, d1) -> ()>
+#map1 = affine_map<(d0, d1) -> (d0, d1)>
+#map2 = affine_map<(d0, d1) -> (d1, d0)>
+#map3 = affine_map<(d0, d1) -> (d0)>
+// CHECK-LABEL: @no_fusion_missing_reduction_shape
+func @no_fusion_missing_reduction_shape(%arg0: tensor<f32>, %arg1: index) -> tensor<?xf32> {
+  %cst = arith.constant 0xFF800000 : f32
+  %4 = linalg.init_tensor [%arg1, %arg1] : tensor<?x?xf32>
+  // CHECK: linalg.generic
+  %5 = linalg.generic {
+    indexing_maps = [#map0, #map1],
+    iterator_types = ["parallel", "parallel"]
+  } ins(%arg0 : tensor<f32>) outs(%4 : tensor<?x?xf32>) {
+  ^bb0(%arg2: f32, %arg3: f32):  // no predecessors
+    linalg.yield %arg2 : f32
+  } -> tensor<?x?xf32>
+  %6 = linalg.init_tensor [%arg1] : tensor<?xf32>
+  %7 = linalg.fill(%cst, %6) : f32, tensor<?xf32> -> tensor<?xf32>
+  // CHECK: linalg.generic
+  %8 = linalg.generic {
+    indexing_maps = [#map2, #map3],
+    iterator_types = ["parallel", "reduction"]
+  } ins(%5 : tensor<?x?xf32>) outs(%7 : tensor<?xf32>) {
+  ^bb0(%arg2: f32, %arg3: f32):  // no predecessors
+    %9 = maxf %arg2, %arg3 : f32
+    linalg.yield %9 : f32
+  } -> tensor<?xf32>
+  return %8 : tensor<?xf32>
+}
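
Note (not part of the patch): the new test is rejected because, once the producer's operand maps are pulled into the consumer's iteration space, no remaining operand constrains the reduction dimension d1; only the fused-away operand %5 (read through #map2) carried its size. Below is a minimal standalone C++ sketch, assuming it is built against MLIR's IR library, that replays the composition done by getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp on the maps from @no_fusion_missing_reduction_shape; the file layout and variable names here are illustrative only.

#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;

int main() {
  MLIRContext ctx;
  AffineExpr d0 = getAffineDimExpr(0, &ctx);
  AffineExpr d1 = getAffineDimExpr(1, &ctx);

  // #map0: the producer's input is rank-0, (d0, d1) -> ().
  AffineMap argMap = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0,
                                    /*results=*/{}, &ctx);
  // #map1: the producer writes its result with the identity map.
  AffineMap producerResultIndexMap = AffineMap::get(2, 0, {d0, d1}, &ctx);
  // #map2: the consumer reads that result as (d0, d1) -> (d1, d0).
  AffineMap fusedConsumerArgIndexMap = AffineMap::get(2, 0, {d1, d0}, &ctx);

  // Same three steps as in the patch: invert the producer result map
  // (tensor index -> producer loop), pull the producer argument map into
  // producer-result coordinates, then into consumer/fused loop coordinates.
  AffineMap inv = inversePermutation(producerResultIndexMap);
  AffineMap t1 = argMap.compose(inv);
  AffineMap fused = t1.compose(fusedConsumerArgIndexMap);

  // Prints "(d0, d1) -> ()": the pulled-in operand constrains no loop
  // dimension, so after fusion nothing would define the extent of the
  // reduction dimension d1.
  fused.print(llvm::outs());
  llvm::outs() << "\n";
}

With the pulled-in map empty and the consumer's output map #map3 covering only d0, the coveredDims bitvector in areElementwiseOpsFusable is left with d1 unset, which is exactly why the pass now refuses to fuse this pair and the test expects two separate linalg.generic ops.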