diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -153,6 +153,28 @@ }]; } +def SplitReductionOp : Op, + FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface]> { + let description = [{ + Indicates that the given `target` op should be transformed with the + `splitReduction` transformation, using the `split_factor` attribute as the + split ratio. The `insert_split_dimension` attribute selects where the extra + dimension is inserted in the shape of the intermediate (partial) result. + The `target` handle is expected to map to exactly one LinalgOp. + + This op returns handles to the split op and the result-combining op. + }]; + + let arguments = (ins PDL_Operation:$target, + DefaultValuedAttr:$split_factor, + DefaultValuedAttr:$insert_split_dimension); + let results = (outs PDL_Operation:$split_linalg_op, + PDL_Operation:$combining_linalg_op); + + let assemblyFormat = "$target attr-dict"; +} + def TileOp : Op, FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface]> { diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -1466,6 +1466,7 @@ /// reduction dimension. The dimension index is used to control where the extra /// dimension is added to the intermediate tensor shape. If the ratio value is /// less or equal to 1 then nothing will be done. +// TODO: don't use unsigned unless doing bit manipulation. using ControlSplitReductionFn = std::function(LinalgOp op)>; @@ -1475,11 +1476,15 @@ const ControlSplitReductionFn &controlSplitReductionFn, const LinalgTransformationFilter &f = LinalgTransformationFilter()); -/// Apply transformation to split the single linalg op reduction into a parallel -/// and reduction dimension. Then create a new linalg.generic op doing the rest -/// of the reduction.
Return the new linalg op with an extra parallel dimension -/// or failure if the transformation didn't happen. +/// Apply transformation to split a single reduction dimension of a linalg op +/// into a pair of (parallel + reduction) dimensions. +/// Subsequently, create a new linalg.generic op further combining the parallel +/// reduced pieces. +/// Return the new linalg op with an extra parallel dimension or failure if the +/// transformation didn't happen. +/// /// Example: +/// /// ``` /// %r = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, /// affine_map<(d0) -> ()>], @@ -1491,7 +1496,11 @@ /// linalg.yield %y : f32 /// } -> tensor /// ``` -/// To: +/// +/// may be rewritten to the following (illustrative only: the current +/// implementation reindexes the inputs and adds a shape-only operand instead +/// of emitting `tensor.expand_shape`): +/// /// ``` /// %cst = arith.constant 0.000000e+00 : f32 /// %0 = tensor.expand_shape %in [[0, 1]] : tensor<32xf32> into tensor<4x8xf32> @@ -1519,6 +1528,18 @@ const ControlSplitReductionFn &controlSplitReductionFn, const LinalgTransformationFilter &f); +/// Ops produced by `splitReduction`: the op carrying the extra parallel +/// dimension and the op combining its partial results into the original output. +struct SplitReductionResult { + LinalgOp splitLinalgOp; + LinalgOp resultCombiningLinalgOp; +}; +/// Filterless version of the above. +/// Returns both the `splitLinalgOp` and the `resultCombiningLinalgOp`. +FailureOr +splitReduction(PatternRewriter &b, LinalgOp op, + const ControlSplitReductionFn &controlSplitReductionFn); + } // namespace linalg } // namespace mlir diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h --- a/mlir/include/mlir/IR/AffineMap.h +++ b/mlir/include/mlir/IR/AffineMap.h @@ -240,6 +240,22 @@ getContext()); } + /// Returns a new AffineMap with the same number of dims and symbols, and with + /// the result at position `pos` removed. + AffineMap dropResult(unsigned pos) { + auto exprs = llvm::to_vector<4>(getResults()); + exprs.erase(exprs.begin() + pos); + return AffineMap::get(getNumDims(), getNumSymbols(), exprs, getContext()); + } + + /// Returns a new AffineMap with the same number of dims and symbols and an + /// extra result inserted at `pos`.
+ AffineMap insertResult(AffineExpr expr, unsigned pos = 0) { + auto exprs = llvm::to_vector<4>(getResults()); + exprs.insert(exprs.begin() + pos, expr); + return AffineMap::get(getNumDims(), getNumSymbols(), exprs, getContext()); + } + /// Folds the results of the application of an affine map on the provided /// operands to a constant if possible. LogicalResult constantFold(ArrayRef operandConstants, diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h --- a/mlir/include/mlir/IR/BuiltinTypes.h +++ b/mlir/include/mlir/IR/BuiltinTypes.h @@ -249,6 +249,16 @@ return *this; } + /// Insert a dimension of size `val` into the shape at position `pos`. + Builder &insertDim(unsigned pos, int64_t val) { + assert(pos <= shape.size() && "overflow"); + if (storage.empty()) + storage.append(shape.begin(), shape.end()); + storage.insert(storage.begin() + pos, val); + shape = {storage.data(), storage.size()}; + return *this; + } + operator RankedTensorType() { return RankedTensorType::get(shape, elementType, encoding); } diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -394,6 +394,44 @@ return result->op; } +//===----------------------------------------------------------------------===// +// SplitReductionOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure +transform::SplitReductionOp::apply(TransformResults &transformResults, + TransformState &state) { + ControlSplitReductionFn splitFn = [&](LinalgOp _) { + return std::pair(getSplitFactor(), + getInsertSplitDimension()); + }; + + SimpleRewriter rewriter(getContext()); + ArrayRef payloadOps = state.getPayloadOps(getTarget()); + // Only call `front()` once we know there is exactly one payload op. + LinalgOp linalgOp; + if (payloadOps.size() == 1) + linalgOp = dyn_cast_or_null(payloadOps.front()); + if (!linalgOp) {
+ getOperation()->emitError("only single LinalgOp payload supported"); + return DiagnosedSilenceableFailure::definiteFailure(); + } + + FailureOr splitResult = + splitReduction(rewriter, linalgOp, splitFn); + if (failed(splitResult)) { + getOperation()->emitError("failed to apply"); + return DiagnosedSilenceableFailure::definiteFailure(); + } + + transformResults.set(getOperation()->getOpResult(0), + splitResult->splitLinalgOp.getOperation()); + transformResults.set(getOperation()->getOpResult(1), + splitResult->resultCombiningLinalgOp.getOperation()); + + return DiagnosedSilenceableFailure::success(); +} + //===----------------------------------------------------------------------===// // TileOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp @@ -19,13 +19,14 @@ #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tensor/Utils/Utils.h" #include "mlir/IR/PatternMatch.h" using namespace mlir; using namespace mlir::linalg; -/// Return the identity numeric value associated to the give op. -static Optional getIdentity(Operation *op) { +/// Return the neutral element of combiner `op`, or a null Attribute if unknown. +static Attribute getNeutralElement(Operation *op) { // Builder only used as helper for attribute creation.
OpBuilder b(op->getContext()); Type resultType = op->getResult(0).getType(); @@ -41,7 +42,7 @@ if (isa(op)) return b.getFloatAttr(resultType, llvm::APFloat::getLargest(semantic, true)); - return llvm::None; + return Attribute(); } if (isa(op)) return b.getIntegerAttr(resultType, 0); @@ -53,9 +54,10 @@ return b.getIntegerAttr(resultType, std::numeric_limits::max()); if (isa(op)) return b.getIntegerAttr(resultType, 1); - return llvm::None; + return Attribute(); } +/// Wrap the core rewrite logic with filter attribute set/update. FailureOr mlir::linalg::splitReduction( PatternRewriter &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, @@ -64,148 +66,220 @@ op.getNumReductionLoops() != 1 || op.getNumOutputs() != 1 || !op.hasOnlyProjectedPermutations()) return b.notifyMatchFailure(op, "precondition not met"); + + auto res = splitReduction(b, op, controlSplitReductionFn); + if (failed(res)) + return failure(); + + filter.replaceLinalgTransformationFilter(b, res->splitLinalgOp); + filter.replaceLinalgTransformationFilter(b, res->resultCombiningLinalgOp); + + return res->splitLinalgOp; +} +
+/// Rewrite f(i, j, k, ...) into f(i, j, k * ratio + k', ...), where the new +/// loop dimension k' is inserted right after k (at `reductionDimPos + 1`). +/// TODO: additionally rewrite f(i, j, k * ratio + k', ...) into +/// f(i, j, k, k', ...) with a proper ExpandShapeOp. +static AffineMap scaleReductionDim(LinalgOp op, OpOperand &opOperand, + unsigned reductionDimPos, + int64_t reductionRatio) { + auto reductionDim = getAffineDimExpr(reductionDimPos, op.getContext()); + auto reductionDimP1 = getAffineDimExpr(reductionDimPos + 1, op.getContext()); + AffineMap map = op.getTiedIndexingMap(&opOperand); + AffineMap idMap = + AffineMap::getMultiDimIdentityMap(map.getNumDims(), op.getContext()); + AffineMap shiftedIdMap = idMap.shiftDims(1, /*offset=*/reductionDimPos + 1); + AffineMap composeMap = shiftedIdMap.replace( + reductionDim, reductionDim * reductionRatio + reductionDimP1, + shiftedIdMap.getNumDims(), /*numSymbols=*/0); + return map.compose(composeMap); +} + +/// Reindex the indexing map of output `opOperand` for the split op: loop dims +/// after `reductionDimPos` are shifted by one and the new parallel loop dim +/// `reductionDimPos` is inserted as result `insertDimPos`, i.e. at the position +/// of the extra dimension in the intermediate tensor shape. +static AffineMap insertParallelDim(LinalgOp op, OpOperand &opOperand, + unsigned reductionDimPos, + unsigned insertDimPos) { + auto reductionDim = getAffineDimExpr(reductionDimPos, op.getContext()); + AffineMap map = op.getTiedIndexingMap(&opOperand); + AffineMap idMap = + AffineMap::getMultiDimIdentityMap(map.getNumDims(), op.getContext()); + AffineMap shiftedIdMap = idMap.shiftDims(1, /*offset=*/reductionDimPos + 1); + return map.compose(shiftedIdMap).insertResult(reductionDim, insertDimPos); +} + +/// Core rewrite implementation. +FailureOr mlir::linalg::splitReduction( + PatternRewriter &b, LinalgOp op, + const ControlSplitReductionFn &controlSplitReductionFn) { + OpBuilder::InsertionGuard guard(b); + b.setInsertionPoint(op); + + // Matcher part, enforce preconditions.
std::pair control = controlSplitReductionFn(op); int64_t ratio = control.first; unsigned insertDimIndex = control.second; if (ratio <= 1) return b.notifyMatchFailure(op, "split ratio needs to be greater than 1"); + SmallVector dims; op.getReductionDims(dims); - assert(dims.size() == 1); - unsigned reductionDim = dims[0]; - SmallVector loopRanges = op.getStaticLoopRanges(); - int64_t reductionDimSize = loopRanges[reductionDim]; + if (dims.empty()) + return b.notifyMatchFailure(op, "needs at least 1 reduction dimension"); + + unsigned reductionDimPos = dims[0]; + SmallVector loopRanges = op.getStaticLoopRanges(); + int64_t reductionDimSize = loopRanges[reductionDimPos]; if (reductionDimSize == ShapedType::kDynamicSize || reductionDimSize % ratio != 0 || insertDimIndex >= loopRanges.size()) return b.notifyMatchFailure( - op, "Reduction dimension not divisible by split ratio"); - SmallVector combinerOps; - if (!matchReduction(op.getRegionOutputArgs(), 0, combinerOps) || - combinerOps.size() != 1) - return b.notifyMatchFailure(op, "Cannot match the reduction pattern"); - Operation *reductionOp = combinerOps[0]; - Optional identity = getIdentity(reductionOp); - if (!identity) - return b.notifyMatchFailure(op, "Unknown identity value for the redution"); + op, "first reduction dimension not divisible by split ratio"); + + SmallVector combinerOps; + if (!matchReduction(op.getRegionOutputArgs(), 0, combinerOps)) + return b.notifyMatchFailure(op, "cannot match a reduction pattern"); + SmallVector identities = llvm::to_vector<4>( + llvm::map_range(combinerOps, [&](Operation *reductionOp) { + return getNeutralElement(reductionOp); + })); + if (!llvm::all_of(identities, [](Attribute attr) { return attr; })) + return b.notifyMatchFailure(op, "unknown reduction neutral"); + + // TODO: relax this when multi-reduction support is available. + if (op.getNumOutputs() != identities.size()) + return b.notifyMatchFailure(op, "expect one reduction per output"); + + // The extra dimension is inserted into every output shape at + // `insertDimIndex`, which therefore must be a valid insertion position. + for (OpOperand *o : op.getOutputOperands()) + if (insertDimIndex > op.getTiedIndexingMap(o).getNumResults()) + return b.notifyMatchFailure(op, "insert split dimension out of bounds"); + + // Rewrite part.
+ // Step 1. Build the outputs filled with the proper identities. Location loc = op->getLoc(); - SmallVector newInputs; + MLIRContext *context = op.getContext(); + // For now assume outputs are 1-1 with reduction identities. + // TODO: generalize when multi-reduction support is available. + SmallVector newOutputs; + newOutputs.reserve(op.getNumOutputs()); + for (auto it : llvm::zip(op.outputs(), identities)) { + Value rankedTensor = std::get<0>(it); + auto t = rankedTensor.getType().cast(); + RankedTensorType newT = RankedTensorType::Builder(t).insertDim( + insertDimIndex, reductionDimSize / ratio); + SmallVector dynamicDims = + tensor::createDynamicDimValues(b, loc, rankedTensor); + Value initTensor = b.create( + loc, dynamicDims, newT.getShape(), t.getElementType()); + Value constantOp = b.create(loc, std::get<1>(it)); + newOutputs.push_back( + b.create(op->getLoc(), constantOp, initTensor) + .getResult(0)); + } + + // Step 2. Reindex / expand indexing maps. + // Reindex existing input indexings: k -> k * ratio + k'. SmallVector newMaps; - // Calculate the new shapes and indexing maps of the input operands. - for (OpOperand *operand : op.getInputOperands()) { - AffineMap map = op.getTiedIndexingMap(operand); - SmallVector newShape; - SmallVector exprs; - SmallVector reassociation; - unsigned index = 0; - for (unsigned idx : llvm::seq(0, map.getNumResults())) { - unsigned dim = map.getDimPosition(idx); - if (reductionDim == dim) { - newShape.push_back(ratio); - newShape.push_back(op.getShape(operand)[idx] / ratio); - reassociation.push_back({index++, index++}); - exprs.push_back(b.getAffineDimExpr(insertDimIndex)); - exprs.push_back( - b.getAffineDimExpr(dim < insertDimIndex ? dim : dim + 1)); - continue; - } - newShape.push_back(op.getShape(operand)[idx]); - exprs.push_back(b.getAffineDimExpr(dim < insertDimIndex ?
dim : dim + 1)); - reassociation.push_back({index++}); - } + newMaps.reserve(op.getNumInputsAndOutputs() + 1); + for (OpOperand *o : op.getInputOperands()) + newMaps.push_back(scaleReductionDim(op, *o, reductionDimPos, ratio)); + // Provision a new indexing for the shape-only tensor. + auto nDims = op.getNumLoops() + 1; + auto redDim = getAffineDimExpr(reductionDimPos, context); + auto redDimP1 = getAffineDimExpr(reductionDimPos + 1, context); + newMaps.push_back(AffineMap::get(nDims, 0, {redDim, redDimP1}, context)); + // Expand existing output indexings. + // TODO: a subset of these may not reduce along reducePos and should be + // reindexed: k -> k * ratio + k', when multi-reduction support is available. + for (OpOperand *o : op.getOutputOperands()) newMaps.push_back( - AffineMap::get(map.getNumDims() + 1, 0, exprs, op.getContext())); - // If the shape is unchanged the input doesn't change. - if (newShape == op.getShape(operand)) { - newInputs.push_back(operand->get()); - continue; - } - Type newType = RankedTensorType::get( - newShape, - operand->get().getType().cast().getElementType()); - Value newInput = b.create( - loc, newType, operand->get(), reassociation); - newInputs.push_back(newInput); - } - // Calculate the new output map and shape, we insert the new dimension based - // on the index returned by `controlSplitReductionFn`. - SmallVector newOutputShape; - AffineMap oldOutputMap = op.getTiedIndexingMap(op.getOutputOperand(0)); - ArrayRef oldShape = op.getShape(op.getOutputOperand(0)); - SmallVector outputExpr; - for (unsigned idx : - llvm::seq(0, oldOutputMap.getNumResults() + 1)) { - if (idx == insertDimIndex) { - newOutputShape.push_back(ratio); - outputExpr.push_back(b.getAffineDimExpr(insertDimIndex)); - continue; - } - unsigned oldDim = idx < insertDimIndex ? idx : idx - 1; - newOutputShape.push_back(oldShape[oldDim]); - unsigned dim = oldOutputMap.getDimPosition(oldDim); - outputExpr.push_back( - b.getAffineDimExpr(dim < insertDimIndex ?
dim : dim + 1)); - } - Value initTensor = b.create( - loc, newOutputShape, op.getRegionOutputArgs()[0].getType()); - Value constantOp = b.create(loc, *identity); - Value identityTensor = - b.create(op->getLoc(), constantOp, initTensor) - .getResult(0); - - newMaps.push_back(AffineMap::get(oldOutputMap.getNumDims() + 1, 0, outputExpr, - op.getContext())); - SmallVector newIteratorTypes; - for (auto &it : llvm::enumerate(op.iterator_types())) { - if (insertDimIndex == it.index()) - newIteratorTypes.push_back(getParallelIteratorTypeName()); - newIteratorTypes.push_back(it.value().cast().getValue()); - } - // Create the new op matching the original op with an extra parallel + insertParallelDim(op, *o, reductionDimPos, insertDimIndex)); + + // Step 3. Handle operands. + // Compute the new input tensors. + auto newInputs = llvm::to_vector<4>(op.inputs()); + // Add a single shape-only tensor to carry the dimensions without resorting to + // more complex inversions. + newInputs.push_back(b.create( + loc, ArrayRef{reductionDimSize / ratio, ratio}, + b.getIntegerType(1))); + // Output tensors are already good to go. + + // Step 4. Create the new op matching the original op with an extra parallel // dimension. - GenericOp genericOp = b.create( - loc, TypeRange({initTensor.getType()}), newInputs, - ValueRange({identityTensor}), newMaps, newIteratorTypes); + SmallVector iteratorTypes = + llvm::to_vector<4>(op.getIteratorTypes().getAsValueRange()); + iteratorTypes.insert(iteratorTypes.begin() + reductionDimPos, + getParallelIteratorTypeName()); + GenericOp genericOp = + b.create(loc, ValueRange(newOutputs).getTypes(), newInputs, + newOutputs, newMaps, iteratorTypes); b.inlineRegionBefore(op->getRegion(0), genericOp.region(), genericOp.region().begin()); + // The shape-only tensor was appended after the original inputs, so its block + // argument goes right after theirs, before the output block arguments. + genericOp.region().front().insertArgument(op.getNumInputs(), + b.getIntegerType(1), loc); + + // Step 5. Create new reduction ops that only reduce the newly added + // dimensions from the previous op.
+ // For now assume outputs are 1-1 with reduction ops. + // TODO: a subset of these may not reduce in the first place and do not + // require a new op, when multi-reduction support is available. + // TODO: all results can be handled in a single GenericOp, when + // multi-reduction support is available. + SmallVector results; + // Combine the partial results computed by `genericOp` (not its + // neutral-filled init tensors) into the original outputs. + for (auto it : + llvm::zip(genericOp->getResults(), op.outputs(), combinerOps)) { + Value partialResult = std::get<0>(it); + Value originalOutput = std::get<1>(it); + auto originalOutputType = originalOutput.getType().cast(); + Operation *combinerOp = std::get<2>(it); - // Then create a new reduction that only reduce the newly added dimension from - // the previous op. - unsigned intermRank = newOutputShape.size(); - AffineMap inputMap = b.getMultiDimIdentityMap(intermRank); - SmallVector outputOperands = op.getOutputOperands(); - SmallVector reductionIteratorTypes; - SmallVector exprs; - for (unsigned i : llvm::seq(0, intermRank)) { - if (insertDimIndex == i) { - reductionIteratorTypes.push_back(getReductionIteratorTypeName()); - } else { - exprs.push_back(b.getAffineDimExpr(i)); - reductionIteratorTypes.push_back(getParallelIteratorTypeName()); - } + AffineMap map = b.getMultiDimIdentityMap(originalOutputType.getRank() + 1); + SmallVector indexingMaps = {map, map.dropResult(insertDimIndex)}; + SmallVector reductionIteratorTypes( + originalOutputType.getRank() + 1, getParallelIteratorTypeName()); + reductionIteratorTypes[insertDimIndex] = getReductionIteratorTypeName(); + + // clang-format off + auto reductionOp = b.create( + loc, + originalOutputType, + partialResult, + originalOutput, + indexingMaps, + reductionIteratorTypes, + [combinerOp](OpBuilder &b, Location loc, ValueRange bbArgs) { + Operation *clonedReductionOp = b.clone(*combinerOp); + clonedReductionOp->setOperand(0, bbArgs[0]); + clonedReductionOp->setOperand(1, bbArgs[1]); + b.create(loc, clonedReductionOp->getResult(0)); + }); + // clang-format on + +
results.push_back(reductionOp); } - AffineMap outputMap = AffineMap::get(intermRank, 0, exprs, op.getContext()); - SmallVector reductionMaps = {inputMap, outputMap}; - - auto reduction = b.create( - loc, op->getResultTypes(), ValueRange({genericOp.getResult(0)}), - outputOperands, reductionMaps, reductionIteratorTypes, - [reductionOp](OpBuilder &b, Location loc, ValueRange inputs) { - Operation *clonedReductionOp = b.clone(*reductionOp); - clonedReductionOp->setOperand(0, inputs[0]); - clonedReductionOp->setOperand(1, inputs[1]); - b.create(loc, clonedReductionOp->getResult(0)); - }); - b.replaceOp(op, reduction.getResults()); - filter.replaceLinalgTransformationFilter(b, genericOp); - filter.replaceLinalgTransformationFilter(b, reduction); - return cast(genericOp.getOperation()); + + // TODO: extend when multi-reduction support is available. + assert(results.size() == 1); + b.replaceOp(op, results.front()->getResults()); + return SplitReductionResult{cast(genericOp.getOperation()), + results.front()}; } namespace { struct LinalgSplitReduction : public OpInterfaceRewritePattern { - /// Construct a generic pattern applied to all LinalgOp that verify `filter`. + /// Construct a generic pattern applied to all LinalgOp that verify + /// `filter`.
LinalgSplitReduction(MLIRContext *context, ControlSplitReductionFn controlSplitReductionFn, LinalgTransformationFilter f, PatternBenefit benefit = 1) diff --git a/mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir b/mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir @@ -0,0 +1,53 @@ +// RUN: mlir-opt --test-transform-dialect-interpreter %s | FileCheck %s + +// CHECK-LABEL: func.func @matmul_split +func.func @matmul_split(%A : tensor, %B: tensor<256x32xf32>, %C: tensor) -> tensor { + + // CHECK: linalg.generic + // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"] + + // CHECK: linalg.generic + // CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"] + + // Snapshot of the IR generated when this test was written (not FileCheck'ed). + // func.func @matmul_split(%arg0: tensor, %arg1: tensor<256x32xf32>, %arg2: tensor) -> tensor { + // %c0 = arith.constant 0 : index + // %0 = tensor.dim %arg2, %c0 : tensor + // %1 = linalg.init_tensor [%0, 32, 64] : tensor + // %cst = arith.constant 0.000000e+00 : f32 + // %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor) -> tensor + // %3 = linalg.init_tensor [64, 4] : tensor<64x4xi1> + // %4 = linalg.generic {indexing_maps = [#map0, #map1, #map2, #map3], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1, %3 : tensor, tensor<256x32xf32>, tensor<64x4xi1>) outs(%2 : tensor) { + // ^bb0(%arg3: f32, %arg4: f32, %arg5: i1, %arg6: f32): + // %6 = arith.mulf %arg3, %arg4 : f32 + // %7 = arith.addf %arg6, %6 : f32 + // linalg.yield %7 : f32 + // } -> tensor + // %5 = linalg.generic {indexing_maps = [#map4, #map5], iterator_types = ["parallel", "parallel", "reduction"]} ins(%2 : tensor) outs(%arg2 : tensor) { + // ^bb0(%arg3: f32, %arg4: f32): + // %6 = arith.addf %arg3, %arg4 : f32 + // linalg.yield %6 : f32 + // } -> tensor + // return %5 : tensor + // } + %0 = linalg.matmul ins(%A, %B:
tensor, tensor<256x32xf32>) + outs(%C: tensor) -> tensor + return %0: tensor +} + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + pdl.pattern @pdl_target : benefit(1) { + %args = operands + %results = types + %0 = pdl.operation "linalg.matmul"(%args : !pdl.range) -> (%results : !pdl.range) + // TODO: we don't want this, but it is the required terminator for pdl.pattern + rewrite %0 with "transform.dialect" + } + + transform.sequence %arg0 { + ^bb1(%arg1: !pdl.operation): + %0 = pdl_match @pdl_target in %arg1 + %1:2 = transform.structured.split_reduction %0 { split_factor = 4, insert_split_dimension = 2} + } +}