Index: mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp =================================================================== --- mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -473,7 +473,7 @@ Operation *write; if (vectorType.getRank() > 0) { - AffineMap writeMap = reindexIndexingMap(opOperandMap); + AffineMap writeMap = inversePermutation(reindexIndexingMap(opOperandMap)); SmallVector indices(linalgOp.getRank(outputOperand), rewriter.create(loc, 0)); value = broadcastIfNeeded(rewriter, value, vectorType.getShape()); Index: mlir/lib/Dialect/Vector/IR/VectorOps.cpp =================================================================== --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -3361,11 +3361,11 @@ p << " : " << getShapedType() << ", " << getVectorType(); } -/// Infers the mask type for a transfer read given its vector type and -/// permutation map. The mask in a transfer read operation applies to the -/// tensor/buffer reading part of it and its type should match the shape read +/// Infers the mask type for a transfer op given its vector type and +/// permutation map. The mask in a transfer operation applies to the +/// tensor/buffer part of it and its type should match the vector shape /// *before* any permutation or broadcasting. -static VectorType inferTransferReadMaskType(VectorType vecType, +static VectorType inferTransferOpMaskType(VectorType vecType, AffineMap permMap) { auto i1Type = IntegerType::get(permMap.getContext(), 1); AffineMap invPermMap = inversePermutation(compressUnusedDims(permMap)); @@ -3424,7 +3424,7 @@ maskInfo.location, "does not support masks with vector element type"); // Instead of adding the mask type as an op type, compute it based on the // vector type and the permutation map (to keep the type signature small). 
- auto maskType = inferTransferReadMaskType(vectorType, permMap); + auto maskType = inferTransferOpMaskType(vectorType, permMap); if (parser.resolveOperand(maskInfo, maskType, result.operands)) return failure(); } @@ -3443,7 +3443,7 @@ auto paddingType = getPadding().getType(); auto permutationMap = getPermutationMap(); VectorType inferredMaskType = - maskType ? inferTransferReadMaskType(vectorType, permutationMap) + maskType ? inferTransferOpMaskType(vectorType, permutationMap) : VectorType(); auto sourceElementType = shapedType.getElementType(); @@ -3483,7 +3483,7 @@ /// Returns the mask type expected by this operation. Mostly used for /// verification purposes. It requires the operation to be vectorized." Type TransferReadOp::getExpectedMaskType() { - return inferTransferReadMaskType(getVectorType(), getPermutationMap()); + return inferTransferOpMaskType(getVectorType(), getPermutationMap()); } template @@ -3824,18 +3824,6 @@ build(builder, result, vector, dest, indices, permutationMap, inBounds); } -/// Infers the mask type for a transfer write given its vector type and -/// permutation map. The mask in a transfer read operation applies to the -/// tensor/buffer writing part of it and its type should match the shape written -/// *after* any permutation. 
-static VectorType inferTransferWriteMaskType(VectorType vecType, - AffineMap permMap) { - auto i1Type = IntegerType::get(permMap.getContext(), 1); - SmallVector maskShape = - compressUnusedDims(permMap).compose(vecType.getShape()); - return VectorType::get(maskShape, i1Type); -} - ParseResult TransferWriteOp::parse(OpAsmParser &parser, OperationState &result) { auto &builder = parser.getBuilder(); @@ -3880,7 +3868,7 @@ if (shapedType.getElementType().dyn_cast()) return parser.emitError( maskInfo.location, "does not support masks with vector element type"); - auto maskType = inferTransferWriteMaskType(vectorType, permMap); + auto maskType = inferTransferOpMaskType(vectorType, permMap); if (parser.resolveOperand(maskInfo, maskType, result.operands)) return failure(); } @@ -3907,7 +3895,7 @@ VectorType maskType = getMaskType(); auto permutationMap = getPermutationMap(); VectorType inferredMaskType = - maskType ? inferTransferWriteMaskType(vectorType, permutationMap) + maskType ? inferTransferOpMaskType(vectorType, permutationMap) : VectorType(); if (llvm::size(getIndices()) != shapedType.getRank()) @@ -3933,7 +3921,7 @@ /// Returns the mask type expected by this operation. Mostly used for /// verification purposes. 
Type TransferWriteOp::getExpectedMaskType() { - return inferTransferWriteMaskType(getVectorType(), getPermutationMap()); + return inferTransferOpMaskType(getVectorType(), getPermutationMap()); } /// Fold: Index: mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp =================================================================== --- mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp +++ mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp @@ -142,9 +142,10 @@ // E.g.: (d0, d1, d2, d3, d4, d5) -> (d5, d3, d4) // comp = (d0, d1, d2) -> (d2, d0, d1) auto comp = compressUnusedDims(map); + AffineMap permutationMap = inversePermutation(comp); // Get positions of remaining result dims. SmallVector indices; - llvm::transform(comp.getResults(), std::back_inserter(indices), + llvm::transform(permutationMap.getResults(), std::back_inserter(indices), [](AffineExpr expr) { return expr.dyn_cast().getPosition(); }); Index: mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir =================================================================== --- mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir +++ mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir @@ -337,11 +337,11 @@ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index %c0 = arith.constant 0 : index - // CHECK: %[[MASK:.*]] = vector.splat %[[M]] : vector<8x14x16x7xi1> - %mask0 = vector.splat %m : vector<8x14x16x7xi1> + // CHECK: %[[MASK:.*]] = vector.splat %[[M]] : vector<16x14x7x8xi1> + %mask0 = vector.splat %m : vector<16x14x7x8xi1> %0 = vector.transfer_write %v1, %arg1[%c0, %c0, %c0, %c0], %mask0 {in_bounds = [true, false, false, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3, d0)>} : vector<7x14x8x16xf32>, tensor - // CHECK: %[[NEW_VEC0:.*]] = vector.transpose %{{.*}} [2, 1, 3, 0] : vector<7x14x8x16xf32> to vector<8x14x16x7xf32> - // CHECK: %[[NEW_RES0:.*]] = 
vector.transfer_write %[[NEW_VEC0]], %[[ARG1]][%c0, %c0, %c0, %c0], %[[MASK]] {in_bounds = [false, false, true, true]} : vector<8x14x16x7xf32>, tensor + // CHECK: %[[NEW_VEC0:.*]] = vector.transpose %{{.*}} [3, 1, 0, 2] : vector<7x14x8x16xf32> to vector<16x14x7x8xf32> + // CHECK: %[[NEW_RES0:.*]] = vector.transfer_write %[[NEW_VEC0]], %[[ARG1]][%c0, %c0, %c0, %c0], %[[MASK]] {in_bounds = [false, false, true, true]} : vector<16x14x7x8xf32>, tensor vector.transfer_write %v2, %arg0[%c0, %c0, %c0, %c0] {permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d2)>} : vector<8x16xf32>, memref // CHECK: %[[NEW_VEC1:.*]] = vector.transpose %{{.*}} [1, 0] : vector<8x16xf32> to vector<16x8xf32> Index: mlir/test/Integration/Dialect/Vector/CPU/test-transfer-write.mlir =================================================================== --- mlir/test/Integration/Dialect/Vector/CPU/test-transfer-write.mlir +++ mlir/test/Integration/Dialect/Vector/CPU/test-transfer-write.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -convert-scf-to-cf -convert-vector-to-llvm -convert-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \ +// RUN: mlir-opt %s -convert-vector-to-scf -convert-scf-to-cf -convert-vector-to-llvm -convert-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \ // RUN: mlir-cpu-runner -e entry -entry-point-result=void \ // RUN: -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \ // RUN: FileCheck %s @@ -39,6 +39,34 @@ return %r : vector<32xf32> } +func.func @transfer_write_inbounds_3d(%A : memref<4x4x4xf32>) { + %c0 = arith.constant 0: index + %f = arith.constant 0.0 : f32 + %v0 = vector.splat %f : vector<2x3x4xf32> + %f1 = arith.constant 1.0 : f32 + %f2 = arith.constant 2.0 : f32 + %f3 = arith.constant 3.0 : f32 + %f4 = arith.constant 4.0 : f32 + %f5 = arith.constant 5.0 : f32 + %f6 = arith.constant 6.0 : f32 + %f7 = arith.constant 7.0 : f32 + %f8 = arith.constant 8.0 : f32 + + %v1 = vector.insert %f1, %v0[0, 0, 0] : f32 into 
vector<2x3x4xf32> + %v2 = vector.insert %f2, %v1[0, 0, 3] : f32 into vector<2x3x4xf32> + %v3 = vector.insert %f3, %v2[0, 2, 0] : f32 into vector<2x3x4xf32> + %v4 = vector.insert %f4, %v3[0, 2, 3] : f32 into vector<2x3x4xf32> + %v5 = vector.insert %f5, %v4[1, 0, 0] : f32 into vector<2x3x4xf32> + %v6 = vector.insert %f6, %v5[1, 0, 3] : f32 into vector<2x3x4xf32> + %v7 = vector.insert %f7, %v6[1, 2, 0] : f32 into vector<2x3x4xf32> + %v8 = vector.insert %f8, %v7[1, 2, 3] : f32 into vector<2x3x4xf32> + vector.transfer_write %v8, %A[%c0, %c0, %c0] + {permutation_map = affine_map<(d0, d1, d2) -> (d2, d0, d1)>, + in_bounds = [true, true, true]} + : vector<2x3x4xf32>, memref<4x4x4xf32> + return +} + func.func @entry() { %c0 = arith.constant 0: index %c1 = arith.constant 1: index @@ -90,6 +118,24 @@ vector.print %6 : vector<32xf32> memref.dealloc %A : memref + + // 3D case + %c4 = arith.constant 4: index + %A1 = memref.alloc() {alignment=64} : memref<4x4x4xf32> + scf.for %i = %c0 to %c4 step %c1 { + scf.for %j = %c0 to %c4 step %c1 { + scf.for %k = %c0 to %c4 step %c1 { + %f = arith.constant 0.0: f32 + memref.store %f, %A1[%i, %j, %k] : memref<4x4x4xf32> + } + } + } + call @transfer_write_inbounds_3d(%A1) : (memref<4x4x4xf32>) -> () + %f = arith.constant 0.0: f32 + %r = vector.transfer_read %A1[%c0, %c0, %c0], %f + : memref<4x4x4xf32>, vector<4x4x4xf32> + vector.print %r : vector<4x4x4xf32> + return } @@ -100,3 +146,7 @@ // CHECK: ( 0, 0, 0, 17, 17, 17, 17, 17, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ) // CHECK: ( 0, 0, 0, 17, 17, 17, 17, 17, 13, 13, 13, 13, 13, 13, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 0 ) // CHECK: ( 0, 0, 0, 17, 17, 17, 17, 17, 13, 13, 13, 13, 13, 13, 17, 17, 17, 17, 17, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13 ) + +// 3D case. 
+// CHECK: ( ( ( 1, 5, 0, 0 ), ( 0, 0, 0, 0 ), ( 0, 0, 0, 0 ), ( 2, 6, 0, 0 ) ), ( ( 0, 0, 0, 0 ), ( 0, 0, 0, 0 ), ( 0, 0, 0, 0 ), ( 0, 0, 0, 0 ) ), +// CHECK-SAME: ( ( 3, 7, 0, 0 ), ( 0, 0, 0, 0 ), ( 0, 0, 0, 0 ), ( 4, 8, 0, 0 ) ), ( ( 0, 0, 0, 0 ), ( 0, 0, 0, 0 ), ( 0, 0, 0, 0 ), ( 0, 0, 0, 0 ) ) )