diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -721,13 +721,19 @@
     ```
   }];
 
+  // TODO: support mixed static-dynamic (see TileToForeachThreadOp).
   let arguments = (ins PDL_Operation:$target,
-                   DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_sizes);
+                   DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$tile_sizes);
   let results = (outs PDL_Operation:$for_op,
                       PDL_Operation:$fill_op,
                       PDL_Operation:$split_linalg_op,
                       PDL_Operation:$combining_linalg_op);
 
+  let builders = [
+    OpBuilder<(ins "Value":$target,
+                   "ArrayRef<int64_t>":$staticTileSizes)>
+  ];
+
   let assemblyFormat = "$target attr-dict";
 
   let extraClassDeclaration = [{
@@ -808,15 +814,23 @@
     ```
   }];
 
+  // TODO: support mixed static-dynamic (see TileToForeachThreadOp).
   let arguments = (ins PDL_Operation:$target,
-                   DefaultValuedAttr<I64ArrayAttr, "{}">:$num_threads,
-                   DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_sizes,
+                   DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$num_threads,
+                   DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$tile_sizes,
                    OptionalAttr<DeviceMappingArrayAttr>:$mapping);
   let results = (outs PDL_Operation:$foreach_thread_op,
                       PDL_Operation:$fill_op,
                       PDL_Operation:$split_linalg_op,
                       PDL_Operation:$combining_linalg_op);
 
+  let builders = [
+    OpBuilder<(ins "Value":$target,
+                   "ArrayRef<int64_t>":$staticNumThreads,
+                   "ArrayRef<int64_t>":$staticTileSizes,
+                   CArg<"ArrayAttr", "{}">:$mapping)>
+  ];
+
   let assemblyFormat = "$target attr-dict";
 
   let extraClassDeclaration = [{
@@ -825,6 +839,7 @@
         ::llvm::SmallVectorImpl<::mlir::Operation *> &results,
         ::mlir::transform::TransformState &state);
   }];
+
 }
 
 def TileOp : Op<Transform_Dialect, "structured.tile",
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1198,21 +1198,32 @@
 // TileReductionUsingScfOp
 //===----------------------------------------------------------------------===//
 
+void transform::TileReductionUsingScfOp::build(
+    OpBuilder &builder, OperationState &result, Value target,
+    ArrayRef<int64_t> staticTileSizes) {
+  // Call the default builder.
+  // This is future-proof re mixed static-dynamic and setting up the proper
+  // operands segment sizes attributes for multiple variadic operands.
+  // In the absence of this, horrible bugs ensue.
+  // TODO: support mixed static-dynamic (see TileToForeachThreadOp).
+  MLIRContext *ctx = builder.getContext();
+  auto opTy = pdl::OperationType::get(ctx);
+  auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
+  build(builder, result,
+        /*resultTypes=*/TypeRange{opTy, opTy, opTy, opTy},
+        /*target=*/target,
+        /*tile_sizes=*/staticTileSizesAttr);
+}
+
 DiagnosedSilenceableFailure
 transform::TileReductionUsingScfOp::applyToOne(
     linalg::LinalgOp target, SmallVectorImpl<Operation *> &results,
     transform::TransformState &state) {
   TrivialPatternRewriter rewriter(getContext());
   rewriter.setInsertionPoint(target);
-  SmallVector<int64_t> tileSizes = extractFromI64ArrayAttr(getTileSizes());
-  SmallVector<OpFoldResult> sizes;
-  for (int64_t size : tileSizes) {
-    sizes.push_back(rewriter.getIndexAttr(size));
-  }
-
   FailureOr<scf::SCFReductionTilingResult> result = scf::tileReductionUsingScf(
       rewriter, cast<PartialReductionOpInterface>(target.getOperation()),
-      sizes);
+      getAsOpFoldResult(rewriter.getI64ArrayAttr(getTileSizes())));
 
   if (failed(result))
     return emitDefaultSilenceableFailure(target);
@@ -1227,14 +1238,37 @@
 // TileReductionUsingForeachThreadOp
 //===----------------------------------------------------------------------===//
 
