diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -78,6 +78,20 @@
   static AffineMap getMinorIdentityMap(unsigned dims, unsigned results,
                                        MLIRContext *context);
 
+  /// Returns an identity affine map with `numDims` input dimensions and
+  /// filtered results using `keepDimFilter`. If `keepDimFilter` returns true
+  /// for a dimension, the dimension is kept in the affine map results.
+  /// Otherwise, the dimension is dropped from the results.
+  ///
+  /// Examples (the per-dimension results of `keepDimFilter` shown as a list):
+  ///   * getFilteredIdentityMap(ctx, 4, [false, true, false, true])
+  ///       -> affine_map<(d0, d1, d2, d3) -> (d1, d3)>
+  ///   * getFilteredIdentityMap(ctx, 3, [false, false, true])
+  ///       -> affine_map<(d0, d1, d2) -> (d2)>
+  static AffineMap
+  getFilteredIdentityMap(MLIRContext *ctx, unsigned numDims,
+                         llvm::function_ref<bool(AffineDimExpr)> keepDimFilter);
+
   /// Returns an AffineMap representing a permutation.
   /// The permutation is expressed as a non-empty vector of integers.
   /// E.g. the permutation `(i,j,k) -> (j,k,i)` will be expressed with
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -605,8 +605,18 @@
   Location loc = value.getLoc();
   auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
   AffineMap opOperandMap = linalgOp.getMatchingIndexingMap(outputOperand);
+
+  // Compute the vector type of the value to store. This type should be an
+  // identity or projection of the canonical vector type without any permutation
+  // applied, given that any permutation in a transfer write happens as part of
+  // the write itself.
+  AffineMap vectorTypeMap = AffineMap::getFilteredIdentityMap(
+      opOperandMap.getContext(), opOperandMap.getNumDims(),
+      [&](AffineDimExpr dimExpr) -> bool {
+        return llvm::is_contained(opOperandMap.getResults(), dimExpr);
+      });
   auto vectorType = state.getCanonicalVecType(
-      getElementTypeOrSelf(outputOperand->get().getType()), opOperandMap);
+      getElementTypeOrSelf(outputOperand->get().getType()), vectorTypeMap);
 
   Operation *write;
   if (vectorType.getRank() > 0) {
@@ -614,13 +624,14 @@
     SmallVector<Value> indices(linalgOp.getRank(outputOperand),
                                rewriter.create<arith::ConstantIndexOp>(loc, 0));
     value = broadcastIfNeeded(rewriter, value, vectorType);
+    assert(value.getType() == vectorType && "Incorrect type");
     write = rewriter.create<vector::TransferWriteOp>(
         loc, value, outputOperand->get(), indices, writeMap);
   } else {
     // 0-d case is still special: do not invert the reindexing writeMap.
     if (!isa<VectorType>(value.getType()))
       value = rewriter.create<vector::BroadcastOp>(loc, vectorType, value);
-    assert(value.getType() == vectorType && "incorrect type");
+    assert(value.getType() == vectorType && "Incorrect type");
     write = rewriter.create<vector::TransferWriteOp>(
         loc, value, outputOperand->get(), ValueRange{});
   }
diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -113,6 +113,19 @@
   return AffineMap::get(dims, 0, id.getResults().take_back(results), context);
 }
 
+AffineMap AffineMap::getFilteredIdentityMap(
+    MLIRContext *ctx, unsigned numDims,
+    llvm::function_ref<bool(AffineDimExpr)> keepDimFilter) {
+  auto identityMap = getMultiDimIdentityMap(numDims, ctx);
+
+  // Mark for removal every identity result whose dimension the filter rejects.
+  llvm::SmallBitVector dropDimResults(numDims);
+  for (auto [idx, resultExpr] : llvm::enumerate(identityMap.getResults()))
+    dropDimResults[idx] = !keepDimFilter(resultExpr.cast<AffineDimExpr>());
+
+  return identityMap.dropResults(dropDimResults);
+}
+
 bool AffineMap::isMinorIdentity() const {
   return getNumDims() >= getNumResults() &&
          *this ==
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -1751,3 +1751,38 @@
 // CHECK: vector.broadcast %{{.*}} : f32 to vector<f32>
 // CHECK: vector.transfer_write {{.*}} : vector<f32>, tensor<f32>
 
+// -----
+
+// Make sure we generate the right transfer writes for multi-output generic ops
+// with different permutation maps.
+
+func.func @multi_output_generic_different_perm_maps(%in0: tensor<4x1xf32>,
+                                                    %out0: tensor<4x1xf32>,
+                                                    %out1: tensor<1x4xf32>) -> (tensor<4x1xf32>, tensor<1x4xf32>) {
+  %13:2 = linalg.generic {indexing_maps = [ affine_map<(d0, d1) -> (d1, d0)>,
+                                          affine_map<(d0, d1) -> (d1, d0)>,
+                                          affine_map<(d0, d1) -> (d0, d1)> ],
+                          iterator_types = ["parallel", "parallel"]}
+                          ins(%in0 : tensor<4x1xf32>)
+                          outs(%out0, %out1 : tensor<4x1xf32>, tensor<1x4xf32>) {
+  ^bb0(%in: f32, %out: f32, %out_2: f32):
+    %16 = arith.addf %in, %in : f32
+    linalg.yield %16, %16 : f32, f32
+  } -> (tensor<4x1xf32>, tensor<1x4xf32>)
+  return %13#0, %13#1 : tensor<4x1xf32>, tensor<1x4xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %3 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  %4 = get_closest_isolated_parent %3 : (!transform.any_op) -> !transform.any_op
+  %5 = transform.structured.vectorize %4 : (!transform.any_op) -> !transform.any_op
+}
+
+// CHECK-LABEL: func @multi_output_generic_different_perm_maps
+// CHECK: %[[VAL_5:.*]] = vector.transfer_read %{{.*}} {in_bounds = [true, true]} : tensor<4x1xf32>, vector<4x1xf32>
+// CHECK: %[[VAL_6:.*]] = arith.addf %[[VAL_5]], %[[VAL_5]] : vector<4x1xf32>
+// CHECK: %[[VAL_7:.*]] = vector.transpose %[[VAL_6]], [1, 0] : vector<4x1xf32> to vector<1x4xf32>
+// CHECK: %[[VAL_8:.*]] = vector.transpose %[[VAL_7]], [1, 0] : vector<1x4xf32> to vector<4x1xf32>
+// CHECK: vector.transfer_write %[[VAL_8]], %{{.*}} {in_bounds = [true, true]} : vector<4x1xf32>, tensor<4x1xf32>
+// CHECK: vector.transfer_write %[[VAL_7]], %{{.*}} {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>