diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1050,7 +1050,7 @@
   ControlSplitReductionFn splitFn = [&](LinalgOp) {
     return linalg::SplitReductionOptions{int64_t(getSplitFactor()),
                                          unsigned(getInsertSplitDimension()),
-                                         /*innerParallel=*/false};
+                                         bool(getInnerParallel())};
   };
   SimpleRewriter rewriter(getContext());
   rewriter.setInsertionPoint(target);
diff --git a/mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir b/mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir
--- a/mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir
@@ -1,24 +1,275 @@
-// RUN: mlir-opt --test-transform-dialect-interpreter %s | FileCheck %s
+// RUN: mlir-opt --split-input-file --test-transform-dialect-interpreter %s | FileCheck %s
 
-// CHECK-LABEL: func.func @matmul_split
 func.func @matmul_split(%A : tensor<16x256xf32>, %B: tensor<256x32xf32>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
+  %0 = linalg.matmul ins(%A, %B: tensor<16x256xf32>, tensor<256x32xf32>)
+                     outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
+  return %0: tensor<16x32xf32>
+}
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d3, d1)>
+// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: @matmul_split
+// CHECK-DAG: %[[ID:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}{{\[}}[0], [1, 2]] : tensor<16x256xf32> into tensor<16x4x64xf32>
+// CHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}{{\[}}[0, 1], [2]] : tensor<256x32xf32> into tensor<4x64x32xf32>
+// CHECK-DAG: %[[INI:.*]] = tensor.empty() : tensor<16x32x4xf32>
+// CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<16x32x4xf32>) -> tensor<16x32x4xf32>
+// CHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
+// CHECK-SAME: , iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
+// CHECK-SAME: ins(%[[I1]], %[[I2]] : tensor<16x4x64xf32>, tensor<4x64x32xf32>) outs(%[[F]] : tensor<16x32x4xf32>) {
+// CHECK: arith.mulf
+// CHECK: arith.addf
+// CHECK: linalg.yield
+// CHECK: } -> tensor<16x32x4xf32>
+// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP3]], #[[$MAP4]]],
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]} ins(%[[G]] : tensor<16x32x4xf32>) outs(%{{.*}} : tensor<16x32xf32>) {
+// CHECK: arith.addf
+// CHECK: linalg.yield %{{.*}} : f32
+// CHECK: } -> tensor<16x32xf32>
+// CHECK: return %[[R]] : tensor<16x32xf32>
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
+  %1:4 = transform.structured.split_reduction %0 { split_factor = 4, insert_split_dimension = 2}
+}
+
+// -----
 
-  // CHECK: linalg.generic
-  // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]
-  // CHECK-SAME: ins(%{{[a-zA-Z0-9_]*}}, %{{[a-zA-Z0-9_]*}} : tensor<16x4x64xf32>, tensor<4x64x32xf32>)
-  // CHECK-SAME: outs(%{{[a-zA-Z0-9_]*}} : tensor<16x32x4xf32>) {
+func.func @generic_split_1d(%arg0: tensor<32xf32>, %arg1: tensor<f32>, %out: tensor<f32>) -> tensor<f32> {
+  %red = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
+                                          affine_map<(d0) -> ()>,
+                                          affine_map<(d0) -> ()>],
+  iterator_types = ["reduction"]}
+  ins(%arg0, %arg1 : tensor<32xf32>, tensor<f32>)
+  outs(%out : tensor<f32>) {
+    ^bb0(%arg7: f32, %arg8: f32, %arg9: f32):
+      %40 = arith.subf %arg7, %arg8 : f32
+      %41 = math.exp %40 : f32
+      %42 = arith.mulf %41, %arg9 : f32
+      linalg.yield %42 : f32
+    } -> tensor<f32>
+  return %red : tensor<f32>
+}
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> ()>
+// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0)>
+// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0) -> (d0)>
+// CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0) -> ()>
+// CHECK-LABEL: @generic_split_1d
+// CHECK-DAG: %[[ID:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}{{\[}}[0, 1]] : tensor<32xf32> into tensor<4x8xf32>
+// CHECK-DAG: %[[INI:.*]] = tensor.empty() : tensor<4xf32>
+// CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<4xf32>) -> tensor<4xf32>
+// CHECK: %[[G:.*]] = linalg.generic
+// CHECK: {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]],
+// CHECK: iterator_types = ["parallel", "reduction"]} ins(%[[I1]], %{{.*}} : tensor<4x8xf32>, tensor<f32>) outs(%[[F]] : tensor<4xf32>) {
+// CHECK: arith.subf
+// CHECK: math.exp
+// CHECK: arith.mulf
+// CHECK: linalg.yield
+// CHECK: } -> tensor<4xf32>
+// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP3]], #[[$MAP4]]], iterator_types = ["reduction"]} ins(%[[G]] : tensor<4xf32>) outs(%{{.*}} : tensor<f32>) {
+// CHECK: arith.mulf
+// CHECK: linalg.yield
+// CHECK: } -> tensor<f32>
+// CHECK: return %[[R]] : tensor<f32>
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+  %1:4 = transform.structured.split_reduction %0 { split_factor = 4, insert_split_dimension = 0}
+}
 
