Support transform ops with several results in the generic `applyToOne` path.

Each application of `applyToOne` to one of M payload ops returns N ops, one
for each result of the transform op. The M lists are transposed into N lists
of M ops (`detail::transposeResults`) so that result #j of the transform op
holds the j-th op returned by every application. All applications must return
the same number of ops, otherwise a silenceable error is reported; if no
application took place, every result is set to an empty list of payload ops.

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -348,7 +348,7 @@
     does **not** fail when no ops were vectorized.
 
     Note that this transformation is invalidating the handles to any payload IR
-    operation that is contained inside the vectoriztaion target.
+    operation that is contained inside the vectorization target.
   }];
 
   let arguments = (ins PDL_Operation:$target,
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -794,6 +794,26 @@
   }
   return DiagnosedSilenceableFailure::success();
 }
+
+/// Transposes the results of M applications that returned N payload ops each
+/// into N lists of M payload ops: list #j gathers the j-th op returned by each
+/// application. All applications must have returned the same number of ops.
+static inline SmallVector<SmallVector<Operation *>>
+transposeResults(ArrayRef<SmallVector<Operation *>> m) {
+  SmallVector<SmallVector<Operation *>> res;
+  if (m.empty())
+    return res;
+  int64_t rows = m.size(), cols = m[0].size();
+  for (int64_t j = 0; j < cols; ++j)
+    res.push_back(SmallVector<Operation *>(rows, nullptr));
+  for (int64_t i = 0; i < rows; ++i) {
+    assert(static_cast<int64_t>(m[i].size()) == cols &&
+           "expected all applications to return the same number of ops");
+    for (int64_t j = 0; j < cols; ++j)
+      res[j][i] = m[i][j];
+  }
+  return res;
+}
 } // namespace detail
 } // namespace transform
 } // namespace mlir
@@ -815,27 +835,56 @@
       });
   if (!result.succeeded())
     return result;
-  for (const SmallVector<Operation *> &oneTargetResults : results) {
-    if (OpTy::template hasTrait<OpTrait::ZeroResults>())
-      continue;
-    if (OpTy::template hasTrait<OpTrait::OneResult>()) {
-      transformResults.set(
-          this->getOperation()->getResult(0).template cast<OpResult>(),
-          oneTargetResults);
-      continue;
-    }
-    if (this->getOperation()->getNumResults() != oneTargetResults.size()) {
-      Diagnostic diag(this->getOperation()->getLoc(),
-                      DiagnosticSeverity::Error);
-      diag << "unexpected number of results (got " << oneTargetResults.size()
-           << " expected " << this->getOperation()->getNumResults() << ")";
-      return DiagnosedSilenceableFailure::silencableFailure(std::move(diag));
-    }
-    for (const auto &it :
-         llvm::zip(this->getOperation()->getResults(), oneTargetResults)) {
-      transformResults.set(std::get<0>(it).template cast<OpResult>(),
-                           std::get<1>(it));
-    }
+  // No application took place, i.e. the target handle was empty. Still set
+  // every result, to an empty list of payload ops, so that no result of this
+  // transform op is left unset.
+  if (results.empty()) {
+    for (OpResult r : this->getOperation()->getResults())
+      transformResults.set(r, ArrayRef<Operation *>());
+    return DiagnosedSilenceableFailure::success();
+  }
+
+  // Ensure all applications return the same number of results.
+  // Variadic cases are much trickier to handle in a generic fashion.
+  int64_t nRes = results[0].size();
+  if (llvm::any_of(results, [&](const auto &r) {
+        return static_cast<int64_t>(r.size()) != nRes;
+      })) {
+    return static_cast<OpTy *>(this)->emitSilenceableError()
+           << "expected all applications of " << OpTy::getOperationName()
+           << " to produce " << nRes
+           << " results.\n If you need variadic results, consider using a "
+              "generic `apply` instead of the specialized `applyToOne`";
+  }
+  // Ensure the number of results agrees with what the transform op expects.
+  if (static_cast<int64_t>(this->getOperation()->getNumResults()) != nRes) {
+    static_cast<OpTy *>(this)->emitError()
+        << "unexpected number of results (got " << nRes << " expected "
+        << this->getOperation()->getNumResults() << ")";
+    return DiagnosedSilenceableFailure::definiteFailure();
+  }
+
+  // If no results, bail early.
+  if (OpTy::template hasTrait<OpTrait::ZeroResults>())
+    return DiagnosedSilenceableFailure::success();
+
+  // Transpose the M applications, which returned N payload ops each, into N
+  // lists of M payload ops: one list per result of this transform op.
+  SmallVector<SmallVector<Operation *>> transposedResults =
+      detail::transposeResults(results);
+  // With a single result, all M payload ops go into that one handle.
+  if (OpTy::template hasTrait<OpTrait::OneResult>()) {
+    assert(transposedResults.size() == 1 && "Expected single result");
+    transformResults.set(
+        this->getOperation()->getResult(0).template cast<OpResult>(),
+        transposedResults[0]);
+    return DiagnosedSilenceableFailure::success();
+  }
+  // N results: result #j receives the j-th payload op of every application.
+  for (const auto &it :
+       llvm::zip(this->getOperation()->getResults(), transposedResults)) {
+    transformResults.set(std::get<0>(it).template cast<OpResult>(),
+                         std::get<1>(it));
   }
   return DiagnosedSilenceableFailure::success();
 }
