diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -12,6 +12,7 @@
 include "mlir/Dialect/Transform/IR/TransformDialect.td"
 include "mlir/Dialect/Transform/IR/TransformEffects.td"
 include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
+include "mlir/Dialect/Transform/IR/TransformTypes.td"
 include "mlir/Dialect/PDL/IR/PDLTypes.td"
 include "mlir/Dialect/SCF/IR/DeviceMappingInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -19,6 +20,13 @@
 include "mlir/IR/OpBase.td"
 include "mlir/IR/RegionKindInterface.td"
 
+// This is roughly similar to OpFoldResult assuming the handle produces a single
+// value in the payload IR.
+def TransformParamTypeOrAnyHandle : Type<
+    Or<[TransformHandleTypeInterface.predicate,
+        Transform_ParamType.predicate]>,
+    "transform 'param' type or any handle type">;
+
 //===----------------------------------------------------------------------===//
 // DecomposeOp
 //===----------------------------------------------------------------------===//
@@ -311,27 +319,41 @@
     ```mlir
     %sz1, %sz2, %split = structured.multitile_sizes %target
                          { target_size = 10, dimension = 1 }
+               : !transform.any_op, !transform.param<i64>,
+                 !transform.param<i64>, !transform.param<i64>
     %low, %high = structured.split %target after %split { dimension = 1 }
-    %tiled_low = structured.tile %low [0, %sz1]
-    %tiled_high = structured.tile %high [0, %sz2]
-    %common = merge_handles %tiled_low, %tiled_high
+               : !transform.any_op, !transform.param<i64>
+    %tiled_low, %loop1 = structured.tile %low [0, %sz1]
+               : (!transform.any_op, !transform.param<i64>)
+              -> (!transform.any_op, !transform.any_op)
+    %tiled_high, %loop2 = structured.tile %high [0, %sz2]
+               : (!transform.any_op, !transform.param<i64>)
+              -> (!transform.any_op, !transform.any_op)
+    %common = merge_handles %tiled_low, %tiled_high : !transform.any_op
 
     %sz3, %sz4, %split = structured.multitile_sizes %target
                          { target_size = 42, dimension = 0 }
+               : !transform.any_op, !transform.any_op,
+                 !transform.any_op, !transform.any_op
     %sz3r, %sz4r, %splitr = replicate num(%common) %sz3, %sz4, %split
+               : !transform.any_op, !transform.any_op, !transform.any_op
     structured.split %common after %splitr { dimension = 0 }
+               : !transform.any_op, !transform.any_op
     // ...
     ```
   }];
 
-  let arguments = (ins PDL_Operation:$target,
+  let arguments = (ins TransformHandleTypeInterface:$target,
                        I64Attr:$dimension,
                        I64Attr:$target_size,
                        DefaultValuedAttr<I64Attr, "1">:$divisor);
-  let results = (outs PDL_Operation:$low_size,
-                      PDL_Operation:$high_size,
-                      PDL_Operation:$split_point);
-  let assemblyFormat = "$target attr-dict";
+  let results = (outs TransformParamTypeOrAnyHandle:$low_size,
+                      TransformParamTypeOrAnyHandle:$high_size,
+                      TransformParamTypeOrAnyHandle:$split_point);
+  let hasVerifier = 1;
+  let assemblyFormat =
+    "$target attr-dict `:` custom<MultitileSizesTypes>("
+    "type($target), type($low_size), type($high_size), type($split_point))";
 
   let extraClassDeclaration = [{
     ::mlir::DiagnosedSilenceableFailure applyToOne(
@@ -524,11 +546,12 @@
     iteration space indices.
   }];
 
-  let arguments = (ins PDL_Operation:$target,
+  let arguments = (ins TransformHandleTypeInterface:$target,
                        I64Attr:$dimension,
-                       Optional<PDL_Operation>:$dynamic_split_point,
+                       Optional<TransformParamTypeOrAnyHandle>:$dynamic_split_point,
                        I64Attr:$static_split_point);
-  let results = (outs PDL_Operation:$first, PDL_Operation:$second);
+  let results = (outs TransformHandleTypeInterface:$first,
+                      TransformHandleTypeInterface:$second);
   let hasVerifier = 1;
   let hasCustomAssemblyFormat = 1;
 }
 
@@ -962,17 +985,19 @@
     produces a definite failure.
   }];
 
-  let arguments = (ins PDL_Operation:$target,
-                   Variadic<PDL_Operation>:$dynamic_sizes,
+  let arguments = (ins TransformHandleTypeInterface:$target,
+                   Variadic<TransformParamTypeOrAnyHandle>:$dynamic_sizes,
                    DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sizes,
                    DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$interchange);
-  let results = (outs PDL_Operation:$tiled_linalg_op,
-                      Variadic<PDL_Operation>:$loops);
+  let results = (outs TransformHandleTypeInterface:$tiled_linalg_op,
+                      Variadic<TransformHandleTypeInterface>:$loops);
   let builders = [
-    OpBuilder<(ins "Value":$target,
+    OpBuilder<(ins "TypeRange":$loopTypes,
+                   "Value":$target,
                    "ArrayRef<int64_t>":$staticTileSizes,
                    CArg<"ArrayRef<int64_t>", "{}">:$interchange)>,
-    OpBuilder<(ins "Value":$target,
+    OpBuilder<(ins "TypeRange":$loopTypes,
+                   "Value":$target,
                    "ArrayRef<OpFoldResult>":$mixedTileSizes,
                    CArg<"ArrayRef<int64_t>", "{}">:$interchange)>
   ];
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -412,15 +412,23 @@
                     ArrayRef<OpFoldResult> allShapeSizes,
                     ArrayRef<OpFoldResult> allTileSizes);
 
-/// A description of a multi-size tiling comprising tile sizes and numbers of
-/// tiles, expressed as Values which may or may not be constant. Multi-size
-/// currently means two-size.
-struct MultiSizeSpecification {
+namespace detail {
+template <typename T>
+struct MultiSizeSpecificationBase {
   /// Tile sizes.
-  Value lowTileSize, highTileSize;
+  T lowTileSize, highTileSize;
   /// Number of tiles associated with each size.
-  Value lowTripCount, highTripCount;
+  T lowTripCount, highTripCount;
 };
+} // namespace detail
+
+/// A description of a multi-size tiling comprising tile sizes and numbers of
+/// tiles, expressed as Values which may or may not be constant. Multi-size
+/// currently means two-size.
+struct MultiSizeSpecification
+    : public detail::MultiSizeSpecificationBase<Value> {};
+struct StaticMultiSizeSpecification
+    : public detail::MultiSizeSpecificationBase<int64_t> {};
 
 /// Emits the IR computing the multi-sized tiling specification with two tile
 /// sizes not exceeding `targetSize`, each divisible by `sizeDivisor`, such
@@ -453,6 +461,9 @@
 computeMultiTileSizes(OpBuilder &builder, LinalgOp op, unsigned dimension,
                       OpFoldResult targetSize, OpFoldResult divisor,
                       bool emitAssertions = true);
+FailureOr<StaticMultiSizeSpecification>
+computeStaticMultiTileSizes(LinalgOp op, unsigned dimension, int64_t targetSize,
+                            int64_t divisor);
 
 /// Rewrite a TilingInterface `op` to a tiled `scf.foreach_thread`, applying
 /// tiling by `numThreads`.
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
@@ -149,9 +149,13 @@
     underlying type. A user of the value can assume that the parameter has
     been verified.
   }];
-}
+}
+
+def Transform_AnyHandleOrParamType
+    : Type<Or<[TransformHandleTypeInterface.predicate,
+               Transform_ParamType.predicate]>,
+           "any transform handle or parameter">;
+
 def FunctionalStyleTransformOpTrait
     : NativeOpTrait<"FunctionalStyleTransformOpTrait"> {
   let cppNamespace = "::mlir::transform";
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -350,8 +350,8 @@
   }];
 
   let arguments = (ins TransformHandleTypeInterface:$pattern,
-                       Variadic<TransformHandleTypeInterface>:$handles);
-  let results = (outs Variadic<TransformHandleTypeInterface>:$replicated);
+                       Variadic<Transform_AnyHandleOrParamType>:$handles);
+  let results = (outs Variadic<Transform_AnyHandleOrParamType>:$replicated);
   let assemblyFormat = "`num` `(` $pattern `)` $handles attr-dict `:` "
                        "type($pattern) `,` type($handles)";
 }
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Transform/IR/TransformTypes.h"
 #include "mlir/Dialect/Transform/IR/TransformUtils.h"
 #include "mlir/Dialect/Transform/Utils/Utils.h"
 #include "mlir/IR/BuiltinTypes.h"
@@ -707,9 +708,65 @@
 // MultiTileSizesOp
 //===---------------------------------------------------------------------===//
 
+static void printMultitileSizesTypes(OpAsmPrinter &printer, Operation *op,
+                                     Type targetType, Type lowSizeType, Type,
+                                     Type) {
+  printer.printFunctionalType(TypeRange{targetType}, TypeRange{lowSizeType});
+}
+
+static ParseResult parseMultitileSizesTypes(OpAsmParser &parser,
+                                            Type &targetType, Type &lowSizeType,
+                                            Type &highSizeType,
+                                            Type &splitPointType) {
+  FunctionType funcType;
+  llvm::SMLoc typeLoc = parser.getCurrentLocation();
+  if (failed(parser.parseType<FunctionType>(funcType)))
+    return failure();
+
+  if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {
+    parser.emitError(typeLoc) << "expects a trailing functional type with one "
+                                 "argument and one result";
+    return failure();
+  }
+  targetType = funcType.getInput(0);
+  lowSizeType = highSizeType = splitPointType = funcType.getResult(0);
+
+  return success();
+}
+
 DiagnosedSilenceableFailure transform::MultiTileSizesOp::applyToOne(
     LinalgOp target, transform::ApplyToEachResultList &results,
     TransformState &state) {
+  if (getLowSize().getType().isa<TransformParamTypeInterface>()) {
+    if (target.hasDynamicShape()) {
+      results.assign(
+          ArrayRef<Attribute>({Attribute(), Attribute(), Attribute()}));
+      auto diag = emitSilenceableError()
+                  << "cannot compute parametric tile sizes for dynamically "
+                     "shaped payload op";
+      diag.attachNote(target->getLoc()) << "payload op";
+      return diag;
+    }
+
+    FailureOr<StaticMultiSizeSpecification> spec = computeStaticMultiTileSizes(
+        target, getDimension(), getTargetSize(), getDivisor());
+    if (failed(spec)) {
+      results.assign(
+          ArrayRef<Attribute>({Attribute(), Attribute(), Attribute()}));
+      return emitSilenceableError()
+             << "failed to compute multi-size tiling sizes";
+    }
+
+    Builder builder(target.getContext());
+    results.assign(llvm::map_range(
+        ArrayRef<int64_t>({spec->lowTileSize, spec->highTileSize,
+                           spec->lowTileSize * spec->lowTripCount}),
+        [&builder, this](int64_t value) {
+          return builder.getIntegerAttr(
+              getLowSize().getType().cast<ParamType>().getType(), value);
+        }));
+    return DiagnosedSilenceableFailure::success();
+  }
+
+  OpBuilder
builder(target.getContext()); builder.setInsertionPoint(target); OpFoldResult targetSize = builder.getIndexAttr(getTargetSize()); @@ -740,7 +797,18 @@ SmallVectorImpl &effects) { onlyReadsHandle(getTarget(), effects); producesHandle(getResults(), effects); - modifiesPayload(effects); + if (getLowSize().getType().isa()) + onlyReadsPayload(effects); + else + modifiesPayload(effects); +} + +LogicalResult transform::MultiTileSizesOp::verify() { + if (getLowSize().getType() != getHighSize().getType() || + getLowSize().getType() != getSplitPoint().getType()) { + return emitOpError() << "expects all results type to be the same"; + } + return success(); } //===---------------------------------------------------------------------===// @@ -1012,17 +1080,23 @@ splitPoints.reserve(payload.size()); if (getDynamicSplitPoint()) { auto diag = DiagnosedSilenceableFailure::success(); - splitPoints = llvm::to_vector(llvm::map_range( - state.getPayloadOps(getDynamicSplitPoint()), [&](Operation *op) { - if (op->getNumResults() != 1 || - !op->getResult(0).getType().isIndex()) { - diag = emitSilenceableError() - << "expected dynamic split point handle to point to a " - "single-result index-typed op"; - diag.attachNote(op->getLoc()) << "dynamic split point"; - } - return OpFoldResult(op->getResult(0)); - })); + if (getDynamicSplitPoint().getType().isa()) { + splitPoints = llvm::to_vector(llvm::map_range( + state.getPayloadOps(getDynamicSplitPoint()), [&](Operation *op) { + if (op->getNumResults() != 1 || + !op->getResult(0).getType().isIndex()) { + diag = emitSilenceableError() + << "expected dynamic split point handle to point to a " + "single-result index-typed op"; + diag.attachNote(op->getLoc()) << "dynamic split point"; + } + return OpFoldResult(op->getResult(0)); + })); + } else { + splitPoints = llvm::to_vector( + llvm::map_range(state.getParams(getDynamicSplitPoint()), + [](Attribute attr) { return OpFoldResult(attr); })); + } if (diag.isSilenceableFailure()) { results.set(getFirst().cast(), {}); results.set(getSecond().cast(), {}); @@ -1086,11 +1160,7 @@ ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) { OpAsmParser::UnresolvedOperand target, dynamicSplitPoint; IntegerAttr staticSplitPoint; - auto pdlOperationType = - pdl::OperationType::get(parser.getBuilder().getContext()); - if (parser.parseOperand(target) || - parser.resolveOperand(target, pdlOperationType, result.operands) || - parser.parseKeyword("after")) + if (parser.parseOperand(target) || parser.parseKeyword("after")) return failure(); OptionalParseResult dynamicPointParseResult = @@ -1102,9 +1172,19 @@ staticSplitPoint = parser.getBuilder().getI64IntegerAttr(staticSplitPointValue); - } else { - if (failed(*dynamicPointParseResult) || - parser.resolveOperand(dynamicSplitPoint, pdlOperationType, + } + + Type targetType; + if (parser.parseOptionalAttrDict(result.attributes) || + parser.parseColonType(targetType) || + parser.resolveOperand(target, targetType, result.operands)) { + return failure(); + } + if (dynamicPointParseResult.has_value()) { + Type splitPointType; + if (failed(*dynamicPointParseResult) || parser.parseComma() || + parser.parseType(splitPointType) || + parser.resolveOperand(dynamicSplitPoint, splitPointType, result.operands)) { return failure(); } @@ -1116,10 +1196,7 @@ result.addAttribute( SplitOp::getStaticSplitPointAttrName(result.name).getValue(), staticSplitPoint); - if (failed(parser.parseOptionalAttrDict(result.attributes))) - return failure(); - - result.addTypes({pdlOperationType, 
pdlOperationType}); + result.addTypes({targetType, targetType}); return success(); } @@ -1133,6 +1210,9 @@ printer << " "; printer.printOptionalAttrDict(getOperation()->getAttrs(), {getStaticSplitPointAttrName()}); + printer << " : " << getTarget().getType(); + if (staticSplitSize == ShapedType::kDynamic) + printer << ", " << getDynamicSplitPoint().getType(); } LogicalResult SplitOp::verify() { @@ -1295,10 +1375,12 @@ //===----------------------------------------------------------------------===// // TileOp //===----------------------------------------------------------------------===// + void transform::TileOp::build(OpBuilder &builder, OperationState &result, - Value target, ArrayRef staticTileSizes, + TypeRange loopTypes, Value target, + ArrayRef staticTileSizes, ArrayRef interchange) { - return build(builder, result, + return build(builder, result, loopTypes, /*target=*/target, /*mixedTileSizes=*/ getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)), @@ -1306,7 +1388,7 @@ } void transform::TileOp::build(OpBuilder &builder, OperationState &result, - Value target, + TypeRange loopTypes, Value target, ArrayRef mixedTileSizes, ArrayRef interchange) { SmallVector staticTileSizes; @@ -1315,11 +1397,19 @@ // Call the default builder which sets up the proper operands segment sizes // attributes for multiple variadic operands. In the absence of this, horrible // bugs ensue. - MLIRContext *ctx = builder.getContext(); - auto operationType = pdl::OperationType::get(ctx); auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes); - build(builder, result, - /*resultTypes=*/TypeRange{operationType, operationType}, + unsigned numExpectedLoops = + staticTileSizes.size() - llvm::count(staticTileSizes, 0); + SmallVector resultTypes; + resultTypes.reserve(numExpectedLoops); + assert((loopTypes.size() == 1 || loopTypes.size() == numExpectedLoops) && + "expected one loop type or as many as loops"); + if (loopTypes.size() == 1) + resultTypes.append(numExpectedLoops, loopTypes[0]); + else + llvm::append_range(resultTypes, loopTypes); + build(builder, result, /*tiled_linalg_op=*/target.getType(), + /*loops=*/resultTypes, /*target=*/target, /*dynamic_sizes=*/dynamicTileSizes, /*static_sizes=*/staticTileSizesAttr, @@ -1333,18 +1423,44 @@ ArrayRef targets = state.getPayloadOps(getTarget()); SmallVector> dynamicSizeProducers; + SmallVector> paramSizes; dynamicSizeProducers.reserve(getDynamicSizes().size()); - for (Value dynamicSizeProducerHandle : getDynamicSizes()) { - dynamicSizeProducers.push_back( - state.getPayloadOps(dynamicSizeProducerHandle)); + paramSizes.reserve(getDynamicSizes().size()); + for (Value transformValue : getDynamicSizes()) { + if (transformValue.getType().isa()) { + dynamicSizeProducers.push_back({}); + ArrayRef params = state.getParams(transformValue); + paramSizes.push_back( + llvm::to_vector(llvm::map_range(params, [](Attribute attr) { + return attr.cast().getValue().getSExtValue(); + }))); + + if (paramSizes.back().size() != targets.size()) { + for (OpResult r : getResults()) + transformResults.set(r, {}); + DiagnosedSilenceableFailure diag = + emitSilenceableError() + << "expected as many parameter values (" + << dynamicSizeProducers.back().size() << ") as target ops (" + << targets.size() << ")"; + diag.attachNote(transformValue.getLoc()) << "for this parameter"; + return diag; + } + + continue; + } + paramSizes.push_back({}); + dynamicSizeProducers.push_back(state.getPayloadOps(transformValue)); if (dynamicSizeProducers.back().size() != targets.size()) { + for 
(OpResult r : getResults()) + transformResults.set(r, {}); DiagnosedSilenceableFailure diag = emitSilenceableError() << "expected as many dynamic size-producing operations (" << dynamicSizeProducers.back().size() << ") as target ops (" << targets.size() << ")"; - diag.attachNote(dynamicSizeProducerHandle.getLoc()) << "for this handle"; + diag.attachNote(transformValue.getLoc()) << "for this handle"; return diag; } @@ -1352,11 +1468,14 @@ if (op->getNumResults() == 1 && op->getResult(0).getType().isa()) continue; + + for (OpResult r : getResults()) + transformResults.set(r, {}); DiagnosedSilenceableFailure diag = emitSilenceableError() << "expected sizes to be produced by ops " "with a single index-type result"; diag.attachNote(op->getLoc()) << "size producer op"; - diag.attachNote(dynamicSizeProducerHandle.getLoc()) << "for this handle"; + diag.attachNote(transformValue.getLoc()) << "for this handle"; return diag; } } @@ -1385,9 +1504,19 @@ if (auto attr = ofr.dyn_cast()) { sizes.push_back(b.create( getLoc(), attr.cast().getInt())); - } else { + continue; + } + ArrayRef dynamicSizes = + dynamicSizeProducers[dynamicIdx]; + ArrayRef params = paramSizes[dynamicIdx]; + ++dynamicIdx; + assert((dynamicSizes.empty() ^ params.empty()) && + "expected either dynamic sizes or parameters"); + if (!params.empty()) { sizes.push_back( - dynamicSizeProducers[dynamicIdx++][index]->getResult(0)); + b.create(getLoc(), params[index])); + } else { + sizes.push_back(dynamicSizes[index]->getResult(0)); } } return sizes; @@ -1469,20 +1598,34 @@ OpAsmParser::UnresolvedOperand target; SmallVector dynamicSizes; DenseI64ArrayAttr staticSizes; - auto pdlOperationType = pdl::OperationType::get(parser.getContext()); - if (parser.parseOperand(target) || - parser.resolveOperand(target, pdlOperationType, result.operands) || + FunctionType functionalType; + llvm::SMLoc operandLoc; + if (parser.parseOperand(target) || parser.getCurrentLocation(&operandLoc) || parseDynamicIndexList(parser, dynamicSizes, staticSizes) || - parser.resolveOperands(dynamicSizes, pdlOperationType, result.operands)) + parseOptionalInterchange(parser, result) || + parser.parseColonType(functionalType)) return ParseResult::failure(); - // Parse optional interchange. 
-  if (failed(parseOptionalInterchange(parser, result)))
-    return ParseResult::failure();
-  result.addAttribute(getStaticSizesAttrName(result.name), staticSizes);
   size_t numExpectedLoops =
       staticSizes.size() - llvm::count(staticSizes.asArrayRef(), 0);
-  result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOperationType));
+  if (functionalType.getNumResults() != numExpectedLoops + 1) {
+    return parser.emitError(parser.getNameLoc())
+           << "expected " << (numExpectedLoops + 1) << " result type(s)";
+  }
+  if (functionalType.getNumInputs() != dynamicSizes.size() + 1) {
+    return parser.emitError(operandLoc)
+           << "expected " << dynamicSizes.size() + 1 << " operand type(s)";
+  }
+  if (parser.resolveOperand(target, functionalType.getInputs().front(),
+                            result.operands) ||
+      parser.resolveOperands(dynamicSizes,
+                             functionalType.getInputs().drop_front(),
+                             operandLoc, result.operands)) {
+    return failure();
+  }
+
+  result.addAttribute(getStaticSizesAttrName(result.name), staticSizes);
+  result.addTypes(functionalType.getResults());
   return success();
 }
 
@@ -1490,6 +1633,8 @@
   p << ' ' << getTarget();
   printDynamicIndexList(p, getOperation(), getDynamicSizes(), getStaticSizes());
   printOptionalInterchange(p, getInterchange());
+  p << " : ";
+  p.printFunctionalType(getOperands().getTypes(), getResults().getTypes());
 }
 
 void transform::TileOp::getEffects(
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -117,6 +117,32 @@
       b.getStringAttr("expected strictly positive tile size and divisor"));
 }
 
+FailureOr<StaticMultiSizeSpecification>
+mlir::linalg::computeStaticMultiTileSizes(LinalgOp op, unsigned dimension,
+                                          int64_t targetSize, int64_t divisor) {
+  assert(!op.hasDynamicShape() &&
+         "cannot compute static multi-tile sizes for an op with dynamic shape");
+  assert(targetSize > 0 && "target size must be positive");
+  assert(divisor > 0 && "divisor must be positive");
+  assert(dimension < op.getNumLoops() && "dimension overflow");
+
+  StaticMultiSizeSpecification spec;
+  int64_t tripCount = op.getStaticLoopRanges()[dimension];
+  int64_t a = tripCount / divisor;
+  int64_t t = (targetSize + divisor - 1) / divisor;
+  int64_t totalTripCount = (a + t - 1) / t;
+  spec.lowTileSize = (a / totalTripCount) * divisor;
+  spec.highTileSize = spec.lowTileSize + divisor;
+  spec.highTripCount = a % totalTripCount;
+  spec.lowTripCount = totalTripCount - spec.highTripCount;
+  if (spec.lowTileSize * spec.lowTripCount +
+          spec.highTileSize * spec.highTripCount !=
+      tripCount) {
+    return failure();
+  }
+  return spec;
+}
+
 FailureOr<MultiSizeSpecification>
 mlir::linalg::computeMultiTileSizes(OpBuilder &builder, LinalgOp op,
                                     unsigned dimension, OpFoldResult targetSize,
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -561,12 +561,23 @@
   unsigned numRepetitions = state.getPayloadOps(getPattern()).size();
   for (const auto &en : llvm::enumerate(getHandles())) {
     Value handle = en.value();
-    ArrayRef<Operation *> current = state.getPayloadOps(handle);
-    SmallVector<Operation *> payload;
-    payload.reserve(numRepetitions * current.size());
-    for (unsigned i = 0; i < numRepetitions; ++i)
-      llvm::append_range(payload, current);
-    results.set(getReplicated()[en.index()].cast<OpResult>(), payload);
+    if (handle.getType().isa<TransformHandleTypeInterface>()) {
+      ArrayRef<Operation *> current = state.getPayloadOps(handle);
+      SmallVector<Operation *> payload;
+
payload.reserve(numRepetitions * current.size()); + for (unsigned i = 0; i < numRepetitions; ++i) + llvm::append_range(payload, current); + results.set(getReplicated()[en.index()].cast(), payload); + } else { + assert(handle.getType().isa() && + "expected param type"); + ArrayRef current = state.getParams(handle); + SmallVector params; + params.reserve(numRepetitions * current.size()); + for (unsigned i = 0; i < numRepetitions; ++i) + llvm::append_range(params, current); + results.setParams(getReplicated()[en.index()].cast(), params); + } } return DiagnosedSilenceableFailure::success(); } diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py --- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py @@ -51,13 +51,13 @@ def _get_dense_int64_array_attr( values: Sequence[int]) -> DenseI64ArrayAttr: - """Creates a dense integer array from a sequence of integers. + """Creates a dense integer array from a sequence of integers. Expects the thread-local MLIR context to have been set by the context manager. """ - if values is None: - return DenseI64ArrayAttr.get([]) - return DenseI64ArrayAttr.get(values) + if values is None: + return DenseI64ArrayAttr.get([]) + return DenseI64ArrayAttr.get(values) def _get_int_int_array_attr( values: Optional[Union[ArrayAttr, Sequence[Union[ArrayAttr, @@ -141,6 +141,7 @@ """Specialization for MultitileSizesOp class.""" def __init__(self, + result_type: Type, target: Union[Operation, Value], *, dimension: Union[int, IntegerAttr], @@ -149,9 +150,9 @@ loc=None, ip=None): super().__init__( - pdl.OperationType.get(), - pdl.OperationType.get(), - pdl.OperationType.get(), + result_type, + result_type, + result_type, _get_op_result_or_value(target), dimension=_get_int64_attr(dimension), target_size=_get_int64_attr(target_size), @@ -223,10 +224,9 @@ static_split_point = _get_int64_attr(ShapedType.get_dynamic_size()) dynamic_split_point = _get_op_result_or_value(split_point) - pdl_operation_type = pdl.OperationType.get() super().__init__( - pdl_operation_type, - pdl_operation_type, + target.type, + target.type, _get_op_result_or_value(target), dimension=dimension, static_split_point=static_split_point, @@ -239,6 +239,7 @@ """Specialization for TileOp class.""" def __init__(self, + loop_types: Union[Type, List[Type]], target: Union[Operation, Value], *, sizes: Optional[Union[Sequence[Union[int, IntegerAttr, Operation, @@ -246,9 +247,6 @@ interchange: OptionalIntList = None, loc=None, ip=None): - pdl_operation_type = pdl.OperationType.get() - i64_type = IntegerType.get_signless(64) - if sizes is None: sizes = [] @@ -268,11 +266,14 @@ num_loops = sum( v if v == 0 else 1 for v in self.__extract_values(sizes_attr)) super().__init__( - pdl_operation_type, [pdl_operation_type] * num_loops, + target.type, + ([loop_types] * + num_loops) if isinstance(loop_types, Type) else loop_types, _get_op_result_or_value(target), dynamic_sizes=dynamic_sizes, static_sizes=sizes_attr, - interchange=_get_dense_int64_array_attr(interchange) if interchange else None, + interchange=_get_dense_int64_array_attr(interchange) + if interchange else None, loc=loc, ip=ip) diff --git a/mlir/test/Dialect/LLVM/transform-e2e.mlir b/mlir/test/Dialect/LLVM/transform-e2e.mlir --- a/mlir/test/Dialect/LLVM/transform-e2e.mlir +++ b/mlir/test/Dialect/LLVM/transform-e2e.mlir @@ -15,7 +15,7 @@ transform.sequence failures(propagate) { ^bb1(%module_op: !pdl.operation): %0 = 
transform.structured.match ops{["linalg.matmul"]} in %module_op - %1, %loops:3 = transform.structured.tile %0 [2, 2, 2] + %1, %loops:3 = transform.structured.tile %0 [2, 2, 2] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation) %2 = get_closest_isolated_parent %1 : (!pdl.operation) -> !pdl.operation transform.structured.vectorize %2 transform.bufferization.one_shot_bufferize layout{IdentityLayoutMap} %module_op diff --git a/mlir/test/Dialect/Linalg/multisize-tiling-full.mlir b/mlir/test/Dialect/Linalg/multisize-tiling-full.mlir --- a/mlir/test/Dialect/Linalg/multisize-tiling-full.mlir +++ b/mlir/test/Dialect/Linalg/multisize-tiling-full.mlir @@ -1,23 +1,29 @@ -// RUN: mlir-opt --test-transform-dialect-interpreter --scf-for-loop-canonicalization --canonicalize %s | FileCheck %s +// RUN: mlir-opt --test-transform-dialect-interpreter --scf-for-loop-canonicalization --canonicalize --split-input-file %s | FileCheck %s +// RUN: mlir-opt --test-transform-dialect-interpreter --split-input-file %s | FileCheck %s --check-prefix=NOCANON // This implements a 2D multisize tiling with target sizes [3, 10]. transform.sequence failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 - %1:3 = transform.structured.multitile_sizes %0 { dimension = 0, target_size = 3} - %t:3 = transform.structured.multitile_sizes %0 { dimension = 1, target_size = 10} - %2:2 = transform.structured.split %0 after %1#2 { dimension = 0 } - %3:2 = transform.structured.tile %2#0 [%1#0] - %4:2 = transform.structured.tile %2#1 [%1#1] + %1:3 = transform.structured.multitile_sizes %0 { dimension = 0, target_size = 3} : (!pdl.operation) -> !pdl.operation + %t:3 = transform.structured.multitile_sizes %0 { dimension = 1, target_size = 10} : (!pdl.operation) -> !pdl.operation + %2:2 = transform.structured.split %0 after %1#2 { dimension = 0 } : !pdl.operation, !pdl.operation + %3:2 = transform.structured.tile %2#0 [%1#0] : (!pdl.operation, !pdl.operation) -> (!pdl.operation, !pdl.operation) + %4:2 = transform.structured.tile %2#1 [%1#1] : (!pdl.operation, !pdl.operation) -> (!pdl.operation, !pdl.operation) %5 = merge_handles %3#0, %4#0 : !pdl.operation %tt:3 = replicate num(%5) %t#0, %t#1, %t#2 : !pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation - %6:2 = transform.structured.split %5 after %tt#2 { dimension = 1 } - transform.structured.tile %6#0 [0, %tt#0] - transform.structured.tile %6#1 [0, %tt#1] + %6:2 = transform.structured.split %5 after %tt#2 { dimension = 1 } : !pdl.operation, !pdl.operation + transform.structured.tile %6#0 [0, %tt#0] : (!pdl.operation, !pdl.operation) -> (!pdl.operation, !pdl.operation) + transform.structured.tile %6#1 [0, %tt#1] : (!pdl.operation, !pdl.operation) -> (!pdl.operation, !pdl.operation) } func.func private @elem(%arg0: f32, %arg1: index, %arg2: index) -> f32 +// Without canonicalization, tile sizes are computed dynamically as affine maps. 
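+// For reference, a worked instance of the multi-size computation for this
+// payload (trip counts 10 and 34, divisor 1), derived by hand from
+// computeMultiTileSizes:
+//   dimension 0, target_size = 3:  ceil(10/3) = 4 tiles, low size 2 (x2),
+//     high size 3 (x2), split point 2*2 = 4;
+//   dimension 1, target_size = 10: ceil(34/10) = 4 tiles, low size 8 (x2),
+//     high size 9 (x2), split point 8*2 = 16.
+// These values match the slice sizes checked below.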
+// NOCANON-LABEL: @two_d +// NOCANON-COUNT-8: affine.apply +// NOCANON: scf.for + // CHECK-LABEL: @two_d // CHECK-SAME: %[[IN:.+]]: tensor<10x34xf32>, %[[OUT:.+]]: tensor<10x34xf32> func.func @two_d(%arg0: tensor<10x34xf32>, @@ -93,3 +99,96 @@ return %0 : tensor<10x34xf32> } + +// ----- + +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 + %1:3 = transform.structured.multitile_sizes %0 { dimension = 0, target_size = 3} : (!pdl.operation) -> !transform.param + %t:3 = transform.structured.multitile_sizes %0 { dimension = 1, target_size = 10} : (!pdl.operation) -> !transform.param + %2:2 = transform.structured.split %0 after %1#2 { dimension = 0 } : !pdl.operation, !transform.param + %3:2 = transform.structured.tile %2#0 [%1#0] : (!pdl.operation, !transform.param) -> (!pdl.operation, !pdl.operation) + %4:2 = transform.structured.tile %2#1 [%1#1] : (!pdl.operation, !transform.param) -> (!pdl.operation, !pdl.operation) + %5 = merge_handles %3#0, %4#0 : !pdl.operation + %tt:3 = replicate num(%5) %t#0, %t#1, %t#2 : !pdl.operation, !transform.param, !transform.param, !transform.param + %6:2 = transform.structured.split %5 after %tt#2 { dimension = 1 } : !pdl.operation, !transform.param + transform.structured.tile %6#0 [0, %tt#0] : (!pdl.operation, !transform.param) -> (!pdl.operation, !pdl.operation) + transform.structured.tile %6#1 [0, %tt#1] : (!pdl.operation, !transform.param) -> (!pdl.operation, !pdl.operation) +} + +func.func private @elem(%arg0: f32, %arg1: index, %arg2: index) -> f32 + +// Even without canonicalization, tile sizes can be computed statically thanks +// to parameters. +// NOCANON-LABEL: @two_d +// NOCANON-NOT: affine.apply +// NOCANON: scf.for + +// CHECK-LABEL: @two_d_param +// CHECK-SAME: %[[IN:.+]]: tensor<10x34xf32>, %[[OUT:.+]]: tensor<10x34xf32> +func.func @two_d_param(%arg0: tensor<10x34xf32>, + %arg1: tensor<10x34xf32>) -> tensor<10x34xf32> { + %0 = linalg.generic { + indexing_maps = [affine_map<(i, j) -> (i, j)>, + affine_map<(i, j) -> (i, j)>], + iterator_types = ["parallel", "parallel"] + } + ins(%arg0: tensor<10x34xf32>) + outs(%arg1: tensor<10x34xf32>) { + ^bb0(%0: f32, %1: f32): + %i = linalg.index 0 : index + %j = linalg.index 1 : index + %call_res = func.call @elem(%0, %i, %j) : (f32, index, index) -> f32 + linalg.yield %call_res : f32 + } -> tensor<10x34xf32> + + // CHECK: %[[SLICE_1_IN:.+]] = tensor.extract_slice %[[IN]][0, 0] [4, 34] [1, 1] + // CHECK: %[[SLICE_1:.+]] = tensor.extract_slice %[[OUT]][0, 0] [4, 34] [1, 1] + // CHECK: scf.for %[[I1:.+]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ITERARG_1:.+]] = %[[SLICE_1]]) + // CHECK: %[[OUTSLICE_1_IN:.+]] = tensor.extract_slice %[[SLICE_1_IN]][%[[I1]], 0] [2, 34] [1, 1] + // CHECK: %[[OUTSLICE_1:.+]] = tensor.extract_slice %[[ITERARG_1]][%[[I1]], 0] [2, 34] [1, 1] + + // CHECK: %[[SLICE_2_IN:.+]] = tensor.extract_slice %[[OUTSLICE_1_IN]][0, 0] [2, 16] [1, 1] + // CHECK: %[[SLICE_2:.+]] = tensor.extract_slice %[[OUTSLICE_1]][0, 0] [2, 16] [1, 1] + // CHECK: %[[LOOPRES:.+]] = scf.for %[[I2:.+]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ITERARG_2:.+]] = %[[SLICE_2]]) + // CHECK: %[[INSLICE_2:.+]] = tensor.extract_slice %[[SLICE_2_IN]][0, %[[I2]]] [2, 8] [1, 1] + // CHECK: %[[OUTSLICE_2:.+]] = tensor.extract_slice %[[ITERARG_2]][0, %[[I2]]] [2, 8] [1, 1] + // CHECK: %[[RESSLICE_1:.+]] = linalg.generic {{.*}} ins(%[[INSLICE_2]] : tensor<2x8xf32>) outs(%[[OUTSLICE_2]] : tensor<2x8xf32>) + // CHECK: 
%[[RESPARTIAL:.+]] = tensor.insert_slice %[[RESSLICE_1]] into %[[ITERARG_2]] + // CHECK: scf.yield %[[RESPARTIAL]] + + // CHECK: %[[INSERTED:.+]] = tensor.insert_slice %[[LOOPRES]] into %[[OUTSLICE_1]][0, 0] [2, 16] [1, 1] + // CHECK: %[[OUTSLICE_3:.+]] = tensor.extract_slice %[[INSERTED]][0, 16] [2, 18] [1, 1] + // CHECK: scf.for %{{.*}} iter_args(%{{.*}} = %[[OUTSLICE_3]]) + // CHECK-COUNT-2: tensor.extract_slice + // CHECK: linalg.generic {{.*}} ins(%{{.*}} : tensor<2x9xf32>) + // CHECK: tensor.insert_slice + // CHECK: scf.yield + // CHECK: %[[INSERTED_2:.+]] = tensor.insert_slice %{{.*}} into %[[INSERTED]] + // CHECK: %[[INSERTED_3:.+]] = tensor.insert_slice %[[INSERTED_2]] into %[[ITERARG_1]] + // CHECK: scf.yield %[[INSERTED_3]] + + // CHECK: tensor.insert_slice + // CHECK: tensor.extract_slice + // CHECK: scf.for + // CHECK-COUNT-2: tensor.extract_slice + // CHECK: scf.for + // CHECK-COUNT-2: tensor.extract_slice + // CHECK: linalg.generic {{.*}} ins(%{{.*}} : tensor<3x8xf32>) + // CHECK: tensor.insert_slice + // CHECK: scf.yield + // CHECK: tensor.insert_slice + // CHECK: tensor.extract_slice + // CHECK: scf.for + // CHECK-COUNT-2: tensor.extract_slice + // CHECK: linalg.generic {{.*}} ins(%{{.*}} : tensor<3x9xf32>) + // CHECK: tensor.insert_slice + // CHECK: scf.yield + // CHECK-COUNT-2: tensor.insert_slice + // CHECK: scf.yield + // CHECK: %[[RESULT:.+]] = tensor.insert_slice + // CHECK: return %[[RESULT]] + + return %0 : tensor<10x34xf32> +} diff --git a/mlir/test/Dialect/Linalg/promotion_options.mlir b/mlir/test/Dialect/Linalg/promotion_options.mlir --- a/mlir/test/Dialect/Linalg/promotion_options.mlir +++ b/mlir/test/Dialect/Linalg/promotion_options.mlir @@ -37,6 +37,6 @@ transform.sequence failures(propagate) { ^bb0(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 - %1, %loops:3 = transform.structured.tile %0 [16, 16, 16] + %1, %loops:3 = transform.structured.tile %0 [16, 16, 16] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation) %2 = transform.structured.promote %1 { operands_to_promote = [0, 2], force_full_tiles = [false, false], use_full_tiles_by_default } } diff --git a/mlir/test/Dialect/Linalg/tile-conv.mlir b/mlir/test/Dialect/Linalg/tile-conv.mlir --- a/mlir/test/Dialect/Linalg/tile-conv.mlir +++ b/mlir/test/Dialect/Linalg/tile-conv.mlir @@ -12,7 +12,7 @@ transform.sequence failures(propagate) { ^bb0(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.conv_2d"]} in %arg1 - %1, %loop:2 = transform.structured.tile %0 [2, 3] + %1, %loop:2 = transform.structured.tile %0 [2, 3] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation) } // CHECK: func @conv diff --git a/mlir/test/Dialect/Linalg/tile-indexed.mlir b/mlir/test/Dialect/Linalg/tile-indexed.mlir --- a/mlir/test/Dialect/Linalg/tile-indexed.mlir +++ b/mlir/test/Dialect/Linalg/tile-indexed.mlir @@ -14,7 +14,7 @@ transform.sequence failures(propagate) { ^bb0(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 - %1, %loop = transform.structured.tile %0 [10] + %1, %loop = transform.structured.tile %0 [10] : (!pdl.operation) -> (!pdl.operation, !pdl.operation) } // TILE-10n25-DAG: [[$MAP:#[a-zA-Z0-9_]*]] = affine_map<(d0, d1) -> (d0 + d1)> @@ -44,7 +44,7 @@ transform.sequence failures(propagate) { ^bb0(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 - %1, %loop:2 = transform.structured.tile %0 [10, 25] + %1, %loop:2 = 
transform.structured.tile %0 [10, 25] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation) } // TILE-10n25-DAG: [[$MAP:#[a-zA-Z0-9_]*]] = affine_map<(d0, d1) -> (d0 + d1)> diff --git a/mlir/test/Dialect/Linalg/tile-tensors.mlir b/mlir/test/Dialect/Linalg/tile-tensors.mlir --- a/mlir/test/Dialect/Linalg/tile-tensors.mlir +++ b/mlir/test/Dialect/Linalg/tile-tensors.mlir @@ -30,7 +30,7 @@ transform.sequence failures(propagate) { ^bb0(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 - %1, %loops:3 = transform.structured.tile %0 [2, 3, 4] + %1, %loops:3 = transform.structured.tile %0 [2, 3, 4] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation) } // ----- @@ -61,7 +61,7 @@ transform.sequence failures(propagate) { ^bb0(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 - %1, %loops:3 = transform.structured.tile %0 [2, 3, 4] + %1, %loops:3 = transform.structured.tile %0 [2, 3, 4] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation) } // CHECK-LABEL: func @generic_op_tensors @@ -132,5 +132,5 @@ transform.sequence failures(propagate) { ^bb0(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 - %1, %loops:3 = transform.structured.tile %0 [2, 3, 4] + %1, %loops:3 = transform.structured.tile %0 [2, 3, 4] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation) } diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir --- a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir @@ -89,5 +89,5 @@ ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [5, 0, 7], tile_interchange = [0, 2, 1]} - %2, %loops_2 = transform.structured.tile %1 [0, 4] + %2, %loops_2 = transform.structured.tile %1 [0, 4] : (!pdl.operation) -> (!pdl.operation, !pdl.operation) } diff --git a/mlir/test/Dialect/Linalg/transform-op-multitile-sizes.mlir b/mlir/test/Dialect/Linalg/transform-op-multitile-sizes.mlir --- a/mlir/test/Dialect/Linalg/transform-op-multitile-sizes.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-multitile-sizes.mlir @@ -1,11 +1,11 @@ -// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file | FileCheck %s +// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file --verify-diagnostics | FileCheck %s // CHECK-DAG: #[[$MAP13:.+]] = affine_map<() -> (13)> transform.sequence failures(propagate) { ^bb0(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 - transform.structured.multitile_sizes %0 { target_size = 3, dimension = 0 } + transform.structured.multitile_sizes %0 { target_size = 3, dimension = 0 } : (!pdl.operation) -> !pdl.operation } // CHECK-LABEL: @multitile_sizes_static @@ -29,7 +29,34 @@ transform.sequence failures(propagate) { ^bb0(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 - transform.structured.multitile_sizes %0 { target_size = 3, divisor = 2, dimension = 0 } + %low_tile, %high_tile, %split_point = + transform.structured.multitile_sizes %0 { target_size = 3, dimension = 0 } + : (!pdl.operation) -> !transform.param + // expected-remark @below {{2 : i64}} + transform.test_print_param %low_tile : !transform.param + // expected-remark @below {{3 : i64}} + 
transform.test_print_param %high_tile : !transform.param + // expected-remark @below {{4 : i64}} + transform.test_print_param %split_point : !transform.param +} + +// CHECK-LABEL: @multitile_sizes_static_gen +func.func @multitile_sizes_static_gen( + %arg0: tensor<13x34xf32>, %arg1: tensor<34x42xf32>, %arg2: tensor<13x42xf32>) + -> tensor<13x42xf32> { + %0 = linalg.matmul ins(%arg0, %arg1: tensor<13x34xf32>, tensor<34x42xf32>) + outs(%arg2: tensor<13x42xf32>) + -> tensor<13x42xf32> + + return %0 : tensor<13x42xf32> +} + +// ----- + +transform.sequence failures(propagate) { + ^bb0(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 + transform.structured.multitile_sizes %0 { target_size = 3, divisor = 2, dimension = 0 } : (!pdl.operation) -> !pdl.operation } // CHECK: #[[$MAP_A:.+]] = affine_map<()[s0] -> ([[A_IMPL:s0 floordiv 2]])> @@ -64,3 +91,24 @@ return %0 : tensor } + +// ----- + +transform.sequence failures(propagate) { + ^bb0(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 + // expected-error @below {{cannot compute parametric tile sizes for dynamically shaped payload op}} + transform.structured.multitile_sizes %0 { target_size = 3, divisor = 2, dimension = 0 } + : (!pdl.operation) -> !transform.param +} + +func.func @multitile_sizes_dynamic_gen( + %arg0: tensor, %arg1: tensor, %arg2: tensor) + -> tensor { + // expected-note @below {{payload op}} + %0 = linalg.matmul ins(%arg0, %arg1: tensor, tensor) + outs(%arg2: tensor) + -> tensor + + return %0 : tensor +} diff --git a/mlir/test/Dialect/Linalg/transform-op-scalarize.mlir b/mlir/test/Dialect/Linalg/transform-op-scalarize.mlir --- a/mlir/test/Dialect/Linalg/transform-op-scalarize.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-scalarize.mlir @@ -21,6 +21,6 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 - %1, %loops = transform.structured.tile %0 [10, 0, 0] + %1, %loops = transform.structured.tile %0 [10, 0, 0] : (!pdl.operation) -> (!pdl.operation, !pdl.operation) %2 = transform.structured.scalarize %1 } diff --git a/mlir/test/Dialect/Linalg/transform-op-split.mlir b/mlir/test/Dialect/Linalg/transform-op-split.mlir --- a/mlir/test/Dialect/Linalg/transform-op-split.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-split.mlir @@ -3,7 +3,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 - %1:2 = transform.structured.split %0 after 42 { dimension = 0 } + %1:2 = transform.structured.split %0 after 42 { dimension = 0 } : !pdl.operation } func.func private @elem(%arg0: f32, %arg1: index, %arg2: index) -> f32 @@ -75,7 +75,7 @@ ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 %1 = transform.structured.match ops{["func.call"]} in %arg1 - transform.structured.split %0 after %1 { dimension = 0 } + transform.structured.split %0 after %1 { dimension = 0 } : !pdl.operation, !pdl.operation } func.func private @get_size() -> index @@ -122,8 +122,8 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 - %1:2 = transform.structured.split %0 after 4 { dimension = 0} - %2:2 = transform.structured.split %1#1 after 16 { dimension = 1 } + %1:2 = transform.structured.split %0 after 4 { dimension = 0 } : !pdl.operation + %2:2 = transform.structured.split %1#1 after 16 { 
dimension = 1 } : !pdl.operation } func.func private @elem(%arg0: f32, %arg1: index, %arg2: index) -> f32 @@ -189,7 +189,7 @@ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 %1 = transform.structured.match ops{["func.call"]} in %arg1 // expected-error @below {{expected dynamic split point handle to point to a single-result index-typed op}} - transform.structured.split %0 after %1 { dimension = 0 } + transform.structured.split %0 after %1 { dimension = 0 } : !pdl.operation, !pdl.operation } func.func private @get_size() -> i64 @@ -215,7 +215,7 @@ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 %1 = transform.structured.match ops{["func.call"]} in %arg1 // expected-error @below {{expected the dynamic split point handle to point to as many operations (0) as the target handle (1)}} - transform.structured.split %0 after %1 { dimension = 0 } + transform.structured.split %0 after %1 { dimension = 0 } : !pdl.operation, !pdl.operation } func.func private @get_size() -> i64 @@ -238,7 +238,7 @@ ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["func.return"]} in %arg1 // expected-error @below {{only applies to structured ops}} - transform.structured.split %0 after 16 { dimension = 1 } + transform.structured.split %0 after 16 { dimension = 1 } : !pdl.operation } func.func @noop(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> { @@ -252,7 +252,7 @@ ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 // expected-error @below {{dimension 1 does not exist in target op}} - transform.structured.split %0 after 16 { dimension = 1 } + transform.structured.split %0 after 16 { dimension = 1 } : !pdl.operation } func.func @one_d_static(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> { diff --git a/mlir/test/Dialect/Linalg/transform-op-tile.mlir b/mlir/test/Dialect/Linalg/transform-op-tile.mlir --- a/mlir/test/Dialect/Linalg/transform-op-tile.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-tile.mlir @@ -1,9 +1,9 @@ -// RUN: mlir-opt --test-transform-dialect-interpreter --split-input-file %s | FileCheck %s +// RUN: mlir-opt --test-transform-dialect-interpreter --split-input-file --verify-diagnostics %s | FileCheck %s transform.sequence failures(propagate) { ^bb0(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 - %1, %loops:3 = transform.structured.tile %0 [4, 4, 4] + %1, %loops:3 = transform.structured.tile %0 [4, 4, 4] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation) } // CHECK-LABEL: func @tile_linalg_matmul( @@ -40,7 +40,7 @@ ^bb0(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 %1 = transform.structured.match ops{["func.call"]} in %arg1 - %2, %loops:3 = transform.structured.tile %0 [%1, %1, 4] + %2, %loops:3 = transform.structured.tile %0 [%1, %1, 4] : (!pdl.operation, !pdl.operation, !pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation) } func.func private @get_dynamic_tile_size() -> index @@ -73,3 +73,53 @@ // CHECK: return %[[TD0]] : tensor<128x128xf32> return %0 : tensor<128x128xf32> } + +// ----- + +transform.sequence failures(propagate) { +^bb0(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 + // expected-note @below {{for this parameter}} + %1 = transform.test_produce_integer_param_with_type i64 : !transform.param + // expected-error @below {{expected as many parameter values (0) as target 
ops (2)}} + transform.structured.tile %0 [%1, %1, %1] + : (!pdl.operation, !transform.param, !transform.param, !transform.param) + -> (!pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation) +} + +func.func @tile_linalg_matmul( + %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>) + -> (tensor<128x128xf32>, tensor<128x128xf32>) { + %0 = linalg.matmul ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>) + outs(%arg2: tensor<128x128xf32>) + -> tensor<128x128xf32> + %1 = linalg.matmul ins(%0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>) + outs(%arg2: tensor<128x128xf32>) + -> tensor<128x128xf32> + return %0, %1 : tensor<128x128xf32>, tensor<128x128xf32> +} + +// ----- + +transform.sequence failures(propagate) { +^bb0(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 + // expected-note @below {{for this handle}} + %1 = transform.structured.match ops{["arith.constant"]} in %arg1 + // expected-error @below {{expected as many dynamic size-producing operations (0) as target ops (2)}} + transform.structured.tile %0 [%1, %1, 1] + : (!pdl.operation, !pdl.operation, !pdl.operation) + -> (!pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation) +} + +func.func @tile_linalg_matmul( + %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>) + -> (tensor<128x128xf32>, tensor<128x128xf32>) { + %0 = linalg.matmul ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>) + outs(%arg2: tensor<128x128xf32>) + -> tensor<128x128xf32> + %1 = linalg.matmul ins(%0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>) + outs(%arg2: tensor<128x128xf32>) + -> tensor<128x128xf32> + return %0, %1 : tensor<128x128xf32>, tensor<128x128xf32> +} diff --git a/mlir/test/Dialect/Linalg/transform-ops-invalid.mlir b/mlir/test/Dialect/Linalg/transform-ops-invalid.mlir --- a/mlir/test/Dialect/Linalg/transform-ops-invalid.mlir +++ b/mlir/test/Dialect/Linalg/transform-ops-invalid.mlir @@ -45,3 +45,12 @@ // expected-error@below {{'transform.structured.interchange' op attribute 'iterator_interchange' failed to satisfy constraint: i64 dense array attribute whose value is non-negative}} transform.structured.interchange %arg0 iterator_interchange = [-3, 1] } + +// ----- + +transform.sequence failures(propagate) { +^bb0(%arg0: !pdl.operation): + // expected-error@below {{expects all results type to be the same}} + "transform.structured.multitile_sizes"(%arg0) { target_size = 3, divisor = 2, dimension = 0 } + : (!pdl.operation) -> (!transform.param, !transform.param, !transform.param) +} diff --git a/mlir/test/Dialect/Linalg/transform-ops.mlir b/mlir/test/Dialect/Linalg/transform-ops.mlir --- a/mlir/test/Dialect/Linalg/transform-ops.mlir +++ b/mlir/test/Dialect/Linalg/transform-ops.mlir @@ -3,7 +3,13 @@ transform.sequence failures(propagate) { ^bb1(%arg0: !pdl.operation): // CHECK %{{.*}}, %{{.*}}:2 = transform.structured.tile - %0, %1:2 = transform.structured.tile %arg0 [2, 0, 3] + %0, %1:2 = transform.structured.tile %arg0 [2, 0, 3] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation) +} + +transform.sequence failures(propagate) { +^bb1(%arg0: !transform.any_op): + %0:2 = transform.structured.split %arg0 after 42 { dimension = 0 } : !transform.any_op + transform.structured.split %0#0 after %0#1 { dimension = 1 } : !transform.any_op, !transform.any_op } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir 
b/mlir/test/Dialect/Linalg/transform-patterns.mlir --- a/mlir/test/Dialect/Linalg/transform-patterns.mlir +++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir @@ -14,7 +14,7 @@ transform.sequence failures(propagate) { ^bb0(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.dot"]} in %arg1 - %1, %loop = transform.structured.tile %0 [8000] + %1, %loop = transform.structured.tile %0 [8000] : (!pdl.operation) -> (!pdl.operation, !pdl.operation) } // CHECK-LABEL: func @dot @@ -38,7 +38,7 @@ transform.sequence failures(propagate) { ^bb0(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.matvec"]} in %arg1 - %1, %loops:2 = transform.structured.tile %0 [5, 6] + %1, %loops:2 = transform.structured.tile %0 [5, 6] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation) } // CHECK-LABEL: func @matvec @@ -65,10 +65,10 @@ transform.sequence failures(propagate) { ^bb0(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 - %1, %loops:3 = transform.structured.tile %0 [2000, 3000, 4000] - %2, %loops_2:3 = transform.structured.tile %1 [200, 300, 400] - %3, %loops_3:3 = transform.structured.tile %2 [20, 30, 40] - %4, %loops_4:3 = transform.structured.tile %3 [2, 3, 4] + %1, %loops:3 = transform.structured.tile %0 [2000, 3000, 4000] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation) + %2, %loops_2:3 = transform.structured.tile %1 [200, 300, 400] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation) + %3, %loops_3:3 = transform.structured.tile %2 [20, 30, 40] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation) + %4, %loops_4:3 = transform.structured.tile %3 [2, 3, 4] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation) } // CHECK-LABEL: func @matmul @@ -164,7 +164,7 @@ transform.sequence failures(propagate) { ^bb0(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.matvec"]} in %arg1 - %1, %loops:2 = transform.structured.tile %0 [5, 6] {interchange = [1, 0]} + %1, %loops:2 = transform.structured.tile %0 [5, 6] {interchange = [1, 0]} : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation) } // CHECK-LABEL: func @matvec_perm @@ -191,9 +191,9 @@ transform.sequence failures(propagate) { ^bb0(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 - %1, %loops:3 = transform.structured.tile %0 [2000, 3000, 4000] {interchange = [1, 2, 0]} - %2, %loops_2:3 = transform.structured.tile %1 [200, 300, 400] {interchange = [1, 0, 2]} - %3, %loops_3:3 = transform.structured.tile %2 [20, 30, 40] + %1, %loops:3 = transform.structured.tile %0 [2000, 3000, 4000] {interchange = [1, 2, 0]} : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation) + %2, %loops_2:3 = transform.structured.tile %1 [200, 300, 400] {interchange = [1, 0, 2]} : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation) + %3, %loops_3:3 = transform.structured.tile %2 [20, 30, 40] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation) } // CHECK-LABEL: func @matmul_perm diff --git a/mlir/test/Dialect/Transform/selective-targeting.mlir b/mlir/test/Dialect/Transform/selective-targeting.mlir --- a/mlir/test/Dialect/Transform/selective-targeting.mlir +++ b/mlir/test/Dialect/Transform/selective-targeting.mlir @@ -77,7 +77,7 @@ transform.sequence %arg0 : !pdl.operation failures(propagate) { 
^bb1(%arg1: !pdl.operation): %0 = pdl_match @pdl_target_attrA in %arg1 : (!pdl.operation) -> !pdl.operation - transform.structured.tile %0 [4, 4, 4] + transform.structured.tile %0 [4, 4, 4] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation) %1 = pdl_match @pdl_target_attrC in %arg1 : (!pdl.operation) -> !pdl.operation %2 = transform.get_closest_isolated_parent %1 : (!pdl.operation) -> !pdl.operation transform.structured.vectorize %2 diff --git a/mlir/test/Dialect/Vector/transform-vector.mlir b/mlir/test/Dialect/Vector/transform-vector.mlir --- a/mlir/test/Dialect/Vector/transform-vector.mlir +++ b/mlir/test/Dialect/Vector/transform-vector.mlir @@ -16,7 +16,7 @@ transform.sequence failures(propagate) { ^bb1(%module_op: !pdl.operation): %0 = transform.structured.match ops{["linalg.matmul"]} in %module_op - %1, %loops:3 = transform.structured.tile %0 [8, 4, 2] + %1, %loops:3 = transform.structured.tile %0 [8, 4, 2] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation) %2 = get_closest_isolated_parent %1 : (!pdl.operation) -> !pdl.operation transform.structured.vectorize %2 transform.bufferization.one_shot_bufferize %module_op diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py --- a/mlir/test/python/dialects/transform_structured_ext.py +++ b/mlir/test/python/dialects/transform_structured_ext.py @@ -58,8 +58,10 @@ def testMultitileSizes(): sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()) with InsertionPoint(sequence.body): - structured.MultiTileSizesOp( - sequence.bodyTarget, dimension=1, target_size=42) + structured.MultiTileSizesOp(pdl.OperationType.get(), + sequence.bodyTarget, + dimension=1, + target_size=42) transform.YieldOp() # CHECK-LABEL: TEST: testMultitileSizes # CHECK: transform.sequence @@ -112,7 +114,10 @@ def testTileCompact(): sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()) with InsertionPoint(sequence.body): - structured.TileOp(sequence.bodyTarget, sizes=[4, 8], interchange=[0, 1]) + structured.TileOp(pdl.OperationType.get(), + sequence.bodyTarget, + sizes=[4, 8], + interchange=[0, 1]) transform.YieldOp() # CHECK-LABEL: TEST: testTileCompact # CHECK: transform.sequence @@ -125,7 +130,10 @@ attr = DenseI64ArrayAttr.get([4, 8]) ichange = DenseI64ArrayAttr.get([0, 1]) with InsertionPoint(sequence.body): - structured.TileOp(sequence.bodyTarget, sizes=attr, interchange=ichange) + structured.TileOp(pdl.OperationType.get(), + sequence.bodyTarget, + sizes=attr, + interchange=ichange) transform.YieldOp() # CHECK-LABEL: TEST: testTileAttributes # CHECK: transform.sequence @@ -136,8 +144,10 @@ def testTileZero(): sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()) with InsertionPoint(sequence.body): - structured.TileOp( - sequence.bodyTarget, sizes=[4, 0, 2, 0], interchange=[0, 1, 2, 3]) + structured.TileOp(pdl.OperationType.get(), + sequence.bodyTarget, + sizes=[4, 0, 2, 0], + interchange=[0, 1, 2, 3]) transform.YieldOp() # CHECK-LABEL: TEST: testTileZero # CHECK: transform.sequence @@ -153,7 +163,9 @@ with InsertionPoint(sequence.body): m1 = transform.PDLMatchOp(pdl.OperationType.get(), sequence.bodyTarget, "first") m2 = transform.PDLMatchOp(pdl.OperationType.get(), sequence.bodyTarget, "second") - structured.TileOp(sequence.bodyTarget, sizes=[m1, 3, m2, 0]) + 
structured.TileOp(pdl.OperationType.get(), + sequence.bodyTarget, + sizes=[m1, 3, m2, 0]) transform.YieldOp() # CHECK-LABEL: TEST: testTileDynamic # CHECK: %[[FIRST:.+]] = pdl_match