diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -1383,18 +1383,6 @@
       "ArrayAttr":$inBounds)>
   ];
 
-  let extraClassDeclaration = [{
-    /// Return a new `result` map with `0` inserted in the proper positions so
-    /// that vector.transfer_read `result` produces a vector of same element
-    /// type as `vt` and shape `targetShape.
-    /// Assume that `map` is a permutation map for a vector.transfer_read op,
-    /// `vt` the vector type produced by the vector.transfer_read and
-    /// `targetShape` is the desired `targetShape` for a broadcast version of
-    /// `vt`.
-    static AffineMap insertBroadcasts(AffineMap map, VectorType vt,
-                                      ArrayRef<int64_t> targetShape);
-  }];
-
   let hasFolder = 1;
 }
 
diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -404,6 +404,48 @@
 /// ```
 AffineMap inversePermutation(AffineMap map);
 
+/// Return the inverse map of a projected permutation where the projected
+/// dimensions are transformed into 0s.
+///
+/// Prerequisites: `map` must be a projected permutation.
+///
+/// Example 1:
+///
+/// ```mlir
+///    affine_map<(d0, d1, d2, d3) -> (d2, d0)>
+/// ```
+///
+/// returns:
+///
+/// ```mlir
+///    affine_map<(d0, d1) -> (d1, 0, d0, 0)>
+/// ```
+///
+/// Example 2:
+///
+/// ```mlir
+///    affine_map<(d0, d1, d2, d3) -> (d0, d3)>
+/// ```
+///
+/// returns:
+///
+/// ```mlir
+///    affine_map<(d0, d1) -> (d0, 0, 0, d1)>
+/// ```
+///
+/// Example 3:
+///
+/// ```mlir
+///    affine_map<(d0, d1, d2, d3) -> (d2)>
+/// ```
+///
+/// returns:
+///
+/// ```mlir
+///    affine_map<(d0) -> (0, 0, d0, 0)>
+/// ```
+AffineMap inverseAndBroadcastProjectedPermuation(AffineMap map);
+
 /// Concatenates a list of `maps` into a single AffineMap, stepping over
 /// potentially empty maps. Assumes each of the underlying map has 0 symbols.
 /// The resulting map has a number of dims equal to the max of `maps`' dims and
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -493,15 +493,18 @@
       bvm.map(shapedArg, loaded);
       continue;
     }
