diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -323,6 +323,13 @@
   /// Returns true if the AffineMap represents a symbol-less permutation map.
   bool isPermutation() const;
 
+  /// Given a projected permutation map, returns the projection map without the
+  /// permutation.
+  /// Example:
+  ///    map    : affine_map<(d0, d1, d2, d3) -> (d3, d1, d2)>
+  ///    result : affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>
+  AffineMap getNonPermutedProjectionMap() const;
+
   /// Returns the map consisting of the `resultPos` subset.
   AffineMap getSubMap(ArrayRef<unsigned> resultPos) const;
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -605,8 +605,14 @@
   Location loc = value.getLoc();
   auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
   AffineMap opOperandMap = linalgOp.getMatchingIndexingMap(outputOperand);
+
+  // Compute the vector type of the value to store. This type should be an
+  // identity or projection of the canonical vector type without any permutation
+  // applied, given that any permutation in a transfer write happens as part of
+  // the write itself.
+  AffineMap vectorTypeMap = opOperandMap.getNonPermutedProjectionMap();
   auto vectorType = state.getCanonicalVecType(
-      getElementTypeOrSelf(outputOperand->get().getType()), opOperandMap);
+      getElementTypeOrSelf(outputOperand->get().getType()), vectorTypeMap);
 
   Operation *write;
   if (vectorType.getRank() > 0) {
@@ -614,13 +620,14 @@
     SmallVector<Value> indices(linalgOp.getRank(outputOperand),
                                rewriter.create<arith::ConstantIndexOp>(loc, 0));
     value = broadcastIfNeeded(rewriter, value, vectorType);
+    assert(value.getType() == vectorType && "Incorrect type");
     write = rewriter.create<vector::TransferWriteOp>(
         loc, value, outputOperand->get(), indices, writeMap);
   } else {
     // 0-d case is still special: do not invert the reindexing writeMap.
     if (!isa<VectorType>(value.getType()))
       value = rewriter.create<vector::BroadcastOp>(loc, vectorType, value);
-    assert(value.getType() == vectorType && "incorrect type");
+    assert(value.getType() == vectorType && "Incorrect type");
     write = rewriter.create<vector::TransferWriteOp>(
         loc, value, outputOperand->get(), ValueRange{});
   }
diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -554,6 +554,20 @@
   return isProjectedPermutation();
 }
 
+AffineMap AffineMap::getNonPermutedProjectionMap() const {
+  assert(isProjectedPermutation() && "Expected projected permutation");
+
+  // Create an identity map with the same number of inputs and project the
+  // dimensions needed.
+  AffineMap projectionMap =
+      getMultiDimIdentityMap(getNumInputs(), getContext());
+  llvm::SmallBitVector projectedDims(getNumInputs(), true);
+  for (int i = 0, numResults = getNumResults(); i < numResults; ++i)
+    projectedDims[getDimPosition(i)] = false;
+
+  return projectionMap.dropResults(projectedDims);
+}
+
 AffineMap AffineMap::getSubMap(ArrayRef<unsigned> resultPos) const {
   SmallVector<AffineExpr, 4> exprs;
   exprs.reserve(resultPos.size());
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -1751,3 +1751,38 @@
 // CHECK: vector.broadcast %{{.*}} : f32 to vector<f32>
 // CHECK: vector.transfer_write {{.*}} : vector<f32>, tensor<f32>
 
+// -----
+
+// Make sure we generate the right transfer writes for multi-output generic ops
+// with different permutation maps.
+
+func.func @multi_output_generic_different_perm_maps(%in0: tensor<4x1xf32>,
+                                                    %out0: tensor<4x1xf32>,
+                                                    %out1: tensor<1x4xf32>) -> (tensor<4x1xf32>, tensor<1x4xf32>) {
+  %13:2 = linalg.generic {indexing_maps = [ affine_map<(d0, d1) -> (d1, d0)>,
+                                            affine_map<(d0, d1) -> (d1, d0)>,
+                                            affine_map<(d0, d1) -> (d0, d1)> ],
+                          iterator_types = ["parallel", "parallel"]}
+    ins(%in0 : tensor<4x1xf32>)
+    outs(%out0, %out1 : tensor<4x1xf32>, tensor<1x4xf32>) {
+  ^bb0(%in: f32, %out: f32, %out_2: f32):
+    %16 = arith.addf %in, %in : f32
+    linalg.yield %16, %16 : f32, f32
+  } -> (tensor<4x1xf32>, tensor<1x4xf32>)
+  return %13#0, %13#1 : tensor<4x1xf32>, tensor<1x4xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %3 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  %4 = get_closest_isolated_parent %3 : (!transform.any_op) -> !transform.any_op
+  %5 = transform.structured.vectorize %4 : (!transform.any_op) -> !transform.any_op
+}
+
+// CHECK-LABEL: func @multi_output_generic_different_perm_maps
+// CHECK: %[[VAL_5:.*]] = vector.transfer_read %{{.*}} {in_bounds = [true, true]} : tensor<4x1xf32>, vector<4x1xf32>
+// CHECK: %[[VAL_6:.*]] = arith.addf %[[VAL_5]], %[[VAL_5]] : vector<4x1xf32>
+// CHECK: %[[VAL_7:.*]] = vector.transpose %[[VAL_6]], [1, 0] : vector<4x1xf32> to vector<1x4xf32>
+// CHECK: %[[VAL_8:.*]] = vector.transpose %[[VAL_7]], [1, 0] : vector<1x4xf32> to vector<4x1xf32>
+// CHECK: vector.transfer_write %[[VAL_8]], %{{.*}} {in_bounds = [true, true]} : vector<4x1xf32>, tensor<4x1xf32>
+// CHECK: vector.transfer_write %[[VAL_7]], %{{.*}} {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>