diff --git a/mlir/test/Dialect/Transform/selective-targeting.mlir b/mlir/test/Dialect/Transform/selective-targeting.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Transform/selective-targeting.mlir
@@ -0,0 +1,156 @@
+// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file | FileCheck %s
+
+// CHECK-LABEL: func.func @matmul_tensors_1(
+func.func @matmul_tensors_1(
+  %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>,
+  %arg2: tensor<128x128xf32> {linalg.inplaceable = true})
+    -> tensor<128x128xf32> {
+  // This operation is marked for tiling only.
+  // CHECK-COUNT-3: scf.for
+  // CHECK-COUNT-3: tensor.extract_slice
+  // CHECK: linalg.matmul
+  // CHECK-SAME: -> tensor<4x4xf32>
+  %0 = linalg.matmul { test.attrA }
+    ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
+    outs(%arg2: tensor<128x128xf32>)
+    -> tensor<128x128xf32>
+  func.return %0 : tensor<128x128xf32>
+}
+
+// CHECK-LABEL: func.func @matmul_tensors_2(
+func.func @matmul_tensors_2(
+  %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>,
+  %arg2: tensor<128x128xf32> {linalg.inplaceable = true})
+    -> tensor<128x128xf32> {
+  // This operation is marked for tiling and vectorization.
+  // CHECK-COUNT-3: scf.for
+  // CHECK-COUNT-3: vector.transfer_read
+  // CHECK: vector.contract
+  // CHECK-NOT: linalg.matmul
+  // CHECK: vector.transfer_write
+  %0 = linalg.matmul { test.attrA, test.attrC }
+    ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
+    outs(%arg2: tensor<128x128xf32>)
+    -> tensor<128x128xf32>
+  func.return %0 : tensor<128x128xf32>
+}
+
+// CHECK-LABEL: func.func @matmul_tensors_3(
+func.func @matmul_tensors_3(
+  %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>,
+  %arg2: tensor<128x128xf32> {linalg.inplaceable = true})
+    -> tensor<128x128xf32> {
+  // This operation is marked for vectorization only.
+  // CHECK-NOT: scf.for
+  // CHECK-COUNT-3: vector.transfer_read
+  // CHECK: vector.contract
+  // CHECK-SAME: into vector<128x128xf32>
+  // CHECK: vector.transfer_write
+  %0 = linalg.matmul { test.attrC }
+    ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
+    outs(%arg2: tensor<128x128xf32>)
+    -> tensor<128x128xf32>
+  func.return %0 : tensor<128x128xf32>
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  // Match matmul operations inside the @matmul_tensors_* functions with test.attrA set.
+  pdl.pattern @pdl_target_attrA : benefit(1) {
+    %args = operands
+    %results = types
+    %attr = attribute
+    %0 = operation "linalg.matmul"(%args : !pdl.range<value>) {"test.attrA" = %attr} -> (%results : !pdl.range<type>)
+    // TODO: we don't want this, but it is the required terminator for pdl.pattern
+    rewrite %0 with "transform.dialect"
+  }
+
+  // Match matmul operations inside the @matmul_tensors_* functions with test.attrC set.
+  pdl.pattern @pdl_target_attrC : benefit(1) {
+    %args = operands
+    %results = types
+    %attr = attribute
+    %0 = operation "linalg.matmul"(%args : !pdl.range<value>) {"test.attrC" = %attr} -> (%results : !pdl.range<type>)
+    // TODO: we don't want this, but it is the required terminator for pdl.pattern
+    rewrite %0 with "transform.dialect"
+  }
+
+  transform.sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = pdl_match @pdl_target_attrA in %arg1
+    transform.structured.tile %0 {sizes = [4, 4, 4]}
+    %1 = pdl_match @pdl_target_attrC in %arg1
+    %2 = transform.get_closest_isolated_parent %1
+    transform.structured.vectorize %2
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @vectorize_one
+func.func @vectorize_one(
+  %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>,
+  %arg2: tensor<128x128xf32> {linalg.inplaceable = true})
+    -> tensor<128x128xf32> {
+  // CHECK: vector.contract
+  %0 = linalg.matmul {test.attrA}
+    ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
+    outs(%arg2: tensor<128x128xf32>)
+    -> tensor<128x128xf32>
+  func.return %0 : tensor<128x128xf32>
+}
+
+// CHECK-LABEL: @vectorize_none
+func.func @vectorize_none(
+  %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>,
+  %arg2: tensor<128x128xf32> {linalg.inplaceable = true})
+    -> tensor<128x128xf32> {
+  // CHECK: linalg.matmul
+  %0 = linalg.matmul ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
+    outs(%arg2: tensor<128x128xf32>)
+    -> tensor<128x128xf32>
+  func.return %0 : tensor<128x128xf32>
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @pdl_target : benefit(1) {
+    %args = operands
+    %results = types
+    %attr = attribute
+    %0 = operation "linalg.matmul"(%args : !pdl.range<value>) {"test.attrA" = %attr} -> (%results : !pdl.range<type>)
+    // TODO: we don't want this, but it is the required terminator for pdl.pattern
+    rewrite %0 with "transform.dialect"
+  }
+
+  transform.sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = pdl_match @pdl_target in %arg1
+    %1 = get_closest_isolated_parent %0
+    transform.structured.vectorize %1
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @vectorize_all
+func.func @vectorize_all(
+  %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>,
+  %arg3: tensor<128x128xf32> {linalg.inplaceable = true})
+    -> tensor<128x128xf32> {
+  // CHECK: vector.contract
+  %0 = linalg.matmul {test.attrA}
+    ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
+    outs(%arg2: tensor<128x128xf32>)
+    -> tensor<128x128xf32>
+  // CHECK: vector.contract
+  %1 = linalg.matmul ins(%arg0, %0: tensor<128x128xf32>, tensor<128x128xf32>)
+    outs(%arg3: tensor<128x128xf32>)
+    -> tensor<128x128xf32>
+  return %1 : tensor<128x128xf32>
+}
+
+transform.sequence {
+^bb0(%arg0: !pdl.operation):
+  transform.structured.vectorize %arg0
+}
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file --verify-diagnostics
+// RUN: mlir-opt %s --test-transform-dialect-interpreter --allow-unregistered-dialect --split-input-file --verify-diagnostics
 
 // expected-remark @below {{applying transformation}}
 transform.test_transform_op
@@ -385,3 +385,78 @@
   // expected-error @below {{unexpected number of results (got 0 expected 3)}}
   transform.test_wrong_number_of_results %arg0
 }
+
+// -----
+
+func.func @foo() {
+  "op" () : () -> ()
+  "op" () : () -> ()
+  return
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @some : benefit(1) {
+    %0 = pdl.operands
+    %1 = pdl.types
+    %2 = pdl.operation "op"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+    pdl.rewrite %2 with "transform.dialect"
+  }
+
+  transform.sequence %arg0 {
+  ^bb0(%arg1: !pdl.operation):
+    %0 = pdl_match @some in %arg1
+    // Only the first "op" of @foo yields a result; the second one yields none.
+    // expected-error @below {{expected all applications of transform.test_wrong_number_of_multi_results to produce 1 results}}
+    transform.test_wrong_number_of_multi_results %0
+  }
+}
+
+// -----
+
+func.func @foo() {
+  "op" () : () -> ()
+  "op" () : () -> ()
+  "op" () : () -> ()
+  return
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @some : benefit(1) {
+    %0 = pdl.operands
+    %1 = pdl.types
+    %2 = pdl.operation "op"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+    pdl.rewrite %2 with "transform.dialect"
+  }
+
+  transform.sequence %arg0 {
+  ^bb0(%arg1: !pdl.operation):
+    %0 = pdl_match @some in %arg1
+    // Matches 3 ops; each application yields 2 ops, one per result.
+    %1:2 = transform.test_correct_number_of_multi_results %0
+  }
+}
+
+// -----
+
+func.func @foo() {
+  return
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @some : benefit(1) {
+    %0 = pdl.operands
+    %1 = pdl.types
+    %2 = pdl.operation "op"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+    pdl.rewrite %2 with "transform.dialect"
+  }
+
+  transform.sequence %arg0 {
+  ^bb0(%arg1: !pdl.operation):
+    %0 = pdl_match @some in %arg1
+    // Nothing is matched: both results must be set to empty handles.
+    %1:2 = transform.test_correct_number_of_multi_results %0
+  }
+}
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -232,6 +232,28 @@
   return SmallVector<Operation *>{};
 }
 
