diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -1697,21 +1697,29 @@ OpBuilder<(ins "TypeRange":$loopTypes, "Value":$target, "ArrayRef":$staticTileSizes, - CArg<"ArrayRef", "{}">:$interchange)>, + CArg<"ArrayRef", "{}">:$interchange, + CArg<"std::optional>", "std::nullopt">: + $scalableSizes)>, OpBuilder<(ins "TypeRange":$loopTypes, "Value":$target, "ArrayRef":$mixedTileSizes, - CArg<"ArrayRef", "{}">:$interchange)>, + CArg<"ArrayRef", "{}">:$interchange, + CArg<"std::optional>", "std::nullopt">: + $scalableSizes)>, OpBuilder<(ins "Value":$target, "ArrayRef":$staticTileSizes, - CArg<"ArrayRef", "{}">:$interchange)>, + CArg<"ArrayRef", "{}">:$interchange, + CArg<"std::optional>", "std::nullopt">: + $scalableSizes)>, OpBuilder<(ins "Value":$target, "ArrayRef":$mixedTileSizes, - CArg<"ArrayRef", "{}">:$interchange)> - + CArg<"ArrayRef", "{}">:$interchange, + CArg<"std::optional>", "std::nullopt">: + $scalableSizes)>, ]; let hasCustomAssemblyFormat = 1; + let hasVerifier = 1; let extraClassDeclaration = [{ /// Returns the list of tile sizes, which may be static (Attribute) or diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -2332,36 +2332,41 @@ void transform::TileOp::build(OpBuilder &builder, OperationState &result, TypeRange loopTypes, Value target, ArrayRef staticTileSizes, - ArrayRef interchange) { + ArrayRef interchange, + std::optional> scalableSizes) { return build(builder, result, loopTypes, /*target=*/target, /*mixedTileSizes=*/ getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)), - 
interchange); + interchange, scalableSizes); } void transform::TileOp::build(OpBuilder &builder, OperationState &result, Value target, ArrayRef staticTileSizes, - ArrayRef interchange) { + ArrayRef interchange, + std::optional> scalableSizes) { build(builder, result, target, getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)), - interchange); + interchange, scalableSizes); } void transform::TileOp::build(OpBuilder &builder, OperationState &result, Value target, ArrayRef mixedTileSizes, - ArrayRef interchange) { + ArrayRef interchange, + std::optional> scalableSizes) { // Loop types are automaticaly splat by the callee, setting up one is // enough. SmallVector loopTypes(1, builder.getType()); - build(builder, result, loopTypes, target, mixedTileSizes, interchange); + build(builder, result, loopTypes, target, mixedTileSizes, interchange, + scalableSizes); } void transform::TileOp::build(OpBuilder &builder, OperationState &result, TypeRange loopTypes, Value target, ArrayRef mixedTileSizes, - ArrayRef interchange) { + ArrayRef interchange, + std::optional> scalableSizes) { SmallVector staticTileSizes; SmallVector dynamicTileSizes; dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes); @@ -2379,12 +2384,26 @@ resultTypes.append(numExpectedLoops, loopTypes[0]); else llvm::append_range(resultTypes, loopTypes); + // Without explicit flags, mark every tile size as non-scalable; otherwise use the caller's flags as-is (verify() checks the lengths match). + SmallVector expandedScalableSizes(mixedTileSizes.size(), false); + if (scalableSizes.has_value()) + expandedScalableSizes.assign(scalableSizes->begin(), scalableSizes->end()); build(builder, result, /*tiled_linalg_op=*/target.getType(), /*loops=*/resultTypes, /*target=*/target, /*dynamic_sizes=*/dynamicTileSizes, /*static_sizes=*/staticTileSizesAttr, - /*interchange=*/builder.getDenseI64ArrayAttr(interchange)); + /*interchange=*/builder.getDenseI64ArrayAttr(interchange), + 
/*scalable_sizes=*/expandedScalableSizes); +} + +/// Checks that there is exactly one scalable flag per (static or dynamic) tile size. +LogicalResult transform::TileOp::verify() { + if (getMixedSizes().size() != getScalableSizes().size()) + return 
emitOpError("expected same number of sizes (") << getMixedSizes().size() << ") and scalable sizes (" << getScalableSizes().size() << ")"; + return success(); } DiagnosedSilenceableFailure