diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h @@ -72,8 +72,28 @@ namespace detail { +/// Returns true if the block contains a contraction of the following form: +/// +/// %0 = <elementwise>(permutation-of(cu(block-argument-0), +/// cu(block-argument-1))) +/// %1 = <reduction>(permutation-of(cu(%0), cu(block-argument-2))) +/// return-like cu(%1) +/// +/// where <elementwise> and <reduction> are binary operations constituting a +/// contraction (in the canonical case, <elementwise> is a multiplication and +/// <reduction> is an addition). The name and other properties of these operations +/// are checked by `isaPair`. All operands of all operations may be supplied +/// through a chain of side effect-free unary operations, such as casts, which +/// is denoted as `cu` above. +/// +/// When the body does not contain a contraction, a more precise description of +/// the failed precondition is sent to the `errs` stream, if provided. +bool isContractionBody(Block &block, + function_ref isaPair, + llvm::raw_ostream &errs = llvm::nulls()); + /// Result of matching a Linalg generic against the predicates of it being a -/// contractiom. +/// contraction. enum class MatchContractionResult; /// Checks whether `op` conforms to ContractionOpInterface and populates diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td @@ -106,6 +106,12 @@ * `passthrough`: the body of the structured payload op only forwards inputs to the outputs (copy or broadcast). 
+ * `contraction`: the body of the structured payload op is a contraction + of the form `<red>(<elem>(bbarg0, bbarg1), bbarg2)` where `<elem>` and + `<red>` are binary operations whose names are specified in the attribute + and operands can be permuted and optionally forwarded through a chain of + unary side effect-free operations. + }], StructuredPredicate.extraDescription, [{ #### Return modes @@ -116,12 +122,54 @@ }]); let arguments = (ins TransformHandleTypeInterface:$operand_handle, OptionalAttr:$reduction_position, - UnitAttr:$passthrough); + UnitAttr:$passthrough, + OptionalAttr:$contraction); let assemblyFormat = "$operand_handle attr-dict `:` type($operand_handle)"; let extraClassDeclaration = SingleOpMatcher.extraDeclaration; let hasVerifier = 1; } +def MatchStructuredClassifyContractionDimsOp + : Op { + let summary = + "Checks if an operation has contraction-like dimensions and returns them"; + let description = !strconcat([{ + Checks if the structured payload op has contraction-like dimensions as + follows: + + C(batch, m, n) += A(batch, m, k) * B(batch, k, n) + + That is: + + - 'batch' are parallel dimensions used in inputs and result; + - 'm' are parallel dimensions used in the LHS and result; + - 'n' are parallel dimensions used in the RHS and result; + - 'k' are reduction dimensions present only in LHS and RHS. + + Note that this doesn't check the operation in the body. + + }], StructuredPredicate.extraDescription, [{ + + #### Return modes + + Succeeds if the operation has the contraction-like dimensions, produces a + silenceable failure otherwise. 
+ }]); + + let arguments = (ins TransformHandleTypeInterface:$operand_handle); + let results = (outs TransformParamTypeInterface:$batch, + TransformParamTypeInterface:$m, + TransformParamTypeInterface:$n, + TransformParamTypeInterface:$k); + let assemblyFormat = + "$operand_handle attr-dict `:` functional-type(operands, results)"; + let extraClassDeclaration = SingleOpMatcher.extraDeclaration; +} + class StructuredDimDescription { string description = !strconcat([{ The following }], kind ,[{ specifications are supported: diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp @@ -52,68 +52,106 @@ // ContractionOpInterface implementation //===----------------------------------------------------------------------===// -/// Return true if the use-def chain from `v` to `from` consists of 0 or more -/// unary single-operand operations. +/// If the value is defined by a chain of unary side effect-free operations, +/// go up the use-def chain until the first value not defined by such an op. // TODO: relax to multi-operands with constants, which are technically unary ops // as needed (e.g. add5). -static bool isChainOfUnaryOpsFrom(Value v, Value from) { - while (true) { - if (v == from) - return true; - Operation *op = v.getDefiningOp(); - if (!op || op->getNumOperands() != 1) - return false; - v = op->getOperand(0); - }; +static Value getSourceSkipUnary(Value value) { + Operation *op = value.getDefiningOp(); + while (op && op->getNumOperands() == 1) { + auto iface = dyn_cast(op); + if (!iface || !iface.hasNoEffect()) + break; + value = op->getOperand(0); + op = value.getDefiningOp(); + } + return value; } -/// Return the unique instance of OpType in `block` if it is indeed unique. -/// Return null if none or more than 1 instances exist. 
-template -static OpType getSingleOpOfType(Block &block) { - OpType res = nullptr; - block.walk([&](OpType op) { - if (res) { - res = nullptr; - return WalkResult::interrupt(); - } - res = op; - return WalkResult::advance(); - }); - return res; -} +bool mlir::linalg::detail::isContractionBody( + Block &block, function_ref isaPair, + llvm::raw_ostream &errs) { + if (block.empty() || !block.back().mightHaveTrait()) { + errs << "no terminator in the block"; + return false; + } + + if (block.getNumArguments() != 3) { + errs << "expected block with 3 arguments"; + return false; + } + + Operation *terminator = block.getTerminator(); + if (terminator->getNumOperands() != 1) { + errs << "expected terminator with 1 operand"; + return false; + } + + Value yielded = getSourceSkipUnary(terminator->getOperand(0)); + Operation *reductionOp = yielded.getDefiningOp(); // null for a block argument + if (!reductionOp || reductionOp->getNumResults() != 1 || reductionOp->getNumOperands() != 2) { + errs << "expected reduction op to be binary"; + return false; + } + + Value reductionLHS = getSourceSkipUnary(reductionOp->getOperand(0)); + Value reductionRHS = getSourceSkipUnary(reductionOp->getOperand(1)); -/// Detect whether res is any permutation of `u5(u1(c) + u2(u3(a) * u4(b)))` -/// on the field (AddOpType, MulOpType), where u1, u2, u3, u4 and u5 represent -/// unary operations that may change the type. -template -static bool isAddMul(Block &block) { - if (block.getNumArguments() != 3) + if (reductionLHS != block.getArgument(2) && + reductionRHS != block.getArgument(2)) { + errs << "expected reduction to take block argument #2 as one of the " + "operands (modulo unary casts)"; return false; - Operation *yieldOp = block.getTerminator(); - if (yieldOp->getNumOperands() != 1) + } + + Value contributed = getSourceSkipUnary( + isa(reductionLHS) ? 
reductionRHS : reductionLHS); + Operation *elementwiseOp = contributed.getDefiningOp(); // may be null too + if (!elementwiseOp || elementwiseOp->getNumResults() != 1 || + elementwiseOp->getNumOperands() != 2) { + errs << "expected elementwise op to be binary"; + return false; + } + + if (!isaPair(elementwiseOp, reductionOp)) { + errs << "expected reduction/elementwise op kind not satisfied"; return false; + } + + Value elementwiseLHS = getSourceSkipUnary(elementwiseOp->getOperand(0)); + Value elementwiseRHS = getSourceSkipUnary(elementwiseOp->getOperand(1)); + if ((elementwiseLHS == block.getArgument(0) && + elementwiseRHS == block.getArgument(1)) || + (elementwiseLHS == block.getArgument(1) && + elementwiseRHS == block.getArgument(0))) { + return true; + } - AddOpType addOp = getSingleOpOfType(block); - MulOpType mulOp = getSingleOpOfType(block); - if (!addOp || !mulOp) + errs << "expected elementwise op to apply to block arguments (modulo unary " + "casts)"; + return false; +} + +/// Returns true if the two operations are of the kinds specified by a pair of +/// consecutive template arguments. +template +static bool isPairTemplateImpl(Operation *add, Operation *mul) { + static_assert(sizeof...(Args) % 2 == 0, + "expected an even number of template arguments"); + if (isa(add) && isa(mul)) + return true; + + if constexpr (sizeof...(Args) > 0) + return isPairTemplateImpl(add, mul); + else return false; +} - Value argA = block.getArgument(0), argB = block.getArgument(1); - Value a = mulOp->getOperand(0), b = mulOp->getOperand(1); - Value mul = mulOp->getResult(0); - Value argC = block.getArgument(2); - Value c1 = addOp->getOperand(0), c2 = addOp->getOperand(1); - Value add = addOp->getResult(0); - Value res = yieldOp->getOperand(0); - // Result traces back to add. - auto un = isChainOfUnaryOpsFrom; - bool success = un(res, add); - // One of the operands of add traces back to argC, the other to the mul. 
- success |= (un(c1, argC) && un(c2, mul)) || ((un(c1, mul)) && un(c2, argC)); - // One of the operands of mul traces back to argA, the other to argB. - success |= (un(a, argA) && un(b, argB)) || ((un(a, argB)) && un(b, argA)); - return success; +/// Returns true if the block is a body of a contraction with the kinds of +/// operations given pairwise by template arguments. +template +static bool isContractionBody(Block &block) { + return linalg::detail::isContractionBody(block, &isPairTemplateImpl); } /// Given a `linalgOp` and one of its `opOperand`, returns the positions of the @@ -231,12 +269,16 @@ [](AffineMap m) { return !m.isProjectedPermutation(); })) return MatchContractionResult::NotProjectedPermutations; // TODO: more fields than add/mul. - if (!isAddMul(linalgOp->getRegion(0).front()) && - !isAddMul(linalgOp->getRegion(0).front()) && - !isAddMul( - linalgOp->getRegion(0).front()) && - !isAddMul(linalgOp->getRegion(0).front())) + // clang-format off + if (!::isContractionBody< + arith::MulFOp, arith::AddFOp, + arith::MulIOp, arith::AddIOp, + complex::MulOp, complex::AddOp, + arith::AndIOp, arith::OrIOp>( + *linalgOp.getBlock())) { return MatchContractionResult::NotAddMul; + } + // clang-format on if (dimensions) { FailureOr res = inferContractionDims(linalgOp); diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp @@ -9,6 +9,7 @@ #include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.h" #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" #include "mlir/Dialect/Linalg/TransformOps/Syntax.h" #include "mlir/Dialect/Transform/IR/MatchInterfaces.h" #include "mlir/IR/BuiltinAttributes.h" @@ -186,17 +187,78 @@ } return DiagnosedSilenceableFailure::success(); } + if (std::optional 
contractionOps = getContraction()) { + Block &body = linalgOp->getRegion(0).front(); + std::string message; + llvm::raw_string_ostream os(message); + bool result = linalg::detail::isContractionBody( + body, + [&](Operation *elem, Operation *red) { + return elem->getName().getStringRef() == + (*contractionOps)[0].cast().getValue() && + red->getName().getStringRef() == + (*contractionOps)[1].cast().getValue(); + }, + os); + if (result) + return DiagnosedSilenceableFailure::success(); + return emitSilenceableError() << "contraction: " << os.str(); + } return emitDefiniteFailure() << "unknown body condition"; } LogicalResult transform::MatchStructuredBodyOp::verify() { - if (getReductionPosition() && getPassthrough()) { - return emitOpError() << "reduction position and passthrough conditions are " - "mutually exclusive"; + int64_t numOptions = getReductionPosition().has_value() + getPassthrough() + + getContraction().has_value(); + + if (numOptions > 1) { + std::string attributeNames; + llvm::raw_string_ostream os(attributeNames); + llvm::interleaveComma(ArrayRef{getReductionPositionAttrName(), + getPassthroughAttrName(), + getContractionAttrName()}, + os); + return emitOpError() << "only one of {" << os.str() << "} is allowed"; + } + + if (std::optional contractionAttr = getContraction()) { + if (contractionAttr->size() != 2) { + return emitOpError() << "expects " << getContractionAttrName() + << " to contain two elements"; + } } return success(); } +//===----------------------------------------------------------------------===// +// MatchStructuredClassifyContractionDimsOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure +transform::MatchStructuredClassifyContractionDimsOp::matchOperation( + Operation *current, transform::TransformResults &results, + transform::TransformState &state) { + FailureOr contractionDims = + linalg::inferContractionDims(cast(current)); + if (failed(contractionDims)) + return 
emitSilenceableError() << "could not infer contraction dimensions"; + + MLIRContext *context = current->getContext(); + Builder builder(context); + auto makeI64Attrs = [&](ArrayRef values) { + return llvm::to_vector( + llvm::map_range(values, [&](unsigned value) -> Attribute { + return builder.getI64IntegerAttr(value); + })); + }; + results.setParams(getBatch().cast(), + makeI64Attrs(contractionDims->batch)); + results.setParams(getM().cast(), makeI64Attrs(contractionDims->m)); + results.setParams(getN().cast(), makeI64Attrs(contractionDims->n)); + results.setParams(getK().cast(), makeI64Attrs(contractionDims->k)); + return DiagnosedSilenceableFailure::success(); +} + //===----------------------------------------------------------------------===// // Utilities for structured match predicates. //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir b/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir --- a/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir +++ b/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir @@ -777,7 +777,6 @@ // ----- - module attributes { transform.with_named_sequence } { transform.named_sequence @match_input_indexing_map(%arg0: !transform.any_op {transform.readonly}) -> (!transform.affine_map, !transform.any_op) { @@ -831,3 +830,79 @@ return } } + +// ----- + +module attributes { transform.with_named_sequence } { + transform.named_sequence @match_contraction(%arg0: !transform.any_op {transform.readonly}) + -> (!transform.any_op, !transform.param, !transform.param, !transform.param, !transform.param) { + %1:4 = transform.match.structured %arg0 : (!transform.any_op) -> (!transform.param, !transform.param, !transform.param, !transform.param) { + ^bb0(%struct: !transform.any_op): + transform.match.structured.body %struct { contraction = ["arith.mulf", "arith.addf"] } : !transform.any_op + %0:4 = transform.match.structured.classify_contraction_dims 
%struct + : (!transform.any_op) -> (!transform.param, !transform.param, !transform.param, !transform.param) + transform.match.structured.yield %0#0, %0#1, %0#2, %0#3 + : !transform.param, !transform.param, !transform.param, !transform.param + } + transform.yield %arg0, %1#0, %1#1, %1#2, %1#3 : !transform.any_op, !transform.param, !transform.param, !transform.param, !transform.param + } + + transform.named_sequence @print_contraction( + %op: !transform.any_op {transform.readonly}, + %batch: !transform.param {transform.readonly}, + %m: !transform.param {transform.readonly}, + %n: !transform.param {transform.readonly}, + %k: !transform.param {transform.readonly}) { + transform.test_print_remark_at_operand %op, "contraction" : !transform.any_op + transform.test_print_param %batch, "batch dims" at %op : !transform.param, !transform.any_op + transform.test_print_param %m, "m dims" at %op : !transform.param, !transform.any_op + transform.test_print_param %n, "n dims" at %op : !transform.param, !transform.any_op + transform.test_print_param %k, "k dims" at %op : !transform.param, !transform.any_op + transform.yield + } + + transform.sequence failures(propagate) attributes { transform.target_tag = "transform" } { + ^bb0(%arg0: !transform.any_op): + %3 = transform.foreach_match in %arg0 @match_contraction -> @print_contraction : (!transform.any_op) -> !transform.any_op + transform.yield + } +} + +module attributes { transform.target_tag = "start_here" } { + func.func @matmul_simple(%lhs: tensor<10x20xf32>, %rhs: tensor<20x15xf32>) -> tensor<10x15xf64> { + %cst = arith.constant 0.0 : f64 + %empty = tensor.empty() : tensor<10x15xf64> + %fill = linalg.fill ins(%cst : f64) outs(%empty : tensor<10x15xf64>) -> tensor<10x15xf64> + // expected-remark @below {{contraction}} + // expected-remark @below {{batch dims}} + // expected-remark @below {{m dims 0}} + // expected-remark @below {{n dims 1}} + // expected-remark @below {{k dims 2}} + %result = linalg.matmul ins(%lhs, %rhs: 
tensor<10x20xf32>, tensor<20x15xf32>) outs(%fill: tensor<10x15xf64>) -> tensor<10x15xf64> + return %result : tensor<10x15xf64> + } + + func.func @double_batch(%lhs: tensor<40x10x50x20xf32>, %rhs: tensor<40x20x50x15xf32>) -> tensor<40x10x50x15xf32> { + %cst = arith.constant 0.0 : f32 + %empty = tensor.empty() : tensor<40x10x50x15xf32> + %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<40x10x50x15xf32>) -> tensor<40x10x50x15xf32> + // expected-remark @below {{contraction}} + // expected-remark @below {{batch dims 0 : i64, 2 : i64}} + // expected-remark @below {{m dims 1}} + // expected-remark @below {{n dims 3}} + // expected-remark @below {{k dims 4}} + %result = linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d2, d3)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>], + iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"] + } ins(%lhs, %rhs : tensor<40x10x50x20xf32>, tensor<40x20x50x15xf32>) + outs(%fill : tensor<40x10x50x15xf32>) { + ^bb(%arg0: f32, %arg1: f32, %arg2: f32): + %0 = arith.mulf %arg0, %arg1 : f32 + %1 = arith.addf %arg2, %0 : f32 + linalg.yield %1 : f32 + } -> tensor<40x10x50x15xf32> + return %result : tensor<40x10x50x15xf32> + } +} diff --git a/mlir/test/Dialect/Linalg/match-ops-invalid.mlir b/mlir/test/Dialect/Linalg/match-ops-invalid.mlir --- a/mlir/test/Dialect/Linalg/match-ops-invalid.mlir +++ b/mlir/test/Dialect/Linalg/match-ops-invalid.mlir @@ -64,7 +64,7 @@ ^bb0(%arg0: !transform.any_op): transform.match.structured %arg0 : !transform.any_op { ^bb1(%arg1: !transform.any_op): - // expected-error @below {{reduction position and passthrough conditions are mutually exclusive}} + // expected-error @below {{only one of {"reduction_position", "passthrough", "contraction"} is allowed}} transform.match.structured.body %arg1 { passthrough, reduction_position = 0 } : !transform.any_op transform.match.structured.yield } diff 
--git a/mlir/test/Integration/Dialect/Transform/match_batch_matmul.mlir b/mlir/test/Integration/Dialect/Transform/match_batch_matmul.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Integration/Dialect/Transform/match_batch_matmul.mlir @@ -0,0 +1,66 @@ +// RUN: mlir-opt %s --test-transform-dialect-interpreter='transform-library-file-name=%p/match_matmul_common.mlir' --verify-diagnostics + +module attributes { transform.with_named_sequence } { + transform.named_sequence @_match_matmul_like( + %entry: !transform.any_op {transform.readonly}, + %rank: !transform.param {transform.readonly}) + -> (!transform.any_op, !transform.any_op, !transform.param, + !transform.type, !transform.type, !transform.type, + !transform.param, !transform.param, !transform.param, !transform.param) + + transform.named_sequence @match_bmm(%entry: !transform.any_op {transform.readonly}) + -> (!transform.any_op, !transform.any_op, !transform.param, + !transform.type, !transform.type, !transform.type, !transform.param) { + transform.match.operation_name %entry ["linalg.batch_matmul", "linalg.generic"] : !transform.any_op + %c3 = transform.param.constant 4 : i64 -> !transform.param + %fill, %bmm, %dims, %lhs_type, %rhs_type, %res_type, %batch, %m, %n, %k = + transform.include @_match_matmul_like failures(propagate) (%entry, %c3) + : (!transform.any_op, !transform.param) + -> (!transform.any_op, !transform.any_op, !transform.param, + !transform.type, !transform.type, !transform.type, + !transform.param, !transform.param, !transform.param, !transform.param) + + transform.yield %fill, %bmm, %dims, %lhs_type, %rhs_type, %res_type, %batch + : !transform.any_op, !transform.any_op, !transform.param, !transform.type, !transform.type, !transform.type, !transform.param + } + + transform.named_sequence @print_bmm( + %fill: !transform.any_op {transform.readonly}, + %bmm: !transform.any_op {transform.readonly}, + %dims: !transform.param {transform.readonly}, + %lhs_type: !transform.type 
{transform.readonly}, + %rhs_type: !transform.type {transform.readonly}, + %res_type: !transform.type {transform.readonly}, + %batch: !transform.param {transform.readonly}) { + transform.test_print_remark_at_operand %fill, "fill" : !transform.any_op + transform.test_print_remark_at_operand %bmm, "batch matmul" : !transform.any_op + transform.test_print_param %dims, "dimensions" at %bmm : !transform.param, !transform.any_op + transform.test_print_param %lhs_type, "LHS type" at %bmm : !transform.type, !transform.any_op + transform.test_print_param %rhs_type, "RHS type" at %bmm : !transform.type, !transform.any_op + transform.test_print_param %res_type, "result type" at %bmm : !transform.type, !transform.any_op + transform.test_print_param %batch, "batch dimension" at %bmm : !transform.param, !transform.any_op + transform.yield + } + + transform.sequence failures(propagate) { + ^bb(%root: !transform.any_op): + foreach_match in %root + @match_bmm -> @print_bmm + : (!transform.any_op) -> !transform.any_op + } +} + +func.func @bmm_simple(%lhs: tensor<40x10x20xf16>, %rhs: tensor<40x20x15xf32>) -> tensor<40x10x15xf64>{ + %cst = arith.constant 0.0 : f64 + %empty = tensor.empty() : tensor<40x10x15xf64> + // expected-remark @below {{fill}} + %fill = linalg.fill ins(%cst : f64) outs(%empty : tensor<40x10x15xf64>) -> tensor<40x10x15xf64> + // expected-remark @below {{batch matmul}} + // expected-remark @below {{dimensions 40 : i64, 10 : i64, 15 : i64, 20 : i64}} + // expected-remark @below {{LHS type f16}} + // expected-remark @below {{RHS type f32}} + // expected-remark @below {{result type f64}} + // expected-remark @below {{batch dimension 0}} + %result = linalg.batch_matmul ins(%lhs, %rhs: tensor<40x10x20xf16>, tensor<40x20x15xf32>) outs(%fill: tensor<40x10x15xf64>) -> tensor<40x10x15xf64> + return %result : tensor<40x10x15xf64> +} diff --git a/mlir/test/Integration/Dialect/Transform/match_matmul.mlir b/mlir/test/Integration/Dialect/Transform/match_matmul.mlir --- 
a/mlir/test/Integration/Dialect/Transform/match_matmul.mlir +++ b/mlir/test/Integration/Dialect/Transform/match_matmul.mlir @@ -1,36 +1,26 @@ -// RUN: mlir-opt %s --test-transform-dialect-interpreter --verify-diagnostics +// RUN: mlir-opt %s --test-transform-dialect-interpreter='transform-library-file-name=%p/match_matmul_common.mlir' --verify-diagnostics module attributes { transform.with_named_sequence } { + transform.named_sequence @_match_matmul_like( + %entry: !transform.any_op {transform.readonly}, + %rank: !transform.param {transform.readonly}) + -> (!transform.any_op, !transform.any_op, !transform.param, + !transform.type, !transform.type, !transform.type, + !transform.param, !transform.param, !transform.param, !transform.param) + transform.named_sequence @match_matmul(%entry: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.any_op, !transform.param, !transform.type, !transform.type, !transform.type) { - %c1 = transform.param.constant 1 : i64 -> !transform.param - %c2 = transform.param.constant 2 : i64 -> !transform.param - %capture:5 = transform.match.structured %entry : (!transform.any_op) - -> (!transform.any_op, !transform.param, !transform.type, !transform.type, !transform.type) { - ^bb0(%struct: !transform.any_op): - transform.match.operation_name %struct ["linalg.matmul"] : !transform.any_op - %dims = transform.match.structured.dim %struct[all] : (!transform.any_op) -> !transform.param - - %n_inputs = transform.match.structured.num_inputs %struct : (!transform.any_op) -> !transform.param - %n_inits = transform.match.structured.num_inits %struct : (!transform.any_op) -> !transform.param - transform.match.param.cmpi eq %n_inputs, %c2 : !transform.param - transform.match.param.cmpi eq %n_inits, %c1 : !transform.param - - %lhs = transform.match.structured.input %struct[0] : (!transform.any_op) -> !transform.any_value - %rhs = transform.match.structured.input %struct[1] : (!transform.any_op) -> !transform.any_value - %res = 
transform.match.structured.result %struct[0] : (!transform.any_op) -> !transform.any_value - %lhs_type = transform.get_type elemental %lhs : (!transform.any_value) -> !transform.type - %rhs_type = transform.get_type elemental %rhs : (!transform.any_value) -> !transform.type - %res_type = transform.get_type elemental %res : (!transform.any_value) -> !transform.type + transform.match.operation_name %entry ["linalg.matmul", "linalg.generic"] : !transform.any_op + %c3 = transform.param.constant 3 : i64 -> !transform.param + %fill, %matmul, %dims, %lhs_type, %rhs_type, %res_type, %kinds:4 = + transform.include @_match_matmul_like failures(propagate) (%entry, %c3) + : (!transform.any_op, !transform.param) + -> (!transform.any_op, !transform.any_op, !transform.param, + !transform.type, !transform.type, !transform.type, + !transform.param, !transform.param, !transform.param, !transform.param) - %init = transform.match.structured.init %struct[0] : (!transform.any_op) -> !transform.any_op - transform.match.operation_name %init ["linalg.fill"] : !transform.any_op - - transform.match.structured.yield %init, %dims, %lhs_type, %rhs_type, %res_type - : !transform.any_op, !transform.param, !transform.type, !transform.type, !transform.type - } - transform.yield %capture#0, %entry, %capture#1, %capture#2, %capture#3, %capture#4 + transform.yield %fill, %matmul, %dims, %lhs_type, %rhs_type, %res_type : !transform.any_op, !transform.any_op, !transform.param, !transform.type, !transform.type, !transform.type } @@ -90,3 +80,29 @@ %result = linalg.matmul ins(%real_lhs, %rhs: tensor<10x20xf32>, tensor<20x15xf32>) outs(%fill: tensor<10x15xf32>) -> tensor<10x15xf32> return %result : tensor<10x15xf32> } + +func.func @matmul_generic(%lhs: tensor<10x20xf16>, %rhs: tensor<20x15xf32>) -> tensor<10x15xf64>{ + %cst = arith.constant 0.0 : f64 + %empty = tensor.empty() : tensor<10x15xf64> + // expected-remark @below {{fill}} + %fill = linalg.fill ins(%cst : f64) outs(%empty : tensor<10x15xf64>) -> 
tensor<10x15xf64> + // expected-remark @below {{matmul}} + // expected-remark @below {{dimensions 10 : i64, 15 : i64, 20 : i64}} + // expected-remark @below {{LHS type f16}} + // expected-remark @below {{RHS type f32}} + // expected-remark @below {{result type f64}} + %result = linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, + affine_map<(d0, d1, d2) -> (d2, d1)>, + affine_map<(d0, d1, d2) -> (d0, d1)>], + iterator_types = ["parallel", "parallel", "reduction"] + } ins(%lhs, %rhs: tensor<10x20xf16>, tensor<20x15xf32>) outs(%fill: tensor<10x15xf64>) { + ^bb(%arg0: f16, %arg1: f32, %arg2: f64): + %0 = arith.extf %arg0 : f16 to f32 + %1 = arith.mulf %0, %arg1 : f32 + %2 = arith.extf %1 : f32 to f64 + %3 = arith.addf %2, %arg2 : f64 + linalg.yield %3 : f64 + }-> tensor<10x15xf64> + return %result : tensor<10x15xf64> +} diff --git a/mlir/test/Integration/Dialect/Transform/match_matmul_common.mlir b/mlir/test/Integration/Dialect/Transform/match_matmul_common.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Integration/Dialect/Transform/match_matmul_common.mlir @@ -0,0 +1,48 @@ +// RUN: mlir-opt %s + +module attributes { transform.with_named_sequence } { + transform.named_sequence @_match_matmul_like( + %entry: !transform.any_op {transform.readonly}, + %rank: !transform.param {transform.readonly}) + -> (!transform.any_op, !transform.any_op, !transform.param, + !transform.type, !transform.type, !transform.type, + !transform.param, !transform.param, !transform.param, !transform.param) { + %c1 = transform.param.constant 1 : i64 -> !transform.param + %c2 = transform.param.constant 2 : i64 -> !transform.param + %capture:9 = transform.match.structured %entry : (!transform.any_op) + -> (!transform.any_op, !transform.param, !transform.type, !transform.type, !transform.type, + !transform.param, !transform.param, !transform.param, !transform.param) { + ^bb0(%struct: !transform.any_op): + %op_rank = transform.match.structured.rank %struct : 
(!transform.any_op) -> !transform.param + transform.match.param.cmpi eq %rank, %op_rank : !transform.param + %dims = transform.match.structured.dim %struct[all] : (!transform.any_op) -> !transform.param + + %n_inputs = transform.match.structured.num_inputs %struct : (!transform.any_op) -> !transform.param + %n_inits = transform.match.structured.num_inits %struct : (!transform.any_op) -> !transform.param + transform.match.param.cmpi eq %n_inputs, %c2 : !transform.param + transform.match.param.cmpi eq %n_inits, %c1 : !transform.param + + %lhs = transform.match.structured.input %struct[0] : (!transform.any_op) -> !transform.any_value + %rhs = transform.match.structured.input %struct[1] : (!transform.any_op) -> !transform.any_value + %res = transform.match.structured.result %struct[0] : (!transform.any_op) -> !transform.any_value + %lhs_type = transform.get_type elemental %lhs : (!transform.any_value) -> !transform.type + %rhs_type = transform.get_type elemental %rhs : (!transform.any_value) -> !transform.type + %res_type = transform.get_type elemental %res : (!transform.any_value) -> !transform.type + + %init = transform.match.structured.init %struct[0] : (!transform.any_op) -> !transform.any_op + transform.match.operation_name %init ["linalg.fill"] : !transform.any_op + + transform.match.structured.body %struct { contraction = ["arith.mulf", "arith.addf"] } : !transform.any_op + %dim_kinds:4 = transform.match.structured.classify_contraction_dims %struct + : (!transform.any_op) -> (!transform.param, !transform.param, !transform.param, !transform.param) + + transform.match.structured.yield %init, %dims, %lhs_type, %rhs_type, %res_type, %dim_kinds#0, %dim_kinds#1, %dim_kinds#2, %dim_kinds#3 + : !transform.any_op, !transform.param, !transform.type, !transform.type, !transform.type, + !transform.param, !transform.param, !transform.param, !transform.param + } + transform.yield %capture#0, %entry, %capture#1, %capture#2, %capture#3, %capture#4, + %capture#5, %capture#6, 
%capture#7, %capture#8 + : !transform.any_op, !transform.any_op, !transform.param, !transform.type, !transform.type, !transform.type, + !transform.param, !transform.param, !transform.param, !transform.param + } +}