diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -99,9 +99,10 @@ computational operations, which can be empty. }]; - let arguments = (ins PDL_Operation:$target); - let results = (outs PDL_Operation:$transformed); - let assemblyFormat = "$target attr-dict"; + let arguments = (ins TransformHandleTypeInterface:$target); + let results = (outs TransformHandleTypeInterface:$transformed); + let assemblyFormat = + "$target attr-dict `:` functional-type(operands, results)"; let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( @@ -124,11 +125,11 @@ }]; let arguments = - (ins PDL_Operation:$target, + (ins TransformHandleTypeInterface:$target, DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_sizes, DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_interchange); - let results = (outs PDL_Operation:$transformed, - Variadic<PDL_Operation>:$loops); + let results = (outs TransformHandleTypeInterface:$transformed, + Variadic<TransformHandleTypeInterface>:$loops); let hasCustomAssemblyFormat = 1; let hasVerifier = 1; @@ -177,10 +178,11 @@ This operation reads the containing op handle. 
}]; - let arguments = (ins PDL_Operation:$producer_op, - PDL_Operation:$containing_op); - let results = (outs PDL_Operation:$fused_op); - let assemblyFormat = "$producer_op `into` $containing_op attr-dict"; + let arguments = (ins TransformHandleTypeInterface:$producer_op, + TransformHandleTypeInterface:$containing_op); + let results = (outs TransformHandleTypeInterface:$fused_op); + let assemblyFormat = "$producer_op `into` $containing_op attr-dict " + " `:` functional-type(operands, results)"; let builders = [ OpBuilder<(ins "Value":$producerOp, "Value":$containingOp)> diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -330,34 +330,6 @@ return success(); } -/// Parse a tiling-like operation that returns the tiled op as well as the -/// created tile loops. The function counts the non-zero tile sizes to compute -/// the number of results. 
-static ParseResult parseTileLikeOp(OpAsmParser &parser, OperationState &result, - StringRef sizesAttrName) { - OpAsmParser::UnresolvedOperand targetOperand; - SMLoc opLoc = parser.getCurrentLocation(); - if (parser.parseOperand(targetOperand) || - parser.parseOptionalAttrDict(result.attributes)) - return failure(); - Attribute sizesAttr = result.attributes.get(sizesAttrName); - if (!sizesAttr) - return parser.emitError(opLoc) - << "expected '" << sizesAttrName << "' attribute"; - auto sizesArrayAttr = sizesAttr.dyn_cast<ArrayAttr>(); - if (!sizesArrayAttr) - return parser.emitError(opLoc) - << "'" << sizesAttrName << "' attribute must be an array"; - Type pdlOpType = parser.getBuilder().getType<pdl::OperationType>(); - size_t numExpectedLoops = - sizesArrayAttr.size() - - llvm::count(extractFromI64ArrayAttr(sizesArrayAttr), 0); - result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOpType)); - if (parser.resolveOperand(targetOperand, pdlOpType, result.operands)) - return failure(); - return success(); -} - DiagnosedSilenceableFailure transform::FuseOp::apply(mlir::transform::TransformResults &transformResults, mlir::transform::TransformState &state) { @@ -385,15 +357,34 @@ ParseResult transform::FuseOp::parse(OpAsmParser &parser, OperationState &result) { - return parseTileLikeOp( - parser, result, - transform::FuseOp::getTileSizesAttrName(result.name).getValue()); + OpAsmParser::UnresolvedOperand targetOperand; + if (parser.parseOperand(targetOperand) || + parser.parseOptionalAttrDict(result.attributes)) + return failure(); + + FunctionType trailingType; + SMLoc typeLoc; + if (parser.getCurrentLocation(&typeLoc) || + parser.parseColonType(trailingType)) { + return failure(); + } + if (trailingType.getNumInputs() != 1) + return parser.emitError(typeLoc) << "expected one input type"; + + result.addTypes(trailingType.getResults()); + if (parser.resolveOperand(targetOperand, trailingType.getInput(0), + result.operands)) + return failure(); + return success(); } void 
transform::FuseOp::print(OpAsmPrinter &p) { p << ' '; p << getTarget(); p.printOptionalAttrDict((*this)->getAttrs()); + p << " : "; + p.printFunctionalType(TypeRange(getOperand().getType()), + getResults().getTypes()); } LogicalResult transform::FuseOp::verify() { @@ -405,6 +396,12 @@ return emitOpError() << "expects interchange to be a permutation, found " << getTileInterchange(); } + + SmallVector<int64_t> sizes = extractFromI64ArrayAttr(getTileSizes()); + size_t numExpectedLoops = sizes.size() - llvm::count(sizes, 0); + if (numExpectedLoops != getNumResults() - 1) + return emitOpError() << "expects " << numExpectedLoops << " loop results"; + return success(); } diff --git a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir --- a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir @@ -183,7 +183,7 @@ } transform.sequence failures(propagate) { -^bb1(%arg1: !pdl.operation): - %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!pdl.operation) -> !pdl.operation - %1 = transform.structured.decompose %0 +^bb1(%arg1: !transform.any_op): + %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op + %1 = transform.structured.decompose %0 : (!transform.any_op) -> !transform.any_op } diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir --- a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir @@ -42,12 +42,13 @@ func.func @dummy3() { return } transform.sequence failures(propagate) { - ^bb1(%arg1: !pdl.operation): - %0 = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!pdl.operation) -> !pdl.operation - %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!pdl.operation) -> !pdl.operation + ^bb1(%arg1: 
!transform.any_op): + %0 = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.op<"linalg.fill"> + %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.op<"scf.forall"> // linalg.fill is tileable. The op is tiled and fused. transform.structured.fuse_into_containing_op %0 into %1 + : (!transform.op<"linalg.fill">, !transform.op<"scf.forall">) -> !transform.any_op } } @@ -85,12 +86,13 @@ } transform.sequence failures(propagate) { - ^bb1(%arg1: !pdl.operation): - %0 = transform.structured.match ops{["tensor.empty"]} in %arg1 : (!pdl.operation) -> !pdl.operation - %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!pdl.operation) -> !pdl.operation + ^bb1(%arg1: !transform.any_op): + %0 = transform.structured.match ops{["tensor.empty"]} in %arg1 : (!transform.any_op) -> !transform.op<"tensor.empty"> + %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.op<"scf.forall"> // tensor.empty is not tileable. The op is cloned and fused. transform.structured.fuse_into_containing_op %0 into %1 + : (!transform.op<"tensor.empty">, !transform.op<"scf.forall">) -> !transform.any_op } } @@ -131,12 +133,13 @@ } transform.sequence failures(propagate) { - ^bb1(%arg1: !pdl.operation): - %0 = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!pdl.operation) -> !pdl.operation - %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!pdl.operation) -> !pdl.operation + ^bb1(%arg1: !transform.any_op): + %0 = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.op<"linalg.fill"> + %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.op<"scf.forall"> // linalg.fill is tileable. The op is tiled and fused. 
transform.structured.fuse_into_containing_op %0 into %1 + : (!transform.op<"linalg.fill">, !transform.op<"scf.forall">) -> !transform.any_op } } @@ -179,12 +182,13 @@ } transform.sequence failures(propagate) { - ^bb1(%arg1: !pdl.operation): - %0 = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!pdl.operation) -> !pdl.operation - %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!pdl.operation) -> !pdl.operation + ^bb1(%arg1: !transform.any_op): + %0 = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.any_op // linalg.fill is tileable. The op is tiled and fused. transform.structured.fuse_into_containing_op %0 into %1 + : (!transform.any_op, !transform.any_op) -> !transform.any_op } } @@ -239,11 +243,12 @@ } transform.sequence failures(propagate) { - ^bb1(%arg1: !pdl.operation): - %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!pdl.operation) -> !pdl.operation - %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!pdl.operation) -> !pdl.operation + ^bb1(%arg1: !transform.any_op): + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.op<"linalg.generic"> + %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.op<"scf.forall"> // linalg.generic is tileable. The op is tiled and fused. 
transform.structured.fuse_into_containing_op %0 into %1 + : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> !transform.any_op } } diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir --- a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir @@ -16,9 +16,10 @@ } transform.sequence failures(propagate) { -^bb1(%arg1: !pdl.operation): - %0 = transform.structured.match ops{["linalg.elemwise_binary"]} in %arg1 : (!pdl.operation) -> !pdl.operation +^bb1(%arg1: !transform.any_op): + %0 = transform.structured.match ops{["linalg.elemwise_binary"]} in %arg1 : (!transform.any_op) -> !transform.any_op %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1]} + : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) } // ----- @@ -43,11 +44,11 @@ } transform.sequence failures(propagate) { -^bb1(%arg1: !pdl.operation): - %0 = transform.structured.match ops{["linalg.elemwise_binary"]} in %arg1 : (!pdl.operation) -> !pdl.operation +^bb1(%arg1: !transform.any_op): + %0 = transform.structured.match ops{["linalg.elemwise_binary"]} in %arg1 : (!transform.any_op) -> !transform.any_op %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1]} - %loop = transform.cast %loops#0 : !pdl.operation to !transform.op<"scf.for"> - transform.loop.peel %loop : (!transform.op<"scf.for">) -> !pdl.operation + : (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op) + transform.loop.peel %loops#0 : (!transform.op<"scf.for">) -> !transform.any_op } // ----- @@ -86,10 +87,12 @@ } transform.sequence failures(propagate) { -^bb1(%arg1: !pdl.operation): - %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!pdl.operation) -> !pdl.operation +^bb1(%arg1: !transform.any_op): + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 
: (!transform.any_op) -> !transform.any_op %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [5, 0, 7], tile_interchange = [0, 2, 1]} - %2, %loops_2 = transform.structured.tile %1 [0, 4] : (!pdl.operation) -> (!pdl.operation, !pdl.operation) + : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + %2, %loops_2 = transform.structured.tile %1 [0, 4] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) } // ----- @@ -110,7 +113,8 @@ } transform.sequence failures(propagate) { -^bb1(%arg1: !pdl.operation): - %0 = transform.structured.match ops{["linalg.elemwise_unary"]} in %arg1 : (!pdl.operation) -> !pdl.operation +^bb1(%arg1: !transform.any_op): + %0 = transform.structured.match ops{["linalg.elemwise_unary"]} in %arg1 : (!transform.any_op) -> !transform.any_op %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [16, 32], tile_interchange = [0, 1]} + : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) } diff --git a/mlir/test/Dialect/Linalg/transform-tile-and-fuse.mlir b/mlir/test/Dialect/Linalg/transform-tile-and-fuse.mlir --- a/mlir/test/Dialect/Linalg/transform-tile-and-fuse.mlir +++ b/mlir/test/Dialect/Linalg/transform-tile-and-fuse.mlir @@ -51,6 +51,7 @@ // Fuse all producers. transform.structured.fuse_into_containing_op %producers into %forall_op + : (!pdl.operation, !pdl.operation) -> !transform.any_op } } @@ -109,5 +110,6 @@ // Fuse all producers. transform.structured.fuse_into_containing_op %reversed_producers into %forall_op + : (!pdl.operation, !pdl.operation) -> !transform.any_op } }