diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td @@ -244,7 +244,8 @@ // TODO: allow this to bind multiple inputs simultaneously after checking that // `transform.foreach` works well in matches. - let results = (outs Optional:$result); + let results = + (outs Optional>:$result); let assemblyFormat = "$operand_handle `[`" "custom($raw_position_list, $is_inverted, $is_all)" @@ -262,12 +263,36 @@ def MatchStructuredInputOp : MatchStructuredOperandOp<"match.structured.input"> { let summary = - "Captures input operand(s) of a structured operation in an op or value handle"; + "Captures input operand(s) of a structured operation"; let description = !strconcat([{ - Produces a transform dialect value handle associated with the payload value - supplied as input operand to the given structured payload operation, or an - operation handle to the structured payload operation producing said payload - value depending on the result type. + Produces a transform dialect value depending on the result type: + + - If the result type is a value handle, it will be associated with the input + operand(s) of the payload operation associated with the operand handle. + - If the result type is an operation handle, it will be associated with the + operation defining the input operand(s) of the payload operation associated + with the operand handle. + - If the result type is an affine map parameter type, it will be associated + with the indexing map that corresponds to the input operand(s) of the + payload operation associated with the operand handle. + + For example, given the following operation: + + ```mlir + %arg1 = some.op + linalg.matmul ins(%arg1, %arg2 : ...) outs(%arg3 : ...) 
+ ``` + + in case of a successful match for operand 0 this operation will return, for + each of the respective cases above: + + - A handle to `%arg1` if the result is a value handle. + - A handle to `some.op` if the result is an operation handle. + - A parameter containing the LHS map of the matrix multiplication, i.e. + `affine_map<(d0, d1, d2) -> (d0, d2)>` if the result is an affine + map parameter. + + The match succeeds if the conditions specified as attributes succeed. }], StructuredDimDescription<"input">.description, @@ -288,12 +313,35 @@ def MatchStructuredInitOp : MatchStructuredOperandOp<"match.structured.init"> { let summary = - "Captures init operand(s) of a structured operation in an op or value handle"; + "Captures init operand(s) of a structured operation"; let description = !strconcat([{ - Produces a transform dialect value handle associated with the payload value - supplied as init(outs) operand to the given structured payload operation, - or an operation handle to the structured payload operation producing said - payload value depending on the result type. + Produces a transform dialect value depending on the result type: + - If the result type is a value handle, it will be associated with the init + operand(s) of the payload operation associated with the operand handle. + - If the result type is an operation handle, it will be associated with the + operation defining the init operand(s) of the payload operation associated + with the operand handle. + - If the result type is an affine map parameter type, it will be associated + with the indexing map that corresponds to the init operand(s) of the + payload operation associated with the operand handle. + + For example, given the following operation: + + ```mlir + %arg3 = linalg.fill + linalg.matmul ins(%arg1, %arg2 : ...) outs(%arg3 : ...) 
+ ``` + + in case of a successful match for init operand 0 this operation will return, + for each of the respective cases above: + + - A handle to `%arg3` if the result is a value handle. + - A handle to `linalg.fill` if the result is an operation handle. + - A parameter containing the result map of the matrix multiplication, i.e. + `affine_map<(d0, d1, d2) -> (d0, d1)>` if the result is an affine + map parameter. + + The match succeeds if the conditions specified as attributes succeed. }], StructuredDimDescription<"init">.description, diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td @@ -443,6 +443,7 @@ def GetDefiningOp : TransformDialectOp<"get_defining_op", [DeclareOpInterfaceMethods, + MatchOpInterface, NavigationTransformOpTrait, MemoryEffectsOpInterface]> { let summary = "Get handle to the defining op of a value"; let description = [{ @@ -531,6 +532,25 @@ "functional-type(operands, results)"; } +def GetTypeOp : TransformDialectOp<"get_type", + [DeclareOpInterfaceMethods, + MatchOpInterface, + DeclareOpInterfaceMethods]> { + let summary = "Get a parameter containing the type of the given value"; + let description = [{ + This operation creates a new Transform parameter containing the + type(s) of the value(s) associated with the operand handle. + + This transform never fails. + }]; + + let arguments = (ins TransformValueHandleTypeInterface:$value, + UnitAttr:$elemental); + let results = (outs TransformParamTypeInterface:$type_param); + let assemblyFormat = "(`elemental` $elemental^)? 
$value attr-dict `:`" + "functional-type(operands, results)"; +} + def IncludeOp : TransformDialectOp<"include", [CallOpInterface, MatchOpInterface, @@ -838,6 +858,7 @@ [DeclareOpInterfaceMethods, + MatchOpInterface, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, OpAsmOpInterface, PossibleTopLevelTransformOpTrait, diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformTypes.td b/mlir/include/mlir/Dialect/Transform/IR/TransformTypes.td --- a/mlir/include/mlir/Dialect/Transform/IR/TransformTypes.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformTypes.td @@ -13,6 +13,16 @@ include "mlir/Dialect/Transform/IR/TransformInterfaces.td" include "mlir/Dialect/Transform/IR/TransformDialect.td" +def Transform_AffineMapParamType : TypeDef]> { + let description = [{ + Transform IR parameter value that can be associated with a list of affine + map attributes. + }]; + let mnemonic = "affine_map"; + let assemblyFormat = ""; +} + def Transform_AnyOpType : TypeDef]> { let description = [{ @@ -23,6 +33,15 @@ let assemblyFormat = ""; } +def Transform_AnyValue : TypeDef]> { + let description = [{ + Transform IR value that can be associated with a list of Payload IR values. + }]; + let mnemonic = "any_value"; + let assemblyFormat = ""; +} + def Transform_OperationType : TypeDef]> { let description = [{ @@ -52,12 +71,13 @@ let genVerifyDecl = 1; } -def Transform_AnyValue : TypeDef]> { +def Transform_TypeParamType : TypeDef]> { let description = [{ - Transform IR value that can be associated with a list of Payload IR values. + Transform IR parameter value that can be associated with a list of type + attributes. 
}]; - let mnemonic = "any_value"; + let mnemonic = "type"; let assemblyFormat = ""; } diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp @@ -428,6 +428,11 @@ if (!getResult()) continue; + if (isa(getResult().getType())) { + operandMapping.emplace_back(AffineMapAttr::get(indexingMap)); + continue; + } + Value operand = linalgOp.getDpsInputOperand(position)->get(); if (isa(getResult().getType())) { operandMapping.emplace_back(operand); @@ -513,6 +518,11 @@ if (!getResult()) continue; + if (isa(getResult().getType())) { + operandMapping.emplace_back(AffineMapAttr::get(indexingMap)); + continue; + } + Value operand = linalgOp.getDpsInitOperand(position)->get(); if (isa(getResult().getType())) { operandMapping.emplace_back(operand); diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -1049,6 +1049,37 @@ return DiagnosedSilenceableFailure::success(); } +//===----------------------------------------------------------------------===// +// GetTypeOp +//===----------------------------------------------------------------------===// + +void transform::GetTypeOp::getEffects( + SmallVectorImpl &effects) { + onlyReadsHandle(getValue(), effects); + producesHandle(getResult(), effects); + onlyReadsPayload(effects); +} + +DiagnosedSilenceableFailure +transform::GetTypeOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, + transform::TransformState &state) { + SmallVector params; + ArrayRef values = state.getPayloadValues(getValue()); + params.reserve(values.size()); + for (Value value : values) { + Type type = value.getType(); + if (getElemental()) { + if (auto shaped = dyn_cast(type)) { 
+ type = shaped.getElementType(); + } + } + params.push_back(TypeAttr::get(type)); + } + results.setParams(getResult().cast(), params); + return DiagnosedSilenceableFailure::success(); +} + //===----------------------------------------------------------------------===// // IncludeOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp b/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp @@ -38,6 +38,22 @@ >(); } +//===----------------------------------------------------------------------===// +// transform::AffineMapParamType +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure +transform::AffineMapParamType::checkPayload(Location loc, + ArrayRef payload) const { + for (Attribute attr : payload) { + if (!attr.isa()) { + return emitSilenceableError(loc) + << "expected affine map attribute, got " << attr; + } + } + return DiagnosedSilenceableFailure::success(); +} + //===----------------------------------------------------------------------===// // transform::AnyOpType //===----------------------------------------------------------------------===// @@ -48,6 +64,16 @@ return DiagnosedSilenceableFailure::success(); } +//===----------------------------------------------------------------------===// +// transform::AnyValueType +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure +transform::AnyValueType::checkPayload(Location loc, + ArrayRef payload) const { + return DiagnosedSilenceableFailure::success(); +} + //===----------------------------------------------------------------------===// // transform::OperationType //===----------------------------------------------------------------------===// @@ -103,11 +129,17 @@ } 
//===----------------------------------------------------------------------===// -// transform::AnyValueType +// transform::TypeParamType //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::AnyValueType::checkPayload(Location loc, - ArrayRef payload) const { +transform::TypeParamType::checkPayload(Location loc, + ArrayRef payload) const { + for (Attribute attr : payload) { + if (!attr.isa()) { + return emitSilenceableError(loc) + << "expected type attribute, got " << attr; + } + } return DiagnosedSilenceableFailure::success(); } diff --git a/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir b/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir --- a/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir +++ b/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir @@ -17,7 +17,7 @@ // Entry point. Match any structured operation and emit at remark. transform.sequence failures(propagate) attributes { transform.target_tag = "transform" } { ^bb0(%arg0: !transform.any_op): - transform.foreach_match in %arg0 + transform.foreach_match in %arg0 @match_structured_empty -> @print_structured : (!transform.any_op) -> !transform.any_op } @@ -73,7 +73,7 @@ transform.sequence failures(propagate) attributes { transform.target_tag = "transform" } { ^bb0(%arg0: !transform.any_op): - transform.foreach_match in %arg0 + transform.foreach_match in %arg0 @match_structured_suppress -> @do_nothing : (!transform.any_op) -> !transform.any_op } @@ -118,7 +118,7 @@ transform.sequence failures(propagate) attributes { transform.target_tag = "transform" } { ^bb0(%arg0: !transform.any_op): - transform.foreach_match in %arg0 + transform.foreach_match in %arg0 @match_structured_body_passthrough -> @print_passthrough : (!transform.any_op) -> !transform.any_op } @@ -129,7 +129,7 @@ indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"] } ins(%in : tensor<2xf32>) outs(%out : tensor<2xf32>) { - 
^bb0(%arg0: f32, %arg1: f32): + ^bb0(%arg0: f32, %arg1: f32): linalg.yield %arg0 : f32 } -> tensor<2xf32> @@ -137,7 +137,7 @@ indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"] } ins(%in : tensor<2xf32>) outs(%out : tensor<2xf32>) { - ^bb0(%arg0: f32, %arg1: f32): + ^bb0(%arg0: f32, %arg1: f32): %0 = arith.mulf %arg0, %arg1 : f32 linalg.yield %0 : f32 } -> tensor<2xf32> @@ -168,7 +168,7 @@ transform.sequence failures(propagate) attributes { transform.target_tag = "transform" } { ^bb0(%arg0: !transform.any_op): - transform.foreach_match in %arg0 + transform.foreach_match in %arg0 @match_structured_body_reduction -> @print_reduction : (!transform.any_op) -> !transform.any_op } @@ -230,8 +230,8 @@ transform.named_sequence @match_dimension_capture(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op { // Capture multiple dimension values. Suppress failures so we can print them anyway after the capture. - %0:9 = transform.match.structured failures(suppress) %arg0 - : (!transform.any_op) -> (!transform.any_op, !transform.param, !transform.param, !transform.param, + %0:9 = transform.match.structured failures(suppress) %arg0 + : (!transform.any_op) -> (!transform.any_op, !transform.param, !transform.param, !transform.param, !transform.param, !transform.param, !transform.param, !transform.param, !transform.param) { ^bb0(%arg1: !transform.any_op): // This also tests the positional specification used by other ops, which may not test it again. 
@@ -243,8 +243,8 @@ %6 = transform.match.structured.dim %arg1[except(-1)] : (!transform.any_op) -> !transform.param %7 = transform.match.structured.dim %arg1[except(0, -2)] : (!transform.any_op) -> !transform.param %8 = transform.match.structured.dim %arg1[0, -3] : (!transform.any_op) -> !transform.param - transform.match.structured.yield %arg1, %1, %2, %3, %4, %5, %6, %7, %8 - : !transform.any_op, !transform.param, !transform.param, !transform.param, + transform.match.structured.yield %arg1, %1, %2, %3, %4, %5, %6, %7, %8 + : !transform.any_op, !transform.param, !transform.param, !transform.param, !transform.param, !transform.param, !transform.param, !transform.param, !transform.param } transform.test_print_param %0#1, "dimensions all:" at %0#0 : !transform.param, !transform.any_op @@ -280,7 +280,7 @@ } func.func @payload(%lhs: tensor<2x4xf32>, %rhs: tensor<4x3xf32>, %out: tensor<2x3xf32>) attributes { transform.target_tag = "start_here" } { - // The last does not emit anything because it fails to match + // The last does not emit anything because it fails to match // due to 0 and -3 being the same dimension in the 3D case. 
// expected-remark @below {{dimensions all: 2 : i64, 3 : i64, 4 : i64}} // expected-remark @below {{dimension 0: 2 : i64}} @@ -404,7 +404,7 @@ } transform.yield %arg0, %bw : !transform.any_op, !transform.param } - + transform.named_sequence @print_bitwidth(%arg0: !transform.any_op {transform.readonly}, %arg1: !transform.param {transform.readonly}) { transform.test_print_param %arg1, "bitwidth:" at %arg0 : !transform.param, !transform.any_op transform.yield @@ -417,7 +417,7 @@ } func.func @payload(%f32: f32, %tf32: tensor, - %index: index, %tindex: tensor) + %index: index, %tindex: tensor) attributes { transform.target_tag = "start_here" } { // expected-remark @below {{bitwidth: 32}} linalg.fill ins(%f32: f32) outs(%tf32: tensor) -> tensor @@ -429,7 +429,7 @@ // ----- module attributes { transform.with_named_sequence } { - transform.named_sequence @match_init(%arg0: !transform.any_op {transform.readonly}) + transform.named_sequence @match_init(%arg0: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.any_value, !transform.any_value, !transform.any_op) { %outs:3 = transform.match.structured failures(suppress) %arg0 : (!transform.any_op) -> (!transform.any_value, !transform.any_value, !transform.any_op) { @@ -441,7 +441,7 @@ } transform.yield %arg0, %outs#0, %outs#1, %outs#2 : !transform.any_op, !transform.any_value, !transform.any_value, !transform.any_op } - + transform.named_sequence @print_init(%arg0: !transform.any_op {transform.readonly}, %arg1: !transform.any_value {transform.readonly}, %arg2: !transform.any_value {transform.readonly}, @@ -459,21 +459,21 @@ } - func.func @payload(%f32: f32, + func.func @payload(%f32: f32, // expected-remark @below {{output 0}} // expected-remark @below {{all output}} // expected-note @below {{value handle points to a block argument #1 in block #0 in region #0}} %tf32: tensor, // expected-remark @below {{all output}} // expected-note @below {{value handle points to a block argument #2 in block #0 in region 
#0}} - %tf32_2: tensor) + %tf32_2: tensor) attributes { transform.target_tag = "start_here" } { // expected-remark @below {{output 0}} // expected-remark @below {{output producer}} // expected-remark @below {{all output}} // expected-note @below {{value handle points to an op result #0}} %0 = linalg.fill ins(%f32: f32) outs(%tf32: tensor) -> tensor - + linalg.generic { indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"] @@ -488,7 +488,7 @@ // ----- module attributes { transform.with_named_sequence } { - transform.named_sequence @match_init_0_permutation(%arg0: !transform.any_op {transform.readonly}) + transform.named_sequence @match_init_0_permutation(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op { %0 = transform.match.structured failures(propagate) %arg0 : (!transform.any_op) -> !transform.any_op { ^bb0(%arg1: !transform.any_op): @@ -497,7 +497,7 @@ } transform.yield %0 : !transform.any_op } - transform.named_sequence @match_init_1_permutation(%arg0: !transform.any_op {transform.readonly}) + transform.named_sequence @match_init_1_permutation(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op { %0 = transform.match.structured failures(propagate) %arg0 : (!transform.any_op) -> !transform.any_op { ^bb0(%arg1: !transform.any_op): @@ -506,7 +506,7 @@ } transform.yield %0 : !transform.any_op } - transform.named_sequence @match_init_2_projected_permutation(%arg0: !transform.any_op {transform.readonly}) + transform.named_sequence @match_init_2_projected_permutation(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op { %0 = transform.match.structured failures(propagate) %arg0 : (!transform.any_op) -> !transform.any_op { ^bb0(%arg1: !transform.any_op): @@ -515,7 +515,7 @@ } transform.yield %0 : !transform.any_op } - + transform.named_sequence @print_init_0_permutation(%arg0: !transform.any_op {transform.readonly}) { 
transform.test_print_remark_at_operand %arg0, "matched output 0 permutation" : !transform.any_op transform.yield @@ -537,10 +537,10 @@ transform.yield } - func.func @payload(%f32: f32, + func.func @payload(%f32: f32, %oned: tensor, %oned2: tensor, - %twod: tensor) + %twod: tensor) attributes { transform.target_tag = "start_here" } { // expected-remark @below {{matched output 2 projected permutation}} linalg.generic { @@ -575,9 +575,9 @@ module attributes { transform.with_named_sequence } { - transform.named_sequence @match_num_io(%arg0: !transform.any_op {transform.readonly}) + transform.named_sequence @match_num_io(%arg0: !transform.any_op {transform.readonly}) -> (!transform.param, !transform.param, !transform.any_op) { - %0:3 = transform.match.structured failures(propagate) %arg0 + %0:3 = transform.match.structured failures(propagate) %arg0 : (!transform.any_op) -> (!transform.param, !transform.param, !transform.any_op) { ^bb0(%arg1: !transform.any_op): %1 = transform.match.structured.num_inputs %arg1 : (!transform.any_op) -> !transform.param @@ -587,7 +587,7 @@ transform.yield %0#0, %0#1, %0#2 : !transform.param, !transform.param, !transform.any_op } - + transform.named_sequence @print_num_io( %arg0: !transform.param {transform.readonly}, %arg1: !transform.param {transform.readonly}, @@ -604,10 +604,10 @@ transform.yield } - func.func @payload(%f32: f32, + func.func @payload(%f32: f32, %oned: tensor, %oned2: tensor, - %twod: tensor) + %twod: tensor) attributes { transform.target_tag = "start_here" } { // expected-remark @below {{inputs 1}} // expected-remark @below {{outputs 3}} @@ -641,9 +641,9 @@ // ----- module attributes { transform.with_named_sequence } { - transform.named_sequence @match_rank(%arg0: !transform.any_op {transform.readonly}) + transform.named_sequence @match_rank(%arg0: !transform.any_op {transform.readonly}) -> (!transform.param, !transform.any_op) { - %0:2 = transform.match.structured failures(propagate) %arg0 + %0:2 = 
transform.match.structured failures(propagate) %arg0 : (!transform.any_op) -> (!transform.param, !transform.any_op) { ^bb0(%arg1: !transform.any_op): %1 = transform.match.structured.rank %arg1 : (!transform.any_op) -> !transform.param @@ -652,7 +652,7 @@ transform.yield %0#0, %0#1 : !transform.param, !transform.any_op } - + transform.named_sequence @print_rank(%arg0: !transform.param {transform.readonly}, %arg2: !transform.any_op {transform.readonly}) { transform.test_print_param %arg0, "rank" at %arg2 : !transform.param, !transform.any_op @@ -665,8 +665,8 @@ transform.yield } - func.func @payload(%f32: f32, - %twod: tensor<42x42xf32>) + func.func @payload(%f32: f32, + %twod: tensor<42x42xf32>) attributes { transform.target_tag = "start_here" } { %0 = tensor.empty() : tensor<42x42xf32> // expected-remark @below {{rank 2}} @@ -681,9 +681,9 @@ // ----- module attributes { transform.with_named_sequence } { - transform.named_sequence @match_single_result(%arg0: !transform.any_op {transform.readonly}) + transform.named_sequence @match_single_result(%arg0: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.any_op) { - %0:2 = transform.match.structured failures(propagate) %arg0 + %0:2 = transform.match.structured failures(propagate) %arg0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op) { ^bb0(%arg1: !transform.any_op): %1 = transform.match.structured.result %arg1[0] { single } : (!transform.any_op) -> !transform.any_op @@ -693,7 +693,7 @@ } transform.named_sequence @match_result_value(%arg0: !transform.any_op {transform.readonly}) -> (!transform.any_value, !transform.any_op) { - %0:2 = transform.match.structured failures(propagate) %arg0 + %0:2 = transform.match.structured failures(propagate) %arg0 : (!transform.any_op) -> (!transform.any_value, !transform.any_op) { ^bb0(%arg1: !transform.any_op): %1 = transform.match.structured.result %arg1[0] : (!transform.any_op) -> !transform.any_value @@ -701,9 +701,9 @@ } transform.yield %0#0, 
%0#1 : !transform.any_value, !transform.any_op } - transform.named_sequence @match_any_result(%arg0: !transform.any_op {transform.readonly}) + transform.named_sequence @match_any_result(%arg0: !transform.any_op {transform.readonly}) -> (!transform.any_op) { - %0 = transform.match.structured failures(propagate) %arg0 + %0 = transform.match.structured failures(propagate) %arg0 : (!transform.any_op) -> !transform.any_op { ^bb0(%arg1: !transform.any_op): %1 = transform.match.structured.result %arg1[-1] { any } : (!transform.any_op) -> !transform.any_op @@ -711,7 +711,7 @@ } transform.yield %0 : !transform.any_op } - + transform.named_sequence @print_single_result(%arg0: !transform.any_op {transform.readonly}, %arg2: !transform.any_op {transform.readonly}) { transform.test_print_remark_at_operand %arg2, "matched single result" : !transform.any_op @@ -738,7 +738,7 @@ } func.func @payload(%f32: f32, %f322: f32, %f323: f32, - %twod: tensor<42x42xf32>) + %twod: tensor<42x42xf32>) attributes { transform.target_tag = "start_here" } { %0 = tensor.empty() : tensor<42x42xf32> @@ -774,3 +774,60 @@ return } } + +// ----- + + +module attributes { transform.with_named_sequence } { + transform.named_sequence @match_input_indexing_map(%arg0: !transform.any_op {transform.readonly}) + -> (!transform.affine_map, !transform.any_op) { + %0 = transform.match.structured failures(propagate) %arg0 + : (!transform.any_op) -> !transform.affine_map { + ^bb0(%arg1: !transform.any_op): + %1 = transform.match.structured.input %arg1[0] : (!transform.any_op) -> !transform.affine_map + transform.match.structured.yield %1 : !transform.affine_map + } + transform.yield %0, %arg0 : !transform.affine_map, !transform.any_op + } + transform.named_sequence @match_init_indexing_map(%arg0: !transform.any_op {transform.readonly}) + -> (!transform.affine_map, !transform.any_op) { + %0 = transform.match.structured failures(propagate) %arg0 + : (!transform.any_op) -> !transform.affine_map { + ^bb0(%arg1: 
!transform.any_op): + %1 = transform.match.structured.init %arg1[0] : (!transform.any_op) -> !transform.affine_map + transform.match.structured.yield %1 : !transform.affine_map + } + transform.yield %0, %arg0 : !transform.affine_map, !transform.any_op + } + + transform.named_sequence @print_indexing_map_1(%arg0: !transform.affine_map {transform.readonly}, + %arg1: !transform.any_op {transform.readonly}) { + transform.test_print_param %arg0, "indexing map 1" at %arg1 : !transform.affine_map, !transform.any_op + transform.yield + } + transform.named_sequence @print_indexing_map_2(%arg0: !transform.affine_map {transform.readonly}, + %arg1: !transform.any_op {transform.readonly}) { + transform.test_print_param %arg0, "indexing map 2" at %arg1 : !transform.affine_map, !transform.any_op + transform.yield + } + + transform.sequence failures(propagate) attributes { transform.target_tag = "transform" } { + ^bb0(%arg0: !transform.any_op): + %3 = transform.foreach_match in %arg0 @match_input_indexing_map -> @print_indexing_map_1 : (!transform.any_op) -> !transform.any_op + %4 = transform.foreach_match in %3 @match_init_indexing_map -> @print_indexing_map_2 : (!transform.any_op) -> !transform.any_op + transform.yield + } + + func.func @payload(%lhs: tensor<32x32xf32>, %rhs: tensor<32x32xf32>) + attributes { transform.target_tag = "start_here" } { + %out = tensor.empty() : tensor<32x32xf32> + %cst = arith.constant 1.0 : f32 + // expected-remark @below {{indexing map 1 affine_map<(d0, d1) -> ()>}} + // expected-remark @below {{indexing map 2 affine_map<(d0, d1) -> (d0, d1)>}} + %res = linalg.fill ins(%cst : f32) outs(%out : tensor<32x32xf32>) -> tensor<32x32xf32> + // expected-remark @below {{indexing map 1 affine_map<(d0, d1, d2) -> (d0, d2)>}} + // expected-remark @below {{indexing map 2 affine_map<(d0, d1, d2) -> (d0, d1)>}} + linalg.matmul ins(%lhs, %rhs : tensor<32x32xf32>, tensor<32x32xf32>) outs(%res : tensor<32x32xf32>) -> tensor<32x32xf32> + return + } +} diff --git 
a/mlir/test/Dialect/Linalg/transform-op-tile.mlir b/mlir/test/Dialect/Linalg/transform-op-tile.mlir --- a/mlir/test/Dialect/Linalg/transform-op-tile.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-tile.mlir @@ -80,7 +80,7 @@ ^bb0(%arg1: !transform.any_op): %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op // expected-note @below {{for this parameter}} - %1 = transform.test_produce_integer_param_with_type i64 : !transform.param + %1 = transform.test_produce_param (0 : i64) : !transform.param // expected-error @below {{expected as many parameter values (0) as target ops (2)}} transform.structured.tile %0 [%1, %1, %1] : (!transform.any_op, !transform.param, !transform.param, !transform.param) diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir --- a/mlir/test/Dialect/Transform/test-interpreter.mlir +++ b/mlir/test/Dialect/Transform/test-interpreter.mlir @@ -1037,7 +1037,7 @@ transform.sequence failures(propagate) { ^bb0(%arg0: !transform.any_op): - %0 = transform.test_produce_integer_param_with_type i32 : !transform.test_dialect_param + %0 = transform.test_produce_param (0 : i32) : !transform.test_dialect_param // expected-remark @below {{0 : i32}} transform.test_print_param %0 : !transform.test_dialect_param } @@ -1047,7 +1047,7 @@ transform.sequence failures(propagate) { ^bb0(%arg0: !transform.any_op): // expected-error @below {{expected the type of the parameter attribute ('i32') to match the parameter type ('i64')}} - transform.test_produce_integer_param_with_type i32 : !transform.param + transform.test_produce_param (0 : i32) : !transform.param } // ----- @@ -1860,3 +1860,58 @@ // expected-remark @below{{1}} test_print_number_of_associated_payload_ir_ops %4 : !transform.any_op } + +// ----- + +func.func @cast(%arg0: f32) -> f64 { + // expected-remark @below{{f64}} + %0 = arith.extf %arg0 : f32 to f64 + return %0 : f64 +} + +transform.sequence 
failures(propagate) { +^bb0(%arg0: !transform.any_op): + %0 = transform.structured.match ops{["arith.extf"]} in %arg0 : (!transform.any_op) -> !transform.op<"arith.extf"> + %1 = transform.get_result %0[0] : (!transform.op<"arith.extf">) -> !transform.any_value + %2 = transform.get_type %1 : (!transform.any_value) -> !transform.type + transform.test_print_param %2 at %0 : !transform.type, !transform.op<"arith.extf"> + transform.yield +} + +// ----- + +transform.sequence failures(propagate) { +^bb0(%arg0: !transform.any_op): + // expected-error @below {{expected type attribute, got 0 : i32}} + transform.test_produce_param (0 : i32) : !transform.type +} + +// ----- + +transform.sequence failures(propagate) { +^bb0(%arg0: !transform.any_op): + // expected-error @below {{expected affine map attribute, got 0 : i32}} + transform.test_produce_param (0 : i32) : !transform.affine_map +} + +// ----- + +// CHECK-LABEL: @type_param_anchor +func.func private @type_param_anchor() + +transform.sequence failures(propagate) { +^bb0(%arg0: !transform.any_op): + // CHECK: test_produce_param(f32) : !transform.type + transform.test_produce_param(f32) : !transform.type +} + +// ----- + +// CHECK-LABEL: @affine_map_param_anchor +func.func private @affine_map_param_anchor() + +transform.sequence failures(propagate) { +^bb0(%arg0: !transform.any_op): + // CHECK: test_produce_param(#{{.*}}) : !transform.affine_map + transform.test_produce_param(affine_map<(d0) -> ()>) : !transform.affine_map +} diff --git a/mlir/test/Integration/Dialect/Transform/match_matmul.mlir b/mlir/test/Integration/Dialect/Transform/match_matmul.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Integration/Dialect/Transform/match_matmul.mlir @@ -0,0 +1,92 @@ +// RUN: mlir-opt %s --test-transform-dialect-interpreter --verify-diagnostics + +module attributes { transform.with_named_sequence } { + transform.named_sequence @match_matmul(%entry: !transform.any_op {transform.readonly}) + -> (!transform.any_op, 
!transform.any_op, !transform.param, + !transform.type, !transform.type, !transform.type) { + %c1 = transform.param.constant 1 : i64 -> !transform.param + %c2 = transform.param.constant 2 : i64 -> !transform.param + %capture:5 = transform.match.structured %entry : (!transform.any_op) + -> (!transform.any_op, !transform.param, !transform.type, !transform.type, !transform.type) { + ^bb0(%struct: !transform.any_op): + transform.match.operation_name %struct ["linalg.matmul"] : !transform.any_op + %dims = transform.match.structured.dim %struct[all] : (!transform.any_op) -> !transform.param + + %n_inputs = transform.match.structured.num_inputs %struct : (!transform.any_op) -> !transform.param + %n_inits = transform.match.structured.num_inits %struct : (!transform.any_op) -> !transform.param + transform.match.param.cmpi eq %n_inputs, %c2 : !transform.param + transform.match.param.cmpi eq %n_inits, %c1 : !transform.param + + %lhs = transform.match.structured.input %struct[0] : (!transform.any_op) -> !transform.any_value + %rhs = transform.match.structured.input %struct[1] : (!transform.any_op) -> !transform.any_value + %res = transform.match.structured.result %struct[0] : (!transform.any_op) -> !transform.any_value + %lhs_type = transform.get_type elemental %lhs : (!transform.any_value) -> !transform.type + %rhs_type = transform.get_type elemental %rhs : (!transform.any_value) -> !transform.type + %res_type = transform.get_type elemental %res : (!transform.any_value) -> !transform.type + + %init = transform.match.structured.init %struct[0] : (!transform.any_op) -> !transform.any_op + transform.match.operation_name %init ["linalg.fill"] : !transform.any_op + + transform.match.structured.yield %init, %dims, %lhs_type, %rhs_type, %res_type + : !transform.any_op, !transform.param, !transform.type, !transform.type, !transform.type + } + transform.yield %capture#0, %entry, %capture#1, %capture#2, %capture#3, %capture#4 + : !transform.any_op, !transform.any_op, !transform.param, 
!transform.type, !transform.type, !transform.type + } + + transform.named_sequence @print_matmul( + %fill: !transform.any_op {transform.readonly}, + %matmul: !transform.any_op {transform.readonly}, + %dims: !transform.param {transform.readonly}, + %lhs_type: !transform.type {transform.readonly}, + %rhs_type: !transform.type {transform.readonly}, + %res_type: !transform.type {transform.readonly}) { + transform.test_print_remark_at_operand %fill, "fill" : !transform.any_op + transform.test_print_remark_at_operand %matmul, "matmul" : !transform.any_op + transform.test_print_param %dims, "dimensions" at %matmul : !transform.param, !transform.any_op + transform.test_print_param %lhs_type, "LHS type" at %matmul : !transform.type, !transform.any_op + transform.test_print_param %rhs_type, "RHS type" at %matmul : !transform.type, !transform.any_op + transform.test_print_param %res_type, "result type" at %matmul : !transform.type, !transform.any_op + transform.yield + } + + transform.sequence failures(propagate) { + ^bb(%root: !transform.any_op): + foreach_match in %root + @match_matmul -> @print_matmul + : (!transform.any_op) -> !transform.any_op + } +} + +func.func @matmul_simple(%lhs: tensor<10x20xf16>, %rhs: tensor<20x15xf32>) -> tensor<10x15xf64>{ + %cst = arith.constant 0.0 : f64 + %empty = tensor.empty() : tensor<10x15xf64> + // expected-remark @below {{fill}} + %fill = linalg.fill ins(%cst : f64) outs(%empty : tensor<10x15xf64>) -> tensor<10x15xf64> + // expected-remark @below {{matmul}} + // expected-remark @below {{dimensions 10 : i64, 15 : i64, 20 : i64}} + // expected-remark @below {{LHS type f16}} + // expected-remark @below {{RHS type f32}} + // expected-remark @below {{result type f64}} + %result = linalg.matmul ins(%lhs, %rhs: tensor<10x20xf16>, tensor<20x15xf32>) outs(%fill: tensor<10x15xf64>) -> tensor<10x15xf64> + return %result : tensor<10x15xf64> +} + +func.func @matmul_with_extra_ops_in_func(%lhs: tensor<10x20xf32>, %rhs: tensor<20x15xf32>) -> 
tensor<10x15xf32> { + %cst = arith.constant 0.0 : f64 + %empty = tensor.empty() : tensor<10x15xf32> + + // expected-remark @below {{fill}} + %fill = linalg.fill ins(%cst : f64) outs(%empty : tensor<10x15xf32>) -> tensor<10x15xf32> + + %real_lhs = linalg.elemwise_binary { fun = #linalg.binary_fn } + ins(%lhs, %lhs : tensor<10x20xf32>, tensor<10x20xf32>) outs(%lhs : tensor<10x20xf32>) -> tensor<10x20xf32> + + // expected-remark @below {{matmul}} + // expected-remark @below {{dimensions 10 : i64, 15 : i64, 20 : i64}} + // expected-remark @below {{LHS type f32}} + // expected-remark @below {{RHS type f32}} + // expected-remark @below {{result type f32}} + %result = linalg.matmul ins(%real_lhs, %rhs: tensor<10x20xf32>, tensor<20x15xf32>) outs(%fill: tensor<10x15xf32>) -> tensor<10x15xf32> + return %result : tensor<10x15xf32> +} diff --git a/mlir/test/Integration/Dialect/Transform/match_reduction.mlir b/mlir/test/Integration/Dialect/Transform/match_reduction.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Integration/Dialect/Transform/match_reduction.mlir @@ -0,0 +1,319 @@ +// RUN: mlir-opt %s --test-transform-dialect-interpreter --verify-diagnostics + +module attributes { transform.with_named_sequence } { + transform.named_sequence @_reduce_leading_trailing(%entry: !transform.any_op {transform.readonly}) + -> (!transform.any_op) { + %c1 = transform.param.constant 1 : i64 -> !transform.param + + transform.match.structured %entry : !transform.any_op { + ^bb0(%struct: !transform.any_op): + transform.match.structured.dim %struct[all] {parallel} : !transform.any_op + transform.match.structured.input %struct[all] {projected_permutation} : !transform.any_op + transform.match.structured.init %struct[all] {permutation} : !transform.any_op + %ni = transform.match.structured.num_inits %struct : (!transform.any_op) -> !transform.param + transform.match.param.cmpi eq %ni, %c1 : !transform.param + } + transform.yield %entry : !transform.any_op + } + + transform.named_sequence 
@fill_reduce_leading_trailing(%entry: !transform.any_op {transform.readonly}) + -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, + !transform.param, !transform.param, !transform.param) { + %c1 = transform.param.constant 1 : i64 -> !transform.param + %c2 = transform.param.constant 2 : i64 -> !transform.param + %c4 = transform.param.constant 4 : i64 -> !transform.param + + %rk, %dms, %bw, %operand_o, %init_v, %trailing_o = transform.match.structured failures(propagate) %entry + : (!transform.any_op) -> (!transform.param, !transform.param, !transform.param, + !transform.any_op, !transform.any_value, !transform.any_op) { + ^bb0(%struct: !transform.any_op): + %rank = transform.match.structured.rank %struct : (!transform.any_op) -> !transform.param + transform.match.param.cmpi ge %rank, %c2 : !transform.param + transform.match.param.cmpi le %rank, %c4 : !transform.param + + transform.match.structured.dim %struct[-1] {reduction} : !transform.any_op + transform.match.structured.dim %struct[except(-1)] {parallel} : !transform.any_op + %dims = transform.match.structured.dim %struct[all] : (!transform.any_op) -> !transform.param + + %n_inputs = transform.match.structured.num_inputs %struct : (!transform.any_op) -> !transform.param + %n_outputs = transform.match.structured.num_inits %struct : (!transform.any_op) -> !transform.param + transform.match.param.cmpi eq %n_inputs, %c1 : !transform.param + transform.match.param.cmpi eq %n_outputs, %c1 : !transform.param + + transform.match.structured.input %struct[0] {projected_permutation} : !transform.any_op + transform.match.structured.init %struct[0] {projected_permutation} : !transform.any_op + %init = transform.match.structured.init %struct[0] : (!transform.any_op) -> !transform.any_value + + // This dance is necessary to create an empty handle if there is no single + // user without failing the entire match + %trailing_optional = transform.sequence %struct : (!transform.any_op) -> 
!transform.any_op failures(suppress) { + ^bb0(%struct_inner: !transform.any_op): + %result = transform.match.structured failures(propagate) %struct_inner : (!transform.any_op) -> !transform.any_op { + ^bb0(%struct_inner_inner: !transform.any_op): + %result_inner = transform.match.structured.result %struct_inner_inner[0] {single} : (!transform.any_op) -> !transform.any_op + %trailing = transform.include @_reduce_leading_trailing failures(propagate) (%result_inner) : (!transform.any_op) -> !transform.any_op + transform.match.structured.yield %trailing : !transform.any_op + } + transform.yield %result: !transform.any_op + } + + // Suppress errors as a way to implement optionality. We cannot suppress them in + // the include because it keeps matching after "get_defining_op" fails, which + // breaks the single-op precondition of the following ops. We don't want to + // propagate that failure though. + // + // Additionally, we cannot put the sequence inside the call because its first + // operand must be an operation handle (the verifier asserts!) and there is + // no such handle available there. + // + // TODO: extend the structured matching to gracefully handle empty handles + // or provide the suppress-errors-but-stop failure mode for includes to + // implement optionality. 
+ %operand_optional = transform.sequence %struct : (!transform.any_op) -> !transform.any_op failures(suppress) { + ^bb0(%struct_inner: !transform.any_op): + %operand3 = transform.match.structured failures(propagate) %struct_inner : (!transform.any_op) -> !transform.any_op { + ^bb1(%struct_inner_inner: !transform.any_op): + %operand = transform.match.structured.input %struct_inner_inner[0] : (!transform.any_op) -> !transform.any_op + %operand2 = transform.include @_reduce_leading_trailing failures(propagate) (%operand) : (!transform.any_op) -> !transform.any_op + transform.match.structured.yield %operand2 : !transform.any_op + } + transform.yield %operand3 : !transform.any_op + } + + %bitwidth = transform.match.structured.elemental_bitwidth %init : (!transform.any_value) -> !transform.param + + transform.match.structured.body %struct { reduction_position = 0 } : !transform.any_op + transform.match.structured.yield %rank, %dims, %bitwidth, %operand_optional, %init, %trailing_optional + : !transform.param, !transform.param, !transform.param, + !transform.any_op, !transform.any_value, !transform.any_op + } + + %init_o = transform.get_defining_op %init_v : (!transform.any_value) -> !transform.any_op + transform.match.operation_name %init_o ["linalg.fill"] : !transform.any_op + + transform.yield %operand_o, %init_o, %entry, %trailing_o, %rk, %dms, %bw + : !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, + !transform.param, !transform.param, !transform.param + } + + transform.named_sequence @print_reduce_leading_trailing( + %leading: !transform.any_op {transform.readonly}, + %fill: !transform.any_op {transform.readonly}, + %reduction: !transform.any_op {transform.readonly}, + %trailing: !transform.any_op {transform.readonly}, + %rank: !transform.param {transform.readonly}, + %dims: !transform.param {transform.readonly}, + %bitwidth: !transform.param {transform.readonly}) { + transform.test_print_remark_at_operand %leading, "leading" : 
!transform.any_op + transform.test_print_remark_at_operand %fill, "fill" : !transform.any_op + transform.test_print_remark_at_operand %reduction, "reduction" : !transform.any_op + transform.test_print_remark_at_operand %trailing, "trailing" : !transform.any_op + transform.test_print_param %rank, "rank" at %reduction : !transform.param, !transform.any_op + transform.test_print_param %dims, "dimensions" at %reduction : !transform.param, !transform.any_op + transform.test_print_param %bitwidth, "bitwidth" at %reduction : !transform.param, !transform.any_op + transform.yield + } + + transform.sequence failures(propagate) { + ^bb(%root: !transform.any_op): + foreach_match in %root + @fill_reduce_leading_trailing -> @print_reduce_leading_trailing + : (!transform.any_op) -> !transform.any_op + } +} + +!in_tensor_t = tensor<8x64xf32> +!out_tensor_t = tensor<8xf32> + +func.func @eltwise_reduce(%arg : !in_tensor_t) -> (!out_tensor_t) { + %cst = arith.constant -0.000000e+00 : f32 + + %0 = tensor.empty() : !out_tensor_t + // expected-remark @below {{fill}} + %1 = linalg.fill ins(%cst : f32) outs(%0 : !out_tensor_t) -> !out_tensor_t + %2 = tensor.empty() : !in_tensor_t + // expected-remark @below {{leading}} + %3 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"]} + ins(%arg : !in_tensor_t) outs(%2 : !in_tensor_t) { + ^bb0(%arg3: f32, %arg4: f32): + %4 = arith.addf %arg3, %arg3 : f32 + %5 = arith.addf %4, %4 : f32 + linalg.yield %5 : f32 + } -> !in_tensor_t + + // expected-remark @below {{reduction}} + // expected-remark @below {{rank 2}} + // expected-remark @below {{dimensions 8 : i64, 64 : i64}} + // expected-remark @below {{bitwidth 32 : i64}} + %6 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0)>], + iterator_types = ["parallel", "reduction"]} + ins(%3 : !in_tensor_t) outs(%1 : !out_tensor_t) { + ^bb0(%arg3: f32, %arg4: 
f32): + %4 = arith.addf %arg3, %arg4 : f32 + linalg.yield %4 : f32 + } -> !out_tensor_t + + return %6 : !out_tensor_t +} + +func.func @reduce_eltwise(%arg : !in_tensor_t) -> (!out_tensor_t) { + %cst = arith.constant -0.000000e+00 : f32 + + %0 = tensor.empty() : !out_tensor_t + // expected-remark @below {{fill}} + %1 = linalg.fill ins(%cst : f32) outs(%0 : !out_tensor_t) -> !out_tensor_t + // expected-remark @below {{reduction}} + // expected-remark @below {{rank 2}} + // expected-remark @below {{dimensions 8 : i64, 64 : i64}} + // expected-remark @below {{bitwidth 32 : i64}} + %5 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0)>], + iterator_types = ["parallel", "reduction"]} + ins(%arg : !in_tensor_t) outs(%1 : !out_tensor_t) { + ^bb0(%arg3: f32, %arg4: f32): + %4 = arith.addf %arg3, %arg4 : f32 + linalg.yield %4 : f32 + } -> !out_tensor_t + + %6 = tensor.empty() : !out_tensor_t + // expected-remark @below {{trailing}} + %7 = linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, + affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"]} + ins(%5 : !out_tensor_t) outs(%6 : !out_tensor_t) { + ^bb0(%arg3: f32, %arg4: f32): + %4 = math.sqrt %arg3 : f32 + linalg.yield %4 : f32 + } -> !out_tensor_t + return %7 : !out_tensor_t +} + +func.func @eltwise_reduce_eltwise(%arg : !in_tensor_t) -> (!out_tensor_t) { + %cst = arith.constant -0.000000e+00 : f32 + + %0 = tensor.empty() : !out_tensor_t + // expected-remark @below {{fill}} + %1 = linalg.fill ins(%cst : f32) outs(%0 : !out_tensor_t) -> !out_tensor_t + %2 = tensor.empty() : !in_tensor_t + // expected-remark @below {{leading}} + %3 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"]} + ins(%arg : !in_tensor_t) outs(%2 : !in_tensor_t) { + ^bb0(%arg3: f32, %arg4: f32): + %4 = arith.addf %arg3, %arg3 : f32 + %5 = arith.addf %4, %4 : f32 + linalg.yield %5 : f32 + 
} -> !in_tensor_t + + // expected-remark @below {{reduction}} + // expected-remark @below {{rank 2}} + // expected-remark @below {{dimensions 8 : i64, 64 : i64}} + // expected-remark @below {{bitwidth 32 : i64}} + %6 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0)>], + iterator_types = ["parallel", "reduction"]} + ins(%3 : !in_tensor_t) outs(%1 : !out_tensor_t) { + ^bb0(%arg3: f32, %arg4: f32): + %4 = arith.addf %arg3, %arg4 : f32 + linalg.yield %4 : f32 + } -> !out_tensor_t + + %7 = tensor.empty() : !out_tensor_t + // expected-remark @below {{trailing}} + %8 = linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, + affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"]} + ins(%6 : !out_tensor_t) outs(%7 : !out_tensor_t) { + ^bb0(%arg3: f32, %arg4: f32): + %4 = math.sqrt %arg3 : f32 + linalg.yield %4 : f32 + } -> !out_tensor_t + + + return %8 : !out_tensor_t +} + +func.func @eltwise_reduce_eltwise_swapped(%arg : !in_tensor_t) -> (!out_tensor_t) { + %cst = arith.constant -0.000000e+00 : f32 + + %2 = tensor.empty() : !in_tensor_t + // expected-remark @below {{leading}} + %3 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"]} + ins(%arg : !in_tensor_t) outs(%2 : !in_tensor_t) { + ^bb0(%arg3: f32, %arg4: f32): + %4 = arith.addf %arg3, %arg3 : f32 + %5 = arith.addf %4, %4 : f32 + linalg.yield %5 : f32 + } -> !in_tensor_t + + %0 = tensor.empty() : !out_tensor_t + // expected-remark @below {{fill}} + %1 = linalg.fill ins(%cst : f32) outs(%0 : !out_tensor_t) -> !out_tensor_t + // expected-remark @below {{reduction}} + // expected-remark @below {{rank 2}} + // expected-remark @below {{dimensions 8 : i64, 64 : i64}} + // expected-remark @below {{bitwidth 32 : i64}} + %6 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0)>], + iterator_types = ["parallel", 
"reduction"]} + ins(%3 : !in_tensor_t) outs(%1 : !out_tensor_t) { + ^bb0(%arg3: f32, %arg4: f32): + %4 = arith.addf %arg3, %arg4 : f32 + linalg.yield %4 : f32 + } -> !out_tensor_t + + %7 = tensor.empty() : !out_tensor_t + // expected-remark @below {{trailing}} + %8 = linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, + affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"]} + ins(%6 : !out_tensor_t) outs(%7 : !out_tensor_t) { + ^bb0(%arg3: f32, %arg4: f32): + %4 = math.sqrt %arg3 : f32 + linalg.yield %4 : f32 + } -> !out_tensor_t + + + return %8 : !out_tensor_t +} + +func.func @reduction_with_extra_op_in_func(%arg0: tensor<8x479xf32>, %arg1: tensor<32x32xf32>) -> (tensor<8xf32>, tensor<32xf32>) { + %cst = arith.constant 0.0 : f32 + %empty = tensor.empty() : tensor<8xf32> + // expected-remark @below {{fill}} + %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<8xf32>) -> tensor<8xf32> + // expected-remark @below {{reduction}} + // expected-remark @below {{rank 2}} + // expected-remark @below {{dimensions 8 : i64, 479 : i64}} + // expected-remark @below {{bitwidth 32 : i64}} + %result = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0)>], + iterator_types = ["parallel", "reduction"]} + ins(%arg0 : tensor<8x479xf32>) + outs(%fill : tensor<8xf32>) { + ^bb0(%in: f32, %out: f32): + %6 = arith.addf %in, %out : f32 + linalg.yield %6 : f32 + } -> tensor<8xf32> + + %empty2 = tensor.empty() : tensor<32xf32> + %fill2 = linalg.fill ins(%cst : f32) outs(%empty2 : tensor<32xf32>) -> tensor<32xf32> + return %result, %fill2 : tensor<8xf32>, tensor<32xf32> +} diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp @@ -633,21 +633,13 @@ } DiagnosedSilenceableFailure 
-mlir::test::TestProduceIntegerParamWithTypeOp::apply( - transform::TransformRewriter &rewriter, - transform::TransformResults &results, transform::TransformState &state) { - Attribute zero = IntegerAttr::get(getType(), 0); - results.setParams(llvm::cast(getResult()), zero); +mlir::test::TestProduceParamOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, + transform::TransformState &state) { + results.setParams(llvm::cast(getResult()), getAttr()); return DiagnosedSilenceableFailure::success(); } -LogicalResult mlir::test::TestProduceIntegerParamWithTypeOp::verify() { - if (!llvm::isa(getType())) { - return emitOpError() << "expects an integer type"; - } - return success(); -} - void mlir::test::TestProduceTransformParamOrForwardOperandOp::getEffects( SmallVectorImpl &effects) { transform::onlyReadsHandle(getIn(), effects); diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td @@ -428,15 +428,14 @@ let cppNamespace = "::mlir::test"; } -def TestProduceIntegerParamWithTypeOp - : Op]> { - let arguments = (ins TypeAttr:$type); + let arguments = (ins AnyAttr:$attr); let results = (outs TransformParamTypeInterface:$result); - let assemblyFormat = "$type attr-dict `:` type($result)"; + let assemblyFormat = "`(` $attr `)` attr-dict `:` type($result)"; let cppNamespace = "::mlir::test"; - let hasVerifier = 1; } def TestProduceTransformParamOrForwardOperandOp