diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h
--- a/mlir/include/mlir-c/BuiltinTypes.h
+++ b/mlir/include/mlir-c/BuiltinTypes.h
@@ -150,10 +150,19 @@
 /// in shaped types.
 MLIR_CAPI_EXPORTED bool mlirShapedTypeIsDynamicSize(int64_t size);
 
+/// Returns the value indicating a dynamic size in a shaped type. Prefer
+/// mlirShapedTypeIsDynamicSize to direct comparisons with this value.
+MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetDynamicSize(void);
+
 /// Checks whether the given value is used as a placeholder for dynamic strides
 /// and offsets in shaped types.
 MLIR_CAPI_EXPORTED bool mlirShapedTypeIsDynamicStrideOrOffset(int64_t val);
 
+/// Returns the value indicating a dynamic stride or offset in a shaped type.
+/// Prefer mlirShapedTypeIsDynamicStrideOrOffset to direct comparisons with
+/// this value.
+MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetDynamicStrideOrOffset(void);
+
 //===----------------------------------------------------------------------===//
 // Vector type.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -153,6 +153,38 @@
   }];
 }
 
+def SplitOp : Op<Transform_Dialect, "structured.split",
+    [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+     DeclareOpInterfaceMethods<TransformOpInterface>]> {
+  let description = [{
+    Indicates that the given `target` op should be split into two parts, which
+    together cover the entire iteration domain of the original op. The split
+    is performed along the iteration space dimension provided as an attribute;
+    if the dimension does not exist in the target op, the transformation
+    fails. The split happens at the iterator value specified either as a
+    static split point attribute, when it is known at transform IR
+    construction time, or as a handle to an operation producing a single
+    index-typed value, when it is computed by the payload IR. In the latter
+    case, the static split point must be set to `ShapedType::kDynamicSize`,
+    and the dynamic split point handle must point to as many value-producing
+    operations as there are structured operations pointed to by the target
+    handle.
+
+    The operation consumes the target handle, but preserves the split point
+    handle if provided. It produces two new handles pointing to the two parts
+    of the structured op after splitting, in the same order as the payload
+    operations of the target handle, with the first handle corresponding to
+    the part with lower iteration space indices.
+  }];
+
+  let arguments = (ins PDL_Operation:$target,
+                       I64Attr:$dimension,
+                       Optional<PDL_Operation>:$dynamic_split_point,
+                       I64Attr:$static_split_point);
+  let results = (outs PDL_Operation:$first, PDL_Operation:$second);
+  let hasVerifier = 1;
+  let hasCustomAssemblyFormat = 1;
+}
+
 def SplitReductionOp : Op<Transform_Dialect, "structured.split_reduction",
        [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
         TransformEachOpTrait, TransformOpInterface]> {
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -106,6 +106,34 @@
 /// Patterns that are used to bubble up extract slice op above linalg op.
 void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns);
 
+/// Split the given `op` into two parts along the given iteration space
+/// `dimension` at the specified `splitPoint`, and return the two parts.
+///
+/// For example, the following op:
+///
+///   linalg.matmul ins(%0, %1 : tensor<128x32xf32>, tensor<32x64xf32>)
+///                 outs(%2 : tensor<128x64xf32>)
+///
+/// split along the first dimension at position 42 will result in:
+///
+///   %3 = tensor.extract_slice %0[0, 0][42, 32][1, 1]
+///   %4 = tensor.extract_slice %2[0, 0][42, 64][1, 1]
+///   %5 = linalg.matmul ins(%3, %1 : tensor<42x32xf32>, tensor<32x64xf32>)
+///                      outs(%4 : tensor<42x64xf32>)
+///   %6 = tensor.insert_slice %5 into %2[0, 0][42, 64][1, 1]
+///
+///   %7 = tensor.extract_slice %0[42, 0][86, 32][1, 1]
+///   %8 = tensor.extract_slice %6[42, 0][86, 64][1, 1]
+///   %9 = linalg.matmul ins(%7, %1 : tensor<86x32xf32>, tensor<32x64xf32>)
+///                      outs(%8 : tensor<86x64xf32>)
+///   tensor.insert_slice %9 into %6[42, 0][86, 64][1, 1]
+///
+/// Note that there is no simplification other than constant propagation
+/// applied to slice extraction and insertion.
+std::pair<LinalgOp, LinalgOp> splitOp(RewriterBase &rewriter, LinalgOp op,
+                                      unsigned dimension,
+                                      OpFoldResult splitPoint);
+
 /// Perform standalone tiling of a single LinalgOp by `tileSizes`.
 /// and permute the loop nest according to `interchangeVector`
 /// The permutation is expressed as a list of integers that specify
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -168,12 +168,18 @@
 bool isFusableInto(const LinalgDependenceGraph &graph, LinalgOp consumer,
                    Value consumedView, LinalgOp producer);
 
-/// Compute tile offsets, given a list of loop `ivs` and `tileSizes`. In case a
+/// Creates either a memref.subview or a tensor.extract_slice with the given
+/// offsets/sizes/strides based on the type of `value`.
+Value createSlice(OpBuilder &builder, Location loc, Value value,
+                  ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+                  ArrayRef<OpFoldResult> strides);
+
+/// Computes tile offsets, given a list of loop `ivs` and `tileSizes`. In case a
 /// tile size is zero (i.e., no tiling), the corresponding offset is also zero.
 SmallVector<Value> computeTileOffsets(OpBuilder &b, Location loc,
                                       ValueRange ivs, ValueRange tileSizes);
 
-/// Compute tile sizes, given a list of `tileSizes` and dimension
+/// Computes tile sizes, given a list of `tileSizes` and dimension
 /// sizes (`sizeBounds`). In case a tile size is zero (i.e., no tiling), the
 /// corresponding result size is the corresponding value from `sizeBounds`.
 /// Note: The returned tile sizes are closed intervals.
@@ -181,6 +187,20 @@
                                     ValueRange tileSizes,
                                     ArrayRef<Value> sizeBounds);
 
+/// Returns the list of tensor output types produced when the given structured
+/// operation `op` is applied to the given `operands`. Note that `operands` are
+/// not necessarily the actual operands of `op`.
+SmallVector<Type> getTensorOutputTypes(LinalgOp op, ValueRange operands);
+
+/// Creates `insert_slice` ops that insert `results` back into the larger
+/// tensors they were originally extracted from with `extract_slice` before
+/// being passed as `operands` to the given structured operation `op` or its
+/// clone. Note that `operands` are not necessarily the actual operands of
+/// `op`; the operation serves only as a metadata container for operand types
+/// and positions.
+SmallVector<Value> insertSlicesBack(OpBuilder &builder, Location loc,
+                                    LinalgOp op, ValueRange operands,
+                                    ValueRange results);
+
 /// Creates an extract_slice/subview op for a single `valueToTile` with
 /// `builder`. This new operation extracts a tile of `valueToTile`, starting
 /// at offsets `lbs` and with sizes `subShapeSizes`. `omitPartialTileCheck`
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -301,6 +301,15 @@
           return shape;
         },
         "Returns the shape of the ranked shaped type as a list of integers.");
+    c.def_static(
+        "_get_dynamic_size", []() { return mlirShapedTypeGetDynamicSize(); },
+        "Returns the value used to indicate dynamic dimensions in shaped "
+        "types.");
+    c.def_static(
+        "_get_dynamic_stride_or_offset",
+        []() { return mlirShapedTypeGetDynamicStrideOrOffset(); },
+        "Returns the value used to indicate dynamic strides or offsets in "
+        "shaped types.");
   }
 
 private:
diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -149,6 +149,8 @@
   return unwrap(type).cast<ShapedType>().getDimSize(static_cast<unsigned>(dim));
 }
 
+int64_t mlirShapedTypeGetDynamicSize() { return ShapedType::kDynamicSize; }
+
 bool mlirShapedTypeIsDynamicSize(int64_t size) {
   return ShapedType::isDynamic(size);
 }
@@ -157,6 +159,10 @@
   return ShapedType::isDynamicStrideOrOffset(val);
 }
 
+int64_t mlirShapedTypeGetDynamicStrideOrOffset() {
+  return ShapedType::kDynamicStrideOrOffset;
+}
+
 //===----------------------------------------------------------------------===//
 // Vector type.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -399,6 +399,161 @@
   return result->op;
 }
 
+//===----------------------------------------------------------------------===//
+// SplitOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure SplitOp::apply(TransformResults &results,
+                                           TransformState &state) {
+  // Collect the dynamic split points if provided.
+  ArrayRef<Operation *> payload = state.getPayloadOps(getTarget());
+  SimpleRewriter rewriter(getContext());
+  SmallVector<OpFoldResult> splitPoints;
+  splitPoints.reserve(payload.size());
+  if (getDynamicSplitPoint()) {
+    auto diag = DiagnosedSilenceableFailure::success();
+    splitPoints = llvm::to_vector(llvm::map_range(
+        state.getPayloadOps(getDynamicSplitPoint()), [&](Operation *op) {
+          if (op->getNumResults() != 1 ||
+              !op->getResult(0).getType().isIndex()) {
+            diag = emitSilenceableError()
+                   << "expected dynamic split point handle to point to a "
+                      "single-result index-typed op";
+            diag.attachNote(op->getLoc()) << "dynamic split point";
+          }
+          return OpFoldResult(op->getResult(0));
+        }));
+    if (!diag.succeeded())
+      return diag;
+
+    if (splitPoints.size() != payload.size()) {
+      emitError() << "expected the dynamic split point handle to point to as "
+                     "many operations ("
+                  << splitPoints.size() << ") as the target handle ("
+                  << payload.size() << ")";
+      return DiagnosedSilenceableFailure::definiteFailure();
+    }
+  } else {
+    splitPoints.resize(payload.size(),
+                       rewriter.getIndexAttr(getStaticSplitPoint()));
+  }
+
+  // Split each target operation.
+  SmallVector<Operation *> first, second;
+  for (const auto &pair : llvm::zip(payload, splitPoints)) {
+    Operation *target = std::get<0>(pair);
+    auto linalgOp = dyn_cast<LinalgOp>(target);
+    if (!linalgOp) {
+      auto diag = emitSilenceableError() << "only applies to structured ops";
+      diag.attachNote(target->getLoc()) << "target op";
+      return diag;
+    }
+
+    if (getDimension() >= linalgOp.getNumLoops()) {
+      auto diag = emitSilenceableError() << "dimension " << getDimension()
+                                         << " does not exist in target op";
+      diag.attachNote(target->getLoc()) << "target op";
+      return diag;
+    }
+
+    rewriter.setInsertionPoint(linalgOp);
+    std::tie(first.emplace_back(), second.emplace_back()) =
+        linalg::splitOp(rewriter, linalgOp, getDimension(), std::get<1>(pair));
+  }
+
+  results.set(getFirst().cast<OpResult>(), first);
+  results.set(getSecond().cast<OpResult>(), second);
+  return DiagnosedSilenceableFailure::success();
+}
+
+void SplitOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  // The target handle is consumed.
+  effects.emplace_back(MemoryEffects::Read::get(), getTarget(),
+                       TransformMappingResource::get());
+  effects.emplace_back(MemoryEffects::Free::get(), getTarget(),
+                       TransformMappingResource::get());
+
+  // The dynamic split point handle is not consumed.
+  if (getDynamicSplitPoint()) {
+    effects.emplace_back(MemoryEffects::Read::get(), getDynamicSplitPoint(),
+                         TransformMappingResource::get());
+  }
+
+  // The resulting handles are produced.
+  for (Value result : getResults()) {
+    effects.emplace_back(MemoryEffects::Allocate::get(), result,
+                         TransformMappingResource::get());
+    effects.emplace_back(MemoryEffects::Write::get(), result,
+                         TransformMappingResource::get());
+  }
+
+  effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
+  effects.emplace_back(MemoryEffects::Write::get(), PayloadIRResource::get());
+}
+
+ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) {
+  OpAsmParser::UnresolvedOperand target, dynamicSplitPoint;
+  IntegerAttr staticSplitPoint;
+  auto pdlOperationType =
+      pdl::OperationType::get(parser.getBuilder().getContext());
+  if (parser.parseOperand(target) ||
+      parser.resolveOperand(target, pdlOperationType, result.operands) ||
+      parser.parseKeyword("after"))
+    return failure();
+
+  OptionalParseResult dynamicPointParseResult =
+      parser.parseOptionalOperand(dynamicSplitPoint);
+  if (!dynamicPointParseResult.hasValue()) {
+    int64_t staticSplitPointValue;
+    if (failed(parser.parseInteger(staticSplitPointValue)))
+      return failure();
+
+    staticSplitPoint =
+        parser.getBuilder().getI64IntegerAttr(staticSplitPointValue);
+  } else {
+    if (failed(*dynamicPointParseResult) ||
+        parser.resolveOperand(dynamicSplitPoint, pdlOperationType,
+                              result.operands)) {
+      return failure();
+    }
+
+    staticSplitPoint =
+        parser.getBuilder().getI64IntegerAttr(ShapedType::kDynamicSize);
+  }
+
+  result.addAttribute(
+      SplitOp::getStaticSplitPointAttrName(result.name).getValue(),
+      staticSplitPoint);
+  if (failed(parser.parseOptionalAttrDict(result.attributes)))
+    return failure();
+
+  result.addTypes({pdlOperationType, pdlOperationType});
+  return success();
+}
+
+void SplitOp::print(OpAsmPrinter &printer) {
+  printer << " " << getTarget() << " after ";
+  int64_t staticSplitPoint = static_cast<int64_t>(getStaticSplitPoint());
+  if (staticSplitPoint != ShapedType::kDynamicSize)
+    printer << staticSplitPoint;
+  else
+    printer << getDynamicSplitPoint();
+  printer << " ";
+  printer.printOptionalAttrDict(getOperation()->getAttrs(),
+                                {getStaticSplitPointAttrName()});
+}
+
+LogicalResult SplitOp::verify() {
+  if ((static_cast<int64_t>(getStaticSplitPoint()) !=
+       ShapedType::kDynamicSize) ^
+      (getDynamicSplitPoint() == nullptr)) {
+    return emitOpError()
+           << "expects either a dynamic or a static split point to be provided";
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // SplitReductionOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -22,6 +22,7 @@
   NamedOpConversions.cpp
   Promotion.cpp
   SparseTensorRewriting.cpp
+  Split.cpp
  SplitReduction.cpp
   Tiling.cpp
   TilingInterfaceImpl.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Split.cpp b/mlir/lib/Dialect/Linalg/Transforms/Split.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/Split.cpp
@@ -0,0 +1,149 @@
+//===- Split.cpp - Structured op splitting --------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+
+#include "llvm/ADT/STLExtras.h"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+/// Turns an OpFoldResult into a value, creating an index-typed constant if
+/// necessary.
+static Value materializeOpFoldResult(ImplicitLocOpBuilder &builder,
+                                     OpFoldResult opFoldResult) {
+  if (opFoldResult.is<Value>())
+    return opFoldResult.get<Value>();
+  auto attr = opFoldResult.get<Attribute>().cast<IntegerAttr>();
+  return builder.create<arith::ConstantIndexOp>(attr.getValue().getSExtValue());
+}
+
+/// Extract the slices of `operands` supplied to the given operation `op` such
+/// that they are sufficient to execute the op for the subset of its iteration
+/// space defined by `splitIterationSpace`. The subset is a part of the original
+/// iteration space split at the given `dimension`. If `offset` is provided, it
+/// indicates the iterator value at which the dimension has been split and
+/// requires the "high" part starting at the given offset of the operands to be
+/// generated; otherwise, the "low" part with no offset is generated. Note that
+/// `operands` are not necessarily the actual operands of `op`.
+static SmallVector<Value>
+getOperandSlices(ImplicitLocOpBuilder &builder, LinalgOp op,
+                 ValueRange splitIterationSpace, ValueRange operands,
+                 unsigned dimension, Value offset = nullptr) {
+  SmallVector<Value> slices;
+  slices.reserve(op.getNumInputsAndOutputs());
+  for (OpOperand *opOperand : op.getInputAndOutputOperands()) {
+    auto type = opOperand->get().getType().dyn_cast<ShapedType>();
+    AffineMap indexing = op.getTiedIndexingMap(opOperand);
+    if (!type || dimension >= type.getRank() ||
+        !indexing.isFunctionOfDim(dimension)) {
+      slices.push_back(opOperand->get());
+      continue;
+    }
+
+    SmallVector<Value> sizes =
+        applyMapToValues(builder, op.getLoc(), indexing, splitIterationSpace);
+    SmallVector<OpFoldResult> offsets(type.getRank(), builder.getIndexAttr(0));
+    SmallVector<OpFoldResult> strides(type.getRank(), builder.getIndexAttr(1));
+
+    if (offset) {
+      // TODO: it is unfortunate to convert OpFoldResults to Values and back,
+      // because this may create spurious constant operations in the process.
+      // However, the createOrFold API does not expose OpFoldResult, so we
+      // cannot just keep using it.
+      SmallVector<Value> offsetValues =
+          llvm::to_vector(llvm::map_range(offsets, [&](OpFoldResult ofr) {
+            return materializeOpFoldResult(builder, ofr);
+          }));
+      offsetValues[dimension] = offset;
+      offsets = getAsOpFoldResult(
+          applyMapToValues(builder, builder.getLoc(), indexing, offsetValues));
+    }
+
+    slices.push_back(createSlice(builder, op.getLoc(),
+                                 operands[opOperand->getOperandNumber()],
+                                 offsets, getAsOpFoldResult(sizes), strides));
+  }
+
+  return slices;
+}
+
+std::pair<LinalgOp, LinalgOp> linalg::splitOp(RewriterBase &rewriter,
+                                              LinalgOp op, unsigned dimension,
+                                              OpFoldResult splitPoint) {
+  // Bail out on dimension overflow.
+  if (dimension >= op.getNumLoops())
+    return std::make_pair(op, LinalgOp());
+
+  // Compute the iteration space size as values.
+  ImplicitLocOpBuilder builder(op.getLoc(), rewriter);
+  SmallVector<Value> allShapes =
+      op.createFlatListOfOperandDims(builder, op.getLoc());
+  AffineMap shapesToLoops = op.getShapesToLoopsMap();
+  SmallVector<Value> iterationSpaceShapes =
+      applyMapToValues(builder, op.getLoc(), shapesToLoops, allShapes);
+
+  // Update the iteration space to have `splitPoint` as the size of `dimension`
+  // and use it to slice operands and results for a new, smaller instance of
+  // the `op`. Adjust the size if necessary to prevent overflows. Insert the
+  // partial results back.
+  Value splitPointValue = materializeOpFoldResult(builder, splitPoint);
+  splitPointValue = builder.createOrFold<AffineMinOp>(
+      builder.getIndexType(),
+      AffineMap::getMultiDimIdentityMap(/*numDims=*/2, builder.getContext()),
+      ValueRange({splitPointValue, iterationSpaceShapes[dimension]}));
+  SmallVector<Value> splitIterationSpace =
+      llvm::to_vector(iterationSpaceShapes);
+  splitIterationSpace[dimension] = splitPointValue;
+  SmallVector<Value> firstOperands = llvm::to_vector(
+      llvm::map_range(op.getInputAndOutputOperands(),
+                      [](OpOperand *opOperand) { return opOperand->get(); }));
+  firstOperands = getOperandSlices(builder, op, splitIterationSpace,
+                                   firstOperands, dimension);
+  SmallVector<Type> firstResultTypes = getTensorOutputTypes(op, firstOperands);
+  Operation *first =
+      op.clone(rewriter, op.getLoc(), firstResultTypes, firstOperands);
+  SmallVector<Value> firstResults = insertSlicesBack(
+      builder, builder.getLoc(), op, firstOperands, first->getResults());
+
+  // Update the iteration space to cover the remaining part of the original
+  // space, then create another instance of the `op` in that space. The size
+  // of the remaining part may become zero, but is never negative because of
+  // the adjustment above.
+  AffineExpr d0 = builder.getAffineDimExpr(0);
+  AffineExpr d1 = builder.getAffineDimExpr(1);
+  SmallVector<Value> remainingSizes = applyMapToValues(
+      builder, op.getLoc(), AffineMap::inferFromExprList({d0 - d1}).front(),
+      {iterationSpaceShapes[dimension], splitPointValue});
+  splitIterationSpace[dimension] = remainingSizes.front();
+  SmallVector<Value> secondOperands = llvm::to_vector(
+      llvm::map_range(op.getInputOperands(),
+                      [](OpOperand *opOperand) { return opOperand->get(); }));
+  llvm::append_range(secondOperands, firstResults);
+  secondOperands = getOperandSlices(builder, op, splitIterationSpace,
+                                    secondOperands, dimension, splitPointValue);
+  SmallVector<Type> secondResultTypes =
+      getTensorOutputTypes(op, secondOperands);
+  Operation *second =
+      op.clone(rewriter, op.getLoc(), secondResultTypes, secondOperands);
+  SmallVector<Value> secondResults = insertSlicesBack(
+      builder, builder.getLoc(), op, secondOperands, second->getResults());
+
+  // Fixup the linalg.index results in the second part.
+  SmallVector<Value> ivAdditions;
+  ivAdditions.resize(splitIterationSpace.size());
+  ivAdditions[dimension] = splitPointValue;
+  linalg::addTileLoopIvsToIndexOpResults(builder, cast<LinalgOp>(second),
+                                         ivAdditions);
+
+  // Replace the original op with the results of the two newly created ops.
+  rewriter.replaceOp(op, secondResults);
+  return std::make_pair(cast<LinalgOp>(first), cast<LinalgOp>(second));
+}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -182,32 +182,11 @@
         makeTiledShapes(b, loc, op, valuesToTile, interchangedIvs, tileSizes,
                         sizeBounds, /*omitPartialTileCheck=*/false);
 
-        // TODO: use an interface/adaptor to avoid leaking position in
-        // `tiledOperands`.
-        SmallVector<Type, 4> resultTensorTypes;
-        for (OpOperand *opOperand : op.getOutputTensorOperands())
-          resultTensorTypes.push_back(
-              tiledOperands[opOperand->getOperandNumber()].getType());
-
+        SmallVector<Type> resultTensorTypes =
+            getTensorOutputTypes(op, tiledOperands);
         res = op.clone(b, loc, resultTensorTypes, tiledOperands);
-
-        // Insert a insert_slice for each output tensor.
-        unsigned resultIdx = 0;
-        for (OpOperand *opOperand : op.getOutputTensorOperands()) {
-          // TODO: use an interface/adaptor to avoid leaking position in
-          // `tiledOperands`.
-          Value outputTensor = tiledOperands[opOperand->getOperandNumber()];
-          // TODO: Propagate RewriterBase everywhere.
-          IRRewriter rewriter(b);
-          if (auto sliceOp = outputTensor.getDefiningOp<tensor::ExtractSliceOp>()) {
-            tensorResults.push_back(insertSliceIntoTensor(rewriter, loc, sliceOp,
-                                                          res->getResult(resultIdx),
-                                                          sliceOp.getSource()));
-          } else {
-            tensorResults.push_back(res->getResult(resultIdx));
-          }
-          ++resultIdx;
-        }
+        tensorResults =
+            insertSlicesBack(b, loc, op, tiledOperands, res->getResults());
         return scf::ValueVector(tensorResults.begin(), tensorResults.end());
       };
   GenerateLoopNest<LoopTy>::doit(b, op.getLoc(), loopRanges, op, iteratorTypes,
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -913,6 +913,21 @@
   return sliceOp->getResult(0);
 }
 
+Value createSlice(OpBuilder &builder, Location loc, Value value,
+                  ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+                  ArrayRef<OpFoldResult> strides) {
+  if (value.getType().isa<MemRefType>()) {
+    return builder.create<memref::SubViewOp>(loc, value, offsets, sizes,
+                                             strides);
+  }
+
+  // This intentionally does not attempt to compose the extract_slice
+  // operations.
+  assert(value.getType().isa<RankedTensorType>() &&
+         "expected either a memref or a ranked tensor type");
+  return builder.create<tensor::ExtractSliceOp>(loc, value, offsets, sizes,
+                                                strides);
+}
+
 SmallVector<Value> computeTileOffsets(OpBuilder &b, Location loc,
                                       ValueRange ivs, ValueRange tileSizes) {
   SmallVector<Value> offsets;
@@ -943,6 +958,41 @@
   return sizes;
 }
 
+SmallVector<Type> getTensorOutputTypes(LinalgOp op, ValueRange operands) {
+  // TODO: use an interface/adaptor to avoid leaking position in `operands`.
+  return llvm::to_vector(
+      llvm::map_range(op.getOutputTensorOperands(), [&](OpOperand *opOperand) {
+        return operands[opOperand->getOperandNumber()].getType();
+      }));
+}
+
+SmallVector<Value> insertSlicesBack(OpBuilder &builder, Location loc,
+                                    LinalgOp op, ValueRange operands,
+                                    ValueRange results) {
+  SmallVector<Value> tensorResults;
+  tensorResults.reserve(results.size());
+  // Insert an insert_slice for each output tensor.
+  unsigned resultIdx = 0;
+  for (OpOperand *opOperand : op.getOutputTensorOperands()) {
+    // TODO: use an interface/adaptor to avoid leaking position in `operands`.
+    Value outputTensor = operands[opOperand->getOperandNumber()];
+    if (auto sliceOp = outputTensor.getDefiningOp<tensor::ExtractSliceOp>()) {
+      Value inserted = builder.create<tensor::InsertSliceOp>(
+          loc, sliceOp.source().getType(), results[resultIdx], sliceOp.source(),
+          sliceOp.offsets(), sliceOp.sizes(), sliceOp.strides(),
+          sliceOp.static_offsets(), sliceOp.static_sizes(),
+          sliceOp.static_strides());
+      tensorResults.push_back(inserted);
+    } else {
+      tensorResults.push_back(results[resultIdx]);
+    }
+    ++resultIdx;
+  }
+  return tensorResults;
+}
+
 SmallVector<Value>
 makeTiledShapes(OpBuilder &b, Location loc, LinalgOp linalgOp,
                 ArrayRef<Value> valuesToTile,
diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
--- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
+++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
@@ -15,6 +15,12 @@
 OptionalIntList = Optional[Union[ArrayAttr, IntOrAttrList]]
 
 
+def _get_int64_attr(value: Union[int, Attribute]) -> IntegerAttr:
+  if isinstance(value, int):
+    return IntegerAttr.get(IntegerType.get_signless(64), value)
+  return value
+
+
 def _get_array_attr(
     values: Optional[Union[ArrayAttr, Sequence[Attribute]]]) -> ArrayAttr:
   """Creates an array attribute from its operand."""
@@ -41,13 +47,7 @@
   if isinstance(values, ArrayAttr):
     return values
 
-  attributes = []
-  for value in values:
-    if isinstance(value, IntegerAttr):
-      attributes.append(value)
-    else:
-      attributes.append(IntegerAttr.get(IntegerType.get_signless(64), value))
-  return ArrayAttr.get(attributes)
+  return ArrayAttr.get([_get_int64_attr(v) for v in values])
 
 
 def _get_int_int_array_attr(
@@ -152,6 +152,39 @@
       pdl_operation_type, _get_op_result_or_value(target), loc=loc, ip=ip)
 
 
+class SplitOp:
+  """Specialization for SplitOp class."""
+
+  def __init__(self,
+               target: Union[Operation, Value],
+               dimension: Union[int, Attribute],
+               split_point: Union[int, Operation, Value, Attribute],
+               *,
+               loc=None,
+               ip=None):
+    dimension = _get_int64_attr(dimension)
+    if isinstance(split_point, int):
+      split_point = _get_int64_attr(split_point)
+
+    if isinstance(split_point, Attribute):
+      static_split_point = split_point
+      dynamic_split_point = None
+    else:
+      static_split_point = _get_int64_attr(ShapedType._get_dynamic_size())
+      dynamic_split_point = _get_op_result_or_value(split_point)
+
+    pdl_operation_type = pdl.OperationType.get()
+    super().__init__(
+        pdl_operation_type,
+        pdl_operation_type,
+        _get_op_result_or_value(target),
+        dimension=dimension,
+        static_split_point=static_split_point,
+        dynamic_split_point=dynamic_split_point,
+        loc=loc,
+        ip=ip)
+
+
 class TileOp:
   """Specialization for TileOp class."""
diff --git a/mlir/test/Dialect/Linalg/transform-op-split.mlir b/mlir/test/Dialect/Linalg/transform-op-split.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-split.mlir
@@ -0,0 +1,366 @@
+// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file -verify-diagnostics | FileCheck %s
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @linalg_generic : benefit(1) {
+    %0 = pdl.operands
+    %1 = pdl.types
+    %2 = pdl.operation "linalg.generic"(%0 : !pdl.range) -> (%1 : !pdl.range)
+    pdl.rewrite %2 with "transform.dialect"
+  }
+
+  transform.sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.pdl_match @linalg_generic in %arg1
+    %1:2 = transform.structured.split %0 after 42 { dimension = 0 }
+  }
+}
+
+func.func private @elem(%arg0: f32, %arg1: index, %arg2: index) -> f32
+
+// CHECK: 
#[[$ADD_42_MAP:.+]] = affine_map<(d0) -> (d0 + 42)> +// CHECK: #[[$ADD_10_MAP:.+]] = affine_map<(d0) -> (d0 + 10)> + +// CHECK-LABEL: @one_d_static +// CHECK-SAME: %[[IN:.+]]: tensor<100xf32>, %[[OUT:.+]]: tensor<100xf32> +func.func @one_d_static(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> { + // CHECK: %[[IN_SLICE_LOW:.+]] = tensor.extract_slice %[[IN]][0] [42] [1] : tensor<100xf32> to tensor<42xf32> + // CHECK: %[[OUT_SLICE_LOW:.+]] = tensor.extract_slice %[[OUT]][0] [42] [1] : tensor<100xf32> to tensor<42xf32> + // CHECK: %[[RES_SLICE_LOW:.+]] = linalg.generic + // CHECK: ins(%[[IN_SLICE_LOW]] + // CHECK: outs(%[[OUT_SLICE_LOW]] + // CHECK: linalg.index 0 + // CHECK: func.call @elem + // CHECK: %[[RES_PARTIAL:.+]] = tensor.insert_slice %[[RES_SLICE_LOW]] into %[[OUT]][0] [42] [1] + // + // CHECK: %[[IN_SLICE_HIGH:.+]] = tensor.extract_slice %[[IN]][42] [58] [1] : tensor<100xf32> to tensor<58xf32> + // CHECK: %[[OUT_SLICE_HIGH:.+]] = tensor.extract_slice %[[RES_PARTIAL]][42] [58] [1] : tensor<100xf32> to tensor<58xf32> + // CHECK: %[[RES_SLICE_HIGH:.+]] = linalg.generic + // CHECK: ins(%[[IN_SLICE_HIGH]] + // CHECK: outs(%[[OUT_SLICE_HIGH]] + // CHECK: %[[IDX:.+]] = linalg.index 0 + // CHECK: affine.apply #[[$ADD_42_MAP]](%[[IDX]]) + // CHECK: func.call @elem + // CHECK: %[[RES:.+]] = tensor.insert_slice %[[RES_SLICE_HIGH]] into %[[RES_PARTIAL]][42] [58] [1] + %0 = linalg.generic { + indexing_maps = [affine_map<(i) -> (i)>, affine_map<(i) -> (i)>], + iterator_types = ["parallel"] + } + ins(%arg0: tensor<100xf32>) outs(%arg1: tensor<100xf32>) { + ^bb0(%0: f32, %1: f32): + %i = linalg.index 0 : index + %call_res = func.call @elem(%0, %i, %i) : (f32, index, index) -> f32 + linalg.yield %call_res : f32 + } -> tensor<100xf32> + + // CHECK: return %[[RES]] + return %0 : tensor<100xf32> +} + +// CHECK-LABEL: @one_d_static_overflow +// CHECK-SAME: %[[IN:.+]]: tensor<10xf32>, %[[OUT:.+]]: tensor<10xf32> +func.func @one_d_static_overflow(%arg0: tensor<10xf32>, %arg1: tensor<10xf32>) -> tensor<10xf32> { + // CHECK: %[[IN_SLICE_LOW:.+]] = tensor.extract_slice %[[IN]][0] [10] [1] : tensor<10xf32> to tensor<10xf32> + // CHECK: %[[OUT_SLICE_LOW:.+]] = tensor.extract_slice %[[OUT]][0] [10] [1] : tensor<10xf32> to tensor<10xf32> + // CHECK: %[[RES_SLICE_LOW:.+]] = linalg.generic + // CHECK: ins(%[[IN_SLICE_LOW]] + // CHECK: outs(%[[OUT_SLICE_LOW]] + // CHECK: linalg.index 0 + // CHECK: func.call @elem + // CHECK: %[[RES_PARTIAL:.+]] = tensor.insert_slice %[[RES_SLICE_LOW]] into %[[OUT]][0] [10] [1] + // + // CHECK: %[[IN_SLICE_HIGH:.+]] = tensor.extract_slice %[[IN]][10] [0] [1] : tensor<10xf32> to tensor<0xf32> + // CHECK: %[[OUT_SLICE_HIGH:.+]] = tensor.extract_slice %[[RES_PARTIAL]][10] [0] [1] : tensor<10xf32> to tensor<0xf32> + // CHECK: %[[RES_SLICE_HIGH:.+]] = linalg.generic + // CHECK: ins(%[[IN_SLICE_HIGH]] + // CHECK: outs(%[[OUT_SLICE_HIGH]] + // CHECK: %[[IDX:.+]] = linalg.index 0 + // CHECK: affine.apply #[[$ADD_10_MAP]](%[[IDX]]) + // CHECK: func.call @elem + // CHECK: %[[RES:.+]] = tensor.insert_slice %[[RES_SLICE_HIGH]] into %[[RES_PARTIAL]][10] [0] [1] + %0 = linalg.generic { + indexing_maps = [affine_map<(i) -> (i)>, affine_map<(i) -> (i)>], + iterator_types = ["parallel"] + } + ins(%arg0: tensor<10xf32>) outs(%arg1: tensor<10xf32>) { + ^bb0(%0: f32, %1: f32): + %i = linalg.index 0 : index + %call_res = func.call @elem(%0, %i, %i) : (f32, index, index) -> f32 + linalg.yield %call_res : f32 + } -> tensor<10xf32> + return %0 : tensor<10xf32> +} + +// ----- + 
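+// Note on the assembly format exercised by these tests: the op prints a
+// static split point as an integer literal after the `after` keyword, and a
+// dynamic one as the handle operand. Illustrative forms only (%0 and %pt
+// stand for previously defined handles; not matched by FileCheck):
+//
+//   %1:2 = transform.structured.split %0 after 42 { dimension = 0 }
+//   %1:2 = transform.structured.split %0 after %pt { dimension = 0 }
+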
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @func_call : benefit(1) {
+    %0 = pdl.operands
+    %1 = pdl.types
+    %2 = pdl.operation "func.call"(%0 : !pdl.range) -> (%1 : !pdl.range)
+    pdl.rewrite %2 with "transform.dialect"
+  }
+  pdl.pattern @linalg_generic : benefit(1) {
+    %0 = pdl.operands
+    %1 = pdl.types
+    %2 = pdl.operation "linalg.generic"(%0 : !pdl.range) -> (%1 : !pdl.range)
+    pdl.rewrite %2 with "transform.dialect"
+  }
+
+  transform.sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.pdl_match @linalg_generic in %arg1
+    %1 = transform.pdl_match @func_call in %arg1
+    transform.structured.split %0 after %1 { dimension = 0 }
+  }
+}
+
+func.func private @get_size() -> index
+
+// CHECK: #[[$MAP_MIN_100:.+]] = affine_map<(d0, d1) -> (d0, 100)>
+// CHECK: #[[$MAP_S_MINUS_100:.+]] = affine_map<()[s0] -> (-s0 + 100)>
+
+// CHECK-LABEL: @dynamic
+func.func @dynamic(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> {
+  // CHECK: %[[SPLIT:.+]] = call @get_size
+  // CHECK: %[[SPLIT_LOW:.+]] = affine.min #[[$MAP_MIN_100]](%[[SPLIT]]
+  // CHECK: %[[IN_SLICE_LOW:.+]] = tensor.extract_slice %[[IN:.+]][0] [%[[SPLIT_LOW]]] [1] : tensor<100xf32> to tensor<?xf32>
+  // CHECK: %[[OUT_SLICE_LOW:.+]] = tensor.extract_slice %[[OUT:.+]][0] [%[[SPLIT_LOW]]] [1] : tensor<100xf32> to tensor<?xf32>
+  // CHECK: %[[RES_SLICE_LOW:.+]] = linalg.generic
+  // CHECK:   ins(%[[IN_SLICE_LOW]]
+  // CHECK:   outs(%[[OUT_SLICE_LOW]]
+  // CHECK: %[[PARTIAL:.+]] = tensor.insert_slice %[[RES_SLICE_LOW]] into %[[OUT]][0] [%[[SPLIT_LOW]]] [1]
+  //
+  // CHECK: %[[SPLIT_HIGH_1:.+]] = affine.apply #[[$MAP_S_MINUS_100]]()[%[[SPLIT_LOW]]]
+  // CHECK: %[[SPLIT_HIGH_2:.+]] = affine.apply #[[$MAP_S_MINUS_100]]()[%[[SPLIT_LOW]]]
+  // CHECK: %[[IN_SLICE_HIGH:.+]] = tensor.extract_slice %[[IN:.+]][%[[SPLIT_LOW]]] [%[[SPLIT_HIGH_2]]] [1] : tensor<100xf32> to tensor<?xf32>
+  // CHECK: %[[SPLIT_HIGH_3:.+]] = affine.apply #[[$MAP_S_MINUS_100]]()[%[[SPLIT_LOW]]]
+  // CHECK: %[[OUT_SLICE_HIGH:.+]] = tensor.extract_slice %[[PARTIAL:.+]][%[[SPLIT_LOW]]] [%[[SPLIT_HIGH_3]]] [1] : tensor<100xf32> to tensor<?xf32>
+  // CHECK: %[[RES_SLICE_HIGH:.+]] = linalg.generic
+  // CHECK:   ins(%[[IN_SLICE_HIGH]]
+  // CHECK:   outs(%[[OUT_SLICE_HIGH]]
+  // CHECK: tensor.insert_slice %[[RES_SLICE_HIGH]] into %[[PARTIAL]][%[[SPLIT_LOW]]] [%[[SPLIT_HIGH_3]]] [1]
+  %0 = func.call @get_size() : () -> index
+  %1 = linalg.generic {
+    indexing_maps = [affine_map<(i) -> (i)>, affine_map<(i) -> (i)>],
+    iterator_types = ["parallel"]
+  }
+  ins(%arg0: tensor<100xf32>) outs(%arg1: tensor<100xf32>) {
+  ^bb0(%3: f32, %4: f32):
+    linalg.yield %3 : f32
+  } -> tensor<100xf32>
+  return %1 : tensor<100xf32>
+}
+
+// -----
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @linalg_generic : benefit(1) {
+    %0 = pdl.operands
+    %1 = pdl.types
+    %2 = pdl.operation "linalg.generic"(%0 : !pdl.range) -> (%1 : !pdl.range)
+    pdl.rewrite %2 with "transform.dialect"
+  }
+
+  transform.sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.pdl_match @linalg_generic in %arg1
+    %1:2 = transform.structured.split %0 after 4 { dimension = 0 }
+    %2:2 = transform.structured.split %1#1 after 16 { dimension = 1 }
+  }
+}
+
+func.func private @elem(%arg0: f32, %arg1: index, %arg2: index) -> f32
+
+// CHECK-LABEL: @two_d
+func.func @two_d(%arg0: tensor<10x34xf32>,
+                 %arg1: tensor<10x34xf32>) -> tensor<10x34xf32> {
+  // Check the overall structure: split along dimension 0, and then split
+  // the second half only along dimension 1.
+ // CHECK: %[[IN_1:.+]] = tensor.extract_slice %[[IN:.+]][0, 0] + // CHECK: %[[OUT_1:.+]] = tensor.extract_slice %[[OUT:.+]][0, 0] + // CHECK: %[[RES_1:.+]] = linalg.generic + // CHECK-SAME: ins(%[[IN_1]] : tensor<4x34xf32>) + // CHECK-SAME: outs(%[[OUT_1]] : tensor<4x34xf32>) + // CHECK: %[[PARTIAL_1:.+]] = tensor.insert_slice %[[RES_1]] into %[[OUT]] + // + // CHECK: %[[IN_2:.+]] = tensor.extract_slice %[[IN]] + // CHECK: %[[OUT_2:.+]] = tensor.extract_slice %[[PARTIAL_1]] + // CHECK: %[[IN_21:.+]] = tensor.extract_slice %[[IN_2]] + // CHECK: %[[OUT_21:.+]] = tensor.extract_slice %[[OUT_2]] + // CHECK: %[[RES_21:.+]] = linalg.generic + // CHECK-SAME: ins(%[[IN_21]] : tensor<6x16xf32>) + // CHECK-SAME: outs(%[[OUT_21]] : tensor<6x16xf32>) + // CHECK: %[[PARTIAL_21:.+]] = tensor.insert_slice %[[RES_21]] into %[[OUT_2]] + // + // CHECK: %[[IN_22:.+]] = tensor.extract_slice %[[IN_2]] + // CHECK: %[[OUT_22:.+]] = tensor.extract_slice %[[PARTIAL_21]] + // CHECK: %[[RES_22:.+]] = linalg.generic + // CHECK-SAME: ins(%[[IN_22]] : tensor<6x18xf32>) + // CHECK-SAME: outs(%[[OUT_22]] : tensor<6x18xf32>) + // CHECK: %[[PARTIAL_22:.+]] = tensor.insert_slice %[[RES_22]] into %[[PARTIAL_21]] + // CHECK: %[[PARTIAL_2:.+]] = tensor.insert_slice %[[PARTIAL_22]] into %[[PARTIAL_1]] + %0 = linalg.generic { + indexing_maps = [affine_map<(i, j) -> (i, j)>, + affine_map<(i, j) -> (i, j)>], + iterator_types = ["parallel", "parallel"] + } + ins(%arg0: tensor<10x34xf32>) + outs(%arg1: tensor<10x34xf32>) { + ^bb0(%0: f32, %1: f32): + %i = linalg.index 0 : index + %j = linalg.index 1 : index + %call_res = func.call @elem(%0, %i, %j) : (f32, index, index) -> f32 + linalg.yield %call_res : f32 + } -> tensor<10x34xf32> + return %0 : tensor<10x34xf32> +} + +// ----- + +transform.sequence { +^bb1(%arg1: !pdl.operation): + // expected-error @below {{expects either a dynamic or a static split point to be provided}} + %0:2 = "transform.structured.split"(%arg1) { dimension = 1, static_split_point = -1 } : (!pdl.operation) -> (!pdl.operation, !pdl.operation) +} + +// ----- + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + pdl.pattern @func_call : benefit(1) { + %0 = pdl.operands + %1 = pdl.types + %2 = pdl.operation "func.call"(%0 : !pdl.range) -> (%1 : !pdl.range) + pdl.rewrite %2 with "transform.dialect" + } + pdl.pattern @linalg_generic : benefit(1) { + %0 = pdl.operands + %1 = pdl.types + %2 = pdl.operation "linalg.generic"(%0 : !pdl.range) -> (%1 : !pdl.range) + pdl.rewrite %2 with "transform.dialect" + } + + transform.sequence %arg0 { + ^bb1(%arg1: !pdl.operation): + %0 = transform.pdl_match @linalg_generic in %arg1 + %1 = transform.pdl_match @func_call in %arg1 + // expected-error @below {{expected dynamic split point handle to point to a single-result index-typed op}} + transform.structured.split %0 after %1 { dimension = 0 } + } +} + +func.func private @get_size() -> i64 + +func.func @dynamic(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> { + // expected-note @below {{dynamic split point}} + %0 = func.call @get_size() : () -> i64 + %1 = linalg.generic { + indexing_maps = [affine_map<(i) -> (i)>, affine_map<(i) -> (i)>], + iterator_types = ["parallel"] + } + ins(%arg0: tensor<100xf32>) outs(%arg1: tensor<100xf32>) { + ^bb0(%3: f32, %4: f32): + linalg.yield %3 : f32 + } -> tensor<100xf32> + return %1 : tensor<100xf32> +} + +// ----- + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + pdl.pattern @func_call : benefit(1) { + %0 = pdl.operands + %1 = pdl.types + %2 = 
pdl.operation "func.call"(%0 : !pdl.range) -> (%1 : !pdl.range)
+    pdl.rewrite %2 with "transform.dialect"
+  }
+  pdl.pattern @linalg_generic : benefit(1) {
+    %0 = pdl.operands
+    %1 = pdl.types
+    %2 = pdl.operation "linalg.generic"(%0 : !pdl.range) -> (%1 : !pdl.range)
+    pdl.rewrite %2 with "transform.dialect"
+  }
+
+  transform.sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.pdl_match @linalg_generic in %arg1
+    %1 = transform.pdl_match @func_call in %arg1
+    // expected-error @below {{expected the dynamic split point handle to point to as many operations (0) as the target handle (1)}}
+    transform.structured.split %0 after %1 { dimension = 0 }
+  }
+}
+
+func.func private @get_size() -> i64
+
+func.func @dynamic(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> {
+  %1 = linalg.generic {
+    indexing_maps = [affine_map<(i) -> (i)>, affine_map<(i) -> (i)>],
+    iterator_types = ["parallel"]
+  }
+  ins(%arg0: tensor<100xf32>) outs(%arg1: tensor<100xf32>) {
+  ^bb0(%3: f32, %4: f32):
+    linalg.yield %3 : f32
+  } -> tensor<100xf32>
+  return %1 : tensor<100xf32>
+}
+
+// -----
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @func_return : benefit(1) {
+    %0 = pdl.operands
+    %1 = pdl.types
+    %2 = pdl.operation "func.return"(%0 : !pdl.range) -> (%1 : !pdl.range)
+    pdl.rewrite %2 with "transform.dialect"
+  }
+
+  transform.sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.pdl_match @func_return in %arg1
+    // expected-error @below {{only applies to structured ops}}
+    transform.structured.split %0 after 16 { dimension = 1 }
+  }
+}
+
+func.func @noop(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> {
+  // expected-note @below {{target op}}
+  return %arg0 : tensor<100xf32>
+}
+
+// -----
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @linalg_generic : benefit(1) {
+    %0 = pdl.operands
+    %1 = pdl.types
+    %2 = pdl.operation "linalg.generic"(%0 : !pdl.range) -> (%1 : !pdl.range)
+    pdl.rewrite %2 with "transform.dialect"
+  }
+
+  transform.sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.pdl_match @linalg_generic in %arg1
+    // expected-error @below {{dimension 1 does not exist in target op}}
+    transform.structured.split %0 after 16 { dimension = 1 }
+  }
+}
+
+func.func @one_d_static(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> {
+  // expected-note @below {{target op}}
+  %0 = linalg.generic {
+    indexing_maps = [affine_map<(i) -> (i)>, affine_map<(i) -> (i)>],
+    iterator_types = ["parallel"]
+  }
+  ins(%arg0: tensor<100xf32>) outs(%arg1: tensor<100xf32>) {
+  ^bb0(%0: f32, %1: f32):
+    linalg.yield %0 : f32
+  } -> tensor<100xf32>
+  return %0 : tensor<100xf32>
+}
+
diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py
--- a/mlir/test/python/dialects/transform_structured_ext.py
+++ b/mlir/test/python/dialects/transform_structured_ext.py
@@ -84,6 +84,19 @@
   # CHECK: transform.structured.scalarize
 
 
+@run
+def testSplit():
+  sequence = transform.SequenceOp()
+  with InsertionPoint(sequence.body):
+    split = structured.SplitOp(sequence.bodyTarget, dimension=1, split_point=42)
+    structured.SplitOp(
+        split.results[0], dimension=3, split_point=split.results[1])
+    transform.YieldOp()
+  # CHECK-LABEL: TEST: testSplit
+  # CHECK: %[[F:.+]], %[[S:.+]] = transform.structured.split %{{.*}} after 42 {dimension = 1
+  # CHECK: transform.structured.split %[[F]] after %[[S]] {dimension = 3
+
+
 @run
 def testTileCompact():
sequence = transform.SequenceOp()
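
# ---------------------------------------------------------------------------
# Editorial sketch, not part of the patch: a minimal end-to-end use of the
# SplitOp specialization added in _structured_transform_ops_ext.py. It assumes
# the same imports and `run` helper as this test file; the function name
# `testSplitSketch` is hypothetical.
@run
def testSplitSketch():
  sequence = transform.SequenceOp()
  with InsertionPoint(sequence.body):
    # A Python int becomes the `static_split_point` attribute; passing a
    # handle instead (as in testSplit above) sets that attribute to
    # ShapedType._get_dynamic_size() and wires the handle as the dynamic
    # split point operand. The two results are handles to the low and high
    # parts of the split op, respectively.
    structured.SplitOp(sequence.bodyTarget, dimension=0, split_point=8)
    transform.YieldOp()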