diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2074,6 +2074,40 @@
   }];
 }
 
+def VectorizeOneOp : Op<Transform_Dialect, "structured.vectorize_one_op",
+    [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+     TransformOpInterface, ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    Vectorize the target ops, which must be Linalg ops.
+
+    Use this op to exercise the vectorizer. Unlike `VectorizeOp`, it does
+    not apply any rewrite patterns, so the output can easily be mapped back
+    to the transformation performed by the vectorizer itself.
+
+    Typically this operator should be applied to linalg operations that have
+    already been tiled to the appropriate sizes.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target,
+                       UnitAttr:$vectorize_nd_extract);
+
+  let results = (outs);
+  let assemblyFormat = [{
+    $target
+    attr-dict
+    `:` type($target)
+  }];
+
+  let extraClassDeclaration = [{
+    // TODO: applyToOne.
+    ::mlir::DiagnosedSilenceableFailure apply(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::transform::TransformResults &transformResults,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
 def MaskedVectorizeOp : Op<Transform_Dialect, "structured.masked_vectorize",
     [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
      TransformOpInterface, ReportTrackingListenerFailuresOpTrait]> {
@@ -2089,7 +2123,7 @@
      counterpart iteration space sizes.
 
      Typically this operator should be applied to linalg operations that have
-     already be tiled to the appropriate sizes.
+     already been tiled to the appropriate sizes.
 
      #### Return modes:
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3252,6 +3252,41 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// VectorizeOneOp
+//===----------------------------------------------------------------------===//
+DiagnosedSilenceableFailure transform::VectorizeOneOp::apply(
+    transform::TransformRewriter &rewriter,
+    mlir::transform::TransformResults &transformResults,
+    mlir::transform::TransformState &state) {
+  auto targets = state.getPayloadOps(getTarget());
+  if (std::empty(targets))
+    return DiagnosedSilenceableFailure::success();
+
+  SmallVector<int64_t> vectorSizes;
+
+  for (Operation *target : targets) {
+    if (!isa<linalg::LinalgOp>(target)) {
+      return mlir::emitSilenceableFailure(target->getLoc())
+             << "Unsupported Op, cannot vectorize";
+    }
+
+    if (failed(vectorize(rewriter, target, /*inputVectorSizes=*/{},
+                         /*scalableVecDims=*/{}, getVectorizeNdExtract()))) {
+      return mlir::emitSilenceableFailure(target->getLoc())
+             << "Attempted to vectorize, but failed";
+    }
+  }
+
+  return DiagnosedSilenceableFailure::success();
+}
+
+void transform::VectorizeOneOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  consumesHandle(getTarget(), effects);
+  modifiesPayload(effects);
+}
+
 //===----------------------------------------------------------------------===//
 // HoistRedundantVectorTransfersOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -13,8 +13,7 @@
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !transform.any_op):
   %0 = transform.structured.match ops{["linalg.dot"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-  %1 = get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
-  %2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns } : (!transform.any_op) -> !transform.any_op
+  transform.structured.vectorize_one_op %0 : !transform.any_op
 }
 
 // -----
diff --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
--- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
@@ -35,13 +35,12 @@
 }
 
 // -----
-
-#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 func.func @vectorize_nd_tensor_extract_constant_idx(%arg0: tensor<3x3xf32>, %arg2: tensor<1x1x3xf32>) -> tensor<1x1x3xf32> {
   %c0 = arith.constant 1 : index
   %c1 = arith.constant 2 : index
   %2 = linalg.generic {
-    indexing_maps = [#map1],
+    indexing_maps = [#map],
     iterator_types = ["parallel", "parallel", "parallel"]
   } outs(%arg2 : tensor<1x1x3xf32>) {
   ^bb0(%arg4: f32):
@@ -51,23 +50,22 @@
     return %2 : tensor<1x1x3xf32>
 }
 
+// CHECK: #[[$MAP:.*]] = affine_map<(d0, d1) -> (0, 0, 0)>
 // CHECK-LABEL: func.func @vectorize_nd_tensor_extract_constant_idx(
 // CHECK-SAME: %[[ARG_0:.*]]: tensor<3x3xf32>,
 // CHECK-SAME: %[[ARG_1:.*]]: tensor<1x1x3xf32>) -> tensor<1x1x3xf32> {
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
-// CHECK: %[[C2:.*]] = arith.constant 2 : index
-// CHECK: %[[EXTRACT:.*]] = tensor.extract %[[ARG_0]]{{\[}}%[[C1]], %[[C2]]] : tensor<3x3xf32>
-// CHECK: %[[BCAST:.*]] = vector.broadcast %[[EXTRACT]] : f32 to vector<1x1x3xf32>
-// CHECK: %[[VAL_7:.*]] = vector.transfer_write %[[BCAST]], %[[ARG_1]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x1x3xf32>, tensor<1x1x3xf32>
-// CHECK: return %[[VAL_7]] : tensor<1x1x3xf32>
-// CHECK: }
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[C0_f32:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[READ:.*]] = vector.transfer_read %[[ARG_0]][%[[C1]], %[[C2]]], %[[C0_f32]] {in_bounds = [true, true, true], permutation_map = #[[$MAP]]} : tensor<3x3xf32>, vector<1x1x3xf32>
+// CHECK: %[[C0_4:.*]] = arith.constant 0 : index
+// CHECK: vector.transfer_write %[[READ]], %[[ARG_1]][%[[C0_4]], %[[C0_4]], %[[C0_4]]] : vector<1x1x3xf32>, tensor<1x1x3xf32>
 
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !transform.any_op):
   %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-  %1 = get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
-  %2 = transform.structured.vectorize %1 { vectorize_nd_extract } : (!transform.any_op) -> !transform.any_op
+  transform.structured.vectorize_one_op %0 { vectorize_nd_extract } : !transform.any_op
 }
 
 // -----
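
For reference, below is a minimal sketch of how the new op is driven end to end. The matmul payload function is illustrative only and is not part of this patch; the transform script mirrors the updated tests. Because `vectorize_one_op` applies no follow-up rewrite patterns, the resulting vector.transfer_read/transfer_write ops appear exactly as the vectorizer emits them.

// Payload: a statically-shaped op, assumed to already be tiled to its final
// sizes, so the vectorizer needs no explicit vector sizes.
func.func @matmul(%A: tensor<8x16xf32>, %B: tensor<16x32xf32>,
                  %C: tensor<8x32xf32>) -> tensor<8x32xf32> {
  %0 = linalg.matmul ins(%A, %B : tensor<8x16xf32>, tensor<16x32xf32>)
                     outs(%C : tensor<8x32xf32>) -> tensor<8x32xf32>
  return %0 : tensor<8x32xf32>
}

// Transform script: match the payload op and vectorize it directly, with no
// rewrite patterns applied afterwards.
transform.sequence failures(propagate) {
^bb1(%arg1: !transform.any_op):
  %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
    : (!transform.any_op) -> !transform.any_op
  transform.structured.vectorize_one_op %0 : !transform.any_op
}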