+void transform::TileReductionUsingForeachThreadOp::build(
+    OpBuilder &builder, OperationState &result, Value target,
+    ArrayRef<int64_t> staticNumThreads, ArrayRef<int64_t> staticTileSizes,
+    ArrayAttr mapping) {
+  // Call the default builder.
+  // This is future-proof re mixed static-dynamic and setting up the proper
+  // operands segment sizes attributes for multiple variadic operands.
+  // In the absence of this, horrible bugs ensue.
+  // TODO: support mixed static-dynamic (see TileToForeachThreadOp).
+  MLIRContext *ctx = builder.getContext();
+  auto opTy = pdl::OperationType::get(ctx);
+  auto staticNumThreadsAttr = builder.getDenseI64ArrayAttr(staticNumThreads);
+  auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
+  build(builder, result,
+        /*resultTypes=*/TypeRange{opTy, opTy, opTy, opTy},
+        /*target=*/target,
+        /*num_threads=*/staticNumThreadsAttr,
+        /*tile_sizes=*/staticTileSizesAttr,
+        /*mapping=*/mapping);
+}
+
 DiagnosedSilenceableFailure
 transform::TileReductionUsingForeachThreadOp::applyToOne(
     linalg::LinalgOp target, SmallVectorImpl<Operation *> &results,
     transform::TransformState &state) {
   TrivialPatternRewriter rewriter(getContext());
   rewriter.setInsertionPoint(target);
-  SmallVector<OpFoldResult> numThreads = getAsOpFoldResult(getNumThreads());
-  SmallVector<OpFoldResult> tileSizes = getAsOpFoldResult(getTileSizes());
+  SmallVector<OpFoldResult> numThreads =
+      getAsOpFoldResult(rewriter.getI64ArrayAttr(getNumThreads()));
+  SmallVector<OpFoldResult> tileSizes =
+      getAsOpFoldResult(rewriter.getI64ArrayAttr(getTileSizes()));
   FailureOr<linalg::ForeachThreadReductionTilingResult> result =
       linalg::tileReductionUsingForeachThread(
           rewriter, cast<PartialReductionOpInterface>(target.getOperation()),
diff --git a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
--- a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
@@ -17,7 +17,7 @@
 transform.sequence failures(propagate) {
 ^bb0(%arg1: !pdl.operation):
   %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
-  %loop, %1, %2, %3 = transform.structured.tile_reduction_using_scf %0 { tile_sizes = [0, 5] }
+  %loop, %1, %2, %3 = transform.structured.tile_reduction_using_scf %0 { tile_sizes = array<i64: 0, 5> }
 }
 
 // CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
@@ -71,7 +71,7 @@
 transform.sequence failures(propagate) {
 ^bb0(%arg1: !pdl.operation):
   %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
-  %loop, %1, %2, %3 = transform.structured.tile_reduction_using_scf %0 { tile_sizes = [5, 0] }
+  %loop, %1, %2, %3 = transform.structured.tile_reduction_using_scf %0 { tile_sizes = array<i64: 5, 0> }
 }
 
 // CHECK: func @reduction_tile_transpose
@@ -107,7 +107,7 @@
 transform.sequence failures(propagate) {
 ^bb0(%arg1: !pdl.operation):
   %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
-  %loop, %1, %2, %3 = transform.structured.tile_reduction_using_foreach_thread %0 { num_threads = [0, 5] }
+  %loop, %1, %2, %3 = transform.structured.tile_reduction_using_foreach_thread %0 { num_threads = array<i64: 0, 5> }
 }
 
 // CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0)[s0] -> (-(d0 * (s0 ceildiv 5)) + s0, s0 ceildiv 5)>
@@ -159,7 +159,7 @@
 transform.sequence failures(propagate) {
 ^bb0(%arg1: !pdl.operation):
   %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
-  %loop, %1, %2, %3 = transform.structured.tile_reduction_using_foreach_thread %0 { num_threads = [0, 0, 5] }
+  %loop, %1, %2, %3 = transform.structured.tile_reduction_using_foreach_thread %0 { num_threads = array<i64: 0, 0, 5> }
 }
 
 // CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0)[s0] -> (-(d0 * (s0 ceildiv 5)) + s0, s0 ceildiv 5)>
@@ -219,7 +219,7 @@
 ^bb0(%arg1: !pdl.operation):
   %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
   %loop, %1, %2, %3 = transform.structured.tile_reduction_using_foreach_thread %0
-    { num_threads = [0, 5], tile_sizes = [0, 3], mapping = [#gpu.thread<x>] }
+    { num_threads = array<i64: 0, 5>, tile_sizes = array<i64: 0, 3>, mapping = [#gpu.thread<x>] }
 }
 
 // CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 * 3)>
@@ -285,7 +285,7 @@
 ^bb0(%arg1: !pdl.operation):
   %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
   %loop, %1, %2, %3 = transform.structured.tile_reduction_using_foreach_thread %0
-    { num_threads = [0, 5], tile_sizes = [0, 3], mapping = [#gpu.thread<x>] }
+    { num_threads = array<i64: 0, 5>, tile_sizes = array<i64: 0, 3>, mapping = [#gpu.thread<x>] }
 
 // CHECK: expecting fill
 // CHECK-NEXT: linalg.fill