-  // CHECK: linalg.generic
-  // CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
-  // CHECK-SAME: ins(%{{[a-zA-Z0-9_]*}} : tensor<16x32x4xf32>)
-  // CHECK-SAME: outs(%{{[a-zA-Z0-9_]*}} : tensor<16x32xf32>) {
+// -----
+
+func.func @generic_split_3d(%input: tensor<32x2xf32>, %input_2: tensor<5x32xf32>, %output: tensor<5x2xf32>)
+  -> tensor<5x2xf32>
+{
+  %0 = linalg.generic {
+      indexing_maps = [
+        affine_map<(d0, d1, d2) -> (d1, d0)>,
+        affine_map<(d0, d1, d2) -> (d2, d1)>,
+        affine_map<(d0, d1, d2) -> (d2, d0)>
+      ],
+      iterator_types = ["parallel", "reduction", "parallel"]
+    } ins(%input, %input_2 : tensor<32x2xf32>, tensor<5x32xf32>) outs(%output : tensor<5x2xf32>) {
+    ^bb0(%arg0: f32, %arg1: f32, %arg2: f32):
+      %3 = arith.addf %arg0, %arg1 : f32
+      %4 = arith.maxf %3, %arg2 : f32
+      linalg.yield %4 : f32
+  } -> tensor<5x2xf32>
+  return %0 : tensor<5x2xf32>
+}
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d1, d0)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d2, d1)>
+// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d0, d2)>
+// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func @generic_split_3d
+// CHECK-DAG: %[[ID:.*]] = arith.constant -3.40282347E+38 : f32
+// CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}{{\[}}[0, 1], [2]] : tensor<32x2xf32> into tensor<4x8x2xf32>
+// CHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}{{\[}}[0], [1, 2]] : tensor<5x32xf32> into tensor<5x4x8xf32>
+// CHECK-DAG: %[[INI:.*]] = tensor.empty() : tensor<5x2x4xf32>
+// CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<5x2x4xf32>) -> tensor<5x2x4xf32>
+// CHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "reduction", "parallel", "parallel"]}
+// CHECK-SAME: ins(%[[I1]], %[[I2]] : tensor<4x8x2xf32>, tensor<5x4x8xf32>) outs(%[[F]] : tensor<5x2x4xf32>) {
+// CHECK: arith.addf
+// CHECK: arith.maxf
+// CHECK: linalg.yield
+// CHECK: } -> tensor<5x2x4xf32>
+// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP3]], #[[$MAP4]]], iterator_types = ["parallel", "parallel", "reduction"]}
+// CHECK-SAME: ins(%[[G]] : tensor<5x2x4xf32>) outs(%{{.*}} : tensor<5x2xf32>) {
+// CHECK: arith.maxf
+// CHECK: linalg.yield
+// CHECK: } -> tensor<5x2xf32>
+// CHECK: return %[[R]] : tensor<5x2xf32>
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+  %1:4 = transform.structured.split_reduction %0 { split_factor = 4, insert_split_dimension = 2}
+}
+
+// -----
+
+func.func @matmul_split(%A : tensor<16x256xf32>, %B: tensor<256x32xf32>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
   %0 = linalg.matmul ins(%A, %B: tensor<16x256xf32>, tensor<256x32xf32>)
                      outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
   return %0: tensor<16x32xf32>
 }
 
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d3, d1)>
+// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: @matmul_split
+// CHECK-DAG: %[[ID:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}{{\[}}[0], [1, 2]] : tensor<16x256xf32> into tensor<16x64x4xf32>
+// CHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}{{\[}}[0, 1], [2]] : tensor<256x32xf32> into tensor<64x4x32xf32>
+// CHECK-DAG: %[[INI:.*]] = tensor.empty() : tensor<16x32x4xf32>
+// CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<16x32x4xf32>) -> tensor<16x32x4xf32>
+// CHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
+// CHECK-SAME: , iterator_types = ["parallel", "parallel", "reduction", "parallel"]}
+// CHECK-SAME: ins(%[[I1]], %[[I2]] : tensor<16x64x4xf32>, tensor<64x4x32xf32>) outs(%[[F]] : tensor<16x32x4xf32>) {
+// CHECK: arith.mulf
+// CHECK: arith.addf
+// CHECK: linalg.yield
+// CHECK: } -> tensor<16x32x4xf32>
+// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP3]], #[[$MAP4]]],
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]} ins(%[[G]] : tensor<16x32x4xf32>) outs(%{{.*}} : tensor<16x32xf32>) {
+// CHECK: arith.addf
+// CHECK: linalg.yield %{{.*}} : f32
+// CHECK: } -> tensor<16x32xf32>
+// CHECK: return %[[R]] : tensor<16x32xf32>
+
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !pdl.operation):
   %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
-  %1:4 = transform.structured.split_reduction %0 { split_factor = 4, insert_split_dimension = 2}
+  %1:4 = transform.structured.split_reduction %0 { split_factor = 4, insert_split_dimension = 2, inner_parallel}
 }
