diff --git a/mlir/include/mlir/Analysis/SliceAnalysis.h b/mlir/include/mlir/Analysis/SliceAnalysis.h --- a/mlir/include/mlir/Analysis/SliceAnalysis.h +++ b/mlir/include/mlir/Analysis/SliceAnalysis.h @@ -33,7 +33,7 @@ /// This additionally takes a TransitiveFilter which acts as a frontier: /// when looking at uses transitively, an operation that does not pass the /// filter is never propagated through. This allows in particular to carve out -/// the scope within a ForInst or the scope within an IfInst. +/// the scope within a ForOp or the scope within an IfOp. /// /// The implementation traverses the use chains in postorder traversal for /// efficiency reasons: if an operation is already in `forwardSlice`, no @@ -82,7 +82,7 @@ /// This additionally takes a TransitiveFilter which acts as a frontier: /// when looking at defs transitively, an operation that does not pass the /// filter is never propagated through. This allows in particular to carve out -/// the scope within a ForInst or the scope within an IfInst. +/// the scope within a ForOp or the scope within an IfOp. 
/// /// The implementation traverses the def chains in postorder traversal for /// efficiency reasons: if an operation is already in `backwardSlice`, no diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td @@ -286,6 +286,70 @@ }]; } +def Vector_MultiDimReductionOp : + Vector_Op<"multi_reduction", [NoSideEffect, + PredOpTrait<"source operand and result have same element type", + TCresVTEtIsSameAsOpBase<0, 0>>]>, + Arguments<(ins Vector_CombiningKindAttr:$kind, + AnyVector:$source, + I64ArrayAttr:$reduction_dims)>, + Results<(outs AnyType:$dest)> { + let summary = "Multi-dimensional reduction operation"; + let description = [{ + Reduces an n-D vector into an (n-k)-D vector using the given operation + (add/mul/min/max for int/fp and and/or/xor for int only). + + Example: + + ```mlir + %1 = vector.multi_reduction "add", %0 [1, 3] : + vector<4x8x16x32xf32> to vector<4x16xf32> + ``` + }]; + let builders = [ + OpBuilder<(ins "Value":$source, "ArrayRef":$reductionMask, + "CombiningKind":$kind)> + ]; + let extraClassDeclaration = [{ + static StringRef getKindAttrName() { return "kind"; } + static StringRef getReductionDimsAttrName() { return "reduction_dims"; } + + VectorType getSourceVectorType() { + return source().getType().cast(); + } + VectorType getDestVectorType() { + return dest().getType().cast(); + } + + SmallVector getReductionMask() { + SmallVector res(getSourceVectorType().getRank(), false); + for (auto ia : reduction_dims().getAsRange()) + res[ia.getInt()] = true; + return res; + } + static SmallVector getReductionMask( + ArrayRef reductionDims, unsigned sourceRank) { + SmallVector res(sourceRank, false); + for (auto idx : reductionDims) + res[idx] = true; + return res; + } + + static SmallVector inferDestShape( + ArrayRef shape, ArrayRef reducedDimsMask) { + assert(shape.size() == 
reducedDimsMask.size() && + "shape and maks of different sizes"); + SmallVector res; + for (auto it : llvm::zip(reducedDimsMask, shape)) + if (!std::get<0>(it)) + res.push_back(std::get<1>(it)); + return res; + } + }]; + let assemblyFormat = + "$kind `,` $source attr-dict $reduction_dims `:` type($source) `to` type($dest)"; +} + def Vector_BroadcastOp : Vector_Op<"broadcast", [NoSideEffect, PredOpTrait<"source operand and result have same element type", @@ -1317,6 +1381,18 @@ "ArrayAttr":$inBounds)> ]; + let extraClassDeclaration = [{ + /// Return a new `result` map with `0` inserted in the proper positions so + /// that vector.transfer_read `result` produces a vector of same element + /// type as `vt` and shape `targetShape`. + /// Assume that `map` is a permutation map for a vector.transfer_read op, + /// `vt` the vector type produced by the vector.transfer_read and + /// `targetShape` is the desired `targetShape` for a broadcast version of + /// `vt`. + static AffineMap insertBroadcasts(AffineMap map, VectorType vt, + ArrayRef targetShape); + }]; + let hasFolder = 1; } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -10,6 +10,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" @@ -25,6 +26,7 @@ #include "mlir/Support/LLVM.h" #include "mlir/Transforms/RegionUtils.h" #include "llvm/ADT/ScopeExit.h" +#include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include @@ -40,7 +42,8 @@ /// Return the unique instance of OpType in `block` if it is indeed unique. 
/// Return null if none or more than 1 instances exist. -template static OpType getSingleOpOfType(Block &block) { +template +static OpType getSingleOpOfType(Block &block) { OpType res; block.walk([&](OpType op) { if (res) { @@ -53,6 +56,31 @@ return res; } +/// Given an indexing `map` coming from a LinalgOp indexing, restricted to a +/// projectedPermutation, compress the unused dimensions to serve as a +/// permutation_map for a vector transfer operation. +/// For example, given a linalg op such as: +/// +/// ``` +/// %0 = linalg.generic { +/// indexing_maps = affine_map<(d0, d1, d2, d3, d4) -> (d4, d0, d2)>, +/// indexing_maps = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3)> +/// } +/// ins(%0 : tensor<2x3x4xf32>) +/// outs(%1 : tensor<5x6xf32>) +/// ``` +/// +/// the iteration domain size of the linalg op is 3x5x4x6x2. The first affine +/// map is reindexed to `affine_map<(d0, d1, d2) -> (d2, d0, d1)>`, the second +/// affine map is reindexed to `affine_map<(d0, d1) -> (d0, d1)>`. +static AffineMap reindexIndexingMap(AffineMap map) { + assert(map.isProjectedPermutation() && "expected projected permutation"); + auto res = compressUnusedDims(map); + assert(res.getNumDims() == res.getNumResults() && + "expected reindexed map with same number of dims and results"); + return res; +} + /// Helper data structure to represent the result of vectorization. /// In certain specific cases, like terminators, we do not want to propagate/ enum VectorizationStatus { @@ -83,6 +111,109 @@ return VectorType::get(st.getShape(), st.getElementType()); } +/// Given an `outputOperand` of a LinalgOp, compute the intersection of the +/// forward slice starting from `outputOperand` and the backward slice +/// starting from the corresponding linalg.yield operand. +/// This intersection is assumed to have a single binary operation that is +/// the reduction operation. 
Multiple reduction operations would impose an +/// ordering between reduction dimensions and is currently unsupported in +/// Linalg. This limitation is motivated by the fact that e.g. +/// min(max(X)) != max(min(X)) +// TODO: use in LinalgOp verification, there is a circular dependency atm. +static Operation *getSingleBinaryOpAssumedReduction(OpOperand &outputOperand) { + auto linalgOp = cast(outputOperand.getOwner()); + auto yieldOp = cast(linalgOp->getRegion(0).front().getTerminator()); + unsigned yieldNum = + outputOperand.getOperandNumber() - linalgOp.getNumInputs(); + llvm::SetVector backwardSlice, forwardSlice; + BlockArgument bbArg = linalgOp->getRegion(0).front().getArgument( + outputOperand.getOperandNumber()); + Value yieldVal = yieldOp->getOperand(yieldNum); + getBackwardSlice(yieldVal, &backwardSlice, [&](Operation *op) { + return op->getParentOp() == linalgOp; + }); + backwardSlice.insert(yieldVal.getDefiningOp()); + getForwardSlice(bbArg, &forwardSlice, + [&](Operation *op) { return op->getParentOp() == linalgOp; }); + // Search for the (assumed unique) elementwiseMappable op at the intersection + // of forward and backward slices. + Operation *reductionOp = nullptr; + for (Operation *op : llvm::reverse(backwardSlice)) { + if (!forwardSlice.contains(op)) + continue; + if (OpTrait::hasElementwiseMappableTraits(op)) { + if (reductionOp) { + // Reduction detection fails: found more than 1 elementwise-mappable op. + return nullptr; + } + reductionOp = op; + } + } + // TODO: also assert no other subsequent ops break the reduction. + return reductionOp; +} + +/// If `value` of assumed VectorType has a shape different than `shape`, try to +/// build and return a new vector.broadcast to `shape`. +/// Otherwise, just return `value`. +// TODO: this is best effort atm and there is currently no guarantee of +// correctness for the broadcast semantics. 
+static Value broadcastIfNeeded(OpBuilder &builder, Value value, + ArrayRef shape) { + unsigned numDimsGtOne = std::count_if(shape.begin(), shape.end(), + [](int64_t val) { return val > 1; }); + auto vecType = value.getType().dyn_cast(); + if (shape.empty() || + (vecType != nullptr && + (vecType.getShape() == shape || vecType.getRank() > numDimsGtOne))) + return value; + auto newVecType = VectorType::get(shape, vecType ? vecType.getElementType() + : value.getType()); + return builder.create( + builder.getInsertionPoint()->getLoc(), newVecType, value); +} + +/// If `value` of assumed VectorType has a shape different than that of +/// `targetVectorType`, build and return a new vector.multi_reduction over the +/// reduction iterators of the owning LinalgOp. Otherwise, just return `value`. +static Value reduceIfNeeded(OpBuilder &builder, VectorType targetVectorType, + Value value, OpOperand &outputOperand) { + assert(targetVectorType.getShape() == + outputOperand.get().getType().cast().getShape()); + auto vecType = value.getType().dyn_cast(); + if (vecType.getShape() == targetVectorType.getShape()) + return value; + // At this point, we know we need to reduce. Detect the reduction operator. + // TODO: Use the generic reduction detection util. + Operation *reductionOp = getSingleBinaryOpAssumedReduction(outputOperand); + assert(reductionOp && "expected reduction op."); + auto linalgOp = cast(outputOperand.getOwner()); + unsigned pos = 0; + MLIRContext *ctx = builder.getContext(); + SmallVector exprs; + for (auto s : linalgOp.iterator_types()) + if (isParallelIterator(s)) + exprs.push_back(getAffineDimExpr(pos++, ctx)); + auto loc = reductionOp->getLoc(); + // TODO: reuse common CombiningKind logic and support more than add. 
+ auto kind = llvm::TypeSwitch(reductionOp) + .Case( + [&](auto op) { return vector::CombiningKind::ADD; }) + .Default([&](auto op) { + llvm_unreachable("Unsupported reduction"); + return vector::CombiningKind::ADD; + }); + unsigned idx = 0; + SmallVector reductionMask(linalgOp.iterator_types().size(), false); + for (auto attr : linalgOp.iterator_types()) { + if (isReductionIteratorType(attr)) + reductionMask[idx] = true; + ++idx; + } + return builder.create(loc, value, reductionMask, + kind); +} + /// Build a vector.transfer_read from `source` at indices set to all `0`. /// If source has rank zero, build an memref.load. /// Return the produced value. @@ -90,29 +221,30 @@ VectorType vectorType, AffineMap map) { edsc::ScopedContext scope(builder); auto shapedType = source.getType().cast(); - if (vectorType) { - SmallVector indices(shapedType.getRank(), std_constant_index(0)); - if (map) - return vector_transfer_read(vectorType, source, indices, map); - return vector_transfer_read(vectorType, source, indices); - } - return memref_load(source); + SmallVector indices(shapedType.getRank(), std_constant_index(0)); + return vector_transfer_read(vectorType, source, indices, map); } -/// Build a vector.transfer_write of `value` into `dest` at indices set to all -/// `0`. If `dest` has null rank, build an memref.store. +/// Build a vector.transfer_write of `value` into `outputOperand` at indices set +/// to all `0`; where `outputOperand` is an output operand of the LinalgOp +/// currently being vectorized. If `dest` has null rank, build an memref.store. /// Return the produced value or null if no value is produced. 
-static Value buildVectorWrite(OpBuilder &builder, Value value, Value dest) { +static Value buildVectorWrite(OpBuilder &builder, Value value, + OpOperand &outputOperand) { edsc::ScopedContext scope(builder); Operation *write; - auto shapedType = dest.getType().cast(); - if (VectorType vectorType = extractVectorTypeFromShapedValue(dest)) { + auto shapedType = outputOperand.get().getType().cast(); + if (VectorType vectorType = + extractVectorTypeFromShapedValue(outputOperand.get())) { + auto linalgOp = cast(outputOperand.getOwner()); + AffineMap map = reindexIndexingMap( + linalgOp.getIndexingMap(outputOperand.getOperandNumber())); SmallVector indices(shapedType.getRank(), std_constant_index(0)); - if (vectorType != value.getType()) - value = vector_broadcast(vectorType, value); - write = vector_transfer_write(value, dest, indices); + value = broadcastIfNeeded(builder, value, vectorType.getShape()); + value = reduceIfNeeded(builder, vectorType, value, outputOperand); + write = vector_transfer_write(value, outputOperand.get(), indices, map); } else { - write = memref_store(value, dest); + write = memref_store(value, outputOperand.get()); } LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: vectorized op: " << *write); if (!write->getResults().empty()) @@ -120,20 +252,6 @@ return Value(); } -/// If value of assumed VectorType has a shape different than `shape`, buil and -/// return a new vector.broadcast to `shape`. -/// Otherwise, just return value. -static Value broadcastIfNeeded(OpBuilder &builder, Value value, - ArrayRef shape) { - auto vecType = value.getType().dyn_cast(); - if (shape.empty() || (vecType != nullptr && vecType.getShape() == shape)) - return value; - auto newVecType = VectorType::get(shape, vecType ? vecType.getElementType() - : value.getType()); - return builder.create( - builder.getInsertionPoint()->getLoc(), newVecType, value); -} - // Custom vectorization function type. 
Produce a vector form of Operation* // assuming all its vectorized operands are already in the BlockAndValueMapping. // Return nullptr if the Operation cannot be vectorized. @@ -158,8 +276,8 @@ // TODO: Scan for an opportunity for reuse. // TODO: use a map. Value vectorValue = bvm.lookup(outputs.value()); - Value newResult = buildVectorWrite(builder, vectorValue, - linalgOp.getOutput(outputs.index())); + Value newResult = buildVectorWrite( + builder, vectorValue, linalgOp.getOutputOpOperands()[outputs.index()]); if (newResult) newResults.push_back(newResult); } @@ -307,20 +425,6 @@ return hasOnlyScalarElementwiseOp(linalgOp->getRegion(0)); } -// Calculate the map to apply to transfer_read to convert the input shape into -// the output shape. -static AffineMap getTransferReadMap(LinalgOp linalgOp, unsigned argIndex) { - AffineMap linalgMap = linalgOp.getIndexingMap(argIndex); - MLIRContext *context = linalgMap.getContext(); - AffineExpr zero = mlir::getAffineConstantExpr(0, context); - SmallVector exprs(linalgMap.getNumInputs(), zero); - for (unsigned i : llvm::seq(unsigned(0), linalgMap.getNumResults())) { - exprs[linalgMap.getDimPosition(i)] = getAffineDimExpr(i, context); - } - return AffineMap::get(linalgMap.getNumResults(), /*symbolCount=*/0, exprs, - context); -} - /// Generic vectorization function that rewrites the body of a `linalgOp` into /// vector form. Generic vectorization proceeds as follows: /// 1. Verify the `linalgOp` has one non-empty region. @@ -333,42 +437,70 @@ /// 4b. Register CustomVectorizationHook for IndexOp to access the iteration /// indices. /// 5. Iteratively call vectorizeOneOp on the region operations. +/// +/// When `broadcastToMaximalCommonShape` is set to true, eager broadcasting is +/// performed to the maximal common vector size implied by the `linalgOp` +/// iteration space. This eager broadcasting is introduced in the +/// permutation_map of the vector.transfer_read operations. 
The eager +/// broadcasting makes it trivial to determine where broadcasts, transposes and +/// reductions should occur, without any bookkeeping. The tradeoff is that, in +/// the absence of good canonicalizations, the amount of work increases. +/// This is not deemed a problem as we expect canonicalizations and foldings to +/// aggressively clean up the useless work. LogicalResult vectorizeAsLinalgGeneric( OpBuilder &builder, LinalgOp linalgOp, SmallVectorImpl &newResults, + bool broadcastToMaximalCommonShape = false, ArrayRef customVectorizationHooks = {}) { // 1. Fail to vectorize if the operation does not have one non-empty region. if (linalgOp->getNumRegions() != 1 || linalgOp->getRegion(0).empty()) return failure(); auto &block = linalgOp->getRegion(0).front(); - BlockAndValueMapping bvm; // 2. Values defined above the region can only be broadcast for now. Make them // map to themselves. + BlockAndValueMapping bvm; SetVector valuesSet; mlir::getUsedValuesDefinedAbove(linalgOp->getRegion(0), valuesSet); bvm.map(valuesSet.getArrayRef(), valuesSet.getArrayRef()); + if (linalgOp.getNumOutputs() == 0) + return failure(); + + // TODO: the common vector shape is equal to the static loop sizes only when + // all indexing maps are projected permutations. For convs and stencils the + // logic will need to evolve. + SmallVector commonVectorShape = linalgOp.computeStaticLoopSizes(); + // 3. Turn all BBArgs into vector.transfer_read / load. SmallVector indexings; for (auto bbarg : block.getArguments()) { - Value vectorArg = linalgOp.getShapedOperand(bbarg.getArgNumber()); - AffineMap map; - VectorType vectorType = extractVectorTypeFromShapedValue(vectorArg); - if (isElementwise(linalgOp) && - !linalgOp.getIndexingMap(bbarg.getArgNumber()).isMinorIdentity()) { - // Currently assume we don't support output permutations. 
- assert(linalgOp.getNumOutputs() > 0 && - linalgOp.getOutputIndexingMap(0).isIdentity()); - ArrayRef outputShape = - linalgOp.getOutputShapedType(0).getShape(); - vectorType = VectorType::get(outputShape, vectorType.getElementType()); - map = getTransferReadMap(linalgOp, bbarg.getArgNumber()); + Value shapedArg = linalgOp.getShapedOperand(bbarg.getArgNumber()); + ShapedType shapedType = shapedArg.getType().cast(); + // TODO: 0-d vectors. + if (shapedType.getShape().empty()) { + Value loaded = + builder.create(linalgOp.getLoc(), shapedArg); + LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vectorized bbarg(" + << bbarg.getArgNumber() << "): " << loaded); + bvm.map(bbarg, loaded); + bvm.map(shapedArg, loaded); + continue; + } + AffineMap map = inversePermutation( + reindexIndexingMap(linalgOp.getIndexingMap(bbarg.getArgNumber()))); + VectorType vectorType = VectorType::get(map.compose(shapedType.getShape()), + shapedType.getElementType()); + if (broadcastToMaximalCommonShape) { + map = vector::TransferReadOp::insertBroadcasts(map, vectorType, + commonVectorShape); + vectorType = + VectorType::get(commonVectorShape, vectorType.getElementType()); } - Value vectorRead = buildVectorRead(builder, vectorArg, vectorType, map); + Value vectorRead = buildVectorRead(builder, shapedArg, vectorType, map); LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vectorized bbarg(" << bbarg.getArgNumber() << "): " << vectorRead); bvm.map(bbarg, vectorRead); - bvm.map(vectorArg, vectorRead); + bvm.map(shapedArg, vectorRead); } auto hooks = llvm::to_vector<4>(customVectorizationHooks); @@ -428,15 +560,35 @@ : VectorType::get(outShape, op->getResult(0).getType()); auto zero = builder.create(loc, vType, builder.getZeroAttr(vType)); + // Indexing maps at the time of vector.transfer_read are adjusted to order + // vector dimensions in the same order as the canonical linalg op iteration + // space order. + // The indexings for the contraction therefore need to be adjusted. 
+ // TODO: consider dropping contraction special casing altogether, this will + // require more advanced canonicalizations involving vector.multi_reduction + // that are not yet available. + SmallVector indexingMaps{ + inversePermutation(reindexIndexingMap(linalgOp.getIndexingMap(0))) + .compose(linalgOp.getIndexingMap(0)), + inversePermutation(reindexIndexingMap(linalgOp.getIndexingMap(1))) + .compose(linalgOp.getIndexingMap(1)), + inversePermutation(reindexIndexingMap(linalgOp.getIndexingMap(2))) + .compose(linalgOp.getIndexingMap(2))}; Operation *contract = builder.create( loc, bvm.lookup(op->getOperand(0)), bvm.lookup(op->getOperand(1)), zero, - linalgOp.indexing_maps(), linalgOp.iterator_types()); + builder.getAffineMapArrayAttr(indexingMaps), linalgOp.iterator_types()); return VectorizationResult{VectorizationStatus::NewOp, contract}; }; return vectorizeAsLinalgGeneric(builder, linalgOp, newResults, + /*broadcastToMaximalCommonShape=*/false, {vectorizeContraction}); } +static bool allIndexingsAreProjectedPermutation(LinalgOp op) { + return llvm::all_of(op.getIndexingMaps(), + [](AffineMap m) { return m.isProjectedPermutation(); }); +} + LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) { auto linalgOp = cast(op); // All types must be static shape to go to vector. @@ -448,7 +600,16 @@ return failure(); if (isElementwise(op)) return success(); - return success(isaContractionOpInterface(linalgOp)); + if (isaContractionOpInterface(linalgOp)) + return success(); + // TODO: the common vector shape is equal to the static loop sizes only when + // all indexing maps are projected permutations. For convs and stencils the + // logic will need to evolve. + // TODO: probably need some extra checks for reduction followed by consumer + // ops that may not commute (e.g. linear reduction + non-linear instructions). 
+ if (allIndexingsAreProjectedPermutation(linalgOp)) + return success(); + return failure(); } LogicalResult @@ -458,13 +619,17 @@ return failure(); edsc::ScopedContext scope(builder, op->getLoc()); - if (isElementwise(op)) { - LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: " - << "Vectorize linalg op as a generic: " << *op); - return vectorizeAsLinalgGeneric(builder, cast(op), newResults); - } + auto linalgOp = cast(op); - return vectorizeContraction(builder, cast(op), newResults); + if (isaContractionOpInterface(linalgOp)) + return vectorizeContraction(builder, linalgOp, newResults); + + LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: " + << "Vectorize linalg op as a generic by broadcasting to " + "maximal common shape: " + << *op); + return vectorizeAsLinalgGeneric(builder, linalgOp, newResults, + /*broadcastToMaximalCommonShape=*/true); } //----------------------------------------------------------------------------// diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -231,6 +231,45 @@ return builder.getI64ArrayAttr(values); } +//===----------------------------------------------------------------------===// +// MultiDimReductionOp +//===----------------------------------------------------------------------===// + +void vector::MultiDimReductionOp::build(OpBuilder &builder, + OperationState &result, Value source, + ArrayRef reductionMask, + CombiningKind kind) { + result.addOperands(source); + auto sourceVectorType = source.getType().cast(); + auto targetShape = MultiDimReductionOp::inferDestShape( + sourceVectorType.getShape(), reductionMask); + auto targetVectorType = + VectorType::get(targetShape, sourceVectorType.getElementType()); + result.addTypes(targetVectorType); + + SmallVector reductionDims; + for (auto en : llvm::enumerate(reductionMask)) + if (en.value()) + reductionDims.push_back(en.index()); + 
result.addAttribute(getReductionDimsAttrName(), + builder.getI64ArrayAttr(reductionDims)); + result.addAttribute(getKindAttrName(), + CombiningKindAttr::get(kind, builder.getContext())); +} + +static LogicalResult verify(MultiDimReductionOp op) { + auto reductionMask = op.getReductionMask(); + auto targetShape = MultiDimReductionOp::inferDestShape( + op.getSourceVectorType().getShape(), reductionMask); + auto targetVectorType = + VectorType::get(targetShape, op.getSourceVectorType().getElementType()); + if (targetVectorType != op.getDestVectorType()) + return op.emitError("invalid output vector type: ") + << op.getDestVectorType() << " (expected: " << targetVectorType + << ")"; + return success(); +} + //===----------------------------------------------------------------------===// // ReductionOp //===----------------------------------------------------------------------===// @@ -2160,6 +2199,29 @@ // TransferReadOp //===----------------------------------------------------------------------===// +AffineMap TransferReadOp::insertBroadcasts(AffineMap map, VectorType vt, + ArrayRef targetShape) { + unsigned targetRank = targetShape.size(); + assert(vt.getShape().size() <= targetRank && "mismatching ranks"); + if (vt.getShape().size() == targetRank) + return map; + MLIRContext *ctx = map.getContext(); + SmallVector exprs; + exprs.reserve(targetRank); + for (unsigned idx = 0, vtidx = 0; idx < targetRank; ++idx) { + // If shapes match, just keep the existing indexing and advance ranks. + if (vtidx < vt.getShape().size() && + vt.getShape()[vtidx] == targetShape[idx]) { + exprs.push_back(map.getResult(vtidx)); + ++vtidx; + continue; + } + // Otherwise insert a broadcast. 
+ exprs.push_back(getAffineConstantExpr(0, ctx)); + } + return AffineMap::get(map.getNumDims(), /*numSymbols=*/0, exprs, ctx); +} + template static LogicalResult verifyPermutationMap(AffineMap permutationMap, EmitFun emitOpError) { diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -1752,18 +1752,23 @@ /// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul /// semantics to: /// ``` -/// %flattened_a = vector.shape_cast %a -/// %flattened_b = vector.shape_cast %b +/// %mta = maybe_transpose +/// %mtb = maybe_transpose +/// %flattened_a = vector.shape_cast %mta +/// %flattened_b = vector.shape_cast %mtb /// %flattened_d = vector.matmul %flattened_a, %flattened_b -/// %d = vector.shape_cast %%flattened_d +/// %mtd = vector.shape_cast %flattened_d +/// %d = maybe_untranspose %mtd /// %e = add %c, %d /// ``` /// `vector.matmul` later lowers to `llvm.matrix.multiply`. // -/// This only kicks in when VectorTransformsOptions is set to OuterProduct and -/// the vector.contract op is a row-major matrix multiply. -LogicalResult ContractionOpToMatmulOpLowering::matchAndRewrite( - vector::ContractionOp op, PatternRewriter &rewriter) const { +/// This only kicks in when VectorTransformsOptions is set to `Matmul`. +/// vector.transpose operations are inserted if the vector.contract op is not a +/// row-major matrix multiply. 
+LogicalResult +ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op, + PatternRewriter &rew) const { // TODO: implement masks if (llvm::size(op.masks()) != 0) return failure(); @@ -1779,37 +1784,67 @@ !isReductionIterator(iteratorTypes[2])) return failure(); - if (!isRowMajorMatmul(op.indexing_maps())) - return failure(); - Type elementType = op.getLhsType().getElementType(); if (!elementType.isIntOrFloat()) return failure(); - VectorType lhsType = op.getLhsType(); - VectorType rhsType = op.getRhsType(); + // Perform lhs + rhs transpositions to conform to matmul row-major semantics. + // Bail out if the contraction cannot be put in this form. + MLIRContext *ctx = op.getContext(); + Location loc = op.getLoc(); + AffineExpr m, n, k; + bindDims(rew.getContext(), m, n, k); + // LHS must be A(m, k) or A(k, m). + Value lhs = op.lhs(); + auto lhsMap = op.indexing_maps()[0].cast().getValue(); + if (lhsMap == AffineMap::get(3, 0, {k, m}, ctx)) + lhs = rew.create(loc, lhs, ArrayRef{1, 0}); + else if (lhsMap != AffineMap::get(3, 0, {m, k}, ctx)) + return failure(); + + // RHS must be B(k, n) or B(n, k). + Value rhs = op.rhs(); + auto rhsMap = op.indexing_maps()[1].cast().getValue(); + if (rhsMap == AffineMap::get(3, 0, {n, k}, ctx)) + rhs = rew.create(loc, rhs, ArrayRef{1, 0}); + else if (rhsMap != AffineMap::get(3, 0, {k, n}, ctx)) + return failure(); + + // At this point lhs and rhs are in row-major. 
+ VectorType lhsType = lhs.getType().cast(); + VectorType rhsType = rhs.getType().cast(); int64_t lhsRows = lhsType.getDimSize(0); int64_t lhsColumns = lhsType.getDimSize(1); int64_t rhsColumns = rhsType.getDimSize(1); Type flattenedLHSType = VectorType::get(lhsType.getNumElements(), lhsType.getElementType()); + lhs = rew.create(loc, flattenedLHSType, lhs); + Type flattenedRHSType = VectorType::get(rhsType.getNumElements(), rhsType.getElementType()); - auto lhs = rewriter.create(op.getLoc(), flattenedLHSType, - op.lhs()); - auto rhs = rewriter.create(op.getLoc(), flattenedRHSType, - op.rhs()); - - Value mul = rewriter.create(op.getLoc(), lhs, rhs, lhsRows, - lhsColumns, rhsColumns); - mul = rewriter.create(op.getLoc(), op.acc().getType(), - mul); - if (elementType.isa()) - rewriter.replaceOpWithNewOp(op, op.acc(), mul); - else - rewriter.replaceOpWithNewOp(op, op.acc(), mul); - + rhs = rew.create(loc, flattenedRHSType, rhs); + + Value mul = rew.create(loc, lhs, rhs, lhsRows, lhsColumns, + rhsColumns); + mul = rew.create( + loc, + VectorType::get({lhsRows, rhsColumns}, + getElementTypeOrSelf(op.acc().getType())), + mul); + + // ACC must be C(m, n) or C(n, m). + auto accMap = op.indexing_maps()[2].cast().getValue(); + if (accMap == AffineMap::get(3, 0, {n, m}, ctx)) + mul = rew.create(loc, mul, ArrayRef{1, 0}); + else if (accMap != AffineMap::get(3, 0, {m, n}, ctx)) + llvm_unreachable("invalid contraction semantics"); + + Value res = elementType.isa() + ? 
static_cast(rew.create(loc, op.acc(), mul)) + : static_cast(rew.create(loc, op.acc(), mul)); + + rew.replaceOp(op, res); return success(); } diff --git a/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir b/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir --- a/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir +++ b/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir @@ -22,6 +22,6 @@ // // CHECK: vector.contract // CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"] -// CHECK-SAME: : vector<8x16xf32>, vector<16x12xf32> into vector<8x12xf32> +// CHECK-SAME: : vector<8x16xf32>, vector<12x16xf32> into vector<8x12xf32> // // CHECK: linalg.copy diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir --- a/mlir/test/Dialect/Linalg/vectorization.mlir +++ b/mlir/test/Dialect/Linalg/vectorization.mlir @@ -58,18 +58,19 @@ iterator_types = ["parallel", "parallel", "reduction"] } +// CHECK-DAG: #[[$trans_2d:.*]] = affine_map<(d0, d1) -> (d1, d0)> // CHECK-DAG: #[[$mk:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)> -// CHECK-DAG: #[[$kn:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)> +// CHECK-DAG: #[[$nk:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)> // CHECK-DAG: #[[$mn:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)> // CHECK-LABEL: func @vectorization_test func @vectorization_test(%A: memref<8x16xf32>, %B: memref<16x32xf32>, %C: memref<8x32xf32>) { // CHECK: vector.transfer_read %{{.*}} : memref<8x16xf32>, vector<8x16xf32> - // CHECK: vector.transfer_read %{{.*}} : memref<16x32xf32>, vector<16x32xf32> + // CHECK: vector.transfer_read %{{.*}} : memref<16x32xf32>, vector<32x16xf32> // CHECK: vector.transfer_read %{{.*}} : memref<8x32xf32>, vector<8x32xf32> - // CHECK: vector.contract {indexing_maps = [#[[$mk]], #[[$kn]], #[[$mn]]] - // CHECK-SAME: vector<8x16xf32>, vector<16x32xf32> into vector<8x32xf32> + // CHECK: vector.contract {indexing_maps = [#[[$mk]], #[[$nk]], 
#[[$mn]]] + // CHECK-SAME: vector<8x16xf32>, vector<32x16xf32> into vector<8x32xf32> // CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xf32>, memref<8x32xf32> linalg.generic #matmul_trait ins(%A, %B : memref<8x16xf32>, memref<16x32xf32>) @@ -95,18 +96,19 @@ iterator_types = ["parallel", "parallel", "reduction"] } +// CHECK-DAG: #[[$trans_2d:.*]] = affine_map<(d0, d1) -> (d1, d0)> // CHECK-DAG: #[[$mk:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)> -// CHECK-DAG: #[[$kn:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)> +// CHECK-DAG: #[[$nk:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)> // CHECK-DAG: #[[$mn:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)> // CHECK-LABEL: func @vectorization_test_integer func @vectorization_test_integer(%A: memref<8x16xi32>, %B: memref<16x32xi32>, %C: memref<8x32xi32>) { // CHECK: vector.transfer_read %{{.*}} : memref<8x16xi32>, vector<8x16xi32> - // CHECK: vector.transfer_read %{{.*}} : memref<16x32xi32>, vector<16x32xi32> + // CHECK: vector.transfer_read %{{.*}} : memref<16x32xi32>, vector<32x16xi32> // CHECK: vector.transfer_read %{{.*}} : memref<8x32xi32>, vector<8x32xi32> - // CHECK: vector.contract {indexing_maps = [#[[$mk]], #[[$kn]], #[[$mn]]], - // CHECK-SAME: vector<8x16xi32>, vector<16x32xi32> into vector<8x32xi32> + // CHECK: vector.contract {indexing_maps = [#[[$mk]], #[[$nk]], #[[$mn]]], + // CHECK-SAME: vector<8x16xi32>, vector<32x16xi32> into vector<8x32xi32> // CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xi32>, memref<8x32xi32> linalg.generic #matmul_trait ins(%A, %B : memref<8x16xi32>, memref<16x32xi32>) @@ -252,13 +254,12 @@ memref<4x256xf32>, memref<4x256xf32>) { ^bb0(%arg3 : f32, %arg4 : f32, %arg5: f32, %arg6: f32, %arg7: f32, %arg8: f32, // CHECK: %[[V2:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : memref<4x256xf32>, vector<4x256xf32> - // CHECK: %[[V0:.*]] = vector.transfer_read %[[ARG2]][%[[C0]]], {{.*}} : memref<256xf32>, vector<256xf32> + // CHECK: %[[V0:.*]] = 
vector.transfer_read %[[ARG2]][%[[C0]]], {{.*}} : memref<256xf32>, vector<4x256xf32> // CHECK: %[[V3:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : memref<4x256xf32>, vector<4x256xf32> // CHECK: %[[V1:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : memref<4x256xf32>, vector<4x256xf32> %arg9 : f32, %arg10 : f32, %arg11 : f32, %arg12 : f32, %arg13 : f32, %arg14 : f32): - // CHECK: %[[V0B:.*]] = vector.broadcast %[[V0]] : vector<256xf32> to vector<4x256xf32> - // CHECK: %[[ADD:.*]] = addf %[[V0B]], %[[V1]] : vector<4x256xf32> + // CHECK: %[[ADD:.*]] = addf %[[V0]], %[[V1]] : vector<4x256xf32> %6 = addf %arg4, %arg6 : f32 // CHECK: %[[CMP:.*]] = cmpf ogt, %[[V2]], %[[V1]] : vector<4x256xf32> %7 = cmpf ogt, %arg3, %arg6 : f32 @@ -274,8 +275,7 @@ %12 = math.rsqrt %arg5 : f32 // CHECK: %[[SEL:.*]] = select %[[CMP]], %[[V3]], %[[V1]] : vector<4x256xi1>, vector<4x256xf32> %13 = select %7, %arg5, %arg6 : f32 - // CHECK: %[[V0B:.*]] = vector.broadcast %[[V0]] : vector<256xf32> to vector<4x256xf32> - // CHECK: %[[SUB:.*]] = subf %[[V3]], %[[V0B]] : vector<4x256xf32> + // CHECK: %[[SUB:.*]] = subf %[[V3]], %[[V0]] : vector<4x256xf32> %14 = subf %arg5, %arg4 : f32 // CHECK: %[[TAN:.*]] = math.tanh %[[V3]] : vector<4x256xf32> %15 = math.tanh %arg5 : f32 @@ -334,11 +334,10 @@ // CHECK-DAG: %[[CST1:.*]] = constant dense<1.000000e+00> : vector<4x256xf32> // CHECK-DAG: %[[C0:.*]] = constant 0 : index // CHECK: %[[V2:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x256xf32>, vector<4x256xf32> - // CHECK: %[[V0:.*]] = vector.transfer_read %[[ARG2]][%[[C0]]], {{.*}} : tensor<256xf32>, vector<256xf32> + // CHECK: %[[V0:.*]] = vector.transfer_read %[[ARG2]][%[[C0]]], {{.*}} : tensor<256xf32>, vector<4x256xf32> // CHECK: %[[V3:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x256xf32>, vector<4x256xf32> // CHECK: %[[V1:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x256xf32>, 
vector<4x256xf32> - // CHECK: %[[V0B:.*]] = vector.broadcast %[[V0]] : vector<256xf32> to vector<4x256xf32> - // CHECK: %[[ADD:.*]] = addf %[[V0B]], %[[V1]] : vector<4x256xf32> + // CHECK: %[[ADD:.*]] = addf %[[V0]], %[[V1]] : vector<4x256xf32> %6 = addf %arg4, %arg6 : f32 // CHECK: %[[CMP:.*]] = cmpf ogt, %[[V2]], %[[V1]] : vector<4x256xf32> %7 = cmpf ogt, %arg3, %arg6 : f32 @@ -354,8 +353,7 @@ %12 = math.rsqrt %arg5 : f32 // CHECK: %[[SEL:.*]] = select %[[CMP]], %[[V3]], %[[V1]] : vector<4x256xi1>, vector<4x256xf32> %13 = select %7, %arg5, %arg6 : f32 - // CHECK: %[[V0B:.*]] = vector.broadcast %[[V0]] : vector<256xf32> to vector<4x256xf32> - // CHECK: %[[SUB:.*]] = subf %[[V3]], %[[V0B]] : vector<4x256xf32> + // CHECK: %[[SUB:.*]] = subf %[[V3]], %[[V0]] : vector<4x256xf32> %14 = subf %arg5, %arg4 : f32 // CHECK: %[[TAN:.*]] = math.tanh %[[V3]] : vector<4x256xf32> %15 = math.tanh %arg5 : f32 @@ -428,12 +426,15 @@ // CHECK-DAG: %[[C0:.*]] = constant 0 : index // CHECK-DAG: %[[VEC_C0:.*]] = constant dense<0.000000e+00> : vector<8x12xf32> // CHECK-DAG: %[[V0:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<8x4xf32>, vector<8x4xf32> - // CHECK-DAG: %[[V1:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x12xf32>, vector<4x12xf32> + // CHECK-DAG: %[[V1:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x12xf32>, vector<12x4xf32> // CHECK-DAG: %[[V2:.*]] = vector.transfer_read %[[ARG2]][%[[C0]], %[[C0]]], {{.*}} : tensor<8x12xf32>, vector<8x12xf32> // // linalg contraction lowers to %tmp = vector.contract %a, %b, %c0 followed by addf %c, %tmp. // a later canonicalization fuses the add into vector.contract. 
- // CHECK: %[[C:.*]] = vector.contract {{.*}} iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %[[V0]], %[[V1]], %[[VEC_C0]] : vector<8x4xf32>, vector<4x12xf32> into vector<8x12xf32> + // CHECK: %[[C:.*]] = vector.contract + // CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} + // CHECK-SAME: %[[V0]], %[[V1]], %[[VEC_C0]] : + // CHECK-SAME: vector<8x4xf32>, vector<12x4xf32> into vector<8x12xf32> // CHECK: %[[C2:.*]] = addf %[[V2]], %[[C]] : vector<8x12xf32> // CHECK: %[[W:.*]] = vector.transfer_write %[[C2]], %[[ARG2]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x12xf32>, tensor<8x12xf32> %0 = linalg.matmul ins(%arg0, %arg1: tensor<8x4xf32>, tensor<4x12xf32>) @@ -453,15 +454,17 @@ // CHECK-DAG: %[[C0:.*]] = constant 0 : index // CHECK-DAG: %[[VEC_C0:.*]] = constant dense<0> : vector<4x12xi32> // CHECK-DAG: %[[V0:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : memref<4x6xi8>, vector<4x6xi8> - // CHECK-DAG: %[[V1:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : memref<6x12xi8>, vector<6x12xi8> + // CHECK-DAG: %[[V1:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : memref<6x12xi8>, vector<12x6xi8> // CHECK-DAG: %[[V2:.*]] = vector.transfer_read %[[ARG2]][%[[C0]], %[[C0]]], {{.*}} : memref<4x12xi32>, vector<4x12xi32> // CHECK-DAG: %[[V0_32:.*]] = sexti %[[V0]] : vector<4x6xi8> to vector<4x6xi32> - // CHECK-DAG: %[[V1_32:.*]] = sexti %[[V1]] : vector<6x12xi8> to vector<6x12xi32> + // CHECK-DAG: %[[V1_32:.*]] = sexti %[[V1]] : vector<12x6xi8> to vector<12x6xi32> // // linalg contraction lowers to %tmp = vector.contract %a, %b, %c0 followed by addf %c, %tmp. // a later canonicalization fuses the add into vector.contract. 
- // CHECK: %[[C:.*]] = vector.contract {{.*}} iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %[[V0_32]], %[[V1_32]], %[[VEC_C0]] - // CHECK-SAME: vector<4x6xi32>, vector<6x12xi32> into vector<4x12xi32> + // CHECK: %[[C:.*]] = vector.contract + // CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} + // CHECK-SAME: %[[V0_32]], %[[V1_32]], %[[VEC_C0]] + // CHECK-SAME: vector<4x6xi32>, vector<12x6xi32> into vector<4x12xi32> // CHECK: %[[RES:.*]] = addi %[[V2]], %[[C]] : vector<4x12xi32> // CHECK: vector.transfer_write %[[RES]], %[[ARG2]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} // CHECK-SAME: vector<4x12xi32>, memref<4x12xi32> @@ -491,6 +494,8 @@ return %0 : tensor<2x3x4xf32> } +// ----- + // CHECK-LABEL: func @pad_static_high_padding // CHECK: linalg.pad_tensor func @pad_static_high_padding(%arg0: tensor, %pad_value: f32) -> tensor<2x3x4xf32> { @@ -501,6 +506,8 @@ return %0 : tensor<2x3x4xf32> } +// ----- + // CHECK-LABEL: func @pad_dynamic // CHECK: linalg.pad_tensor func @pad_dynamic(%arg0: tensor<1x2x2x?xf32>, %low: index, %high: index, @@ -511,3 +518,72 @@ } : tensor<1x2x2x?xf32> to tensor<6x?x?x?xf32> return %0 : tensor<6x?x?x?xf32> } + +// ----- + +// CHECK-DAG: #[[$M0:.*]] = affine_map<(d0, d1) -> (d0, d1, 0)> + +// CHECK-LABEL: func @sum_exp +func @sum_exp(%input: tensor<4x16x8xf32>, %output: tensor<4x16xf32>) + -> tensor<4x16xf32> +{ + // CHECK: vector.transfer_read {{.*}} : tensor<4x16x8xf32>, vector<4x16x8xf32> + // CHECK: vector.transfer_read {{.*}} {permutation_map = #[[$M0]]} : tensor<4x16xf32>, vector<4x16x8xf32> + // CHECK: math.exp {{.*}} : vector<4x16x8xf32> + // CHECK: addf {{.*}} : vector<4x16x8xf32> + // CHECK: vector.multi_reduction #vector.kind, %{{.*}} [2] : vector<4x16x8xf32> to vector<4x16xf32> + // CHECK: vector.transfer_write {{.*}} : vector<4x16xf32>, tensor<4x16xf32> + // CHECK: return {{.*}} : tensor<4x16xf32> + %0 = linalg.generic { + indexing_maps = [ + 
affine_map<(d0, d1, d2) -> (d0, d1, d2)>, + affine_map<(d0, d1, d2) -> (d0, d1)> + ], + iterator_types = ["parallel", "parallel", "reduction"] + } ins(%input : tensor<4x16x8xf32>) outs(%output : tensor<4x16xf32>) { + ^bb0(%arg0: f32, %arg1: f32): // no predecessors + %1 = math.exp %arg0 : f32 + %2 = addf %1, %arg1 : f32 + linalg.yield %2 : f32 + } -> tensor<4x16xf32> + return %0 : tensor<4x16xf32> +} + +// ----- + +// CHECK-DAG: #[[$M1:.*]] = affine_map<(d0, d1) -> (d1, d0, 0, 0)> +// CHECK-DAG: #[[$M2:.*]] = affine_map<(d0, d1) -> (0, 0, d1, d0)> +// CHECK-DAG: #[[$M3:.*]] = affine_map<(d0, d1) -> (d1, 0, 0, d0)> +// CHECK-DAG: #[[$M4:.*]] = affine_map<(d0, d1) -> (d1, d0)> + +// CHECK-LABEL: func @sum_exp_2 +func @sum_exp_2(%input: tensor<3x2xf32>, %input_2: tensor<5x4xf32>, %output: tensor<5x2xf32>) + -> tensor<5x2xf32> +{ + // CHECK: vector.transfer_read {{.*}} {permutation_map = #[[$M1]]} : tensor<3x2xf32>, vector<2x3x4x5xf32> + // CHECK: vector.transfer_read {{.*}} {permutation_map = #[[$M2]]} : tensor<5x4xf32>, vector<2x3x4x5xf32> + // CHECK: vector.transfer_read {{.*}} {permutation_map = #[[$M3]]} : tensor<5x2xf32>, vector<2x3x4x5xf32> + // CHECK: math.exp {{.*}} : vector<2x3x4x5xf32> + // CHECK: math.exp {{.*}} : vector<2x3x4x5xf32> + // CHECK: addf {{.*}} : vector<2x3x4x5xf32> + // CHECK: addf {{.*}} : vector<2x3x4x5xf32> + // CHECK: vector.multi_reduction #vector.kind, {{.*}} [1, 2] : vector<2x3x4x5xf32> to vector<2x5xf32> + // CHECK: vector.transfer_write {{.*}} {permutation_map = #[[$M4]]} : vector<2x5xf32>, tensor<5x2xf32> + // CHECK: return {{.*}} : tensor<5x2xf32> + %0 = linalg.generic { + indexing_maps = [ + affine_map<(d0, d1, d2, d3) -> (d1, d0)>, + affine_map<(d0, d1, d2, d3) -> (d3, d2)>, + affine_map<(d0, d1, d2, d3) -> (d3, d0)> + ], + iterator_types = ["parallel", "reduction", "reduction", "parallel"] + } ins(%input, %input_2 : tensor<3x2xf32>, tensor<5x4xf32>) outs(%output : tensor<5x2xf32>) { + ^bb0(%arg0: f32, %arg1: f32, %arg2: f32): // 
no predecessors + %1 = math.exp %arg0 : f32 + %2 = math.exp %arg1 : f32 + %3 = addf %1, %2 : f32 + %4 = addf %3, %arg2 : f32 + linalg.yield %4 : f32 + } -> tensor<5x2xf32> + return %0 : tensor<5x2xf32> +}