+FailureOr<SmallVector<Operation *>>
+mlir::test::TestWrongNumberOfMultiResultsOp::applyToOne(
+    Operation *op, transform::TransformState &state) {
+  // Only the first op of its block yields a result; any other op yields none.
+  // This makes applications to several such ops disagree on the number of
+  // results. Deriving it from the payload IR, rather than from a static call
+  // counter, keeps the op stateless across inputs and interpreter runs.
+  if (op->getPrevNode())
+    return SmallVector<Operation *>{};
+  OperationState opState(op->getLoc(), "foo");
+  return SmallVector<Operation *>{OpBuilder(op).create(opState)};
+}
+
+FailureOr<SmallVector<Operation *>>
+mlir::test::TestCorrectNumberOfMultiResultsOp::applyToOne(
+    Operation *op, transform::TransformState &state) {
+  // Always yield exactly two new ops so that every application agrees.
+  OperationState opState(op->getLoc(), "foo");
+  return SmallVector<Operation *>{OpBuilder(op).create(opState),
+                                  OpBuilder(op).create(opState)};
+}
+
 namespace {
 /// Test extension of the Transform dialect. Registers additional ops and
 /// declares PDL as dependent dialect since the additional ops are using PDL
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
@@ -144,4 +144,33 @@
   }];
 }
 
