diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h @@ -20,6 +20,12 @@ class GenericOp; class LinalgOp; } // namespace linalg + +namespace transform { +// Tag types used to select between TileToForeachThreadOp builder overloads. +struct TileSizesSpec {}; +struct NumThreadsSpec {}; +} // namespace transform } // namespace mlir //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -112,6 +112,10 @@ [TransformMappingAlloc, TransformMappingWrite]>:$fused_op); let assemblyFormat = "$producer_op `into` $containing_op attr-dict"; + + let builders = [ + OpBuilder<(ins "Value":$producerOp, "Value":$containingOp)> + ]; } def GeneralizeOp : Op":$opNames)> + ]; + let assemblyFormat = [{ (`ops` `{` $ops^ `}`)? (`interface` `{` $interface^ `}`)?
@@ -600,6 +608,15 @@ let assemblyFormat = "$target attr-dict"; + let builders = [ + OpBuilder<(ins "Value":$target, + "int64_t":$splitFactor, + "int64_t":$insertSplitDimension, + CArg<"bool", "false">:$innerParallel, + CArg<"bool", "false">:$useScalingAlgorithm, + CArg<"bool", "false">:$useAlloc)> + ]; + let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( ::mlir::linalg::LinalgOp target, @@ -818,6 +835,30 @@ OptionalAttr:$thread_dim_mapping); let results = (outs PDL_Operation:$foreach_thread_op, PDL_Operation:$tiled_op); + + let builders = [ // Unnamed *Spec args are overload-selection tags. + OpBuilder<(ins "Value":$target, + "ArrayRef":$staticTileSizes, + CArg<"::mlir::transform::TileSizesSpec", + "::mlir::transform::TileSizesSpec()">, + CArg<"ArrayRef", "{}">:$threadDimMapping)>, + OpBuilder<(ins "Value":$target, + "ArrayRef":$mixedTileSizes, + CArg<"::mlir::transform::TileSizesSpec", + "::mlir::transform::TileSizesSpec()">, + CArg<"ArrayRef", "{}">:$threadDimMapping)>, + OpBuilder<(ins "Value":$target, + "ArrayRef":$staticNumThreads, + CArg<"::mlir::transform::NumThreadsSpec", + "::mlir::transform::NumThreadsSpec()">, + CArg<"ArrayRef", "{}">:$threadDimMapping)>, + OpBuilder<(ins "Value":$target, + "ArrayRef":$mixedNumThreads, + CArg<"::mlir::transform::NumThreadsSpec", + "::mlir::transform::NumThreadsSpec()">, + CArg<"ArrayRef", "{}">:$threadDimMapping)>, + ]; + let assemblyFormat = [{ $target oilist( `num_threads` custom($num_threads, @@ -943,6 +984,10 @@ let results = (outs PDL_Operation:$transformed); let assemblyFormat = "$target attr-dict"; + + let builders = [ + OpBuilder<(ins "Value":$target, CArg<"bool", "false">:$vectorizePadding)> + ]; let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( ::mlir::Operation *target, diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td +++
b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td @@ -253,6 +253,11 @@ let arguments = (ins TransformTypeInterface:$handle, I64Attr:$num_result_handles); let results = (outs Variadic:$results); + + let builders = [ + OpBuilder<(ins "Value":$handle, "int64_t":$numResultHandles)> + ]; + let assemblyFormat = [{ $handle `in` `[` $num_result_handles `]` attr-dict `:` functional-type(operands, results) @@ -305,6 +310,12 @@ let arguments = (ins Optional:$target, OptionalAttr:$name); let results = (outs); + + let builders = [ + OpBuilder<(ins CArg<"StringRef", "StringRef()">:$name)>, + OpBuilder<(ins "Value":$target, CArg<"StringRef", "StringRef()">:$name)> + ]; + let assemblyFormat = "$target attr-dict (`:` type($target)^)?"; } diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -254,6 +254,14 @@ // FuseIntoContainingOp //===----------------------------------------------------------------------===// +void transform::FuseIntoContainingOp::build(OpBuilder &builder, + OperationState &result, + Value producerOp, + Value containingOp) { + result.addOperands({producerOp, containingOp}); + result.addTypes(pdl::OperationType::get(builder.getContext())); +} // Operands: (producer, containing); one !pdl.operation result ($fused_op). + /// Find the first "extract" user of `producerOp` and tile it right before its /// use. The tiled op is fused under the `containingOp`. /// Return this fused op on success or nullptr if anything fails.
@@ -628,6 +636,14 @@ // MatchOp //===---------------------------------------------------------------------===// +void transform::MatchOp::build(OpBuilder &builder, OperationState &result, + Value target, ArrayRef opNames) { + result.addOperands(target); + result.addAttribute(MatchOp::getOpsAttrName(result.name), + builder.getStrArrayAttr(opNames)); + result.addTypes(pdl::OperationType::get(builder.getContext())); +} // NOTE(review): `ops` is attached even for empty `opNames`; confirm intent. + DiagnosedSilenceableFailure transform::MatchOp::apply(transform::TransformResults &results, transform::TransformState &state) { @@ -1069,6 +1085,34 @@ // SplitReductionOp //===----------------------------------------------------------------------===// +void transform::SplitReductionOp::build( + OpBuilder &builder, OperationState &result, Value target, + int64_t splitFactor, int64_t insertSplitDimension, bool innerParallel, + bool useScalingAlgorithm, bool useAlloc) { + MLIRContext *ctx = builder.getContext(); + result.addOperands(target); + result.addAttribute(SplitReductionOp::getSplitFactorAttrName(result.name), + builder.getI64IntegerAttr(splitFactor)); + result.addAttribute( + SplitReductionOp::getInsertSplitDimensionAttrName(result.name), + builder.getI64IntegerAttr(insertSplitDimension)); + if (innerParallel) { + result.addAttribute(SplitReductionOp::getInnerParallelAttrName(result.name), + builder.getUnitAttr()); + } + if (useScalingAlgorithm) { + result.addAttribute( + SplitReductionOp::getUseScalingAlgorithmAttrName(result.name), + builder.getUnitAttr()); + } + if (useAlloc) { + result.addAttribute(SplitReductionOp::getUseAllocAttrName(result.name), + builder.getUnitAttr()); + } + auto resultType = pdl::OperationType::get(ctx); + result.addTypes({resultType, resultType, resultType, resultType}); +} // Flags map to unit attrs (set only when true); 4 !pdl.operation results. + DiagnosedSilenceableFailure transform::SplitReductionOp::applyToOne(linalg::LinalgOp target, SmallVectorImpl &results, @@ -1277,13 +1321,75 @@ // TileToForeachThreadOp
//===----------------------------------------------------------------------===// +void transform::TileToForeachThreadOp::build( + OpBuilder &builder, OperationState &result, Value target, + ArrayRef staticTileSizes, transform::TileSizesSpec, + ArrayRef threadDimMapping) { + return build(builder, result, target, + getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)), + TileSizesSpec(), threadDimMapping); +} // Wraps static sizes as OpFoldResults; defers to the mixed-sizes builder. + +void transform::TileToForeachThreadOp::build( + OpBuilder &builder, OperationState &result, Value target, + ArrayRef mixedTileSizes, transform::TileSizesSpec, + ArrayRef threadDimMapping) { + SmallVector staticTileSizes; + SmallVector dynamicTileSizes; + dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes, + ShapedType::kDynamicSize); + // Delegate to the default builder: it also sets the operand segment sizes + // attribute, which the op needs because it has several variadic operand + // groups (numThreads and tileSizes). + MLIRContext *ctx = builder.getContext(); + auto operationType = pdl::OperationType::get(ctx); + auto staticTileSizesAttr = builder.getI64ArrayAttr(staticTileSizes); + ArrayAttr threadDimMappingAttr; + if (!threadDimMapping.empty()) + threadDimMappingAttr = builder.getI64ArrayAttr(threadDimMapping); + build(builder, result, TypeRange{operationType, operationType}, target, + /*numThreads=*/ValueRange{}, dynamicTileSizes, + /*staticNumThreads=*/ArrayAttr(), staticTileSizesAttr, + threadDimMappingAttr); +} + +void transform::TileToForeachThreadOp::build( + OpBuilder &builder, OperationState &result, Value target, + ArrayRef staticNumThreads, transform::NumThreadsSpec, + ArrayRef threadDimMapping) { + return build(builder, result, target, + getAsOpFoldResult(builder.getI64ArrayAttr(staticNumThreads)), + NumThreadsSpec(), threadDimMapping); +} // Wraps static counts as OpFoldResults; defers to the mixed builder. + +void transform::TileToForeachThreadOp::build( + OpBuilder &builder, OperationState &result, Value target, + ArrayRef mixedNumThreads, transform::NumThreadsSpec, + ArrayRef
threadDimMapping) { + SmallVector staticNumThreads; + SmallVector dynamicNumThreads; + dispatchIndexOpFoldResults(mixedNumThreads, dynamicNumThreads, + staticNumThreads, ShapedType::kDynamicSize); + // Delegate to the default builder: it also sets the operand segment sizes + // attribute, which the op needs because it has several variadic operand + // groups (numThreads and tileSizes). + MLIRContext *ctx = builder.getContext(); + auto operationType = pdl::OperationType::get(ctx); + auto staticNumThreadsAttr = builder.getI64ArrayAttr(staticNumThreads); + ArrayAttr threadDimMappingAttr; + if (!threadDimMapping.empty()) + threadDimMappingAttr = builder.getI64ArrayAttr(threadDimMapping); + build(builder, result, TypeRange{operationType, operationType}, target, + dynamicNumThreads, /*tileSizes=*/ValueRange{}, staticNumThreadsAttr, + /*staticTileSizes=*/ArrayAttr(), threadDimMappingAttr); +} + DiagnosedSilenceableFailure transform::tileToForeachThreadOpImpl( RewriterBase &rewriter, transform::TransformState &state, TransformOpInterface transformOp, ArrayRef targets, ArrayRef mixedNumThreads, ArrayRef mixedTileSizes, Optional threadDimMapping, SmallVector &tileOps, SmallVector &tiledOps) { - if (targets.empty()) return DiagnosedSilenceableFailure(success()); @@ -1573,6 +1679,16 @@ // VectorizeOp //===----------------------------------------------------------------------===// +void transform::VectorizeOp::build(OpBuilder &builder, OperationState &result, + Value target, bool vectorizePadding) { + result.addOperands(target); + if (vectorizePadding) { + result.addAttribute(VectorizeOp::getVectorizePaddingAttrName(result.name), + builder.getUnitAttr()); + } + result.addTypes(pdl::OperationType::get(builder.getContext())); +} // The padding flag is a unit attribute, attached only when requested. + namespace { /// This is an helper only to call vectorize via a pattern inside of /// VectorizeOp::applyToOne.
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -314,11 +314,11 @@
 
 void transform::TransformResults::set(OpResult value,
                                       ArrayRef<Operation *> ops) {
-  unsigned position = value.getResultNumber();
-  assert(position < segments.size() &&
+  int64_t position = value.getResultNumber();
+  assert(position < static_cast<int64_t>(segments.size()) &&
          "setting results for a non-existent handle");
   assert(segments[position].data() == nullptr && "results already set");
-  unsigned start = operations.size();
+  int64_t start = operations.size();
   llvm::append_range(operations, ops);
   segments[position] = makeArrayRef(operations).drop_front(start);
 }
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -472,6 +472,19 @@
 // SplitHandlesOp
 //===----------------------------------------------------------------------===//
 
+/// Builds a SplitHandlesOp that splits `handle` into `numResultHandles`
+/// result handles, each of type !pdl.operation.
+void transform::SplitHandlesOp::build(OpBuilder &builder,
+                                      OperationState &result, Value handle,
+                                      int64_t numResultHandles) {
+  assert(numResultHandles >= 0 && "expected non-negative number of handles");
+  result.addOperands(handle);
+  result.addAttribute(SplitHandlesOp::getNumResultHandlesAttrName(result.name),
+                      builder.getI64IntegerAttr(numResultHandles));
+  auto pdlOpType = pdl::OperationType::get(builder.getContext());
+  result.addTypes(SmallVector<Type>(numResultHandles, pdlOpType));
+}
+
 DiagnosedSilenceableFailure
 transform::SplitHandlesOp::apply(transform::TransformResults &results,
                                  transform::TransformState &state) {
@@ -812,6 +825,24 @@
 // PrintOp
 //===----------------------------------------------------------------------===//
 
+/// Builds a PrintOp without a target operand. A non-empty `name` is attached
+/// as the op's optional `name` string attribute; an empty one is omitted.
+void transform::PrintOp::build(OpBuilder &builder, OperationState &result,
+                               StringRef name) {
+  if (!name.empty()) {
+    result.addAttribute(PrintOp::getNameAttrName(result.name),
+                        builder.getStringAttr(name));
+  }
+}
+
+/// Builds a PrintOp with `target` as its operand; `name` is handled as in the
+/// builder without a target.
+void transform::PrintOp::build(OpBuilder &builder, OperationState &result,
+                               Value target, StringRef name) {
+  result.addOperands(target);
+  build(builder, result, name);
+}
+
 DiagnosedSilenceableFailure
 transform::PrintOp::apply(transform::TransformResults &results,
                           transform::TransformState &state) {