diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -842,22 +842,22 @@ "ArrayRef":$staticTileSizes, CArg<"::mlir::transform::TileSizesSpec", "::mlir::transform::TileSizesSpec()">, - CArg<"ArrayRef", "{}">:$mapping)>, + CArg<"ArrayAttr", "{}">:$mapping)>, OpBuilder<(ins "Value":$target, "ArrayRef":$mixedTileSizes, CArg<"::mlir::transform::TileSizesSpec", "::mlir::transform::TileSizesSpec()">, - CArg<"ArrayRef", "{}">:$mapping)>, + CArg<"ArrayAttr", "{}">:$mapping)>, OpBuilder<(ins "Value":$target, "ArrayRef":$staticNumThreads, CArg<"::mlir::transform::NumThreadsSpec", "::mlir::transform::NumThreadsSpec()">, - CArg<"ArrayRef", "{}">:$mapping)>, + CArg<"ArrayAttr", "{}">:$mapping)>, OpBuilder<(ins "Value":$target, "ArrayRef":$mixedNumThreads, CArg<"::mlir::transform::NumThreadsSpec", "::mlir::transform::NumThreadsSpec()">, - CArg<"ArrayRef", "{}">:$mapping)>, + CArg<"ArrayAttr", "{}">:$mapping)>, ]; let assemblyFormat = [{ diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -1326,7 +1326,7 @@ Value target, ArrayRef staticTileSizes, transform::TileSizesSpec, - ArrayRef mapping) { + ArrayAttr mapping) { return build(builder, result, target, getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)), TileSizesSpec(), mapping); @@ -1335,7 +1335,7 @@ void transform::TileToForeachThreadOp::build( OpBuilder &builder, OperationState &result, Value target, ArrayRef mixedTileSizes, transform::TileSizesSpec, - ArrayRef mapping) { + ArrayAttr mapping) { SmallVector staticTileSizes; SmallVector dynamicTileSizes; dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes, @@ -1346,12 +1346,9 @@ MLIRContext *ctx = builder.getContext(); auto operationType = pdl::OperationType::get(ctx); auto staticTileSizesAttr = builder.getI64ArrayAttr(staticTileSizes); - ArrayAttr mappingAttr; - if (!mapping.empty()) - mappingAttr = builder.getI64ArrayAttr(mapping); build(builder, result, TypeRange{operationType, operationType}, target, /*numThreads=*/ValueRange{}, dynamicTileSizes, - /*staticNumThreads=*/ArrayAttr(), staticTileSizesAttr, mappingAttr); + /*staticNumThreads=*/ArrayAttr(), staticTileSizesAttr, mapping); } void transform::TileToForeachThreadOp::build(OpBuilder &builder, @@ -1359,7 +1356,7 @@ Value target, ArrayRef staticNumThreads, transform::NumThreadsSpec, - ArrayRef mapping) { + ArrayAttr mapping) { return build(builder, result, target, getAsOpFoldResult(builder.getI64ArrayAttr(staticNumThreads)), NumThreadsSpec(), mapping); @@ -1368,7 +1365,7 @@ void transform::TileToForeachThreadOp::build( OpBuilder &builder, OperationState &result, Value target, ArrayRef mixedNumThreads, transform::NumThreadsSpec, - ArrayRef mapping) { + ArrayAttr mapping) { SmallVector staticNumThreads; SmallVector dynamicNumThreads; dispatchIndexOpFoldResults(mixedNumThreads, dynamicNumThreads, @@ -1379,12 +1376,9 @@ MLIRContext *ctx = builder.getContext(); auto operationType = pdl::OperationType::get(ctx); auto staticNumThreadsAttr = builder.getI64ArrayAttr(staticNumThreads); - ArrayAttr mappingAttr; - if (!mapping.empty()) - mappingAttr = builder.getI64ArrayAttr(mapping); build(builder, result, TypeRange{operationType, operationType}, target, dynamicNumThreads, /*tileSizes=*/ValueRange{}, staticNumThreadsAttr, - /*staticTileSizes=*/ArrayAttr(), mappingAttr); + /*staticTileSizes=*/ArrayAttr(), mapping); } DiagnosedSilenceableFailure transform::tileToForeachThreadOpImpl(