diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h
--- a/mlir/include/mlir-c/BuiltinTypes.h
+++ b/mlir/include/mlir-c/BuiltinTypes.h
@@ -150,10 +150,19 @@
 /// in shaped types.
 MLIR_CAPI_EXPORTED bool mlirShapedTypeIsDynamicSize(int64_t size);
 
+/// Returns the value indicating a dynamic size in a shaped type. Prefer
+/// mlirShapedTypeIsDynamicSize to direct comparisons with this value.
+MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetDynamicSize();
+
 /// Checks whether the given value is used as a placeholder for dynamic strides
 /// and offsets in shaped types.
 MLIR_CAPI_EXPORTED bool mlirShapedTypeIsDynamicStrideOrOffset(int64_t val);
 
+/// Returns the value indicating a dynamic stride or offset in a shaped type.
+/// Prefer mlirShapedTypeIsDynamicStrideOrOffset to direct comparisons with
+/// this value.
+MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetDynamicStrideOrOffset();
+
 //===----------------------------------------------------------------------===//
 // Vector type.
 //===----------------------------------------------------------------------===//
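These getters expose the dynamic-size sentinels to API clients so that out-of-tree bindings never hard-code the constant. A minimal usage sketch, assuming only the existing C API (`printDims` and its setup are illustrative, not part of the patch):

```cpp
#include <stdio.h>

#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/IR.h"

// Build a 2-D tensor type with a dynamic leading dimension, then query it
// back; the sentinel round-trips through the type without being spelled out.
static void printDims(MlirContext ctx) {
  int64_t shape[] = {mlirShapedTypeGetDynamicSize(), 4};
  MlirType tensor = mlirRankedTensorTypeGet(
      /*rank=*/2, shape, mlirF32TypeGet(ctx),
      /*encoding=*/mlirAttributeGetNull());
  for (intptr_t dim = 0; dim < 2; ++dim) {
    int64_t size = mlirShapedTypeGetDimSize(tensor, dim);
    if (mlirShapedTypeIsDynamicSize(size))
      printf("dim %d: dynamic\n", (int)dim);
    else
      printf("dim %d: %lld\n", (int)dim, (long long)size);
  }
}
```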
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
@@ -25,6 +25,7 @@
 class AffineApplyOp;
 class AffineBound;
 class AffineValueMap;
+class IRRewriter;
 
 /// TODO: These should be renamed if they are on the mlir namespace.
 /// Ideally, they should go in a mlir::affine:: namespace.
@@ -384,6 +385,12 @@
 SmallVector<Value, 4> applyMapToValues(OpBuilder &b, Location loc,
                                        AffineMap map, ValueRange values);
 
+/// Returns the values obtained by applying `map` to the list of values, which
+/// may be known constants.
+SmallVector<OpFoldResult> applyMapToValues(IRRewriter &b, Location loc,
+                                           AffineMap map,
+                                           ArrayRef<OpFoldResult> values);
+
 /// Given an affine map `map` and its input `operands`, this method composes
 /// into `map`, maps of AffineApplyOps whose results are the values in
 /// `operands`, iteratively until no more of `operands` are the result of an
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -153,6 +153,38 @@
   }];
}
 
+def SplitOp : Op<Transform_Dialect, "structured.split",
+    [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+     DeclareOpInterfaceMethods<TransformOpInterface>]> {
+  let description = [{
+    Indicates that the given `target` op should be split into two complementary
+    parts, which combined cover the entire iteration domain of the original op.
+    The split is performed along the iteration space dimension provided as an
+    attribute. In case of dimension overflow, the transformation fails. The
+    split is performed at the dimension iterator value specified as either the
+    static split point attribute, when it is known at transform IR construction
+    time, or as the handle to an operation producing a single index-typed value,
+    when it is computed by payload IR. In the latter case, the static split
+    point must be set to `ShapedType::kDynamicSize` and the dynamic split point
+    handle must point to as many value-producing operations as there are
+    structured operations pointed to by the target handle.
+
+    The operation consumes the target handle, but preserves the split point
+    handle if provided. It produces two new handles pointing to the two parts
+    of the structured op after splitting, in the same order as the target
+    operand, with the first handle corresponding to the part with lower
+    iteration space indices.
+  }];
+
+  let arguments = (ins PDL_Operation:$target,
+                       I64Attr:$dimension,
+                       Optional<PDL_Operation>:$dynamic_split_point,
+                       I64Attr:$static_split_point);
+  let results = (outs PDL_Operation:$first, PDL_Operation:$second);
+  let hasVerifier = 1;
+  let hasCustomAssemblyFormat = 1;
+}
+
 def SplitReductionOp : Op<Transform_Dialect, "structured.split_reduction",
        [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
         TransformEachOpTrait, TransformOpInterface]> {
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -106,6 +106,34 @@
 /// Patterns that are used to bubble up extract slice op above linalg op.
 void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns);
 
+/// Split the given `op` into two parts along the given iteration space
+/// `dimension` at the specified `splitPoint`, and return the two parts.
+///
+/// For example, the following op:
+///
+///   linalg.matmul ins(%0, %1 : tensor<128x32xf32>, tensor<32x64xf32>)
+///                 outs(%2 : tensor<128x64xf32>)
+///
+/// split along the first dimension at position 42 will result in:
+///
+///   %3 = tensor.extract_slice %0[0, 0][42, 32][1, 1]
+///   %4 = tensor.extract_slice %2[0, 0][42, 64][1, 1]
+///   %5 = linalg.matmul ins(%3, %1 : tensor<42x32xf32>, tensor<32x64xf32>)
+///                      outs(%4 : tensor<42x64xf32>)
+///   %6 = tensor.insert_slice %5 into %2[0, 0][42, 64][1, 1]
+///
+///   %7 = tensor.extract_slice %0[42, 0][86, 32][1, 1]
+///   %8 = tensor.extract_slice %6[42, 0][86, 64][1, 1]
+///   %9 = linalg.matmul ins(%7, %1 : tensor<86x32xf32>, tensor<32x64xf32>)
+///                      outs(%8 : tensor<86x64xf32>)
+///   tensor.insert_slice %9 into %6[42, 0][86, 64][1, 1]
+///
+/// Note that there is no simplification other than constant propagation
+/// applied to slice extraction and insertion.
+std::pair<LinalgOp, LinalgOp> splitOp(RewriterBase &rewriter, LinalgOp op,
+                                      unsigned dimension,
+                                      OpFoldResult splitPoint);
+
 /// Perform standalone tiling of a single LinalgOp by `tileSizes`.
 /// and permute the loop nest according to `interchangeVector`
 /// The permutation is expressed as a list of integers that specify
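Outside the transform dialect, the new helper can also be driven directly from C++. A minimal sketch under assumed surroundings (`ctx` is an `MLIRContext *` and `matmul` an already-located `LinalgOp`; neither is defined by this patch):

```cpp
// Split `matmul` at iteration 42 of dimension 0. The helper clamps the split
// point internally, so an oversized split point degenerates to an empty
// second part instead of failing.
IRRewriter rewriter(ctx);
rewriter.setInsertionPoint(matmul);
std::pair<linalg::LinalgOp, linalg::LinalgOp> parts =
    linalg::splitOp(rewriter, matmul, /*dimension=*/0,
                    /*splitPoint=*/rewriter.getIndexAttr(42));
// parts.first covers indices [0, min(42, size)); parts.second covers the rest.
```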
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -177,12 +177,18 @@
 bool isFusableInto(const LinalgDependenceGraph &graph, LinalgOp consumer,
                    Value consumedView, LinalgOp producer);
 
-/// Compute tile offsets, given a list of loop `ivs` and `tileSizes`. In case a
+/// Creates either a memref.subview or a tensor.extract_slice with the given
+/// offsets/sizes/strides based on the type of `value`.
+Value createSlice(OpBuilder &builder, Location loc, Value value,
+                  ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+                  ArrayRef<OpFoldResult> strides);
+
+/// Computes tile offsets, given a list of loop `ivs` and `tileSizes`. In case a
 /// tile size is zero (i.e., no tiling), the corresponding offset is also zero.
 SmallVector<Value> computeTileOffsets(OpBuilder &b, Location loc,
                                       ValueRange ivs, ValueRange tileSizes);
 
-/// Compute tile sizes, given a list of `tileSizes` and dimension
+/// Computes tile sizes, given a list of `tileSizes` and dimension
 /// sizes (`sizeBounds`). In case a tile size is zero (i.e., no tiling), the
 /// corresponding result size is the corresponding value from `sizeBounds`.
 /// Note: The returned tile sizes are closed intervals.
@@ -190,6 +196,20 @@
                                     ValueRange tileSizes,
                                     ArrayRef<Value> sizeBounds);
 
+/// Returns the list of tensor output types produced when the given structured
+/// operation `op` is applied to the given `operands`. Note that `operands` are
+/// not necessarily the actual operands of `op`.
+SmallVector<Type> getTensorOutputTypes(LinalgOp op, ValueRange operands);
+
+/// Creates `insert_slice` ops that insert `results` back into the larger
+/// tensors they were originally extracted from with `extract_slice` before
+/// being passed as `operands` to the given structured operation `op` or its
+/// clone. Note that `operands` are not necessarily the actual operands of
+/// `op`; the operation serves only as a metadata container for operand types
+/// and positions.
+SmallVector<Value> insertSlicesBack(OpBuilder &builder, Location loc,
+                                    LinalgOp op, ValueRange operands,
+                                    ValueRange results);
+
 /// Creates an extract_slice/subview op for a single `valueToTile` with
 /// `builder`. This new operation extracts a tile of `valueToTile`, starting
 /// at offsets `lbs` and with sizes `subShapeSizes`. `omitPartialTileCheck`
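The three helpers above factor the recurring "slice operands, clone, insert partial results back" sequence out of the tiling and splitting code paths. A condensed sketch of how they are meant to compose, mirroring the Tiling.cpp change later in this patch (`op`, `b`, `loc`, and `tiledOperands` are assumed to come from the surrounding tiling logic):

```cpp
// Clone `op` onto pre-sliced operands, then stitch the partial results back
// into the tensors the slices were taken from.
SmallVector<Type> resultTypes = getTensorOutputTypes(op, tiledOperands);
Operation *clone = op.clone(b, loc, resultTypes, tiledOperands);
SmallVector<Value> stitched =
    insertSlicesBack(b, loc, op, tiledOperands, clone->getResults());
// Each entry of `stitched` is a tensor.insert_slice result if the matching
// output operand came from tensor.extract_slice, and the raw result otherwise.
```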
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -301,6 +301,15 @@
           return shape;
         },
         "Returns the shape of the ranked shaped type as a list of integers.");
+    c.def_static(
+        "_get_dynamic_size", []() { return mlirShapedTypeGetDynamicSize(); },
+        "Returns the value used to indicate dynamic dimensions in shaped "
+        "types.");
+    c.def_static(
+        "_get_dynamic_stride_or_offset",
+        []() { return mlirShapedTypeGetDynamicStrideOrOffset(); },
+        "Returns the value used to indicate dynamic strides or offsets in "
+        "shaped types.");
   }
 
 private:
diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -149,6 +149,8 @@
   return unwrap(type).cast<ShapedType>().getDimSize(
      static_cast<unsigned>(dim));
 }
 
+int64_t mlirShapedTypeGetDynamicSize() { return ShapedType::kDynamicSize; }
+
 bool mlirShapedTypeIsDynamicSize(int64_t size) {
   return ShapedType::isDynamic(size);
 }
@@ -157,6 +159,10 @@
   return ShapedType::isDynamicStrideOrOffset(val);
 }
 
+int64_t mlirShapedTypeGetDynamicStrideOrOffset() {
+  return ShapedType::kDynamicStrideOrOffset;
+}
+
 //===----------------------------------------------------------------------===//
 // Vector type.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -748,6 +748,76 @@
   return res;
 }
 
+SmallVector<OpFoldResult>
+mlir::applyMapToValues(IRRewriter &b, Location loc, AffineMap map,
+                       ArrayRef<OpFoldResult> values) {
+  // Materialize constants and keep track of produced operations so we can
+  // clean them up later.
+  SmallVector<Operation *> constants;
+  SmallVector<Value> actualValues;
+  actualValues.reserve(values.size());
+  auto *dialect = b.getContext()->getLoadedDialect<AffineDialect>();
+  for (OpFoldResult ofr : values) {
+    if (auto value = ofr.dyn_cast<Value>()) {
+      actualValues.push_back(value);
+      continue;
+    }
+    constants.push_back(dialect->materializeConstant(
+        b, ofr.get<Attribute>(), b.getIndexType(), loc));
+    actualValues.push_back(constants.back()->getResult(0));
+  }
+
+  // Compose, fold and construct maps for each result independently because
+  // they may simplify more effectively.
+  SmallVector<OpFoldResult> results;
+  results.reserve(map.getNumResults());
+  bool foldedAll = true;
+  for (auto i : llvm::seq<unsigned>(0, map.getNumResults())) {
+    AffineMap submap = map.getSubMap({i});
+    SmallVector<Value> operands = actualValues;
+    fullyComposeAffineMapAndOperands(&submap, &operands);
+    canonicalizeMapAndOperands(&submap, &operands);
+
+    // Identify the constant operands and extract their values as attributes.
+    // Note that we cannot use the original values directly because the list of
+    // operands may have changed due to canonicalization and composition.
+    SmallVector<Attribute> constantOperands;
+    constantOperands.reserve(operands.size());
+    for (Value operand : operands) {
+      IntegerAttr attr;
+      if (matchPattern(operand, m_Constant(&attr)))
+        constantOperands.push_back(attr);
+      else
+        constantOperands.push_back(nullptr);
+    }
+
+    // Create an apply operation and immediately attempt to fold it. On
+    // success, delete the operation and prepare the (unmaterialized) value for
+    // being returned. On failure, return the result of the apply operation.
+    // TODO: arguably, the main folder (createOrFold) API should support this
+    // use case instead of indiscriminately materializing constants.
+    auto apply = b.create<AffineApplyOp>(loc, submap, operands);
+    SmallVector<OpFoldResult> foldResult;
+    if (succeeded(apply->fold(constantOperands, foldResult))) {
+      assert(foldResult.size() == 1 && "expected single-result map");
+      b.eraseOp(apply);
+      results.push_back(foldResult.front());
+    } else {
+      results.push_back(apply.getResult());
+      foldedAll = false;
+    }
+  }
+
+  // If the entire map could be folded, remove the constants that were used in
+  // the initial ops.
+  if (foldedAll) {
+    for (Operation *constant : constants)
+      b.eraseOp(constant);
+  }
+
+  return results;
+}
+
 // A symbol may appear as a dim in affine.apply operations. This function
 // canonicalizes dims that are valid symbols into actual symbols.
 template <class MapOrSet>
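This overload is what lets the splitting logic stay constant-folded when the split point is static: attribute inputs fold to attribute outputs, and any temporarily materialized constants are erased. A small sketch of the contract (assumes `rewriter` has a valid insertion point and that `ctx` and `loc` are in scope; not part of the patch):

```cpp
// Apply (d0, d1) -> (d0 + d1) to two constants; the result comes back as an
// IntegerAttr, and no affine.apply or arith.constant ops are left in the IR.
AffineExpr d0, d1;
bindDims(ctx, d0, d1);
AffineMap sum = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0, d0 + d1);
SmallVector<OpFoldResult> folded = applyMapToValues(
    rewriter, loc, sum,
    {rewriter.getIndexAttr(40), rewriter.getIndexAttr(2)});
// folded.front() holds the attribute 42; a Value is returned only when
// folding fails, e.g. when an input is a non-constant SSA value.
```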
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -399,6 +399,161 @@
   return result->op;
 }
 
+//===----------------------------------------------------------------------===//
+// SplitOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure SplitOp::apply(TransformResults &results,
+                                           TransformState &state) {
+  // Collect the dynamic split points if provided.
+  ArrayRef<Operation *> payload = state.getPayloadOps(getTarget());
+  SimpleRewriter rewriter(getContext());
+  SmallVector<OpFoldResult> splitPoints;
+  splitPoints.reserve(payload.size());
+  if (getDynamicSplitPoint()) {
+    auto diag = DiagnosedSilenceableFailure::success();
+    splitPoints = llvm::to_vector(llvm::map_range(
+        state.getPayloadOps(getDynamicSplitPoint()), [&](Operation *op) {
+          if (op->getNumResults() != 1 ||
+              !op->getResult(0).getType().isIndex()) {
+            diag = emitSilenceableError()
+                   << "expected dynamic split point handle to point to a "
+                      "single-result index-typed op";
+            diag.attachNote(op->getLoc()) << "dynamic split point";
+          }
+          return OpFoldResult(op->getResult(0));
+        }));
+    if (!diag.succeeded())
+      return diag;
+
+    if (splitPoints.size() != payload.size()) {
+      emitError() << "expected the dynamic split point handle to point to as "
+                     "many operations ("
+                  << splitPoints.size() << ") as the target handle ("
+                  << payload.size() << ")";
+      return DiagnosedSilenceableFailure::definiteFailure();
+    }
+  } else {
+    splitPoints.resize(payload.size(),
+                       rewriter.getIndexAttr(getStaticSplitPoint()));
+  }
+
+  // Split each target operation.
+  SmallVector<Operation *> first, second;
+  for (const auto &pair : llvm::zip(payload, splitPoints)) {
+    Operation *target = std::get<0>(pair);
+    auto linalgOp = dyn_cast<LinalgOp>(target);
+    if (!linalgOp) {
+      auto diag = emitSilenceableError() << "only applies to structured ops";
+      diag.attachNote(target->getLoc()) << "target op";
+      return diag;
+    }
+
+    if (getDimension() >= linalgOp.getNumLoops()) {
+      auto diag = emitSilenceableError() << "dimension " << getDimension()
+                                         << " does not exist in target op";
+      diag.attachNote(target->getLoc()) << "target op";
+      return diag;
+    }
+
+    rewriter.setInsertionPoint(linalgOp);
+    std::tie(first.emplace_back(), second.emplace_back()) =
+        linalg::splitOp(rewriter, linalgOp, getDimension(), std::get<1>(pair));
+  }
+
+  results.set(getFirst().cast<OpResult>(), first);
+  results.set(getSecond().cast<OpResult>(), second);
+  return DiagnosedSilenceableFailure::success();
+}
+
+void SplitOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  // The target handle is consumed.
+  effects.emplace_back(MemoryEffects::Read::get(), getTarget(),
+                       TransformMappingResource::get());
+  effects.emplace_back(MemoryEffects::Free::get(), getTarget(),
+                       TransformMappingResource::get());
+
+  // The dynamic split point handle is not consumed.
+  if (getDynamicSplitPoint()) {
+    effects.emplace_back(MemoryEffects::Read::get(), getDynamicSplitPoint(),
+                         TransformMappingResource::get());
+  }
+
+  // The resulting handles are produced.
+  for (Value result : getResults()) {
+    effects.emplace_back(MemoryEffects::Allocate::get(), result,
+                         TransformMappingResource::get());
+    effects.emplace_back(MemoryEffects::Write::get(), result,
+                         TransformMappingResource::get());
+  }
+
+  effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
+  effects.emplace_back(MemoryEffects::Write::get(), PayloadIRResource::get());
+}
+
+ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) {
+  OpAsmParser::UnresolvedOperand target, dynamicSplitPoint;
+  IntegerAttr staticSplitPoint;
+  auto pdlOperationType =
+      pdl::OperationType::get(parser.getBuilder().getContext());
+  if (parser.parseOperand(target) ||
+      parser.resolveOperand(target, pdlOperationType, result.operands) ||
+      parser.parseKeyword("after"))
+    return failure();
+
+  OptionalParseResult dynamicPointParseResult =
+      parser.parseOptionalOperand(dynamicSplitPoint);
+  if (!dynamicPointParseResult.hasValue()) {
+    int64_t staticSplitPointValue;
+    if (failed(parser.parseInteger(staticSplitPointValue)))
+      return failure();
+
+    staticSplitPoint =
+        parser.getBuilder().getI64IntegerAttr(staticSplitPointValue);
+  } else {
+    if (failed(*dynamicPointParseResult) ||
+        parser.resolveOperand(dynamicSplitPoint, pdlOperationType,
+                              result.operands)) {
+      return failure();
+    }
+
+    staticSplitPoint =
+        parser.getBuilder().getI64IntegerAttr(ShapedType::kDynamicSize);
+  }
+
+  result.addAttribute(
+      SplitOp::getStaticSplitPointAttrName(result.name).getValue(),
+      staticSplitPoint);
+  if (failed(parser.parseOptionalAttrDict(result.attributes)))
+    return failure();
+
+  result.addTypes({pdlOperationType, pdlOperationType});
+  return success();
+}
+
+void SplitOp::print(OpAsmPrinter &printer) {
+  printer << " " << getTarget() << " after ";
+  int64_t staticSplitSize = static_cast<int64_t>(getStaticSplitPoint());
+  if (staticSplitSize != ShapedType::kDynamicSize)
+    printer << staticSplitSize;
+  else
+    printer << getDynamicSplitPoint();
+  printer << " ";
+  printer.printOptionalAttrDict(getOperation()->getAttrs(),
+                                {getStaticSplitPointAttrName()});
+}
+
+LogicalResult SplitOp::verify() {
+  if ((static_cast<int64_t>(getStaticSplitPoint()) !=
+       ShapedType::kDynamicSize) ^
+      (getDynamicSplitPoint() == nullptr)) {
+    return emitOpError()
+           << "expects either a dynamic or a static split point to be provided";
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // SplitReductionOp
 //===----------------------------------------------------------------------===//
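The verifier's XOR condition is compact but easy to misread; it rejects both "no split point" and "two split points". Spelled out as a sketch (the locals are illustrative, not patch code):

```cpp
bool isStatic = staticSplitPoint != ShapedType::kDynamicSize;
bool hasDynamic = dynamicSplitPoint != nullptr;
// Accepted: a static point without a dynamic one, or a dynamic point with the
// static attribute set to the kDynamicSize sentinel. Everything else fails.
if (isStatic == hasDynamic)
  return op->emitOpError(
      "expects either a dynamic or a static split point to be provided");
```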
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -22,6 +22,7 @@
   NamedOpConversions.cpp
   Promotion.cpp
   SparseTensorRewriting.cpp
+  Split.cpp
   SplitReduction.cpp
   Tiling.cpp
   TilingInterfaceImpl.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Split.cpp b/mlir/lib/Dialect/Linalg/Transforms/Split.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/Split.cpp
@@ -0,0 +1,158 @@
+//===- Split.cpp - Structured op splitting --------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+
+#include "llvm/ADT/STLExtras.h"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+/// Turns an OpFoldResult into a value, creating an index-typed constant if
+/// necessary.
+static Value materializeOpFoldResult(ImplicitLocOpBuilder &builder,
+                                     OpFoldResult opFoldResult) {
+  if (opFoldResult.is<Value>())
+    return opFoldResult.get<Value>();
+  auto attr = opFoldResult.get<Attribute>().cast<IntegerAttr>();
+  return builder.create<arith::ConstantIndexOp>(
+      attr.getValue().getSExtValue());
+}
+
+/// Extract the slices of `operands` supplied to the given operation `op` such
+/// that they are sufficient to execute the op for the subset of its iteration
+/// space defined by `splitIterationSpace`. The subset is a part of the
+/// original iteration space split at the given `dimension`. If `offset` is
+/// provided, it indicates the iterator value at which the dimension has been
+/// split and requires the "high" part starting at the given offset of the
+/// operands to be generated; otherwise, the "low" part with no offset is
+/// generated. Note that `operands` are not necessarily the actual operands of
+/// `op`.
+static SmallVector<Value>
+getOperandSlices(ImplicitLocOpBuilder &builder, LinalgOp op,
+                 ValueRange splitIterationSpace, ValueRange operands,
+                 unsigned dimension, Value offset = nullptr) {
+  SmallVector<Value> slices;
+  slices.reserve(op.getNumInputsAndOutputs());
+  for (OpOperand *opOperand : op.getInputAndOutputOperands()) {
+    auto type = opOperand->get().getType().dyn_cast<ShapedType>();
+    AffineMap indexing = op.getTiedIndexingMap(opOperand);
+
+    // If the type is not sliceable, or the slice is requested along the
+    // dimension that is not used in indexing this type, just use the entire
+    // operand.
+    if (!type || dimension >= indexing.getNumDims() ||
+        !indexing.isFunctionOfDim(dimension)) {
+      slices.push_back(opOperand->get());
+      continue;
+    }
+
+    SmallVector<Value> sizes =
+        applyMapToValues(builder, op.getLoc(), indexing, splitIterationSpace);
+    SmallVector<OpFoldResult> offsets(type.getRank(), builder.getIndexAttr(0));
+    SmallVector<OpFoldResult> strides(type.getRank(), builder.getIndexAttr(1));
+
+    if (offset) {
+      offsets[dimension] = offset;
+      IRRewriter rewriter(builder);
+      offsets = applyMapToValues(rewriter, builder.getLoc(), indexing, offsets);
+    }
+
+    slices.push_back(createSlice(builder, op.getLoc(),
+                                 operands[opOperand->getOperandNumber()],
+                                 offsets, getAsOpFoldResult(sizes), strides));
+  }
+
+  return slices;
+}
+
+/// Creates a part of the given `op` split along the iteration space
+/// `dimension` with the given `size` and an optional `offset` (default 0).
+/// Makes slices of operands, using the input operands of the original op and
+/// the output operands provided as `resultOperands`. Expects
+/// `splitIterationSpace` to be a list of values representing the shape of the
+/// iteration space of the original op and updates it to be the iteration
+/// space of the current part. Returns the split-out op as well as the output
+/// operand values updated with the partial results produced by this op
+/// through `results`.
+static LinalgOp createSplitPart(
+    ImplicitLocOpBuilder &builder, LinalgOp op, ValueRange resultOperands,
+    llvm::MutableArrayRef<Value> splitIterationSpace, unsigned dimension,
+    Value size, SmallVectorImpl<Value> &results, Value offset = nullptr) {
+  splitIterationSpace[dimension] = size;
+  SmallVector<Value> operands = llvm::to_vector(
+      llvm::map_range(op.getInputOperands(),
+                      [](OpOperand *opOperand) { return opOperand->get(); }));
+  llvm::append_range(operands, resultOperands);
+  operands = getOperandSlices(builder, op, splitIterationSpace, operands,
+                              dimension, offset);
+  Operation *part = op.clone(builder, op.getLoc(),
+                             getTensorOutputTypes(op, operands), operands);
+  results = insertSlicesBack(builder, builder.getLoc(), op, operands,
+                             part->getResults());
+  return cast<LinalgOp>(part);
+}
+
+std::pair<LinalgOp, LinalgOp> linalg::splitOp(RewriterBase &rewriter,
+                                              LinalgOp op, unsigned dimension,
+                                              OpFoldResult splitPoint) {
+  // Bail out on dimension overflow.
+  if (dimension >= op.getNumLoops())
+    return std::make_pair(op, LinalgOp());
+
+  // Compute the iteration space size as values.
+  ImplicitLocOpBuilder builder(op.getLoc(), rewriter);
+  SmallVector<Value> allShapes =
+      op.createFlatListOfOperandDims(builder, op.getLoc());
+  AffineMap shapesToLoops = op.getShapesToLoopsMap();
+  SmallVector<Value> iterationSpaceShapes =
+      applyMapToValues(builder, op.getLoc(), shapesToLoops, allShapes);
+
+  // Update the iteration space to have `splitPoint` as the size of `dimension`
+  // and use it to slice operands and results for a new, smaller instance of
+  // the `op`. Adjust the size if necessary to prevent overflows. Insert the
+  // partial results back.
+  Value splitPointValue = materializeOpFoldResult(builder, splitPoint);
+  splitPointValue = builder.createOrFold<AffineMinOp>(
+      builder.getIndexType(),
+      AffineMap::getMultiDimIdentityMap(/*numDims=*/2, builder.getContext()),
+      ValueRange({splitPointValue, iterationSpaceShapes[dimension]}));
+  SmallVector<Value> splitIterationSpace =
+      llvm::to_vector(iterationSpaceShapes);
+  SmallVector<Value> originalResults = llvm::to_vector(
+      llvm::map_range(op.getOutputOperands(),
+                      [](OpOperand *opOperand) { return opOperand->get(); }));
+  SmallVector<Value> firstResults;
+  LinalgOp first =
+      createSplitPart(builder, op, originalResults, splitIterationSpace,
+                      dimension, splitPointValue, firstResults);
+
+  // Update the iteration space to cover the remaining part of the original
+  // space, then create another instance of the `op` in that space. The size of
+  // the remaining part may become zero, but is never negative because of the
+  // adjustment above.
+  AffineExpr d0 = builder.getAffineDimExpr(0);
+  AffineExpr d1 = builder.getAffineDimExpr(1);
+  SmallVector<Value> remainingSizes = applyMapToValues(
+      builder, op.getLoc(), AffineMap::inferFromExprList({d0 - d1}).front(),
+      {iterationSpaceShapes[dimension], splitPointValue});
+  SmallVector<Value> secondResults;
+  LinalgOp second =
+      createSplitPart(builder, op, firstResults, splitIterationSpace,
+                      dimension, remainingSizes.front(), secondResults,
+                      splitPointValue);
+
+  // Fixup the linalg.index results in the second part.
+  SmallVector<Value> ivAdditions;
+  ivAdditions.resize(splitIterationSpace.size());
+  ivAdditions[dimension] = splitPointValue;
+  linalg::addTileLoopIvsToIndexOpResults(builder, cast<LinalgOp>(second),
+                                         ivAdditions);
+
+  // Replace the original op with the results of the two newly created ops.
+  rewriter.replaceOp(op, secondResults);
+  return std::make_pair(first, second);
+}
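Because `splitOp` returns the two parts and rewires all uses of the original op, multi-way splits compose by re-splitting a returned piece, which is what the `two_d` test case below exercises through the transform dialect. A direct C++ sketch (assumes `rewriter` and a 2-D structured `op`; not part of the patch):

```cpp
// First cut the rows at 4, then cut the columns of the remaining rows at 16.
rewriter.setInsertionPoint(op);
auto [top, bottom] =
    linalg::splitOp(rewriter, op, /*dimension=*/0, rewriter.getIndexAttr(4));
rewriter.setInsertionPoint(bottom);
auto [bottomLeft, bottomRight] = linalg::splitOp(
    rewriter, bottom, /*dimension=*/1, rewriter.getIndexAttr(16));
```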
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -182,32 +182,11 @@
         makeTiledShapes(b, loc, op, valuesToTile, interchangedIvs, tileSizes,
                         sizeBounds, /*omitPartialTileCheck=*/false);
 
-    // TODO: use an interface/adaptor to avoid leaking position in
-    // `tiledOperands`.
-    SmallVector<Type, 4> resultTensorTypes;
-    for (OpOperand *opOperand : op.getOutputTensorOperands())
-      resultTensorTypes.push_back(
-          tiledOperands[opOperand->getOperandNumber()].getType());
-
+    SmallVector<Type> resultTensorTypes =
+        getTensorOutputTypes(op, tiledOperands);
     res = op.clone(b, loc, resultTensorTypes, tiledOperands);
-
-    // Insert a insert_slice for each output tensor.
-    unsigned resultIdx = 0;
-    for (OpOperand *opOperand : op.getOutputTensorOperands()) {
-      // TODO: use an interface/adaptor to avoid leaking position in
-      // `tiledOperands`.
-      Value outputTensor = tiledOperands[opOperand->getOperandNumber()];
-      // TODO: Propagate RewriterBase everywhere.
-      IRRewriter rewriter(b);
-      if (auto sliceOp = outputTensor.getDefiningOp<tensor::ExtractSliceOp>()) {
-        tensorResults.push_back(insertSliceIntoTensor(rewriter, loc, sliceOp,
-                                                      res->getResult(resultIdx),
-                                                      sliceOp.getSource()));
-      } else {
-        tensorResults.push_back(res->getResult(resultIdx));
-      }
-      ++resultIdx;
-    }
+    tensorResults =
+        insertSlicesBack(builder, loc, op, tiledOperands, res->getResults());
     return scf::ValueVector(tensorResults.begin(), tensorResults.end());
   };
   GenerateLoopNest<LoopTy>::doit(b, op.getLoc(), loopRanges, op, iteratorTypes,
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -913,6 +913,21 @@
   return sliceOp->getResult(0);
 }
 
+Value createSlice(OpBuilder &builder, Location loc, Value value,
+                  ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+                  ArrayRef<OpFoldResult> strides) {
+  if (value.getType().isa<MemRefType>()) {
+    return builder.create<memref::SubViewOp>(loc, value, offsets, sizes,
+                                             strides);
+  }
+
+  // This intentionally does not attempt to compose the extract_slice
+  // operations.
+  assert(value.getType().isa<RankedTensorType>() &&
+         "expected a ranked tensor type");
+  return builder.create<tensor::ExtractSliceOp>(loc, value, offsets, sizes,
+                                                strides);
+}
+
 SmallVector<Value> computeTileOffsets(OpBuilder &b, Location loc,
                                       ValueRange ivs, ValueRange tileSizes) {
   SmallVector<Value> offsets;
@@ -943,6 +958,41 @@
   return sizes;
 }
 
+SmallVector<Type> getTensorOutputTypes(LinalgOp op, ValueRange operands) {
+  // TODO: use an interface/adaptor to avoid leaking position in `operands`.
+  return llvm::to_vector(
+      llvm::map_range(op.getOutputTensorOperands(), [&](OpOperand *opOperand) {
+        return operands[opOperand->getOperandNumber()].getType();
+      }));
+}
+
+SmallVector<Value> insertSlicesBack(OpBuilder &builder, Location loc,
+                                    LinalgOp op, ValueRange operands,
+                                    ValueRange results) {
+  SmallVector<Value> tensorResults;
+  tensorResults.reserve(results.size());
+  // Insert an insert_slice for each output tensor.
+  unsigned resultIdx = 0;
+  for (OpOperand *opOperand : op.getOutputTensorOperands()) {
+    // TODO: use an interface/adaptor to avoid leaking position in `operands`.
+    Value outputTensor = operands[opOperand->getOperandNumber()];
+    if (auto sliceOp = outputTensor.getDefiningOp<tensor::ExtractSliceOp>()) {
+      Value inserted = builder.create<tensor::InsertSliceOp>(
+          loc, sliceOp.source().getType(), results[resultIdx],
+          sliceOp.source(), sliceOp.offsets(), sliceOp.sizes(),
+          sliceOp.strides(), sliceOp.static_offsets(), sliceOp.static_sizes(),
+          sliceOp.static_strides());
+      tensorResults.push_back(inserted);
+    } else {
+      tensorResults.push_back(results[resultIdx]);
+    }
+    ++resultIdx;
+  }
+  return tensorResults;
+}
+
 SmallVector<Value> makeTiledShapes(OpBuilder &b, Location loc,
                                    LinalgOp linalgOp,
                                    ArrayRef<Value> valuesToTile,
diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
--- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
+++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
@@ -15,6 +15,12 @@
 OptionalIntList = Optional[Union[ArrayAttr, IntOrAttrList]]
 
 
+def _get_int64_attr(value: Union[int, Attribute]) -> IntegerAttr:
+  if isinstance(value, int):
+    return IntegerAttr.get(IntegerType.get_signless(64), value)
+  return value
+
+
 def _get_array_attr(
     values: Optional[Union[ArrayAttr, Sequence[Attribute]]]) -> ArrayAttr:
   """Creates an array attribute from its operand."""
@@ -41,13 +47,7 @@
   if isinstance(values, ArrayAttr):
     return values
 
-  attributes = []
-  for value in values:
-    if isinstance(value, IntegerAttr):
-      attributes.append(value)
-    else:
-      attributes.append(IntegerAttr.get(IntegerType.get_signless(64), value))
-  return ArrayAttr.get(attributes)
+  return ArrayAttr.get([_get_int64_attr(v) for v in values])
 
 
 def _get_int_int_array_attr(
@@ -152,6 +152,39 @@
       pdl_operation_type, _get_op_result_or_value(target), loc=loc, ip=ip)
 
 
+class SplitOp:
+  """Specialization for SplitOp class."""
+
+  def __init__(self,
+               target: Union[Operation, Value],
+               dimension: Union[int, Attribute],
+               split_point: Union[int, Operation, Value, Attribute],
+               *,
+               loc=None,
+               ip=None):
+    dimension = _get_int64_attr(dimension)
+    if isinstance(split_point, int):
+      split_point = _get_int64_attr(split_point)
+
+    if isinstance(split_point, Attribute):
+      static_split_point = split_point
+      dynamic_split_point = None
+    else:
+      static_split_point = _get_int64_attr(ShapedType._get_dynamic_size())
+      dynamic_split_point = _get_op_result_or_value(split_point)
+
+    pdl_operation_type = pdl.OperationType.get()
+    super().__init__(
+        pdl_operation_type,
+        pdl_operation_type,
+        _get_op_result_or_value(target),
+        dimension=dimension,
+        static_split_point=static_split_point,
+        dynamic_split_point=dynamic_split_point,
+        loc=loc,
+        ip=ip)
+
+
 class TileOp:
   """Specialization for TileOp class."""
diff --git a/mlir/test/Dialect/Linalg/transform-op-split.mlir b/mlir/test/Dialect/Linalg/transform-op-split.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-split.mlir
@@ -0,0 +1,366 @@
+// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file -verify-diagnostics | FileCheck %s
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @linalg_generic : benefit(1) {
+    %0 = pdl.operands
+    %1 = pdl.types
+    %2 = pdl.operation "linalg.generic"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+    pdl.rewrite %2 with "transform.dialect"
+  }
+
+  transform.sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.pdl_match @linalg_generic in %arg1
+    %1:2 = transform.structured.split %0 after 42 { dimension = 0 }
+  }
+}
+
+func.func private @elem(%arg0: f32, %arg1: index, %arg2: index) -> f32
+
+// CHECK: #[[$ADD_42_MAP:.+]] = affine_map<(d0) -> (d0 + 42)>
+// CHECK: #[[$ADD_10_MAP:.+]] = affine_map<(d0) -> (d0 + 10)>
+
+// CHECK-LABEL: @one_d_static
+// CHECK-SAME:  %[[IN:.+]]: tensor<100xf32>, %[[OUT:.+]]: tensor<100xf32>
+func.func @one_d_static(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> {
+  // CHECK: %[[IN_SLICE_LOW:.+]] = tensor.extract_slice %[[IN]][0] [42] [1] : tensor<100xf32> to tensor<42xf32>
+  // CHECK: %[[OUT_SLICE_LOW:.+]] = tensor.extract_slice %[[OUT]][0] [42] [1] : tensor<100xf32> to tensor<42xf32>
+  // CHECK: %[[RES_SLICE_LOW:.+]] = linalg.generic
+  // CHECK:   ins(%[[IN_SLICE_LOW]]
+  // CHECK:   outs(%[[OUT_SLICE_LOW]]
+  // CHECK:   linalg.index 0
+  // CHECK:   func.call @elem
+  // CHECK: %[[RES_PARTIAL:.+]] = tensor.insert_slice %[[RES_SLICE_LOW]] into %[[OUT]][0] [42] [1]
+  //
+  // CHECK: %[[IN_SLICE_HIGH:.+]] = tensor.extract_slice %[[IN]][42] [58] [1] : tensor<100xf32> to tensor<58xf32>
+  // CHECK: %[[OUT_SLICE_HIGH:.+]] = tensor.extract_slice %[[RES_PARTIAL]][42] [58] [1] : tensor<100xf32> to tensor<58xf32>
+  // CHECK: %[[RES_SLICE_HIGH:.+]] = linalg.generic
+  // CHECK:   ins(%[[IN_SLICE_HIGH]]
+  // CHECK:   outs(%[[OUT_SLICE_HIGH]]
+  // CHECK:   %[[IDX:.+]] = linalg.index 0
+  // CHECK:   affine.apply #[[$ADD_42_MAP]](%[[IDX]])
+  // CHECK:   func.call @elem
+  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[RES_SLICE_HIGH]] into %[[RES_PARTIAL]][42] [58] [1]
+  %0 = linalg.generic {
+    indexing_maps = [affine_map<(i) -> (i)>, affine_map<(i) -> (i)>],
+    iterator_types = ["parallel"]
+  }
+  ins(%arg0: tensor<100xf32>) outs(%arg1: tensor<100xf32>) {
+  ^bb0(%0: f32, %1: f32):
+    %i = linalg.index 0 : index
+    %call_res = func.call @elem(%0, %i, %i) : (f32, index, index) -> f32
+    linalg.yield %call_res : f32
+  } -> tensor<100xf32>
+
+  // CHECK: return %[[RES]]
+  return %0 : tensor<100xf32>
+}
+
+// CHECK-LABEL: @one_d_static_overflow
+// CHECK-SAME:  %[[IN:.+]]: tensor<10xf32>, %[[OUT:.+]]: tensor<10xf32>
+func.func @one_d_static_overflow(%arg0: tensor<10xf32>, %arg1: tensor<10xf32>) -> tensor<10xf32> {
+  // CHECK: %[[IN_SLICE_LOW:.+]] = tensor.extract_slice %[[IN]][0] [10] [1] : tensor<10xf32> to tensor<10xf32>
+  // CHECK: %[[OUT_SLICE_LOW:.+]] = tensor.extract_slice %[[OUT]][0] [10] [1] : tensor<10xf32> to tensor<10xf32>
+  // CHECK: %[[RES_SLICE_LOW:.+]] = linalg.generic
+  // CHECK:   ins(%[[IN_SLICE_LOW]]
+  // CHECK:   outs(%[[OUT_SLICE_LOW]]
+  // CHECK:   linalg.index 0
+  // CHECK:   func.call @elem
+  // CHECK: %[[RES_PARTIAL:.+]] = tensor.insert_slice %[[RES_SLICE_LOW]] into %[[OUT]][0] [10] [1]
+  //
+  // CHECK: %[[IN_SLICE_HIGH:.+]] = tensor.extract_slice %[[IN]][10] [0] [1] : tensor<10xf32> to tensor<0xf32>
+  // CHECK: %[[OUT_SLICE_HIGH:.+]] = tensor.extract_slice %[[RES_PARTIAL]][10] [0] [1] : tensor<10xf32> to tensor<0xf32>
+  // CHECK: %[[RES_SLICE_HIGH:.+]] = linalg.generic
+  // CHECK:   ins(%[[IN_SLICE_HIGH]]
+  // CHECK:   outs(%[[OUT_SLICE_HIGH]]
+  // CHECK:   %[[IDX:.+]] = linalg.index 0
+  // CHECK:   affine.apply #[[$ADD_10_MAP]](%[[IDX]])
+  // CHECK:   func.call @elem
+  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[RES_SLICE_HIGH]] into %[[RES_PARTIAL]][10] [0] [1]
+  %0 = linalg.generic {
+    indexing_maps = [affine_map<(i) -> (i)>, affine_map<(i) -> (i)>],
+    iterator_types = ["parallel"]
+  }
+  ins(%arg0: tensor<10xf32>) outs(%arg1: tensor<10xf32>) {
+  ^bb0(%0: f32, %1: f32):
+    %i = linalg.index 0 : index
+    %call_res = func.call @elem(%0, %i, %i) : (f32, index, index) -> f32
+    linalg.yield %call_res : f32
+  } -> tensor<10xf32>
+  return %0 : tensor<10xf32>
+}
+
+// -----
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @func_call : benefit(1) {
+    %0 = pdl.operands
+    %1 = pdl.types
+    %2 = pdl.operation "func.call"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+    pdl.rewrite %2 with "transform.dialect"
+  }
+  pdl.pattern @linalg_generic : benefit(1) {
+    %0 = pdl.operands
+    %1 = pdl.types
+    %2 = pdl.operation "linalg.generic"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+    pdl.rewrite %2 with "transform.dialect"
+  }
+
+  transform.sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.pdl_match @linalg_generic in %arg1
+    %1 = transform.pdl_match @func_call in %arg1
+    transform.structured.split %0 after %1 { dimension = 0 }
+  }
+}
+
+func.func private @get_size() -> index
+
+// CHECK: #[[$MAP_MIN_100:.+]] = affine_map<(d0, d1) -> (d0, 100)>
+// CHECK: #[[$MAP_S_MINUS_100:.+]] = affine_map<()[s0] -> (-s0 + 100)>
+
+// CHECK-LABEL: @dynamic
+func.func @dynamic(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> {
+  // CHECK: %[[SPLIT:.+]] = call @get_size
+  // CHECK: %[[SPLIT_LOW:.+]] = affine.min #[[$MAP_MIN_100]](%[[SPLIT]]
+  // CHECK: %[[IN_SLICE_LOW:.+]] = tensor.extract_slice %[[IN:.+]][0] [%[[SPLIT_LOW]]] [1] : tensor<100xf32> to tensor<?xf32>
+  // CHECK: %[[OUT_SLICE_LOW:.+]] = tensor.extract_slice %[[OUT:.+]][0] [%[[SPLIT_LOW]]] [1] : tensor<100xf32> to tensor<?xf32>
+  // CHECK: %[[RES_SLICE_LOW:.+]] = linalg.generic
+  // CHECK:   ins(%[[IN_SLICE_LOW]]
+  // CHECK:   outs(%[[OUT_SLICE_LOW]]
+  // CHECK: %[[PARTIAL:.+]] = tensor.insert_slice %[[RES_SLICE_LOW]] into %[[OUT]][0] [%[[SPLIT_LOW]]] [1]
+  //
+  // CHECK: %[[SPLIT_HIGH_1:.+]] = affine.apply #[[$MAP_S_MINUS_100]]()[%[[SPLIT_LOW]]]
+  // CHECK: %[[SPLIT_HIGH_2:.+]] = affine.apply #[[$MAP_S_MINUS_100]]()[%[[SPLIT_LOW]]]
+  // CHECK: %[[IN_SLICE_HIGH:.+]] = tensor.extract_slice %[[IN:.+]][%[[SPLIT_LOW]]] [%[[SPLIT_HIGH_2]]] [1] : tensor<100xf32> to tensor<?xf32>
+  // CHECK: %[[SPLIT_HIGH_3:.+]] = affine.apply #[[$MAP_S_MINUS_100]]()[%[[SPLIT_LOW]]]
+  // CHECK: %[[OUT_SLICE_HIGH:.+]] = tensor.extract_slice %[[PARTIAL:.+]][%[[SPLIT_LOW]]] [%[[SPLIT_HIGH_3]]] [1] : tensor<100xf32> to tensor<?xf32>
+  // CHECK: %[[RES_SLICE_HIGH:.+]] = linalg.generic
+  // CHECK:   ins(%[[IN_SLICE_HIGH]]
+  // CHECK:   outs(%[[OUT_SLICE_HIGH]]
+  // CHECK: tensor.insert_slice %[[RES_SLICE_HIGH]] into %[[PARTIAL]][%[[SPLIT_LOW]]] [%[[SPLIT_HIGH_3]]] [1]
+  %0 = func.call @get_size() : () -> index
+  %1 = linalg.generic {
+    indexing_maps = [affine_map<(i) -> (i)>, affine_map<(i) -> (i)>],
+    iterator_types = ["parallel"]
+  }
+  ins(%arg0: tensor<100xf32>) outs(%arg1: tensor<100xf32>) {
+  ^bb0(%3: f32, %4: f32):
+    linalg.yield %3 : f32
+  } -> tensor<100xf32>
+  return %1 : tensor<100xf32>
+}
+
+// -----
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @linalg_generic : benefit(1) {
+    %0 = pdl.operands
+    %1 = pdl.types
+    %2 = pdl.operation "linalg.generic"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+    pdl.rewrite %2 with "transform.dialect"
+  }
+
+  transform.sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.pdl_match @linalg_generic in %arg1
+    %1:2 = transform.structured.split %0 after 4 { dimension = 0 }
+    %2:2 = transform.structured.split %1#1 after 16 { dimension = 1 }
+  }
+}
+
+func.func private @elem(%arg0: f32, %arg1: index, %arg2: index) -> f32
+
+// CHECK-LABEL: @two_d
+func.func @two_d(%arg0: tensor<10x34xf32>,
+                 %arg1: tensor<10x34xf32>) -> tensor<10x34xf32> {
+  // Check the overall structure: split along the dimension 0, and then split
+  // the second half only along the dimension 1.
+  // CHECK: %[[IN_1:.+]] = tensor.extract_slice %[[IN:.+]][0, 0]
+  // CHECK: %[[OUT_1:.+]] = tensor.extract_slice %[[OUT:.+]][0, 0]
+  // CHECK: %[[RES_1:.+]] = linalg.generic
+  // CHECK-SAME:  ins(%[[IN_1]] : tensor<4x34xf32>)
+  // CHECK-SAME:  outs(%[[OUT_1]] : tensor<4x34xf32>)
+  // CHECK: %[[PARTIAL_1:.+]] = tensor.insert_slice %[[RES_1]] into %[[OUT]]
+  //
+  // CHECK: %[[IN_2:.+]] = tensor.extract_slice %[[IN]]
+  // CHECK: %[[OUT_2:.+]] = tensor.extract_slice %[[PARTIAL_1]]
+  // CHECK: %[[IN_21:.+]] = tensor.extract_slice %[[IN_2]]
+  // CHECK: %[[OUT_21:.+]] = tensor.extract_slice %[[OUT_2]]
+  // CHECK: %[[RES_21:.+]] = linalg.generic
+  // CHECK-SAME:  ins(%[[IN_21]] : tensor<6x16xf32>)
+  // CHECK-SAME:  outs(%[[OUT_21]] : tensor<6x16xf32>)
+  // CHECK: %[[PARTIAL_21:.+]] = tensor.insert_slice %[[RES_21]] into %[[OUT_2]]
+  //
+  // CHECK: %[[IN_22:.+]] = tensor.extract_slice %[[IN_2]]
+  // CHECK: %[[OUT_22:.+]] = tensor.extract_slice %[[PARTIAL_21]]
+  // CHECK: %[[RES_22:.+]] = linalg.generic
+  // CHECK-SAME:  ins(%[[IN_22]] : tensor<6x18xf32>)
+  // CHECK-SAME:  outs(%[[OUT_22]] : tensor<6x18xf32>)
+  // CHECK: %[[PARTIAL_22:.+]] = tensor.insert_slice %[[RES_22]] into %[[PARTIAL_21]]
+  // CHECK: %[[PARTIAL_2:.+]] = tensor.insert_slice %[[PARTIAL_22]] into %[[PARTIAL_1]]
+  %0 = linalg.generic {
+    indexing_maps = [affine_map<(i, j) -> (i, j)>,
+                     affine_map<(i, j) -> (i, j)>],
+    iterator_types = ["parallel", "parallel"]
+  }
+  ins(%arg0: tensor<10x34xf32>)
+  outs(%arg1: tensor<10x34xf32>) {
+  ^bb0(%0: f32, %1: f32):
+    %i = linalg.index 0 : index
+    %j = linalg.index 1 : index
+    %call_res = func.call @elem(%0, %i, %j) : (f32, index, index) -> f32
+    linalg.yield %call_res : f32
+  } -> tensor<10x34xf32>
+  return %0 : tensor<10x34xf32>
+}
+
+// -----
+
+transform.sequence {
+^bb1(%arg1: !pdl.operation):
+  // expected-error @below {{expects either a dynamic or a static split point to be provided}}
+  %0:2 = "transform.structured.split"(%arg1) { dimension = 1, static_split_point = -1 } : (!pdl.operation) -> (!pdl.operation, !pdl.operation)
+}
+
+// -----
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @func_call : benefit(1) {
+    %0 = pdl.operands
+    %1 = pdl.types
+    %2 = pdl.operation "func.call"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+    pdl.rewrite %2 with "transform.dialect"
+  }
+  pdl.pattern @linalg_generic : benefit(1) {
+    %0 = pdl.operands
+    %1 = pdl.types
+    %2 = pdl.operation "linalg.generic"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+    pdl.rewrite %2 with "transform.dialect"
+  }
+
+  transform.sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.pdl_match @linalg_generic in %arg1
+    %1 = transform.pdl_match @func_call in %arg1
+    // expected-error @below {{expected dynamic split point handle to point to a single-result index-typed op}}
+    transform.structured.split %0 after %1 { dimension = 0 }
+  }
+}
+
+func.func private @get_size() -> i64
+
+func.func @dynamic(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> {
+  // expected-note @below {{dynamic split point}}
+  %0 = func.call @get_size() : () -> i64
+  %1 = linalg.generic {
+    indexing_maps = [affine_map<(i) -> (i)>, affine_map<(i) -> (i)>],
+    iterator_types = ["parallel"]
+  }
+  ins(%arg0: tensor<100xf32>) outs(%arg1: tensor<100xf32>) {
+  ^bb0(%3: f32, %4: f32):
+    linalg.yield %3 : f32
+  } -> tensor<100xf32>
+  return %1 : tensor<100xf32>
+}
+
+// -----
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @func_call : benefit(1) {
+    %0 = pdl.operands
+    %1 = pdl.types
+    %2 = pdl.operation "func.call"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+    pdl.rewrite %2 with "transform.dialect"
+  }
+  pdl.pattern @linalg_generic : benefit(1) {
+    %0 = pdl.operands
+    %1 = pdl.types
+    %2 = pdl.operation "linalg.generic"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+    pdl.rewrite %2 with "transform.dialect"
+  }
+
+  transform.sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.pdl_match @linalg_generic in %arg1
+    %1 = transform.pdl_match @func_call in %arg1
+    // expected-error @below {{expected the dynamic split point handle to point to as many operations (0) as the target handle (1)}}
+    transform.structured.split %0 after %1 { dimension = 0 }
+  }
+}
+
+func.func private @get_size() -> i64
+
+func.func @dynamic(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> {
+  %1 = linalg.generic {
+    indexing_maps = [affine_map<(i) -> (i)>, affine_map<(i) -> (i)>],
+    iterator_types = ["parallel"]
+  }
+  ins(%arg0: tensor<100xf32>) outs(%arg1: tensor<100xf32>) {
+  ^bb0(%3: f32, %4: f32):
+    linalg.yield %3 : f32
+  } -> tensor<100xf32>
+  return %1 : tensor<100xf32>
+}
+
+// -----
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @func_return : benefit(1) {
+    %0 = pdl.operands
+    %1 = pdl.types
+    %2 = pdl.operation "func.return"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+    pdl.rewrite %2 with "transform.dialect"
+  }
+
+  transform.sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.pdl_match @func_return in %arg1
+    // expected-error @below {{only applies to structured ops}}
+    transform.structured.split %0 after 16 { dimension = 1 }
+  }
+}
+
+func.func @noop(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> {
+  // expected-note @below {{target op}}
+  return %arg0 : tensor<100xf32>
+}
+
+// -----
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @linalg_generic : benefit(1) {
+    %0 = pdl.operands
+    %1 = pdl.types
+    %2 = pdl.operation "linalg.generic"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+    pdl.rewrite %2 with "transform.dialect"
+  }
+
+  transform.sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.pdl_match @linalg_generic in %arg1
+    // expected-error @below {{dimension 1 does not exist in target op}}
+    transform.structured.split %0 after 16 { dimension = 1 }
+  }
+}
+
+func.func @one_d_static(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> {
+  // expected-note @below {{target op}}
+  %0 = linalg.generic {
+    indexing_maps = [affine_map<(i) -> (i)>, affine_map<(i) -> (i)>],
+    iterator_types = ["parallel"]
+  }
+  ins(%arg0: tensor<100xf32>) outs(%arg1: tensor<100xf32>) {
+  ^bb0(%0: f32, %1: f32):
+    linalg.yield %0 : f32
+  } -> tensor<100xf32>
+  return %0 : tensor<100xf32>
+}
diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py
--- a/mlir/test/python/dialects/transform_structured_ext.py
+++ b/mlir/test/python/dialects/transform_structured_ext.py
@@ -84,6 +84,19 @@
   # CHECK: transform.structured.scalarize
 
 
+@run
+def testSplit():
+  sequence = transform.SequenceOp()
+  with InsertionPoint(sequence.body):
+    split = structured.SplitOp(sequence.bodyTarget, dimension=1, split_point=42)
+    structured.SplitOp(
+        split.results[0], dimension=3, split_point=split.results[1])
+    transform.YieldOp()
+  # CHECK-LABEL: TEST: testSplit
+  # CHECK: %[[F:.+]], %[[S:.+]] = transform.structured.split %{{.*}} after 42 {dimension = 1
+  # CHECK: transform.structured.split %[[F]] after %[[S]] {dimension = 3
+
+
 @run
 def testTileCompact():
sequence = transform.SequenceOp()