diff --git a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
@@ -35,6 +35,7 @@
   SplitReductionOptions control = controlSplitReductionFn(op);
   int64_t ratio = control.ratio;
   unsigned insertSplitIndex = control.index;
+  unsigned insertSplitDimension = control.index;
   if (ratio <= 1)
     return b.notifyMatchFailure(op, "split ratio needs to be greater than 1");
 
@@ -42,6 +43,9 @@
   op.getReductionDims(dims);
   assert(dims.size() == 1);
   unsigned reductionDim = dims[0];
+  if (control.innerParallel) {
+    insertSplitDimension = reductionDim + 1;
+  }
   SmallVector<int64_t, 4> loopRanges = op.getStaticLoopRanges();
   int64_t reductionDimSize = loopRanges[reductionDim];
   if (reductionDimSize == ShapedType::kDynamic ||
@@ -78,19 +82,21 @@
       unsigned dim = map.getDimPosition(idx);
       if (reductionDim == dim) {
         if (control.innerParallel) {
-          newShape.push_back(op.getShape(operand)[idx] / ratio);
-          newShape.push_back(ratio);
+          newShape.push_back(op.getShape(operand)[idx] / ratio); // reduce
+          newShape.push_back(ratio); // parallel (insert)
+          exprs.push_back(b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
+          exprs.push_back(b.getAffineDimExpr(insertSplitDimension));
         } else {
-          newShape.push_back(ratio);
-          newShape.push_back(op.getShape(operand)[idx] / ratio);
+          newShape.push_back(ratio); // parallel (insert)
+          newShape.push_back(op.getShape(operand)[idx] / ratio); // reduce
+          exprs.push_back(b.getAffineDimExpr(insertSplitDimension));
+          exprs.push_back(b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
         }
-        exprs.push_back(b.getAffineDimExpr(reductionDim));
-        exprs.push_back(b.getAffineDimExpr(reductionDim + 1));
         reassociation.push_back({index++, index++});
         continue;
       }
       newShape.push_back(op.getShape(operand)[idx]);
-      exprs.push_back(b.getAffineDimExpr(dim < reductionDim ? dim : dim + 1));
+      exprs.push_back(b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
       reassociation.push_back({index++});
     }
     newMaps.push_back(
@@ -117,17 +123,13 @@
   for (unsigned idx : llvm::seq<unsigned>(0, oldShape.size() + 1)) {
     if (insertSplitIndex == idx) {
       newOutputShape.push_back(ratio);
-      if (control.innerParallel) {
-        outputExpr.push_back(b.getAffineDimExpr(reductionDim + 1));
-      } else {
-        outputExpr.push_back(b.getAffineDimExpr(reductionDim));
-      }
+      outputExpr.push_back(b.getAffineDimExpr(insertSplitDimension));
     }
     if (idx < oldShape.size()) {
       newOutputShape.push_back(oldShape[idx]);
       unsigned dim = oldOutputMap.getDimPosition(idx);
       outputExpr.push_back(
-          b.getAffineDimExpr(dim < reductionDim ? dim : dim + 1));
+          b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
     }
   }
   Value emptyOrAllocTensor;
@@ -150,11 +152,12 @@
                                    op.getContext()));
   SmallVector<utils::IteratorType> newIteratorTypes;
   for (auto &it : llvm::enumerate(op.getIteratorTypesArray())) {
-    if (reductionDim == it.index() && !control.innerParallel)
+    if (insertSplitDimension == it.index())
       newIteratorTypes.push_back(utils::IteratorType::parallel);
     newIteratorTypes.push_back(it.value());
-    if (reductionDim == it.index() && control.innerParallel)
-      newIteratorTypes.push_back(utils::IteratorType::parallel);
+  }
+  if (insertSplitDimension == op.getIteratorTypesArray().size()) {
+    newIteratorTypes.push_back(utils::IteratorType::parallel);
   }
   // Create the new op matching the original op with an extra parallel
   // dimension.
diff --git a/mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir b/mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir
--- a/mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir
@@ -106,9 +106,9 @@
   return %0 : tensor<5x2xf32>
 }
 
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d2, d0)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d1, d2)>
-// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d0, d1)>
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d1, d0)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d2, d1)>
+// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d0, d2)>
 // CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 // CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
 // CHECK-LABEL: func @generic_split_3d
@@ -117,7 +117,7 @@
 // CHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}[0], [1, 2]] : tensor<5x32xf32> into tensor<5x4x8xf32>
 // CHECK-DAG: %[[INI:.*]] = tensor.empty() : tensor<5x2x4xf32>
 // CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<5x2x4xf32>) -> tensor<5x2x4xf32>
-// CHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "reduction", "parallel"]}
+// CHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "reduction", "parallel", "parallel"]}
 // CHECK-SAME: ins(%[[I1]], %[[I2]] : tensor<4x8x2xf32>, tensor<5x4x8xf32>) outs(%[[F]] : tensor<5x2x4xf32>) {
 // CHECK: arith.addf
 // CHECK: arith.maxf
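
For reference, the updated CHECK lines above correspond to a transformed op of roughly the following shape. This is a sketch reconstructed from the test expectations only (SSA names such as %i1, %i2 and %fill are illustrative, and the body is abbreviated), not verbatim output of the transform:

// With innerParallel, the ratio-4 parallel dimension (d2) is inserted right
// after the split reduction dimension (d1), so the iterator types become
// ["parallel", "reduction", "parallel", "parallel"].
#map0 = affine_map<(d0, d1, d2, d3) -> (d2, d1, d0)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d3, d2, d1)>
#map2 = affine_map<(d0, d1, d2, d3) -> (d3, d0, d2)>
%partial = linalg.generic
    {indexing_maps = [#map0, #map1, #map2],
     iterator_types = ["parallel", "reduction", "parallel", "parallel"]}
    ins(%i1, %i2 : tensor<4x8x2xf32>, tensor<5x4x8xf32>)
    outs(%fill : tensor<5x2x4xf32>) {
  ^bb0(%a: f32, %b: f32, %acc: f32):
    %sum = arith.addf %a, %b : f32
    %max = arith.maxf %sum, %acc : f32
    linalg.yield %max : f32
} -> tensor<5x2x4xf32>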