diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h --- a/mlir/include/mlir/IR/AffineMap.h +++ b/mlir/include/mlir/IR/AffineMap.h @@ -503,6 +503,20 @@ getProjectedMap(AffineMap map, const llvm::SmallDenseSet &projectedDimensions); +/// Apply a permutation from `map` to `source` and return the result. +template +SmallVector applyPermuationMap(AffineMap map, llvm::ArrayRef source) { + assert(map.isProjectedPermutation()); + assert(map.getNumInputs() == source.size()); + SmallVector result; + result.reserve(map.getNumResults()); + for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) { + unsigned dim = map.getDimPosition(i); + result.push_back(source[dim]); + } + return result; +} + inline raw_ostream &operator<<(raw_ostream &os, AffineMap map) { map.print(os); return os; diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -191,7 +191,6 @@ static Value reduceIfNeeded(OpBuilder &b, VectorType targetVectorType, Value value, OpOperand *outputOperand) { auto linalgOp = cast(outputOperand->getOwner()); - assert(targetVectorType.getShape() == linalgOp.getShape(outputOperand)); auto vecType = value.getType().dyn_cast(); if (!vecType || vecType.getShape() == targetVectorType.getShape()) return value; @@ -245,6 +244,9 @@ auto linalgOp = cast(outputOperand->getOwner()); AffineMap map = reindexIndexingMap(linalgOp.getTiedIndexingMap(outputOperand)); + SmallVector transposeShape = + applyPermuationMap(inversePermutation(map), vectorType.getShape()); + vectorType = VectorType::get(transposeShape, vectorType.getElementType()); SmallVector indices(linalgOp.getRank(outputOperand), b.create(loc, 0)); value = broadcastIfNeeded(b, value, vectorType.getShape()); @@ -569,9 +571,16 @@ return VectorizationResult{VectorizationStatus::Failure, nullptr}; ArrayRef outShape = linalgOp.getShape(linalgOp.getOutputOperand(0)); - auto vType = outShape.empty() - ? op->getResult(0).getType() - : VectorType::get(outShape, op->getResult(0).getType()); + Type vType; + if (outShape.empty()) { + vType = op->getResult(0).getType(); + } else { + SmallVector resultShape = applyPermuationMap( + inversePermutation(reindexIndexingMap( + linalgOp.getTiedIndexingMap(linalgOp.getOutputOperand(0)))), + outShape); + vType = VectorType::get(resultShape, op->getResult(0).getType()); + } auto zero = b.create(loc, vType, b.getZeroAttr(vType)); // Indexing maps at the time of vector.transfer_read are adjusted to order // vector dimensions in the same order as the canonical linalg op iteration diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -342,17 +342,6 @@ vector::UnrollVectorOptions options; }; -template -SmallVector permute(AffineMap map, llvm::ArrayRef source) { - SmallVector result; - result.reserve(map.getNumResults()); - for (unsigned i : llvm::seq(unsigned(0), map.getNumResults())) { - unsigned dim = map.getDimPosition(i); - result.push_back(source[dim]); - } - return result; -} - struct UnrollContractionPattern : public OpRewritePattern { struct OffsetMapInfo { @@ -403,7 +392,7 @@ AffineMap permutationMap, ArrayRef operandOffets) { SmallVector operandShape = - permute(permutationMap, ArrayRef(*targetShape)); + applyPermuationMap(permutationMap, ArrayRef(*targetShape)); SmallVector operandStrides(operandOffets.size(), 1); slicesOperands[index] = rewriter.create( loc, operand, operandOffets, operandShape, operandStrides); @@ -412,7 +401,7 @@ // Extract the new lhs operand. AffineMap lhsPermutationMap = contractOp.getIndexingMaps()[0]; SmallVector lhsOffets = - permute(lhsPermutationMap, ArrayRef(offsets)); + applyPermuationMap(lhsPermutationMap, ArrayRef(offsets)); extractOperand(0, contractOp.lhs(), lhsPermutationMap, lhsOffets); // If there is a mask associated to lhs, extract it as well. if (slicesOperands.size() > 3) @@ -421,7 +410,7 @@ // Extract the new rhs operand. AffineMap rhsPermutationMap = contractOp.getIndexingMaps()[1]; SmallVector rhsOffets = - permute(rhsPermutationMap, ArrayRef(offsets)); + applyPermuationMap(rhsPermutationMap, ArrayRef(offsets)); extractOperand(1, contractOp.rhs(), rhsPermutationMap, rhsOffets); // If there is a mask associated to rhs, extract it as well. if (slicesOperands.size() > 4) @@ -429,7 +418,7 @@ AffineMap accPermutationMap = contractOp.getIndexingMaps()[2]; SmallVector accOffets = - permute(accPermutationMap, ArrayRef(offsets)); + applyPermuationMap(accPermutationMap, ArrayRef(offsets)); // If a version of the accumulator has already been computed, use it // otherwise extract the first version from the original operand. auto accIt = accCache.find(accOffets); @@ -439,13 +428,13 @@ extractOperand(2, contractOp.acc(), accPermutationMap, accOffets); SmallVector dstShape = - permute(dstAffineMap, ArrayRef(*targetShape)); + applyPermuationMap(dstAffineMap, ArrayRef(*targetShape)); auto targetType = VectorType::get(dstShape, dstVecType.getElementType()); Operation *newOp = cloneOpWithOperandsAndTypes( rewriter, loc, contractOp, slicesOperands, targetType); SmallVector dstOffets = - permute(dstAffineMap, ArrayRef(offsets)); + applyPermuationMap(dstAffineMap, ArrayRef(offsets)); // Save the accumulated value untill all the loops are unrolled since // reduction loop keep updating the accumulator. accCache[dstOffets] = newOp->getResult(0); diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir --- a/mlir/test/Dialect/Linalg/vectorization.mlir +++ b/mlir/test/Dialect/Linalg/vectorization.mlir @@ -85,6 +85,44 @@ // ----- +#matmul_transpose_out_trait = { + args_in = 2, + args_out = 1, + indexing_maps = [ + affine_map<(m, n, k) -> (m, k)>, + affine_map<(m, n, k) -> (k, n)>, + affine_map<(m, n, k) -> (n, m)> + ], + iterator_types = ["parallel", "parallel", "reduction"] +} + +// CHECK-DAG: #[[$trans_2d:.*]] = affine_map<(d0, d1) -> (d1, d0)> +// CHECK-DAG: #[[$mk:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)> +// CHECK-DAG: #[[$nk:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)> +// CHECK-DAG: #[[$mn:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)> + +// CHECK-LABEL: func @generic_output_transpose +func @generic_output_transpose(%A: memref<8x16xf32>, %B: memref<16x32xf32>, + %C: memref<32x8xf32>) { + // CHECK: vector.transfer_read %{{.*}} : memref<8x16xf32>, vector<8x16xf32> + // CHECK: vector.transfer_read %{{.*}} : memref<16x32xf32>, vector<32x16xf32> + // CHECK: vector.transfer_read %{{.*}} : memref<32x8xf32>, vector<8x32xf32> + // CHECK: vector.contract {indexing_maps = [#[[$mk]], #[[$nk]], #[[$mn]]] + // CHECK-SAME: vector<8x16xf32>, vector<32x16xf32> into vector<8x32xf32> + // CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xf32>, memref<32x8xf32> + linalg.generic #matmul_transpose_out_trait + ins(%A, %B : memref<8x16xf32>, memref<16x32xf32>) + outs(%C : memref<32x8xf32>) { + ^bb(%a: f32, %b: f32, %c: f32) : + %d = mulf %a, %b: f32 + %e = addf %c, %d: f32 + linalg.yield %e : f32 + } + return +} + +// ----- + #matmul_trait = { args_in = 2, args_out = 1,