Index: mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
===================================================================
--- mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1330,6 +1330,18 @@
           memref<?x?x?xf32>, vector<32x256xf32>
     }}}
+    // or equivalently (rewrite with vector.transpose)
+    %f0 = arith.constant 0.0f : f32
+    for %i0 = 0 to %0 {
+      affine.for %i1 = 0 to %1 step 256 {
+        affine.for %i2 = 0 to %2 step 32 {
+          %v0 = vector.transfer_read %A[%i0, %i1, %i2], (%f0)
+            {permutation_map: (d0, d1, d2) -> (d1, d2)} :
+            memref<?x?x?xf32>, vector<256x32xf32>
+          %v = vector.transpose %v0, [1, 0] :
+            vector<256x32xf32> to vector<32x256xf32>
+    }}}
+
     // Read the slice `%A[%i0, %i1]` (i.e. the element `%A[%i0, %i1]`) into
     // vector<128xf32>. The underlying implementation will require a 1-D vector
     // broadcast:
@@ -1485,6 +1497,19 @@
         vector<16x32x64xf32>, memref<?x?x?x?xf32>
     }}}}
+    // or equivalently (rewrite with vector.transpose)
+    for %i0 = 0 to %0 {
+      affine.for %i1 = 0 to %1 step 32 {
+        affine.for %i2 = 0 to %2 step 64 {
+          affine.for %i3 = 0 to %3 step 16 {
+            %val = `ssa-value` : vector<16x32x64xf32>
+            %valt = vector.transpose %val, [1, 2, 0] :
+              vector<16x32x64xf32> to vector<32x64x16xf32>
+            vector.transfer_write %valt, %A[%i0, %i1, %i2, %i3]
+              {permutation_map: (d0, d1, d2, d3) -> (d1, d2, d3)} :
+              vector<32x64x16xf32>, memref<?x?x?x?xf32>
+    }}}}
+
     // write to a memref with vector element type.
     vector.transfer_write %4, %arg1[%c3, %c3]
       {permutation_map = (d0, d1)->(d0, d1)}
Index: mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
===================================================================
--- mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -473,7 +473,7 @@
   Operation *write;
   if (vectorType.getRank() > 0) {
-    AffineMap writeMap = reindexIndexingMap(opOperandMap);
+    AffineMap writeMap = inversePermutation(reindexIndexingMap(opOperandMap));
     SmallVector<Value> indices(linalgOp.getRank(outputOperand),
                                rewriter.create<arith::ConstantIndexOp>(loc, 0));
     value = broadcastIfNeeded(rewriter, value, vectorType.getShape());
Index: mlir/lib/Dialect/Vector/IR/VectorOps.cpp
===================================================================
--- mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3361,12 +3361,12 @@
   p << " : " << getShapedType() << ", " << getVectorType();
 }
 
-/// Infers the mask type for a transfer read given its vector type and
-/// permutation map. The mask in a transfer read operation applies to the
-/// tensor/buffer reading part of it and its type should match the shape read
+/// Infers the mask type for a transfer op given its vector type and
+/// permutation map. The mask in a transfer operation applies to the
+/// tensor/buffer part of it and its type should match the vector shape
 /// *before* any permutation or broadcasting.
-static VectorType inferTransferReadMaskType(VectorType vecType,
-                                            AffineMap permMap) {
+static VectorType inferTransferOpMaskType(VectorType vecType,
+                                          AffineMap permMap) {
   auto i1Type = IntegerType::get(permMap.getContext(), 1);
   AffineMap invPermMap = inversePermutation(compressUnusedDims(permMap));
   assert(invPermMap && "Inversed permutation map couldn't be computed");
@@ -3424,7 +3424,7 @@
         maskInfo.location, "does not support masks with vector element type");
     // Instead of adding the mask type as an op type, compute it based on the
     // vector type and the permutation map (to keep the type signature small).
-    auto maskType = inferTransferReadMaskType(vectorType, permMap);
+    auto maskType = inferTransferOpMaskType(vectorType, permMap);
     if (parser.resolveOperand(maskInfo, maskType, result.operands))
       return failure();
   }
@@ -3443,7 +3443,7 @@
   auto paddingType = getPadding().getType();
   auto permutationMap = getPermutationMap();
   VectorType inferredMaskType =
-      maskType ? inferTransferReadMaskType(vectorType, permutationMap)
+      maskType ? inferTransferOpMaskType(vectorType, permutationMap)
               : VectorType();
 
   auto sourceElementType = shapedType.getElementType();
@@ -3483,7 +3483,7 @@
 /// Returns the mask type expected by this operation. Mostly used for
 /// verification purposes. It requires the operation to be vectorized.
 Type TransferReadOp::getExpectedMaskType() {
-  return inferTransferReadMaskType(getVectorType(), getPermutationMap());
+  return inferTransferOpMaskType(getVectorType(), getPermutationMap());
 }
 
 template <typename TransferOp>
@@ -3824,18 +3824,6 @@
   build(builder, result, vector, dest, indices, permutationMap, inBounds);
 }
 
-/// Infers the mask type for a transfer write given its vector type and
-/// permutation map. The mask in a transfer read operation applies to the
-/// tensor/buffer writing part of it and its type should match the shape written
-/// *after* any permutation.
-static VectorType inferTransferWriteMaskType(VectorType vecType,
-                                             AffineMap permMap) {
-  auto i1Type = IntegerType::get(permMap.getContext(), 1);
-  SmallVector<int64_t> maskShape =
-      compressUnusedDims(permMap).compose(vecType.getShape());
-  return VectorType::get(maskShape, i1Type);
-}
-
 ParseResult TransferWriteOp::parse(OpAsmParser &parser,
                                    OperationState &result) {
   auto &builder = parser.getBuilder();
@@ -3880,7 +3868,7 @@
     if (shapedType.getElementType().dyn_cast<VectorType>())
       return parser.emitError(
           maskInfo.location, "does not support masks with vector element type");
-    auto maskType = inferTransferWriteMaskType(vectorType, permMap);
+    auto maskType = inferTransferOpMaskType(vectorType, permMap);
     if (parser.resolveOperand(maskInfo, maskType, result.operands))
       return failure();
   }
@@ -3907,7 +3895,7 @@
   VectorType maskType = getMaskType();
   auto permutationMap = getPermutationMap();
   VectorType inferredMaskType =
-      maskType ? inferTransferWriteMaskType(vectorType, permutationMap)
+      maskType ? inferTransferOpMaskType(vectorType, permutationMap)
               : VectorType();
 
   if (llvm::size(getIndices()) != shapedType.getRank())
@@ -3933,7 +3921,7 @@
 /// Returns the mask type expected by this operation. Mostly used for
 /// verification purposes.
 Type TransferWriteOp::getExpectedMaskType() {
-  return inferTransferWriteMaskType(getVectorType(), getPermutationMap());
+  return inferTransferOpMaskType(getVectorType(), getPermutationMap());
 }
 
 /// Fold:
Index: mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp
===================================================================
--- mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp
+++ mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp
@@ -20,15 +20,19 @@
 using namespace mlir;
 using namespace mlir::vector;
 
-/// Transpose a vector transfer op's `in_bounds` attribute according to given
-/// indices.
+/// Transpose a vector transfer op's `in_bounds` attribute by applying the
+/// inverse of the given permutation to it.
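+/// For example, given `permutation` = [2, 0, 1] and `attr` = [true, false,
+/// true], element i of `attr` moves to position `permutation[i]` of the
+/// result, which yields [false, true, true].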
 static ArrayAttr
-transposeInBoundsAttr(OpBuilder &builder, ArrayAttr attr,
-                      const SmallVector<unsigned> &permutation) {
-  SmallVector<bool> newInBoundsValues;
+inverseTransposeInBoundsAttr(OpBuilder &builder, ArrayAttr attr,
+                             const SmallVector<unsigned> &permutation) {
+  SmallVector<bool> newInBoundsValues(permutation.size());
+  size_t index = 0;
   for (unsigned pos : permutation)
-    newInBoundsValues.push_back(
-        attr.getValue()[pos].cast<BoolAttr>().getValue());
+    newInBoundsValues[pos] =
+        attr.getValue()[index++].cast<BoolAttr>().getValue();
   return builder.getBoolArrayAttr(newInBoundsValues);
 }
 
@@ -85,7 +89,7 @@
 
   // Transpose in_bounds attribute.
   ArrayAttr newInBoundsAttr =
-      op.getInBounds() ? transposeInBoundsAttr(
+      op.getInBounds() ? inverseTransposeInBoundsAttr(
                              rewriter, op.getInBounds().value(), permutation)
                        : ArrayAttr();
 
@@ -142,16 +146,17 @@
 
   // E.g.:      (d0, d1, d2, d3, d4, d5) -> (d5, d3, d4)
   // comp = (d0, d1, d2) -> (d2, d0, d1)
   auto comp = compressUnusedDims(map);
+  AffineMap permutationMap = inversePermutation(comp);
   // Get positions of remaining result dims.
   SmallVector<int64_t> indices;
-  llvm::transform(comp.getResults(), std::back_inserter(indices),
+  llvm::transform(permutationMap.getResults(), std::back_inserter(indices),
                   [](AffineExpr expr) {
                     return expr.dyn_cast<AffineDimExpr>().getPosition();
                   });
 
   // Transpose in_bounds attribute.
   ArrayAttr newInBoundsAttr =
-      op.getInBounds() ? transposeInBoundsAttr(
+      op.getInBounds() ? inverseTransposeInBoundsAttr(
                              rewriter, op.getInBounds().value(), permutation)
                        : ArrayAttr();
Index: mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
===================================================================
--- mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
+++ mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
@@ -337,11 +337,11 @@
   // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
   %c0 = arith.constant 0 : index
 
-  // CHECK: %[[MASK:.*]] = vector.splat %[[M]] : vector<8x14x16x7xi1>
-  %mask0 = vector.splat %m : vector<8x14x16x7xi1>
+  // CHECK: %[[MASK:.*]] = vector.splat %[[M]] : vector<16x14x7x8xi1>
+  %mask0 = vector.splat %m : vector<16x14x7x8xi1>
   %0 = vector.transfer_write %v1, %arg1[%c0, %c0, %c0, %c0], %mask0 {in_bounds = [true, false, false, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3, d0)>} : vector<7x14x8x16xf32>, tensor<?x?x?x?xf32>
-  // CHECK: %[[NEW_VEC0:.*]] = vector.transpose %{{.*}} [2, 1, 3, 0] : vector<7x14x8x16xf32> to vector<8x14x16x7xf32>
-  // CHECK: %[[NEW_RES0:.*]] = vector.transfer_write %[[NEW_VEC0]], %[[ARG1]][%c0, %c0, %c0, %c0], %[[MASK]] {in_bounds = [false, false, true, true]} : vector<8x14x16x7xf32>, tensor<?x?x?x?xf32>
+  // CHECK: %[[NEW_VEC0:.*]] = vector.transpose %{{.*}} [3, 1, 0, 2] : vector<7x14x8x16xf32> to vector<16x14x7x8xf32>
+  // CHECK: %[[NEW_RES0:.*]] = vector.transfer_write %[[NEW_VEC0]], %[[ARG1]][%c0, %c0, %c0, %c0], %[[MASK]] {in_bounds = [true, false, true, false]} : vector<16x14x7x8xf32>, tensor<?x?x?x?xf32>
 
   vector.transfer_write %v2, %arg0[%c0, %c0, %c0, %c0] {permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d2)>} : vector<8x16xf32>, memref<?x?x?x?xf32>
   // CHECK: %[[NEW_VEC1:.*]] = vector.transpose %{{.*}} [1, 0] : vector<8x16xf32> to vector<16x8xf32>
Index: mlir/test/Integration/Dialect/Vector/CPU/test-transfer-write.mlir
===================================================================
--- mlir/test/Integration/Dialect/Vector/CPU/test-transfer-write.mlir
+++ mlir/test/Integration/Dialect/Vector/CPU/test-transfer-write.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-scf-to-cf -convert-vector-to-llvm -convert-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \
+// RUN: mlir-opt %s -convert-vector-to-scf -convert-scf-to-cf -convert-vector-to-llvm -convert-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \
 // RUN: mlir-cpu-runner -e entry -entry-point-result=void \
 // RUN:   -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \
 // RUN: FileCheck %s
@@ -39,6 +39,37 @@
   return %r : vector<32xf32>
 }
 
+func.func @transfer_write_inbounds_3d(%A : memref<4x4x4xf32>) {
+  %c0 = arith.constant 0: index
+  %f = arith.constant 0.0 : f32
+  %v0 = vector.splat %f : vector<2x3x4xf32>
+  %f1 = arith.constant 1.0 : f32
+  %f2 = arith.constant 2.0 : f32
+  %f3 = arith.constant 3.0 : f32
+  %f4 = arith.constant 4.0 : f32
+  %f5 = arith.constant 5.0 : f32
+  %f6 = arith.constant 6.0 : f32
+  %f7 = arith.constant 7.0 : f32
+  %f8 = arith.constant 8.0 : f32
+
+  %v1 = vector.insert %f1, %v0[0, 0, 0] : f32 into vector<2x3x4xf32>
+  %v2 = vector.insert %f2, %v1[0, 0, 3] : f32 into vector<2x3x4xf32>
+  %v3 = vector.insert %f3, %v2[0, 2, 0] : f32 into vector<2x3x4xf32>
+  %v4 = vector.insert %f4, %v3[0, 2, 3] : f32 into vector<2x3x4xf32>
+  %v5 = vector.insert %f5, %v4[1, 0, 0] : f32 into vector<2x3x4xf32>
+  %v6 = vector.insert %f6, %v5[1, 0, 3] : f32 into vector<2x3x4xf32>
+  %v7 = vector.insert %f7, %v6[1, 2, 0] : f32 into vector<2x3x4xf32>
+  %v8 = vector.insert %f8, %v7[1, 2, 3] : f32 into vector<2x3x4xf32>
+
+  // With permutation_map (d0, d1, d2) -> (d2, d0, d1), vector element
+  // %v8[i, j, k] is written to %A[j, k, i].
+  vector.transfer_write %v8, %A[%c0, %c0, %c0]
+    {permutation_map = affine_map<(d0, d1, d2) -> (d2, d0, d1)>,
+    in_bounds = [true, true, true]}
+    : vector<2x3x4xf32>, memref<4x4x4xf32>
+  return
+}
+
 func.func @entry() {
   %c0 = arith.constant 0: index
   %c1 = arith.constant 1: index
@@ -90,6 +121,24 @@
   vector.print %6 : vector<32xf32>
 
   memref.dealloc %A : memref<?xf32>
+
+  // 3D case
+  %c4 = arith.constant 4: index
+  %A1 = memref.alloc() {alignment=64} : memref<4x4x4xf32>
+  scf.for %i = %c0 to %c4 step %c1 {
+    scf.for %j = %c0 to %c4 step %c1 {
+      scf.for %k = %c0 to %c4 step %c1 {
+        %f = arith.constant 0.0: f32
+        memref.store %f, %A1[%i, %j, %k] : memref<4x4x4xf32>
+      }
+    }
+  }
+  call @transfer_write_inbounds_3d(%A1) : (memref<4x4x4xf32>) -> ()
+  %f = arith.constant 0.0: f32
+  %r = vector.transfer_read %A1[%c0, %c0, %c0], %f
+      : memref<4x4x4xf32>, vector<4x4x4xf32>
+  vector.print %r : vector<4x4x4xf32>
+
   return
 }
 
@@ -100,3 +149,11 @@
 // CHECK: ( 0, 0, 0, 17, 17, 17, 17, 17, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
 // CHECK: ( 0, 0, 0, 17, 17, 17, 17, 17, 13, 13, 13, 13, 13, 13, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 0 )
 // CHECK: ( 0, 0, 0, 17, 17, 17, 17, 17, 13, 13, 13, 13, 13, 13, 17, 17, 17, 17, 17, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13 )
+
+// 3D case.
+// CHECK: ( ( ( 1, 5, 0, 0 ), ( 0, 0, 0, 0 ), ( 0, 0, 0, 0 ), ( 2, 6, 0, 0 ) ), ( ( 0, 0, 0, 0 ), ( 0, 0, 0, 0 ), ( 0, 0, 0, 0 ), ( 0, 0, 0, 0 ) ),
+// CHECK-SAME: ( ( 3, 7, 0, 0 ), ( 0, 0, 0, 0 ), ( 0, 0, 0, 0 ), ( 4, 8, 0, 0 ) ), ( ( 0, 0, 0, 0 ), ( 0, 0, 0, 0 ), ( 0, 0, 0, 0 ), ( 0, 0, 0, 0 ) ) )
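+
+// Note how %v8[1, 2, 3] = 8 appears at %A1[2, 3, 1] in the printed result:
+// the permutation map (d2, d0, d1) sends vector dimensions 0, 1, 2 to memref
+// dimensions 2, 0, 1, respectively.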