+def TestWrongNumberOfMultiResultsOp
+  : Op<Transform_Dialect, "test_wrong_number_of_multi_results",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     TransformEachOpTrait, TransformOpInterface]> {
+  let arguments = (ins PDL_Operation:$target);
+  let results = (outs PDL_Operation:$result);
+  let assemblyFormat = "$target attr-dict";
+  let cppNamespace = "::mlir::test";
+  let extraClassDeclaration = [{
+    ::mlir::FailureOr<::llvm::SmallVector<::mlir::Operation *>> applyToOne(
+      ::mlir::Operation *target, transform::TransformState &state);
+  }];
+}
+
+def TestCorrectNumberOfMultiResultsOp
+  : Op<Transform_Dialect, "test_correct_number_of_multi_results",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     TransformEachOpTrait, TransformOpInterface]> {
+  let arguments = (ins PDL_Operation:$target);
+  let results = (outs PDL_Operation:$result1,
+                      PDL_Operation:$result2);
+  let assemblyFormat = "$target attr-dict";
+  let cppNamespace = "::mlir::test";
+  let extraClassDeclaration = [{
+    ::mlir::FailureOr<::llvm::SmallVector<::mlir::Operation *>> applyToOne(
+      ::mlir::Operation *target, transform::TransformState &state);
+  }];
+}
+
 #endif // MLIR_TESTTRANSFORMDIALECTEXTENSION_TD