Index: mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
===================================================================
--- mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
+++ mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
@@ -34,7 +34,7 @@
   SplitReductionOptions control = controlSplitReductionFn(op);
   int64_t ratio = control.ratio;
-  unsigned insertSplitDimension = control.index;
+  unsigned insertSplitIndex = control.index;
   if (ratio <= 1)
     return b.notifyMatchFailure(op, "split ratio needs to be greater than 1");
@@ -45,10 +45,14 @@
   SmallVector<int64_t> loopRanges = op.getStaticLoopRanges();
   int64_t reductionDimSize = loopRanges[reductionDim];
   if (reductionDimSize == ShapedType::kDynamicSize ||
-      reductionDimSize % ratio != 0 ||
-      insertSplitDimension >= loopRanges.size())
+      reductionDimSize % ratio != 0)
     return b.notifyMatchFailure(
         op, "Reduction dimension not divisible by split ratio");
+  if (op.getNumDpsInits() != 1)
+    return b.notifyMatchFailure(op, "More than one output in split reduction");
+  if (insertSplitIndex > op.getShape(op.getDpsInitOperand(0)).size())
+    return b.notifyMatchFailure(op, "Insert dimension position too large "
+                                    "compared to intermediate tensor size");
   SmallVector<Operation *, 4> combinerOps;
   if (!matchReduction(op.getRegionOutputArgs(), 0, combinerOps) ||
@@ -80,25 +84,13 @@
           newShape.push_back(ratio);
           newShape.push_back(op.getShape(operand)[idx] / ratio);
         }
+        exprs.push_back(b.getAffineDimExpr(reductionDim));
+        exprs.push_back(b.getAffineDimExpr(reductionDim + 1));
         reassociation.push_back({index++, index++});
-        if (control.innerParallel) {
-          exprs.push_back(b.getAffineDimExpr(reductionDim));
-          exprs.push_back(b.getAffineDimExpr(reductionDim + 1));
-        } else {
-          exprs.push_back(b.getAffineDimExpr(insertSplitDimension));
-          exprs.push_back(
-              b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
-        }
        continue;
      }
      newShape.push_back(op.getShape(operand)[idx]);
-      if (control.innerParallel) {
-        exprs.push_back(
-            b.getAffineDimExpr(dim <= reductionDim ? dim : dim + 1));
-      } else {
-        exprs.push_back(
-            b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
-      }
+      exprs.push_back(b.getAffineDimExpr(dim < reductionDim ? dim : dim + 1));
      reassociation.push_back({index++});
    }
    newMaps.push_back(
@@ -122,26 +114,20 @@
   AffineMap oldOutputMap = op.getMatchingIndexingMap(op.getDpsInitOperand(0));
   ArrayRef<int64_t> oldShape = op.getShape(op.getDpsInitOperand(0));
   SmallVector<AffineExpr> outputExpr;
-  for (unsigned idx :
-       llvm::seq<unsigned>(0, oldOutputMap.getNumResults() + 1)) {
-    if (idx == insertSplitDimension) {
+  for (unsigned idx : llvm::seq<unsigned>(0, oldShape.size() + 1)) {
+    if (insertSplitIndex == idx) {
       newOutputShape.push_back(ratio);
       if (control.innerParallel) {
         outputExpr.push_back(b.getAffineDimExpr(reductionDim + 1));
       } else {
-        outputExpr.push_back(b.getAffineDimExpr(insertSplitDimension));
+        outputExpr.push_back(b.getAffineDimExpr(reductionDim));
       }
-      continue;
     }
-    unsigned oldIdx = idx < insertSplitDimension ? idx : idx - 1;
-    newOutputShape.push_back(oldShape[oldIdx]);
-    unsigned dim = oldOutputMap.getDimPosition(oldIdx);
-    if (control.innerParallel) {
-      outputExpr.push_back(
-          b.getAffineDimExpr(dim <= reductionDim ? dim : dim + 1));
-    } else {
+    if (idx < oldShape.size()) {
+      newOutputShape.push_back(oldShape[idx]);
+      unsigned dim = oldOutputMap.getDimPosition(idx);
       outputExpr.push_back(
-          b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
+          b.getAffineDimExpr(dim < reductionDim ? dim : dim + 1));
     }
   }
   Value emptyOrAllocTensor;
@@ -164,10 +150,10 @@
           op.getContext()));
   SmallVector<utils::IteratorType> newIteratorTypes;
   for (auto &it : llvm::enumerate(op.getIteratorTypesArray())) {
-    if (insertSplitDimension == it.index() && !control.innerParallel)
+    if (reductionDim == it.index() && !control.innerParallel)
       newIteratorTypes.push_back(utils::IteratorType::parallel);
     newIteratorTypes.push_back(it.value());
-    if (insertSplitDimension == it.index() && control.innerParallel)
+    if (reductionDim == it.index() && control.innerParallel)
       newIteratorTypes.push_back(utils::IteratorType::parallel);
   }
   // Create the new op matching the original op with an extra parallel
@@ -185,7 +171,7 @@
   SmallVector<utils::IteratorType> reductionIteratorTypes;
   SmallVector<AffineExpr> exprs;
   for (unsigned i : llvm::seq<unsigned>(0, intermRank)) {
-    if (insertSplitDimension == i) {
+    if (insertSplitIndex == i) {
       reductionIteratorTypes.push_back(utils::IteratorType::reduction);
     } else {
       exprs.push_back(b.getAffineDimExpr(i));
Index: mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir
===================================================================
--- mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir
+++ mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir
@@ -106,9 +106,9 @@
   return %0 : tensor<5x2xf32>
 }

-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d1, d0)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d2, d1)>
-// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d0, d2)>
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d2, d0)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d1, d2)>
+// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d0, d1)>
 // CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 // CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
 // CHECK-LABEL: func @generic_split_3d
@@ -117,7 +117,7 @@
 // CHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}{{\[}}[0], [1, 2]] : tensor<5x32xf32> into tensor<5x4x8xf32>
 // CHECK-DAG: %[[INI:.*]] = tensor.empty() : tensor<5x2x4xf32>
 // CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<5x2x4xf32>) -> tensor<5x2x4xf32>
-// CHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "reduction", "parallel", "parallel"]}
+// CHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "reduction", "parallel"]}
 // CHECK-SAME: ins(%[[I1]], %[[I2]] : tensor<4x8x2xf32>, tensor<5x4x8xf32>) outs(%[[F]] : tensor<5x2x4xf32>) {
 // CHECK: arith.addf
 // CHECK: arith.maxf
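
For reference, a minimal sketch of how the rewrite touched above is typically driven from C++. It assumes the upstream `populateSplitReductionPattern` entry point and the `SplitReductionOptions` fields (`ratio`, `index`, `innerParallel`) that the patch reads through `control`; the concrete split factor and insertion position below are illustrative only, not part of this change.

    #include "mlir/Dialect/Linalg/IR/Linalg.h"
    #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
    #include "mlir/IR/PatternMatch.h"

    using namespace mlir;
    using namespace mlir::linalg;

    // Control callback: decides, per LinalgOp, whether and how to split the
    // reduction. Returning ratio <= 1 makes the pattern bail out, matching the
    // notifyMatchFailure path in SplitReduction.cpp.
    static SplitReductionOptions chooseSplit(LinalgOp op) {
      SplitReductionOptions options;
      options.ratio = 4;   // Split the reduction dimension by a factor of 4.
      options.index = 0;   // With this patch, `index` is the position in the
                           // intermediate output tensor where the new parallel
                           // dimension is inserted (0 <= index <= output rank).
      options.innerParallel = false; // Keep the extra parallel dimension outside
                                     // the remaining reduction dimension.
      return options;
    }

    // Registers the split-reduction rewrite. useAlloc=false keeps the
    // tensor.empty + linalg.fill initialization seen in the CHECK lines above.
    void addSplitReductionPatterns(RewritePatternSet &patterns) {
      populateSplitReductionPattern(patterns, chooseSplit, /*useAlloc=*/false);
    }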