diff --git a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
@@ -34,7 +34,7 @@
 
   SplitReductionOptions control = controlSplitReductionFn(op);
   int64_t ratio = control.ratio;
-  unsigned insertSplitIndex = control.index;
+  unsigned insertSplitDimension = control.index;
   if (ratio <= 1)
     return b.notifyMatchFailure(op, "split ratio needs to be greater than 1");
 
@@ -50,7 +50,7 @@
         op, "Reduction dimension not divisible by split ratio");
   if (op.getNumDpsInits() != 1)
     return b.notifyMatchFailure(op, "More than one output in split reduction");
-  if (insertSplitIndex > op.getShape(op.getDpsInitOperand(0)).size())
+  if (insertSplitDimension > op.getShape(op.getDpsInitOperand(0)).size())
     return b.notifyMatchFailure(op, "Insert dimension position too large "
                                     "compared to intermediate tensor size");
 
@@ -78,19 +78,21 @@
     unsigned dim = map.getDimPosition(idx);
     if (reductionDim == dim) {
       if (control.innerParallel) {
-        newShape.push_back(op.getShape(operand)[idx] / ratio);
-        newShape.push_back(ratio);
+        newShape.push_back(op.getShape(operand)[idx] / ratio); // reduce
+        newShape.push_back(ratio); // parallel (insert)
+        exprs.push_back(b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
+        exprs.push_back(b.getAffineDimExpr(insertSplitDimension));
       } else {
-        newShape.push_back(ratio);
-        newShape.push_back(op.getShape(operand)[idx] / ratio);
+        newShape.push_back(ratio); // parallel (insert)
+        newShape.push_back(op.getShape(operand)[idx] / ratio); // reduce
+        exprs.push_back(b.getAffineDimExpr(insertSplitDimension));
+        exprs.push_back(b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
       }
-      exprs.push_back(b.getAffineDimExpr(reductionDim));
-      exprs.push_back(b.getAffineDimExpr(reductionDim + 1));
       reassociation.push_back({index++, index++});
       continue;
     }
     newShape.push_back(op.getShape(operand)[idx]);
-    exprs.push_back(b.getAffineDimExpr(dim < reductionDim ? dim : dim + 1));
+    exprs.push_back(b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
     reassociation.push_back({index++});
   }
   newMaps.push_back(
@@ -115,19 +117,15 @@
   ArrayRef<int64_t> oldShape = op.getShape(op.getDpsInitOperand(0));
   SmallVector<AffineExpr> outputExpr;
   for (unsigned idx : llvm::seq<unsigned>(0, oldShape.size() + 1)) {
-    if (insertSplitIndex == idx) {
+    if (insertSplitDimension == idx) {
       newOutputShape.push_back(ratio);
-      if (control.innerParallel) {
-        outputExpr.push_back(b.getAffineDimExpr(reductionDim + 1));
-      } else {
-        outputExpr.push_back(b.getAffineDimExpr(reductionDim));
-      }
+      outputExpr.push_back(b.getAffineDimExpr(insertSplitDimension));
     }
     if (idx < oldShape.size()) {
       newOutputShape.push_back(oldShape[idx]);
       unsigned dim = oldOutputMap.getDimPosition(idx);
       outputExpr.push_back(
-          b.getAffineDimExpr(dim < reductionDim ? dim : dim + 1));
+          b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
     }
   }
   Value emptyOrAllocTensor;
@@ -150,11 +148,9 @@
       op.getContext()));
   SmallVector<utils::IteratorType> newIteratorTypes;
   for (auto &it : llvm::enumerate(op.getIteratorTypesArray())) {
-    if (reductionDim == it.index() && !control.innerParallel)
+    if (insertSplitDimension == it.index())
       newIteratorTypes.push_back(utils::IteratorType::parallel);
     newIteratorTypes.push_back(it.value());
-    if (reductionDim == it.index() && control.innerParallel)
-      newIteratorTypes.push_back(utils::IteratorType::parallel);
   }
   // Create the new op matching the original op with an extra parallel
   // dimension.
@@ -171,7 +167,7 @@
   SmallVector<utils::IteratorType> reductionIteratorTypes;
   SmallVector<AffineExpr> exprs;
   for (unsigned i : llvm::seq<unsigned>(0, intermRank)) {
-    if (insertSplitIndex == i) {
+    if (insertSplitDimension == i) {
       reductionIteratorTypes.push_back(utils::IteratorType::reduction);
     } else {
       exprs.push_back(b.getAffineDimExpr(i));
diff --git a/mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir b/mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir
--- a/mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir
@@ -106,9 +106,9 @@
   return %0 : tensor<5x2xf32>
 }
 
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d2, d0)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d1, d2)>
-// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d0, d1)>
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d1, d0)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d2, d1)>
+// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d0, d2)>
 // CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 // CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
 // CHECK-LABEL: func @generic_split_3d
@@ -117,7 +117,7 @@
 // CHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}[0], [1, 2]] : tensor<5x32xf32> into tensor<5x4x8xf32>
 // CHECK-DAG: %[[INI:.*]] = tensor.empty() : tensor<5x2x4xf32>
 // CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<5x2x4xf32>) -> tensor<5x2x4xf32>
-// CHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "reduction", "parallel"]}
+// CHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "reduction", "parallel", "parallel"]}
 // CHECK-SAME: ins(%[[I1]], %[[I2]] : tensor<4x8x2xf32>, tensor<5x4x8xf32>) outs(%[[F]] : tensor<5x2x4xf32>) {
 // CHECK: arith.addf
 // CHECK: arith.maxf
@@ -144,9 +144,9 @@
   return %0: tensor<16x32xf32>
 }
 
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d3, d1)>
-// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d2, d1)>
+// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
 // CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 // CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
 // CHECK-LABEL: @matmul_split
@@ -156,7 +156,7 @@
 // CHECK-DAG: %[[INI:.*]] = tensor.empty() : tensor<16x32x4xf32>
 // CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<16x32x4xf32>) -> tensor<16x32x4xf32>
 // CHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
-// CHECK-SAME: , iterator_types = ["parallel", "parallel", "reduction", "parallel"]}
+// CHECK-SAME: , iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
 // CHECK-SAME: ins(%[[I1]], %[[I2]] : tensor<16x64x4xf32>, tensor<64x4x32xf32>) outs(%[[F]] : tensor<16x32x4xf32>) {
 // CHECK: arith.mulf
 // CHECK: arith.addf
@@ -193,9 +193,9 @@
   return %red : tensor<f32>
 }
 
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d1, d0)>
 // CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> ()>
-// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d1)>
+// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0)>
 // CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0) -> (d0)>
 // CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0) -> ()>
 //CHECK-LABEL: @generic_split_1d
@@ -205,7 +205,7 @@
 // CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<4xf32>) -> tensor<4xf32>
 // CHECK: %[[G:.*]] = linalg.generic
 // CHECK: {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]],
-// CHECK: iterator_types = ["reduction", "parallel"]} ins(%[[I1]], %{{.*}} : tensor<8x4xf32>, tensor<f32>) outs(%[[F]] : tensor<4xf32>) {
+// CHECK: iterator_types = ["parallel", "reduction"]} ins(%[[I1]], %{{.*}} : tensor<8x4xf32>, tensor<f32>) outs(%[[F]] : tensor<4xf32>) {
 // CHECK: arith.subf
 // CHECK: math.exp
 // CHECK: arith.mulf
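
For reference, this is the shape of IR the transformation produces once the indexing maps and iterator types are keyed off the insertion position rather than the reduction dimension. A minimal sketch for a plain 1-D sum, assuming ratio = 4, index = 0 (the split dimension is inserted at position 0 of the intermediate tensor), and innerParallel = false; the function and value names are illustrative, not taken from the tests above.

  // Original: a full 1-D sum reduction.
  func.func @before(%in: tensor<32xf32>, %init: tensor<f32>) -> tensor<f32> {
    %sum = linalg.generic
        {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>],
         iterator_types = ["reduction"]}
        ins(%in : tensor<32xf32>) outs(%init : tensor<f32>) {
      ^bb0(%a: f32, %acc: f32):
        %0 = arith.addf %a, %acc : f32
        linalg.yield %0 : f32
      } -> tensor<f32>
    return %sum : tensor<f32>
  }

  // After splitReduction with ratio = 4 and index = 0.
  func.func @after(%in: tensor<32xf32>, %init: tensor<f32>) -> tensor<f32> {
    %exp = tensor.expand_shape %in [[0, 1]] : tensor<32xf32> into tensor<4x8xf32>
    %cst = arith.constant 0.000000e+00 : f32
    %empty = tensor.empty() : tensor<4xf32>
    %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<4xf32>) -> tensor<4xf32>
    // Partial sums: d0 is the inserted parallel dimension (size 4, position 0
    // of the intermediate tensor), d1 the remaining reduction.
    %partial = linalg.generic
        {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                          affine_map<(d0, d1) -> (d0)>],
         iterator_types = ["parallel", "reduction"]}
        ins(%exp : tensor<4x8xf32>) outs(%fill : tensor<4xf32>) {
      ^bb0(%a: f32, %acc: f32):
        %0 = arith.addf %a, %acc : f32
        linalg.yield %0 : f32
      } -> tensor<4xf32>
    // Final combine over the 4 partial sums.
    %sum = linalg.generic
        {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>],
         iterator_types = ["reduction"]}
        ins(%partial : tensor<4xf32>) outs(%init : tensor<f32>) {
      ^bb0(%a: f32, %acc: f32):
        %0 = arith.addf %a, %acc : f32
        linalg.yield %0 : f32
      } -> tensor<f32>
    return %sum : tensor<f32>
  }

With innerParallel = true the input is instead expanded reduction-major to tensor<8x4xf32> and its map becomes affine_map<(d0, d1) -> (d1, d0)>, while the iterator types are unchanged: the parallel iterator is placed by insertSplitDimension alone, which is exactly what the rewritten iterator-type loop above implements.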