-    AffineMap map = inversePermutation(
-        reindexIndexingMap(linalgOp.getIndexingMap(bbarg.getArgNumber())));
-    VectorType vectorType = VectorType::get(map.compose(shapedType.getShape()),
-                                            shapedType.getElementType());
+    AffineMap map;
+    VectorType vectorType;
     if (broadcastToMaximalCommonShape) {
-      map = vector::TransferReadOp::insertBroadcasts(map, vectorType,
-                                                     commonVectorShape);
+      map = inverseAndBroadcastProjectedPermuation(
+          linalgOp.getIndexingMap(bbarg.getArgNumber()));
       vectorType =
-          VectorType::get(commonVectorShape, vectorType.getElementType());
+          VectorType::get(commonVectorShape, shapedType.getElementType());
+    } else {
+      map = inversePermutation(
+          reindexIndexingMap(linalgOp.getIndexingMap(bbarg.getArgNumber())));
+      vectorType = VectorType::get(map.compose(shapedType.getShape()),
+                                   shapedType.getElementType());
     }
     Value vectorRead = buildVectorRead(builder, shapedArg, vectorType, map);
     LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vectorized bbarg("
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -2253,29 +2253,6 @@
 // TransferReadOp
 //===----------------------------------------------------------------------===//
 
-AffineMap TransferReadOp::insertBroadcasts(AffineMap map, VectorType vt,
-                                           ArrayRef<int64_t> targetShape) {
-  unsigned targetRank = targetShape.size();
-  assert(vt.getShape().size() <= targetRank && "mismatching ranks");
-  if (vt.getShape().size() == targetRank)
-    return map;
-  MLIRContext *ctx = map.getContext();
-  SmallVector<AffineExpr, 4> exprs;
-  exprs.reserve(targetRank);
-  for (unsigned idx = 0, vtidx = 0; idx < targetRank; ++idx) {
-    // If shapes match, just keep the existing indexing and advance ranks.
-    if (vtidx < vt.getShape().size() &&
-        vt.getShape()[vtidx] == targetShape[idx]) {
-      exprs.push_back(map.getResult(vtidx));
-      ++vtidx;
-      continue;
-    }
-    // Otherwise insert a broadcast.
-    exprs.push_back(getAffineConstantExpr(0, ctx));
-  }
-  return AffineMap::get(map.getNumDims(), /*numSymbols=*/0, exprs, ctx);
-}
-
 template <typename EmitFun>
 static LogicalResult verifyPermutationMap(AffineMap permutationMap,
                                           EmitFun emitOpError) {
diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -664,6 +664,19 @@
   return AffineMap::get(map.getNumResults(), 0, seenExprs, map.getContext());
 }
 
+AffineMap mlir::inverseAndBroadcastProjectedPermuation(AffineMap map) {
+  assert(map.isProjectedPermutation() && "expected a projected permutation");
+  MLIRContext *context = map.getContext();
+  AffineExpr zero = mlir::getAffineConstantExpr(0, context);
+  // Start with every input position of `map` broadcast, i.e. mapped to 0.
+  SmallVector<AffineExpr, 4> exprs(map.getNumInputs(), zero);
+  for (unsigned i : llvm::seq(unsigned(0), map.getNumResults())) {
+    // Result `i` of `map` reads dim getDimPosition(i); map it back to d`i`.
+    exprs[map.getDimPosition(i)] = getAffineDimExpr(i, context);
+  }
+  return AffineMap::get(map.getNumResults(), /*symbolCount=*/0, exprs, context);
+}
+
 AffineMap mlir::concatAffineMaps(ArrayRef<AffineMap> maps) {
   unsigned numResults = 0, numDims = 0, numSymbols = 0;
   for (auto m : maps)
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -381,6 +381,43 @@
 
 // -----
 
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, 0, 0, d1)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> (d0, 0, 0, 0)>
+// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0) -> (0, 0, d0, 0)>
+// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1) -> (d1, 0, d0, 0)>
+// CHECK: func @generic_vectorize_broadcast_transpose
+// CHECK-DAG: %[[C0:.*]] = constant 0 : index
+// CHECK-DAG: %[[CF:.*]] = constant 0.000000e+00 : f32
+// CHECK: %[[V0:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %[[CF]] {permutation_map = #[[$MAP0]]} : memref<4x4xf32>, vector<4x4x4x4xf32>
+// CHECK: %[[V1:.*]] = vector.transfer_read %{{.*}}[%[[C0]]], %[[CF]] {permutation_map = #[[$MAP1]]} : memref<4xf32>, vector<4x4x4x4xf32>
+// CHECK: %[[V2:.*]] = vector.transfer_read %{{.*}}[%[[C0]]], %[[CF]] {permutation_map = #[[$MAP2]]} : memref<4xf32>, vector<4x4x4x4xf32>
+// CHECK: %[[V3:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %[[CF]] {permutation_map = #[[$MAP3]]} : memref<4x4xf32>, vector<4x4x4x4xf32>
+// CHECK: %[[SUB:.*]] = subf %[[V0]], %[[V1]] : vector<4x4x4x4xf32>
+// CHECK: %[[ADD0:.*]] = addf %[[V2]], %[[SUB]] : vector<4x4x4x4xf32>
+// CHECK: %[[ADD1:.*]] = addf %[[V3]], %[[ADD0]] : vector<4x4x4x4xf32>
+// CHECK: vector.transfer_write %[[ADD1]], {{.*}} : vector<4x4x4x4xf32>, memref<4x4x4x4xf32>
+func @generic_vectorize_broadcast_transpose(
+  %A: memref<4xf32>, %B: memref<4x4xf32>, %C: memref<4x4x4x4xf32>) {
+  linalg.generic {
+  indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d3)>,
+                   affine_map<(d0, d1, d2, d3) -> (d0)>,
+                   affine_map<(d0, d1, d2, d3) -> (d2)>,
+                   affine_map<(d0, d1, d2, d3) -> (d2, d0)>,
+                   affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
+  iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+  ins(%B, %A, %A, %B: memref<4x4xf32>, memref<4xf32>, memref<4xf32>, memref<4x4xf32>)
+  outs(%C : memref<4x4x4x4xf32>) {
+  ^bb0(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32):  // no predecessors
+    %s = subf %arg0, %arg1 : f32
+    %a = addf %arg2, %s : f32
+    %b = addf %arg3, %a : f32
+    linalg.yield %b : f32
+  }
+  return
+}
+
+// -----
+
 // Test different input maps.
 #matmul_trait = {
   indexing_maps = [