diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -117,4 +117,35 @@ let hasCustomAssemblyFormat = 1; } +def VectorizeOp : Op { + let description = [{ + Indicates that the given `target` op and all the ops it contains should be + vectorized with the configuration specified by the attributes of this op. + This vectorization only handles structured ops that operate on shaped types + and does not vectorize loops or straight-line code. Internally, it applies a + set of rewrite patterns, some of which enable vectorization and some of + which clean up the results. Therefore, it can only be applied to an op with + the "isolated from above" property. If finer granularity is required, it can + be achieved by outlining the target part of the payload IR into, e.g., a + function, performing the transformation, and inlining it back. This + transformation only fails if the entire pattern rewriting failed, i.e., it + does **not** fail when no ops were vectorized. + + Note that this transformation is invalidating the handles to any payload IR + operation that is contained inside the vectorization target. 
+ }]; + + let arguments = (ins PDL_Operation:$target, + DefaultValuedAttr:$vectorize_padding); + let results = (outs PDL_Operation:$transformed); + + let assemblyFormat = "$target attr-dict"; + + let extraClassDeclaration = [{ + ::mlir::FailureOr applyToOne(::mlir::Operation *target); + }]; +} + #endif // LINALG_TRANSFORM_OPS diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td @@ -16,6 +16,32 @@ include "mlir/Dialect/Transform/IR/TransformEffects.td" include "mlir/Dialect/Transform/IR/TransformInterfaces.td" +def GetClosestIsolatedParentOp : TransformDialectOp<"get_closest_isolated_parent", + [DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods]> { + let summary = "Gets handles to the closest isolated-from-above parents"; + let description = [{ + The handles defined by this Transform op correspond to the closest isolated + from above ancestor of the Payload IR operations associated with its + operand. If any of the given Payload IR ops has no such parent (unlikely as + there usually is a top-level ModuleOp), the transformation is considered to + have failed. + + Ancestor ops follow the same order as the ops associated with the + operand, except for potential duplicates (multiple Payload IR ops associated + with the operand have the same parent) for which the ancestor will only be + listed once for the first time it occurs. For example, given the list + "(childof(A), childof(B), childof(B), childof(A), childof(B))", the + resulting list will be just "(A, B)". Note that no other semantic ordering + is applied, e.g., "B" may itself be a parent of "A". This may have an impact + on the further transformation applied to the handle produced here. 
+ }]; + + let arguments = (ins PDL_Operation:$target); + let results = (outs PDL_Operation:$parent); + let assemblyFormat = "$target attr-dict"; +} + def PDLMatchOp : TransformDialectOp<"pdl_match", [DeclareOpInterfaceMethods]> { let summary = "Finds ops that match the named PDL pattern"; diff --git a/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt --- a/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt @@ -15,4 +15,5 @@ MLIRPDL MLIRSideEffectInterfaces MLIRTransformDialect + MLIRVector ) diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -15,6 +15,7 @@ #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Parser/Parser.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/Support/FormatVariadic.h" using namespace mlir; @@ -338,6 +339,40 @@ effects.emplace_back(MemoryEffects::Write::get(), PayloadIRResource::get()); } +//===----------------------------------------------------------------------===// +// VectorizeOp +//===----------------------------------------------------------------------===// + +FailureOr VectorizeOp::applyToOne(Operation *target) { + if (!target->hasTrait()) { + InFlightDiagnostic diag = emitOpError() + << "applies only to isolated-from-above targets"; + diag.attachNote(target->getLoc()) << "non-isolated target"; + return diag; + } + + MLIRContext *ctx = getContext(); + RewritePatternSet patterns(ctx); + patterns.add(ctx); + + vector::populateVectorTransferPermutationMapLoweringPatterns(patterns); + vector::populateVectorReductionToContractPatterns(patterns); + patterns.add(ctx, + /*benefit=*/2); + 
vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx); + vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx); + if (getVectorizePadding()) + linalg::populatePadOpVectorizationPatterns(patterns); + + if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns)))) { + InFlightDiagnostic diag = emitError() << "failed to apply"; + diag.attachNote(target->getLoc()) << "target op"; + return diag; + } + return target; +} + //===----------------------------------------------------------------------===// // Transform op registration //===----------------------------------------------------------------------===// @@ -352,6 +387,7 @@ LinalgTransformDialectExtension() { declareDependentDialect(); declareDependentDialect(); + declareDependentDialect(); registerTransformOps< #define GET_OP_LIST #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc" diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -114,6 +114,39 @@ } } // namespace +//===----------------------------------------------------------------------===// +// GetClosestIsolatedParentOp +//===----------------------------------------------------------------------===// + +LogicalResult transform::GetClosestIsolatedParentOp::apply( + transform::TransformResults &results, transform::TransformState &state) { + SetVector parents; + for (Operation *target : state.getPayloadOps(getTarget())) { + Operation *parent = + target->getParentWithTrait(); + if (!parent) { + InFlightDiagnostic diag = + emitError() << "could not find an isolated-from-above parent op"; + diag.attachNote(target->getLoc()) << "target op"; + return diag; + } + parents.insert(parent); + } + results.set(getResult().cast(), parents.getArrayRef()); + return success(); +} + +void transform::GetClosestIsolatedParentOp::getEffects( + SmallVectorImpl 
&effects) { + effects.emplace_back(MemoryEffects::Read::get(), getTarget(), + TransformMappingResource::get()); + effects.emplace_back(MemoryEffects::Allocate::get(), getParent(), + TransformMappingResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), getParent(), + TransformMappingResource::get()); + effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get()); +} + //===----------------------------------------------------------------------===// // PDLMatchOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir b/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir @@ -0,0 +1,182 @@ +// RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file -verify-diagnostics | FileCheck %s + +// CHECK-LABEL: @vectorize_matmul +// CHECK-SAME: %[[A:.*]]: tensor<24x12xf32> +// CHECK-SAME: %[[B:.*]]: tensor<12x25xf32> +// CHECK-SAME: %[[C:.*]]: tensor<24x25xf32> +func.func @vectorize_matmul(%arg0: tensor<24x12xf32>, + %arg1: tensor<12x25xf32>, + %arg2: tensor<24x25xf32>) -> tensor<24x25xf32> { + // CHECK: %[[vA:.+]] = vector.transfer_read %[[A]] + // CHECK: %[[vB:.+]] = vector.transfer_read %[[B]] + // CHECK: %[[vC:.+]] = vector.transfer_read %[[C]] + // CHECK: %[[vR:.+]] = vector.contract {{.*}} %[[vA]], %[[vB]] + // CHECK: %[[vS:.+]] = arith.addf %[[vR]], %[[vC]] + // CHECK: vector.transfer_write %[[vS]], %[[C]] + %0 = linalg.matmul ins(%arg0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32> + func.return %0 : tensor<24x25xf32> +} + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + pdl.pattern @pdl_target : benefit(1) { + %args = operands + %results = types + %0 = pdl.operation "linalg.matmul"(%args : !pdl.range) -> (%results : !pdl.range) + // TODO: we don't want this, but it is the 
required terminator for pdl.pattern + rewrite %0 with "transform.dialect" + } + + transform.sequence %arg0 { + ^bb1(%arg1: !pdl.operation): + %0 = pdl_match @pdl_target in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 + } +} + +// ----- + +#map0 = affine_map<()[s0] -> (-s0 + 12, 7)> +#map1 = affine_map<()[s0] -> (-s0 + 7)> + +// CHECK-LABEL: @vectorize_keep_pad +// CHECK-SAME: %[[C:[a-zA-Z0-9_]+]]: tensor<24x25xf32> +func.func @vectorize_keep_pad( + %arg0: tensor<24x12xf32>, %arg1: tensor<12x25xf32>, + %arg2: tensor<24x25xf32>, %arg3: index, %arg4: index, + %arg5: index) -> tensor<24x25xf32> { + %c0 = arith.constant 0 : index + %cst = arith.constant 0.000000e+00 : f32 + %0 = affine.min #map0()[%arg5] + %1 = tensor.extract_slice %arg0[%arg3, %arg5] [4, %0] [1, 1] : tensor<24x12xf32> to tensor<4x?xf32> + %2 = tensor.extract_slice %arg1[%arg5, %arg4] [%0, 5] [1, 1] : tensor<12x25xf32> to tensor + %3 = tensor.extract_slice %arg2[%arg3, %arg4] [4, 5] [1, 1] : tensor<24x25xf32> to tensor<4x5xf32> + %4 = affine.apply #map1()[%0] + // CHECK: %[[pA:.*]] = tensor.pad + %5 = tensor.pad %1 nofold low[%c0, %c0] high[%c0, %4] { + ^bb0(%arg6: index, %arg7: index): + tensor.yield %cst : f32 + } : tensor<4x?xf32> to tensor<4x7xf32> + %6 = affine.apply #map1()[%0] + // CHECK: %[[pB:.*]] = tensor.pad + %7 = tensor.pad %2 nofold low[%c0, %c0] high[%6, %c0] { + ^bb0(%arg6: index, %arg7: index): + tensor.yield %cst : f32 + } : tensor to tensor<7x5xf32> + // CHECK: %[[vA:.+]] = vector.transfer_read %[[pA]] + // CHECK: %[[vB:.+]] = vector.transfer_read %[[pB]] + // CHECK: %[[vC:.+]] = vector.transfer_read %[[C]] + // CHECK: %[[vR:.+]] = vector.contract {{.*}} %[[vA]], %[[vB]] + // CHECK: %[[vS:.+]] = arith.addf %[[vR]], %[[vC]] + // CHECK: vector.transfer_write %[[vS]], %[[C]] + %8 = linalg.matmul ins(%5, %7 : tensor<4x7xf32>, tensor<7x5xf32>) outs(%3 : tensor<4x5xf32>) -> tensor<4x5xf32> + %9 = tensor.insert_slice %8 into %arg2[%arg3, %arg4] [4, 5] [1, 1] : 
tensor<4x5xf32> into tensor<24x25xf32> + return %9 : tensor<24x25xf32> +} + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + pdl.pattern @pdl_target : benefit(1) { + %args = operands + %results = types + %0 = pdl.operation "linalg.matmul"(%args : !pdl.range) -> (%results : !pdl.range) + // TODO: we don't want this, but it is the required terminator for pdl.pattern + rewrite %0 with "transform.dialect" + } + + transform.sequence %arg0 { + ^bb1(%arg1: !pdl.operation): + %0 = pdl_match @pdl_target in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 + } +} + +// ----- + +#map0 = affine_map<()[s0] -> (-s0 + 12, 7)> +#map1 = affine_map<()[s0] -> (-s0 + 7)> + +// CHECK-LABEL: @vectorize_pad +// CHECK-SAME: %[[A:.+]]: tensor<24x12xf32> +// CHECK-SAME: %[[B:.+]]: tensor<12x25xf32> +// CHECK-SAME: %[[C:.+]]: tensor<24x25xf32> +func.func @vectorize_pad( + %arg0: tensor<24x12xf32>, %arg1: tensor<12x25xf32>, + %arg2: tensor<24x25xf32>, %arg3: index, %arg4: index, + %arg5: index) -> tensor<24x25xf32> { + %c0 = arith.constant 0 : index + %cst = arith.constant 0.000000e+00 : f32 + %0 = affine.min #map0()[%arg5] + // CHECK: %[[sA:.+]] = tensor.extract_slice %[[A]] + // CHECK: %[[sB:.+]] = tensor.extract_slice %[[B]] + %1 = tensor.extract_slice %arg0[%arg3, %arg5] [4, %0] [1, 1] : tensor<24x12xf32> to tensor<4x?xf32> + %2 = tensor.extract_slice %arg1[%arg5, %arg4] [%0, 5] [1, 1] : tensor<12x25xf32> to tensor + %3 = tensor.extract_slice %arg2[%arg3, %arg4] [4, 5] [1, 1] : tensor<24x25xf32> to tensor<4x5xf32> + // CHECK: %[[vA:.+]] = vector.transfer_read %[[sA]] + %4 = affine.apply #map1()[%0] + %5 = tensor.pad %1 nofold low[%c0, %c0] high[%c0, %4] { + ^bb0(%arg6: index, %arg7: index): + tensor.yield %cst : f32 + } : tensor<4x?xf32> to tensor<4x7xf32> + %6 = affine.apply #map1()[%0] + // CHECK: %[[vB:.+]] = vector.transfer_read %[[sB]] + %7 = tensor.pad %2 nofold low[%c0, %c0] high[%6, %c0] { + ^bb0(%arg6: index, %arg7: index): + tensor.yield 
%cst : f32 + } : tensor to tensor<7x5xf32> + // CHECK: %[[vC:.+]] = vector.transfer_read %[[C]] + // CHECK: %[[vR:.+]] = vector.contract {{.*}} %[[vA]], %[[vB]] + // CHECK: %[[vS:.+]] = arith.addf %[[vR]], %[[vC]] + // CHECK: vector.transfer_write %[[vS]], %[[C]] + %8 = linalg.matmul ins(%5, %7 : tensor<4x7xf32>, tensor<7x5xf32>) outs(%3 : tensor<4x5xf32>) -> tensor<4x5xf32> + %9 = tensor.insert_slice %8 into %arg2[%arg3, %arg4] [4, 5] [1, 1] : tensor<4x5xf32> into tensor<24x25xf32> + return %9 : tensor<24x25xf32> +} + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + pdl.pattern @pdl_target : benefit(1) { + %args = operands + %results = types + %0 = pdl.operation "linalg.matmul"(%args : !pdl.range) -> (%results : !pdl.range) + // TODO: we don't want this, but it is the required terminator for pdl.pattern + rewrite %0 with "transform.dialect" + } + + transform.sequence %arg0 { + ^bb1(%arg1: !pdl.operation): + %0 = pdl_match @pdl_target in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 {vectorize_padding = true} + } +} + +// ----- + +func.func @vectorize(%arg0: tensor<24x12xf32>, + %arg1: tensor<12x25xf32>, + %arg2: tensor<24x25xf32>) -> tensor<24x25xf32> { + // expected-note @below {{non-isolated target}} + %0 = linalg.matmul ins(%arg0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32> + func.return %0 : tensor<24x25xf32> +} + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + pdl.pattern @pdl_target : benefit(1) { + %args = operands + %results = types + %0 = pdl.operation "linalg.matmul"(%args : !pdl.range) -> (%results : !pdl.range) + // TODO: we don't want this, but it is the required terminator for pdl.pattern + rewrite %0 with "transform.dialect" + } + + transform.sequence %arg0 { + ^bb1(%arg1: !pdl.operation): + %0 = pdl_match @pdl_target in %arg1 + // expected-error @below {{applies only to isolated-from-above targets}} + %2 = 
transform.structured.vectorize %0 + } +} diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir --- a/mlir/test/Dialect/Transform/test-interpreter.mlir +++ b/mlir/test/Dialect/Transform/test-interpreter.mlir @@ -97,3 +97,34 @@ // expected-remark @below {{matched}} "test.some_op"() : () -> () +// ----- + +// expected-remark @below {{parent function}} +func.func @foo() { + %0 = arith.constant 0 : i32 + return +} + +// expected-remark @below {{parent function}} +func.func @bar() { + %0 = arith.constant 0 : i32 + %1 = arith.constant 1 : i32 + return +} + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + pdl.pattern @const : benefit(1) { + %r = pdl.types + %0 = pdl.operation "arith.constant" -> (%r : !pdl.range) + pdl.rewrite %0 with "transform.dialect" + } + + transform.sequence %arg0 { + ^bb1(%arg1: !pdl.operation): + %f = pdl_match @const in %arg1 + // CHECK: %{{.+}} = get_closest_isolated_parent %{{.+}} + %m = get_closest_isolated_parent %f + test_print_remark_at_operand %m, "parent function" + } +} diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -7379,6 +7379,7 @@ ":Parser", ":SideEffectInterfaces", ":TransformDialect", + ":TransformUtils", "//llvm:Support", ], )