Index: mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
===================================================================
--- mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -464,6 +464,8 @@
       remaining reduction after splitting).
     - insert_split_dimension: the dimension in the temporary tensor into
       which the new parallel dimension is inserted.
+    - inner_parallel: specifies whether the parallel dimension is before or
+      after the reduction dimension in the splitting op.
     - use_scaling_algorithm: whether to use a scaling based formulation that
       does not create an ExpandShapeOp (default: do not use scaling)
     - use_alloc: whether to use an alloc op to allocate the temporary
@@ -587,6 +589,7 @@
   let arguments = (ins PDL_Operation:$target,
                    DefaultValuedAttr:$split_factor,
                    DefaultValuedAttr:$insert_split_dimension,
+                   UnitAttr:$inner_parallel,
                    UnitAttr:$use_scaling_algorithm,
                    UnitAttr:$use_alloc);
   let results = (outs PDL_Operation:$init_or_alloc_op,
Index: mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
===================================================================
--- mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1363,14 +1363,21 @@
   }
 };
 
-/// Function signature to control reduction splitting. This returns a pair
-/// containing a ratio and a dimension index. The ratio is used to split the
-/// reduction dimension. The dimension index is used to control where the extra
-/// dimension is added to the intermediate tensor shape. If the ratio value is
-/// less or equal to 1 then nothing will be done.
+/// Split Reduction options.
+struct SplitReductionOptions {
+  int64_t ratio = 0;    // Ratio used to split the reduction dimension.
+                        // If the ratio is <= 1, nothing will be done.
+  unsigned index = 0;   // Index where the extra dimension is added to the
+                        // intermediate tensor shape.
+  bool innerParallel = false; // Whether the inner dimension produced by the
+                              // split is parallel (true) or reduction (false).
+};
+
+/// Function signature to control reduction splitting. This returns
+/// `SplitReductionOptions`.
 // TODO: don't use unsigned unless doing bit manipulation.
 using ControlSplitReductionFn =
-    std::function<std::pair<int64_t, unsigned>(LinalgOp op)>;
+    std::function<SplitReductionOptions(LinalgOp op)>;
 
 /// Patterns to apply `splitReduction` below.
 void populateSplitReductionPattern(
Index: mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
===================================================================
--- mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1001,8 +1001,8 @@
     SmallVectorImpl<Operation *> &results, transform::TransformState &state) {
   ControlSplitReductionFn splitFn = [&](LinalgOp) {
-    return std::pair(getSplitFactor(),
-                     getInsertSplitDimension());
+    return linalg::SplitReductionOptions{
+        int64_t(getSplitFactor()), unsigned(getInsertSplitDimension()), false};
   };
   SimpleRewriter rewriter(getContext());
   rewriter.setInsertionPoint(target);
Index: mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
===================================================================
--- mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
+++ mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
@@ -84,9 +84,9 @@
   OpBuilder::InsertionGuard guard(b);
   b.setInsertionPoint(op);
 
-  std::pair<int64_t, unsigned> control = controlSplitReductionFn(op);
-  int64_t ratio = control.first;
-  unsigned insertSplitDimension = control.second;
+  SplitReductionOptions control = controlSplitReductionFn(op);
+  int64_t ratio = control.ratio;
+  unsigned insertSplitDimension = control.index;
   if (ratio <= 1)
     return b.notifyMatchFailure(op, "split ratio needs to be greater than 1");
 
@@ -125,12 +125,23 @@
   for (unsigned idx : llvm::seq<unsigned>(0, map.getNumResults())) {
     unsigned dim = map.getDimPosition(idx);
     if (reductionDim == dim) {
-      newShape.push_back(ratio);
-      newShape.push_back(op.getShape(operand)[idx] / ratio);
+      if (control.innerParallel) {
+        newShape.push_back(op.getShape(operand)[idx] / ratio);
+        newShape.push_back(ratio);
+      } else {
+        newShape.push_back(ratio);
+        newShape.push_back(op.getShape(operand)[idx] / ratio);
+      }
       reassociation.push_back({index++, index++});
-      exprs.push_back(b.getAffineDimExpr(insertSplitDimension));
-      exprs.push_back(
-          b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
+      if (control.innerParallel) {
+        exprs.push_back(
+            b.getAffineDimExpr(dim <= reductionDim ? dim : dim + 1));
+        exprs.push_back(b.getAffineDimExpr(reductionDim + 1));
+      } else {
+        exprs.push_back(b.getAffineDimExpr(insertSplitDimension));
+        exprs.push_back(
+            b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
+      }
       continue;
     }
     newShape.push_back(op.getShape(operand)[idx]);
@@ -163,7 +174,11 @@
        llvm::seq<unsigned>(0, oldOutputMap.getNumResults() + 1)) {
     if (idx == insertSplitDimension) {
       newOutputShape.push_back(ratio);
-      outputExpr.push_back(b.getAffineDimExpr(insertSplitDimension));
+      if (control.innerParallel) {
+        outputExpr.push_back(b.getAffineDimExpr(reductionDim + 1));
+      } else {
+        outputExpr.push_back(b.getAffineDimExpr(insertSplitDimension));
+      }
       continue;
     }
     unsigned oldDim = idx < insertSplitDimension ? idx : idx - 1;
@@ -192,9 +207,11 @@
                                    op.getContext()));
   SmallVector<StringRef> newIteratorTypes;
   for (auto &it : llvm::enumerate(op.iterator_types())) {
-    if (insertSplitDimension == it.index())
+    if (insertSplitDimension == it.index() && !control.innerParallel)
       newIteratorTypes.push_back(getParallelIteratorTypeName());
     newIteratorTypes.push_back(it.value().cast<StringAttr>().getValue());
+    if (insertSplitDimension == it.index() && control.innerParallel)
+      newIteratorTypes.push_back(getParallelIteratorTypeName());
   }
   // Create the new op matching the original op with an extra parallel
   // dimension.
@@ -275,9 +292,12 @@
   b.setInsertionPoint(op);
 
   // Matcher part, enforce preconditions.
-  std::pair<int64_t, unsigned> control = controlSplitReductionFn(op);
-  int64_t splitFactor = control.first;
-  unsigned insertSplitDimension = control.second;
+  SplitReductionOptions control = controlSplitReductionFn(op);
+  if (control.innerParallel)
+    return b.notifyMatchFailure(op, "innerParallel not supported");
+
+  int64_t splitFactor = control.ratio;
+  unsigned insertSplitDimension = control.index;
   if (splitFactor <= 1)
     return b.notifyMatchFailure(op, "split factor needs to be greater than 1");
Index: mlir/test/Dialect/Linalg/split_reduction.mlir
===================================================================
--- mlir/test/Dialect/Linalg/split_reduction.mlir
+++ mlir/test/Dialect/Linalg/split_reduction.mlir
@@ -1,4 +1,5 @@
 // RUN: mlir-opt %s -test-linalg-transform-patterns=test-split-reduction -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-transform-patterns=test-split-reduction-inner-parallel -split-input-file | FileCheck %s --check-prefix=INNERPARALLELCHECK
 
 func.func @matmul_split(%A : tensor<16x256xf32>, %B: tensor<256x32xf32>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
   %0 = linalg.matmul ins(%A, %B: tensor<16x256xf32>, tensor<256x32xf32>)
@@ -31,6 +32,31 @@
 // CHECK: } -> tensor<16x32xf32>
 // CHECK: return %[[R]] : tensor<16x32xf32>
 
+// INNERPARALLELCHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+// INNERPARALLELCHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d3, d1)>
+// INNERPARALLELCHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// INNERPARALLELCHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// INNERPARALLELCHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// INNERPARALLELCHECK-LABEL: @matmul_split
+// INNERPARALLELCHECK-DAG: %[[ID:.*]] = arith.constant 0.000000e+00 : f32
+// INNERPARALLELCHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0], [1, 2]] : tensor<16x256xf32> into tensor<16x64x4xf32>
+// INNERPARALLELCHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}[0, 1], [2]] : tensor<256x32xf32> into tensor<64x4x32xf32>
+// INNERPARALLELCHECK-DAG: %[[INI:.*]] = linalg.init_tensor [16, 32, 4] : tensor<16x32x4xf32>
+// INNERPARALLELCHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<16x32x4xf32>) -> tensor<16x32x4xf32>
+// INNERPARALLELCHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
+// INNERPARALLELCHECK-SAME: , iterator_types = ["parallel", "parallel", "reduction", "parallel"]}
+// INNERPARALLELCHECK-SAME: ins(%[[I1]], %[[I2]] : tensor<16x64x4xf32>, tensor<64x4x32xf32>) outs(%[[F]] : tensor<16x32x4xf32>) {
+// INNERPARALLELCHECK: arith.mulf
+// INNERPARALLELCHECK: arith.addf
+// INNERPARALLELCHECK: linalg.yield
+// INNERPARALLELCHECK: } -> tensor<16x32x4xf32>
+// INNERPARALLELCHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP3]], #[[$MAP4]]],
+// INNERPARALLELCHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]} ins(%[[G]] : tensor<16x32x4xf32>) outs(%{{.*}} : tensor<16x32xf32>) {
+// INNERPARALLELCHECK: arith.addf
+// INNERPARALLELCHECK: linalg.yield %{{.*}} : f32
+// INNERPARALLELCHECK: } -> tensor<16x32xf32>
+// INNERPARALLELCHECK: return %[[R]] : tensor<16x32xf32>
+
 // -----
 
 func.func @generic_split_1d(%arg0: tensor<32xf32>, %arg1: tensor<f32>, %out: tensor<f32>) -> tensor<f32> {
@@ -73,6 +99,30 @@
 // CHECK: } -> tensor<f32>
 // CHECK: return %[[R]] : tensor<f32>
 
+// INNERPARALLELCHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// INNERPARALLELCHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> ()>
+// INNERPARALLELCHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d1)>
+// INNERPARALLELCHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0) -> (d0)>
+// INNERPARALLELCHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0) -> ()>
+// INNERPARALLELCHECK-LABEL: @generic_split_1d
+// INNERPARALLELCHECK: %[[ID:.*]] = arith.constant 1.000000e+00 : f32
+// INNERPARALLELCHECK: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0, 1]] : tensor<32xf32> into tensor<8x4xf32>
+// INNERPARALLELCHECK: %[[INI:.*]] = linalg.init_tensor [4] : tensor<4xf32>
+// INNERPARALLELCHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<4xf32>) -> tensor<4xf32>
+// INNERPARALLELCHECK: %[[G:.*]] = linalg.generic
+// INNERPARALLELCHECK: {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]],
+// INNERPARALLELCHECK: iterator_types = ["reduction", "parallel"]} ins(%[[I1]], %{{.*}} : tensor<8x4xf32>, tensor<f32>) outs(%[[F]] : tensor<4xf32>) {
+// INNERPARALLELCHECK: arith.subf
+// INNERPARALLELCHECK: math.exp
+// INNERPARALLELCHECK: arith.mulf
+// INNERPARALLELCHECK: linalg.yield
+// INNERPARALLELCHECK: } -> tensor<4xf32>
+// INNERPARALLELCHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP3]], #[[$MAP4]]], iterator_types = ["reduction"]} ins(%[[G]] : tensor<4xf32>) outs(%{{.*}} : tensor<f32>) {
+// INNERPARALLELCHECK: arith.mulf
+// INNERPARALLELCHECK: linalg.yield
+// INNERPARALLELCHECK: } -> tensor<f32>
+// INNERPARALLELCHECK: return %[[R]] : tensor<f32>
+
 // -----
 
 func.func @generic_split_3d(%input: tensor<32x2xf32>, %input_2: tensor<5x32xf32>, %output: tensor<5x2xf32>)
@@ -117,3 +167,27 @@
 // CHECK: linalg.yield
 // CHECK: } -> tensor<5x2xf32>
 // CHECK: return %[[R]] : tensor<5x2xf32>
+
+// INNERPARALLELCHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d2, d0)>
+// INNERPARALLELCHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d1, d2)>
+// INNERPARALLELCHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d0, d2)>
+// INNERPARALLELCHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// INNERPARALLELCHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// INNERPARALLELCHECK-LABEL: func @generic_split_3d
+// INNERPARALLELCHECK: %[[ID:.*]] = arith.constant -3.40282347E+38 : f32
+// INNERPARALLELCHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0, 1], [2]] : tensor<32x2xf32> into tensor<8x4x2xf32>
+// INNERPARALLELCHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}[0], [1, 2]] : tensor<5x32xf32> into tensor<5x8x4xf32>
+// INNERPARALLELCHECK: %[[INI:.*]] = linalg.init_tensor [5, 2, 4] : tensor<5x2x4xf32>
+// INNERPARALLELCHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<5x2x4xf32>) -> tensor<5x2x4xf32>
+// INNERPARALLELCHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "reduction", "parallel", "parallel"]}
+// INNERPARALLELCHECK-SAME: ins(%[[I1]], %[[I2]] : tensor<8x4x2xf32>, tensor<5x8x4xf32>) outs(%[[F]] : tensor<5x2x4xf32>) {
+// INNERPARALLELCHECK: arith.addf
+// INNERPARALLELCHECK: arith.maxf
+// INNERPARALLELCHECK: linalg.yield
+// INNERPARALLELCHECK: } -> tensor<5x2x4xf32>
+// INNERPARALLELCHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP3]], #[[$MAP4]]], iterator_types = ["parallel", "parallel", "reduction"]}
+// INNERPARALLELCHECK-SAME: ins(%[[G]] : tensor<5x2x4xf32>) outs(%{{.*}} : tensor<5x2xf32>) {
+// INNERPARALLELCHECK: arith.maxf
+// INNERPARALLELCHECK: linalg.yield
+// INNERPARALLELCHECK: } -> tensor<5x2xf32>
+// INNERPARALLELCHECK: return %[[R]] : tensor<5x2xf32>
Index: mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
===================================================================
--- mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -103,6 +103,10 @@
       *this, "test-split-reduction",
       llvm::cl::desc("Test split reduction transformation"),
       llvm::cl::init(false)};
+  Option<bool> testSplitReductionInnerParallel{
+      *this, "test-split-reduction-inner-parallel",
+      llvm::cl::desc("Test split reduction with inner parallel transformation"),
+      llvm::cl::init(false)};
   ListOption<int64_t> peeledLoops{
       *this, "peeled-loops",
       llvm::cl::desc("Loops to be peeled when test-tile-pattern")};
@@ -499,7 +503,21 @@
       patterns,
       [](LinalgOp op) {
         unsigned insertDimIndex = op.getNumLoops() - 1;
-        return std::make_pair(4, insertDimIndex);
+        return SplitReductionOptions{4, insertDimIndex, false};
+      },
+      LinalgTransformationFilter(
+          ArrayRef<StringAttr>{},
+          StringAttr::get(funcOp.getContext(), "SPLIT")));
+  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+}
+
+static void applySplitReductionInnerParallel(func::FuncOp funcOp) {
+  RewritePatternSet patterns(funcOp.getContext());
+  linalg::populateSplitReductionPattern(
+      patterns,
+      [](LinalgOp op) {
+        unsigned insertDimIndex = op.getNumLoops() - 1;
+        return SplitReductionOptions{4, insertDimIndex, true};
       },
       LinalgTransformationFilter(
          ArrayRef<StringAttr>{},
          StringAttr::get(funcOp.getContext(), "SPLIT")));
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
@@ -560,6 +578,8 @@
       /*peeledLoops=*/{}, /*scalarizeDynamicDims=*/true);
   if (testSplitReduction)
     return applySplitReduction(getOperation());
+  if (testSplitReductionInnerParallel)
+    return applySplitReductionInnerParallel(getOperation());
   if (testBubbleUpExtractSliceOpPattern)
     return applyBubbleUpExtractSliceOpPattern(getOperation());
   if (testSwapExtractSliceWithFill)
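
Note (not part of the patch): for C++ users of the pattern, opting into the new behavior only requires returning `innerParallel = true` from the control callback. Below is a minimal sketch that mirrors `applySplitReductionInnerParallel` from the test driver above; the function name is illustrative, the "SPLIT" filter marker is reused from the test, and the usual `using namespace mlir; using namespace mlir::linalg;` from the test file is assumed.

// Sketch only: register the split-reduction pattern with the inner-parallel
// mode enabled, mirroring the test driver in TestLinalgTransforms.cpp.
static void applyInnerParallelSplitSketch(func::FuncOp funcOp) {
  RewritePatternSet patterns(funcOp.getContext());
  linalg::populateSplitReductionPattern(
      patterns,
      [](LinalgOp op) {
        // Split the innermost loop by a factor of 4 and place the new
        // parallel dimension after (inner to) the remaining reduction.
        unsigned insertDimIndex = op.getNumLoops() - 1;
        return SplitReductionOptions{/*ratio=*/4, insertDimIndex,
                                     /*innerParallel=*/true};
      },
      LinalgTransformationFilter(
          ArrayRef<StringAttr>{},
          StringAttr::get(funcOp.getContext(), "SPLIT")));
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}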