+
+// -----
+
+func.func @generic_split_1d(%arg0: tensor<32xf32>, %arg1: tensor<f32>, %out: tensor<f32>) -> tensor<f32> {
+  %red = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
+                                          affine_map<(d0) -> ()>,
+                                          affine_map<(d0) -> ()>],
+  iterator_types = ["reduction"]}
+  ins(%arg0, %arg1 : tensor<32xf32>, tensor<f32>)
+  outs(%out : tensor<f32>) {
+    ^bb0(%arg7: f32, %arg8: f32, %arg9: f32):
+      %40 = arith.subf %arg7, %arg8 : f32
+      %41 = math.exp %40 : f32
+      %42 = arith.mulf %41, %arg9 : f32
+      linalg.yield %42 : f32
+    } -> tensor<f32>
+  return %red : tensor<f32>
+}
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> ()>
+// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d1)>
+// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0) -> (d0)>
+// CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0) -> ()>
+// CHECK-LABEL: @generic_split_1d
+// CHECK-DAG: %[[ID:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}{{\[}}[0, 1]] : tensor<32xf32> into tensor<8x4xf32>
+// CHECK-DAG: %[[INI:.*]] = tensor.empty() : tensor<4xf32>
+// CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<4xf32>) -> tensor<4xf32>
+// CHECK: %[[G:.*]] = linalg.generic
+// CHECK: {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]],
+// CHECK: iterator_types = ["reduction", "parallel"]} ins(%[[I1]], %{{.*}} : tensor<8x4xf32>, tensor<f32>) outs(%[[F]] : tensor<4xf32>) {
+// CHECK: arith.subf
+// CHECK: math.exp
+// CHECK: arith.mulf
+// CHECK: linalg.yield
+// CHECK: } -> tensor<4xf32>
+// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP3]], #[[$MAP4]]], iterator_types = ["reduction"]} ins(%[[G]] : tensor<4xf32>) outs(%{{.*}} : tensor<f32>) {
+// CHECK: arith.mulf
+// CHECK: linalg.yield
+// CHECK: } -> tensor<f32>
+// CHECK: return %[[R]] : tensor<f32>
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+  %1:4 = transform.structured.split_reduction %0 { split_factor = 4, insert_split_dimension = 0, inner_parallel}
+}
+
+// -----
+
+func.func @generic_split_3d(%input: tensor<32x2xf32>, %input_2: tensor<5x32xf32>, %output: tensor<5x2xf32>)
+  -> tensor<5x2xf32>
+{
+  %0 = linalg.generic {
+      indexing_maps = [
+        affine_map<(d0, d1, d2) -> (d1, d0)>,
+        affine_map<(d0, d1, d2) -> (d2, d1)>,
+        affine_map<(d0, d1, d2) -> (d2, d0)>
+      ],
+      iterator_types = ["parallel", "reduction", "parallel"]
+    } ins(%input, %input_2 : tensor<32x2xf32>, tensor<5x32xf32>) outs(%output : tensor<5x2xf32>) {
+    ^bb0(%arg0: f32, %arg1: f32, %arg2: f32):
+      %3 = arith.addf %arg0, %arg1 : f32
+      %4 = arith.maxf %3, %arg2 : f32
+      linalg.yield %4 : f32
+  } -> tensor<5x2xf32>
+  return %0 : tensor<5x2xf32>
+}
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d2, d0)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d1, d2)>
+// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d0, d2)>
+// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func @generic_split_3d
+// CHECK-DAG: %[[ID:.*]] = arith.constant -3.40282347E+38 : f32
+// CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}{{\[}}[0, 1], [2]] : tensor<32x2xf32> into tensor<8x4x2xf32>
+// CHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}{{\[}}[0], [1, 2]] : tensor<5x32xf32> into tensor<5x8x4xf32>
+// CHECK-DAG: %[[INI:.*]] = tensor.empty() : tensor<5x2x4xf32>
+// CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<5x2x4xf32>) -> tensor<5x2x4xf32>
+// CHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "reduction", "parallel", "parallel"]}
+// CHECK-SAME: ins(%[[I1]], %[[I2]] : tensor<8x4x2xf32>, tensor<5x8x4xf32>) outs(%[[F]] : tensor<5x2x4xf32>) {
+// CHECK: arith.addf
+// CHECK: arith.maxf
+// CHECK: linalg.yield
+// CHECK: } -> tensor<5x2x4xf32>
+// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP3]], #[[$MAP4]]], iterator_types = ["parallel", "parallel", "reduction"]}
+// CHECK-SAME: ins(%[[G]] : tensor<5x2x4xf32>) outs(%{{.*}} : tensor<5x2xf32>) {
+// CHECK: arith.maxf
+// CHECK: linalg.yield
+// CHECK: } -> tensor<5x2xf32>
+// CHECK: return %[[R]] : tensor<5x2xf32>
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+  %1:4 = transform.structured.split_reduction %0 { split_factor = 4, insert_split_dimension = 2, inner_parallel}
+}