diff --git a/mlir/include/mlir/Interfaces/VectorInterfaces.h b/mlir/include/mlir/Interfaces/VectorInterfaces.h
--- a/mlir/include/mlir/Interfaces/VectorInterfaces.h
+++ b/mlir/include/mlir/Interfaces/VectorInterfaces.h
@@ -17,6 +17,20 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/OpDefinition.h"
 
+namespace mlir {
+namespace vector {
+namespace detail {
+
+/// Given the vector type and the permutation map of a vector transfer op,
+/// compute the expected mask type: one mask dimension per non-broadcast
+/// (AffineDimExpr) result of the map. Returns a null VectorType if all
+/// results of the map are broadcasts.
+VectorType transferMaskType(VectorType vecType, AffineMap map);
+
+} // namespace detail
+} // namespace vector
+} // namespace mlir
+
 /// Include the generated interface declarations.
 #include "mlir/Interfaces/VectorInterfaces.h.inc"
 
diff --git a/mlir/include/mlir/Interfaces/VectorInterfaces.td b/mlir/include/mlir/Interfaces/VectorInterfaces.td
--- a/mlir/include/mlir/Interfaces/VectorInterfaces.td
+++ b/mlir/include/mlir/Interfaces/VectorInterfaces.td
@@ -156,6 +156,18 @@
         return $_op.vector().getType().template dyn_cast<VectorType>();
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/"Return the mask type if the op has a mask, null otherwise.",
+      /*retTy=*/"VectorType",
+      /*methodName=*/"getMaskType",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return $_op.mask()
+            ? $_op.mask().getType().template cast<VectorType>()
+            : VectorType();
+      }]
+    >,
     InterfaceMethod<
       /*desc=*/[{ Return the number of dimensions that participate in the
                   permutation map.}],
diff --git a/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp
--- a/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp
@@ -79,13 +79,22 @@
 
   if (xferOp.mask()) {
     auto maskType = MemRefType::get({}, xferOp.mask().getType());
-    result.maskBuffer = memref_alloca(maskType).value;
-    memref_store(xferOp.mask(), result.maskBuffer);
+    auto maskBuffer = memref_alloca(maskType).value;
+    memref_store(xferOp.mask(), maskBuffer);
+    result.maskBuffer = memref_load(maskBuffer);
   }
 
   return result;
 }
 
+/// Return true if the outermost (first) result of the permutation map of
+/// `xferOp` is a broadcast, i.e., an AffineConstantExpr.
+template <typename OpTy>
+static bool isOutermostDimBroadcast(OpTy xferOp) {
+  auto map = xferOp.permutation_map();
+  return map.getResult(0).template isa<AffineConstantExpr>();
+}
+
 /// Given a vector transfer op, calculate which dimension of the `source`
 /// memref should be unpacked in the next application of TransferOpConversion.
 /// A return value of None indicates a broadcast.
@@ -95,7 +104,7 @@
   if (auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>()) {
     return expr.getPosition();
   }
-  assert(map.getResult(0).template isa<AffineConstantExpr>() &&
+  assert(isOutermostDimBroadcast(xferOp) &&
          "Expected AffineDimExpr or AffineConstantExpr");
   return None;
 }
@@ -143,14 +152,19 @@
 }
 
 /// Generates a boolean Value that is true if the iv-th bit in xferOp's mask
-/// is set to true. Does not return a Value if the transfer op is not 1D or
-/// if the transfer op does not have a mask.
+/// is set to true. No such check is generated under following circumstances:
+/// * xferOp does not have a mask.
+/// * xferOp's mask is not 1D. (In case of (N>1)-D, a subvector of the mask is
+///   computed and attached to the new transfer op in the pattern.)
+/// * The to-be-unpacked dim of xferOp is a broadcast.
 template <typename OpTy>
-static Value maybeGenerateMaskCheck(OpBuilder &builder, OpTy xferOp, Value iv) {
-  if (xferOp.getVectorType().getRank() != 1)
-    return Value();
+static Value generateMaskCheck(OpBuilder &builder, OpTy xferOp, Value iv) {
   if (!xferOp.mask())
     return Value();
+  if (xferOp.getMaskType().getRank() != 1)
+    return Value();
+  if (isOutermostDimBroadcast(xferOp))
+    return Value();
 
   auto ivI32 = std_index_cast(IntegerType::get(builder.getContext(), 32), iv);
   return vector_extract_element(xferOp.mask(), ivI32).value;
@@ -200,7 +214,7 @@
   }
 
   // Condition check 2: Masked in?
-  if (auto maskCond = maybeGenerateMaskCheck(builder, xferOp, iv)) {
+  if (auto maskCond = generateMaskCheck(builder, xferOp, iv)) {
     if (cond) {
       cond = builder.create<AndOp>(xferOp.getLoc(), cond, maskCond);
     } else {
@@ -488,8 +502,8 @@
     auto *newXfer = rewriter.clone(*xferOp.getOperation());
     newXfer->setAttr(kPassLabel, rewriter.getUnitAttr());
     if (xferOp.mask()) {
-      auto loadedMask = memref_load(buffers.maskBuffer);
-      dyn_cast<TransferReadOp>(newXfer).maskMutable().assign(loadedMask);
+      dyn_cast<TransferReadOp>(newXfer).maskMutable().assign(
+          buffers.maskBuffer);
     }
 
     memref_store(newXfer->getResult(0), buffers.dataBuffer);
@@ -541,9 +555,8 @@
     });
 
     if (xferOp.mask()) {
-      auto loadedMask = memref_load(buffers.maskBuffer);
       rewriter.updateRootInPlace(
-          xferOp, [&]() { xferOp.maskMutable().assign(loadedMask); });
+          xferOp, [&]() { xferOp.maskMutable().assign(buffers.maskBuffer); });
     }
 
     return success();
@@ -590,8 +603,18 @@
       auto maskBuffer = getMaskBuffer(xferOp);
       auto maskBufferType =
           maskBuffer.getType().template dyn_cast<MemRefType>();
-      auto castedMaskType = unpackOneDim(maskBufferType);
-      castedMaskBuffer = vector_type_cast(castedMaskType, maskBuffer);
+      if (isOutermostDimBroadcast(xferOp) ||
+          xferOp.getMaskType().getRank() == 1) {
+        // Do not unpack a dimension of the mask, if:
+        // * To-be-unpacked transfer op dimension is a broadcast.
+        // * Mask is 1D, i.e., the mask cannot be further unpacked.
+        //   (That means that all remaining dimensions of the transfer op must
+        //   be broadcasted.)
+        castedMaskBuffer = maskBuffer;
+      } else {
+        auto castedMaskType = unpackOneDim(maskBufferType);
+        castedMaskBuffer = vector_type_cast(castedMaskType, maskBuffer);
+      }
     }
 
     // Loop bounds and step.
@@ -616,13 +639,20 @@
             Strategy::rewriteOp(b, xferOp, castedDataBuffer, iv);
 
         // If old transfer op has a mask: Set mask on new transfer op.
-        if (xferOp.mask()) {
+        // Special case: If the mask of the old transfer op is 1D and the
+        //               unpacked dim is not a broadcast, no mask is needed
+        //               on the new transfer op.
+        if (xferOp.mask() && (isOutermostDimBroadcast(xferOp) ||
+                              xferOp.getMaskType().getRank() > 1)) {
           OpBuilder::InsertionGuard guard(b);
           b.setInsertionPoint(newXfer); // Insert load before newXfer.
 
           SmallVector<Value, 8> loadIndices;
           Strategy::getBufferIndices(xferOp, loadIndices);
-          loadIndices.push_back(iv);
+          // In case of broadcast: Use same indices to load from memref as
+          // before.
+          if (!isOutermostDimBroadcast(xferOp))
+            loadIndices.push_back(iv);
 
           auto mask = memref_load(castedMaskBuffer, loadIndices);
           rewriter.updateRootInPlace(
@@ -661,7 +691,7 @@
       return dim;
     }
 
-    assert(map.getResult(0).template isa<AffineConstantExpr>() &&
+    assert(isOutermostDimBroadcast(xferOp) &&
            "Expected AffineDimExpr or AffineConstantExpr");
     return None;
   }
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -2306,6 +2306,7 @@
 
 static LogicalResult verifyTransferOp(Operation *op, ShapedType shapedType,
                                       VectorType vectorType,
+                                      VectorType maskType,
                                       AffineMap permutationMap,
                                       ArrayAttr inBounds) {
   if (op->hasAttr("masked")) {
@@ -2341,6 +2342,9 @@
     if (permutationMap.getNumResults() != rankOffset)
       return op->emitOpError("requires a permutation_map with result dims of "
                              "the same rank as the vector type");
+
+    if (maskType)
+      return op->emitOpError("does not support masks with vector element type");
   } else {
     // Memref or tensor has scalar element type.
     unsigned resultVecSize =
@@ -2355,6 +2359,13 @@
     if (permutationMap.getNumResults() != vectorType.getRank())
       return op->emitOpError("requires a permutation_map with result dims of "
                              "the same rank as the vector type");
+
+    VectorType expectedMaskType =
+        vector::detail::transferMaskType(vectorType, permutationMap);
+    if (maskType && expectedMaskType != maskType)
+      return op->emitOpError("expects mask type consistent with permutation "
+                             "map: ")
+             << maskType;
   }
 
   if (permutationMap.getNumSymbols() != 0)
@@ -2491,10 +2502,11 @@
   if (!vectorType)
     return parser.emitError(typesLoc, "requires vector type");
   auto permutationAttrName = TransferReadOp::getPermutationMapAttrName();
-  auto attr = result.attributes.get(permutationAttrName);
-  if (!attr) {
+  Attribute mapAttr = result.attributes.get(permutationAttrName);
+  if (!mapAttr) {
     auto permMap = getTransferMinorIdentityMap(shapedType, vectorType);
-    result.attributes.set(permutationAttrName, AffineMapAttr::get(permMap));
+    mapAttr = AffineMapAttr::get(permMap);
+    result.attributes.set(permutationAttrName, mapAttr);
   }
   if (parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
       parser.resolveOperands(indexInfo, indexType, result.operands) ||
@@ -2502,7 +2514,25 @@
                             result.operands))
     return failure();
   if (hasMask.succeeded()) {
-    auto maskType = VectorType::get(vectorType.getShape(), builder.getI1Type());
+    if (shapedType.getElementType().dyn_cast<VectorType>())
+      return parser.emitError(
+          maskInfo.location, "does not support masks with vector element type");
+    // The permutation_map may come from the attribute dictionary, so it is
+    // not guaranteed to be an AffineMapAttr.
+    auto permMapAttr = mapAttr.dyn_cast<AffineMapAttr>();
+    if (!permMapAttr)
+      return parser.emitError(typesLoc, "expected permutation_map to be an "
+                                        "affine map attribute");
+    auto map = permMapAttr.getValue();
+    // Instead of adding the mask type as an op type, compute it based on the
+    // vector type and the permutation map (to keep the type signature small).
+    auto maskType = mlir::vector::detail::transferMaskType(vectorType, map);
+    // transferMaskType returns a null type if all dims are broadcasts: such a
+    // transfer has no dimension that a mask could apply to.
+    if (!maskType)
+      return parser.emitError(maskInfo.location,
+                              "does not support masks on fully broadcasted "
+                              "transfers");
     if (parser.resolveOperand(maskInfo, maskType, result.operands))
       return failure();
   }
@@ -2517,6 +2547,7 @@
   // Consistency of elemental types in source and vector.
   ShapedType shapedType = op.getShapedType();
   VectorType vectorType = op.getVectorType();
+  VectorType maskType = op.getMaskType();
   auto paddingType = op.padding().getType();
   auto permutationMap = op.permutation_map();
   auto sourceElementType = shapedType.getElementType();
@@ -2525,7 +2556,7 @@
     return op.emitOpError("requires ") << shapedType.getRank() << " indices";
 
   if (failed(verifyTransferOp(op.getOperation(), shapedType, vectorType,
-                              permutationMap,
+                              maskType, permutationMap,
                               op.in_bounds() ? *op.in_bounds()
                                              : ArrayAttr())))
     return failure();
@@ -2768,6 +2799,9 @@
       parser.resolveOperands(indexInfo, indexType, result.operands))
     return failure();
   if (hasMask.succeeded()) {
+    if (shapedType.getElementType().dyn_cast<VectorType>())
+      return parser.emitError(
+          maskInfo.location, "does not support masks with vector element type");
     auto maskType = VectorType::get(vectorType.getShape(), builder.getI1Type());
     if (parser.resolveOperand(maskInfo, maskType, result.operands))
       return failure();
@@ -2793,6 +2827,7 @@
   // Consistency of elemental types in shape and vector.
   ShapedType shapedType = op.getShapedType();
   VectorType vectorType = op.getVectorType();
+  VectorType maskType = op.getMaskType();
   auto permutationMap = op.permutation_map();
 
   if (llvm::size(op.indices()) != shapedType.getRank())
@@ -2804,7 +2839,7 @@
     return op.emitOpError("should not have broadcast dimensions");
 
   if (failed(verifyTransferOp(op.getOperation(), shapedType, vectorType,
-                              permutationMap,
+                              maskType, permutationMap,
                               op.in_bounds() ? *op.in_bounds()
                                              : ArrayAttr())))
     return failure();
diff --git a/mlir/lib/Interfaces/VectorInterfaces.cpp b/mlir/lib/Interfaces/VectorInterfaces.cpp
--- a/mlir/lib/Interfaces/VectorInterfaces.cpp
+++ b/mlir/lib/Interfaces/VectorInterfaces.cpp
@@ -10,6 +10,19 @@
 
 using namespace mlir;
 
+VectorType mlir::vector::detail::transferMaskType(VectorType vecType,
+                                                  AffineMap map) {
+  auto i1Type = IntegerType::get(map.getContext(), 1);
+  SmallVector<int64_t, 8> shape;
+  for (int64_t i = 0; i < vecType.getRank(); ++i) {
+    // Only result dims have a corresponding dim in the mask.
+    if (map.getResult(i).isa<AffineDimExpr>()) {
+      shape.push_back(vecType.getDimSize(i));
+    }
+  }
+  return shape.empty() ? VectorType() : VectorType::get(shape, i1Type);
+}
+
 //===----------------------------------------------------------------------===//
 // VectorUnroll Interfaces
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir
@@ -5,6 +5,14 @@
 
 // Test for special cases of 1D vector transfer ops.
 
+memref.global "private" @gv : memref<5x6xf32> =
+    dense<[[0. , 1. , 2. , 3. , 4. , 5. ],
+           [10., 11., 12., 13., 14., 15.],
+           [20., 21., 22., 23., 24., 25.],
+           [30., 31., 32., 33., 34., 35.],
+           [40., 41., 42., 43., 44., 45.]]>
+
+// Non-contiguous, strided load.
 func @transfer_read_1d(%A : memref<?x?xf32>, %base1 : index, %base2 : index) {
   %fm42 = constant -42.0: f32
   %f = vector.transfer_read %A[%base1, %base2], %fm42
@@ -14,6 +22,7 @@
   return
 }
 
+// Broadcast.
 func @transfer_read_1d_broadcast(
     %A : memref<?x?xf32>, %base1 : index, %base2 : index) {
   %fm42 = constant -42.0: f32
@@ -24,6 +33,7 @@
   return
 }
 
+// Non-contiguous, strided load.
 func @transfer_read_1d_in_bounds(
     %A : memref<?x?xf32>, %base1 : index, %base2 : index) {
   %fm42 = constant -42.0: f32
@@ -34,6 +44,7 @@
   return
 }
 
+// Non-contiguous, strided load.
 func @transfer_read_1d_mask(
     %A : memref<?x?xf32>, %base1 : index, %base2 : index) {
   %fm42 = constant -42.0: f32
@@ -45,6 +56,7 @@
   return
 }
 
+// Non-contiguous, strided load.
 func @transfer_read_1d_mask_in_bounds(
     %A : memref<?x?xf32>, %base1 : index, %base2 : index) {
   %fm42 = constant -42.0: f32
@@ -56,6 +68,7 @@
   return
 }
 
+// Non-contiguous, strided store.
 func @transfer_write_1d(%A : memref<?x?xf32>, %base1 : index, %base2 : index) {
   %fn1 = constant -1.0 : f32
   %vf0 = splat %fn1 : vector<7xf32>
@@ -65,57 +78,68 @@
   return
 }
 
+// Non-contiguous, strided store.
+func @transfer_write_1d_mask(%A : memref<?x?xf32>, %base1 : index, %base2 : index) {
+  %fn1 = constant -2.0 : f32
+  %vf0 = splat %fn1 : vector<7xf32>
+  %mask = constant dense<[1, 0, 1, 0, 1, 1, 1]> : vector<7xi1>
+  vector.transfer_write %vf0, %A[%base1, %base2], %mask
+    {permutation_map = affine_map<(d0, d1) -> (d0)>}
+    : vector<7xf32>, memref<?x?xf32>
+  return
+}
+
 func @entry() {
   %c0 = constant 0: index
   %c1 = constant 1: index
   %c2 = constant 2: index
   %c3 = constant 3: index
-  %f10 = constant 10.0: f32
-  // work with dims of 4, not of 3
-  %first = constant 5: index
-  %second = constant 6: index
-  %A = memref.alloc(%first, %second) : memref<?x?xf32>
-  scf.for %i = %c0 to %first step %c1 {
-    %i32 = index_cast %i : index to i32
-    %fi = sitofp %i32 : i32 to f32
-    %fi10 = mulf %fi, %f10 : f32
-    scf.for %j = %c0 to %second step %c1 {
-      %j32 = index_cast %j : index to i32
-      %fj = sitofp %j32 : i32 to f32
-      %fres = addf %fi10, %fj : f32
-      memref.store %fres, %A[%i, %j] : memref<?x?xf32>
-    }
-  }
-
-  // Read from 2D memref on first dimension. Cannot be lowered to an LLVM
-  // vector load. Instead, generates scalar loads.
+  %0 = memref.get_global @gv : memref<5x6xf32>
+  %A = memref.cast %0 : memref<5x6xf32> to memref<?x?xf32>
+
+  // 1. Read from 2D memref on first dimension. Cannot be lowered to an LLVM
+  //    vector load. Instead, generates scalar loads.
   call @transfer_read_1d(%A, %c1, %c2) : (memref<?x?xf32>, index, index) -> ()
-  // Write to 2D memref on first dimension. Cannot be lowered to an LLVM
-  // vector store. Instead, generates scalar stores.
+  // CHECK: ( 12, 22, 32, 42, -42, -42, -42, -42, -42 )
+
+  // 2. Write to 2D memref on first dimension. Cannot be lowered to an LLVM
+  //    vector store. Instead, generates scalar stores.
   call @transfer_write_1d(%A, %c3, %c2) : (memref<?x?xf32>, index, index) -> ()
-  // (Same as above.)
+
+  // 3. (Same as 1. To check if 2 works correctly.)
   call @transfer_read_1d(%A, %c0, %c2) : (memref<?x?xf32>, index, index) -> ()
-  // Read a scalar from a 2D memref and broadcast the value to a 1D vector.
-  // Generates a loop with vector.insertelement.
+  // CHECK: ( 2, 12, 22, -1, -1, -42, -42, -42, -42 )
+
+  // 4. Read a scalar from a 2D memref and broadcast the value to a 1D vector.
+  //    Generates a loop with vector.insertelement.
   call @transfer_read_1d_broadcast(%A, %c1, %c2)
       : (memref<?x?xf32>, index, index) -> ()
-  // Read from 2D memref on first dimension. Accesses are in-bounds, so no
-  // if-check is generated inside the generated loop.
+  // CHECK: ( 12, 12, 12, 12, 12, 12, 12, 12, 12 )
+
+  // 5. Read from 2D memref on first dimension. Accesses are in-bounds, so no
+  //    if-check is generated inside the generated loop.
   call @transfer_read_1d_in_bounds(%A, %c1, %c2)
       : (memref<?x?xf32>, index, index) -> ()
-  // Optional mask attribute is specified and, in addition, there may be
-  // out-of-bounds accesses.
+  // CHECK: ( 12, 22, -1 )
+
+  // 6. Optional mask attribute is specified and, in addition, there may be
+  //    out-of-bounds accesses.
   call @transfer_read_1d_mask(%A, %c1, %c2)
       : (memref<?x?xf32>, index, index) -> ()
-  // Same as above, but accesses are in-bounds.
+  // CHECK: ( 12, -42, -1, -42, -42, -42, -42, -42, -42 )
+
+  // 7. Same as 6, but accesses are in-bounds.
   call @transfer_read_1d_mask_in_bounds(%A, %c1, %c2)
       : (memref<?x?xf32>, index, index) -> ()
+  // CHECK: ( 12, -42, -1 )
+
+  // 8. Write to 2D memref on first dimension with a mask.
+  call @transfer_write_1d_mask(%A, %c1, %c0)
+      : (memref<?x?xf32>, index, index) -> ()
+
+  // 9. (Same as 1. To check if 8 works correctly.)
+  call @transfer_read_1d(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
+  // CHECK: ( 0, -2, 20, -2, 40, -42, -42, -42, -42 )
+
   return
 }
-
-// CHECK: ( 12, 22, 32, 42, -42, -42, -42, -42, -42 )
-// CHECK: ( 2, 12, 22, -1, -1, -42, -42, -42, -42 )
-// CHECK: ( 12, 12, 12, 12, 12, 12, 12, 12, 12 )
-// CHECK: ( 12, 22, -1 )
-// CHECK: ( 12, -42, -1, -42, -42, -42, -42, -42, -42 )
-// CHECK: ( 12, -42, -1 )
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir
@@ -3,6 +3,11 @@
 // RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
 // RUN: FileCheck %s
 
+memref.global "private" @gv : memref<3x4xf32> = dense<[[0. , 1. , 2. , 3. ],
+                                                       [10., 11., 12., 13.],
+                                                       [20., 21., 22., 23.]]>
+
+// Vector load.
 func @transfer_read_2d(%A : memref<?x?xf32>, %base1: index, %base2: index) {
   %fm42 = constant -42.0: f32
   %f = vector.transfer_read %A[%base1, %base2], %fm42
@@ -12,6 +17,7 @@
   return
 }
 
+// Vector load with mask.
 func @transfer_read_2d_mask(%A : memref<?x?xf32>, %base1: index, %base2: index) {
   %fm42 = constant -42.0: f32
   %mask = constant dense<[[1, 0, 1, 0, 1, 1, 1, 0, 1],
@@ -25,6 +31,47 @@
   return
 }
 
+// Vector load with mask + transpose.
+func @transfer_read_2d_mask_transposed(
+    %A : memref<?x?xf32>, %base1: index, %base2: index) {
+  %fm42 = constant -42.0: f32
+  %mask = constant dense<[[1, 0, 1, 0], [0, 0, 1, 0],
+                          [1, 1, 1, 1], [0, 1, 1, 0],
+                          [1, 1, 1, 1], [1, 1, 1, 1],
+                          [1, 1, 1, 1], [0, 0, 0, 0],
+                          [1, 1, 1, 1]]> : vector<9x4xi1>
+  %f = vector.transfer_read %A[%base1, %base2], %fm42, %mask
+      {permutation_map = affine_map<(d0, d1) -> (d1, d0)>} :
+    memref<?x?xf32>, vector<9x4xf32>
+  vector.print %f: vector<9x4xf32>
+  return
+}
+
+// Vector load with mask + broadcast.
+func @transfer_read_2d_mask_broadcast(
+    %A : memref<?x?xf32>, %base1: index, %base2: index) {
+  %fm42 = constant -42.0: f32
+  %mask = constant dense<[1, 0, 1, 0, 1, 1, 1, 0, 1]> : vector<9xi1>
+  %f = vector.transfer_read %A[%base1, %base2], %fm42, %mask
+      {permutation_map = affine_map<(d0, d1) -> (0, d1)>} :
+    memref<?x?xf32>, vector<4x9xf32>
+  vector.print %f: vector<4x9xf32>
+  return
+}
+
+// Transpose + vector load with mask + broadcast.
+func @transfer_read_2d_mask_transpose_broadcast_last_dim(
+    %A : memref<?x?xf32>, %base1: index, %base2: index) {
+  %fm42 = constant -42.0: f32
+  %mask = constant dense<[1, 0, 1, 1]> : vector<4xi1>
+  %f = vector.transfer_read %A[%base1, %base2], %fm42, %mask
+      {permutation_map = affine_map<(d0, d1) -> (d1, 0)>} :
+    memref<?x?xf32>, vector<4x9xf32>
+  vector.print %f: vector<4x9xf32>
+  return
+}
+
+// Load + transpose.
 func @transfer_read_2d_transposed(
     %A : memref<?x?xf32>, %base1: index, %base2: index) {
   %fm42 = constant -42.0: f32
@@ -35,6 +82,7 @@
   return
 }
 
+// Load 1D + broadcast to 2D.
 func @transfer_read_2d_broadcast(
     %A : memref<?x?xf32>, %base1: index, %base2: index) {
   %fm42 = constant -42.0: f32
@@ -45,6 +93,7 @@
   return
 }
 
+// Vector store.
 func @transfer_write_2d(%A : memref<?x?xf32>, %base1: index, %base2: index) {
   %fn1 = constant -1.0 : f32
   %vf0 = splat %fn1 : vector<1x4xf32>
@@ -54,55 +103,79 @@
   return
 }
 
+// Vector store with mask.
+func @transfer_write_2d_mask(%A : memref<?x?xf32>, %base1: index, %base2: index) {
+  %fn1 = constant -2.0 : f32
+  %mask = constant dense<[[1, 0, 1, 0]]> : vector<1x4xi1>
+  %vf0 = splat %fn1 : vector<1x4xf32>
+  vector.transfer_write %vf0, %A[%base1, %base2], %mask
+    {permutation_map = affine_map<(d0, d1) -> (d0, d1)>} :
+    vector<1x4xf32>, memref<?x?xf32>
+  return
+}
+
 func @entry() {
   %c0 = constant 0: index
   %c1 = constant 1: index
   %c2 = constant 2: index
   %c3 = constant 3: index
-  %c4 = constant 4: index
-  %c5 = constant 5: index
-  %c8 = constant 5: index
-  %f10 = constant 10.0: f32
-  // work with dims of 4, not of 3
-  %first = constant 3: index
-  %second = constant 4: index
-  %A = memref.alloc(%first, %second) : memref<?x?xf32>
-  scf.for %i = %c0 to %first step %c1 {
-    %i32 = index_cast %i : index to i32
-    %fi = sitofp %i32 : i32 to f32
-    %fi10 = mulf %fi, %f10 : f32
-    scf.for %j = %c0 to %second step %c1 {
-      %j32 = index_cast %j : index to i32
-      %fj = sitofp %j32 : i32 to f32
-      %fres = addf %fi10, %fj : f32
-      memref.store %fres, %A[%i, %j] : memref<?x?xf32>
-    }
-  }
-  // On input, memory contains [[ 0, 1, 2, ...], [10, 11, 12, ...], ...]
-  // Read shifted by 2 and pad with -42:
+  %0 = memref.get_global @gv : memref<3x4xf32>
+  %A = memref.cast %0 : memref<3x4xf32> to memref<?x?xf32>
+
+  // 1. Read 2D vector from 2D memref.
   call @transfer_read_2d(%A, %c1, %c2) : (memref<?x?xf32>, index, index) -> ()
-  // Same as above, but transposed
+  // CHECK: ( ( 12, 13, -42, -42, -42, -42, -42, -42, -42 ), ( 22, 23, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )
+
+  // 2. Read 2D vector from 2D memref at specified location and transpose the
+  //    result.
   call @transfer_read_2d_transposed(%A, %c1, %c2)
       : (memref<?x?xf32>, index, index) -> ()
-  // Write into memory shifted by 3
-  call @transfer_write_2d(%A, %c3, %c1) : (memref<?x?xf32>, index, index) -> ()
-  // Read shifted by 0 and pad with -42:
-  call @transfer_read_2d(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
-  // Same as above, but apply a mask
+  // CHECK: ( ( 12, 22, -42, -42, -42, -42, -42, -42, -42 ), ( 13, 23, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )
+
+  // 3. Read 2D vector from 2D memref with a 2D mask. In addition, some
+  //    accesses are out-of-bounds.
   call @transfer_read_2d_mask(%A, %c0, %c0)
       : (memref<?x?xf32>, index, index) -> ()
-  // Same as above, but without mask and transposed
-  call @transfer_read_2d_transposed(%A, %c0, %c0)
+  // CHECK: ( ( 0, -42, 2, -42, -42, -42, -42, -42, -42 ), ( -42, -42, 12, 13, -42, -42, -42, -42, -42 ), ( 20, 21, 22, 23, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )
+
+  // 4. Same as 3, but transpose the result.
+  call @transfer_read_2d_mask_transposed(%A, %c0, %c0)
       : (memref<?x?xf32>, index, index) -> ()
-  // Second vector dimension is a broadcast
+  // CHECK: ( ( 0, -42, 20, -42 ), ( -42, -42, 21, -42 ), ( 2, 12, 22, -42 ), ( -42, 13, 23, -42 ), ( -42, -42, -42, -42 ), ( -42, -42, -42, -42 ), ( -42, -42, -42, -42 ), ( -42, -42, -42, -42 ), ( -42, -42, -42, -42 ) )
+
+  // 5. Read 1D vector from 2D memref at specified location and broadcast the
+  //    result to 2D.
   call @transfer_read_2d_broadcast(%A, %c1, %c2)
       : (memref<?x?xf32>, index, index) -> ()
+  // CHECK: ( ( 12, 12, 12, 12, 12, 12, 12, 12, 12 ), ( 13, 13, 13, 13, 13, 13, 13, 13, 13 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )
+
+  // 6. Read 1D vector from 2D memref at specified location with mask and
+  //    broadcast the result to 2D.
+  call @transfer_read_2d_mask_broadcast(%A, %c2, %c1)
+      : (memref<?x?xf32>, index, index) -> ()
+  // CHECK: ( ( 21, -42, 23, -42, -42, -42, -42, -42, -42 ), ( 21, -42, 23, -42, -42, -42, -42, -42, -42 ), ( 21, -42, 23, -42, -42, -42, -42, -42, -42 ), ( 21, -42, 23, -42, -42, -42, -42, -42, -42 ) )
+
+  // 7. Read 1D vector from 2D memref (second dimension) at specified location
+  //    with mask and broadcast the result to 2D. In this test case, mask
+  //    elements must be evaluated before lowering to an (N>1)-D transfer.
+  call @transfer_read_2d_mask_transpose_broadcast_last_dim(%A, %c0, %c1)
+      : (memref<?x?xf32>, index, index) -> ()
+  // CHECK: ( ( 1, 1, 1, 1, 1, 1, 1, 1, 1 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ), ( 3, 3, 3, 3, 3, 3, 3, 3, 3 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )
+
+  // 8. Write 2D vector into 2D memref at specified location.
+  call @transfer_write_2d(%A, %c1, %c2) : (memref<?x?xf32>, index, index) -> ()
+
+  // 9. Read memref to verify step 8.
+  call @transfer_read_2d(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
+  // CHECK: ( ( 0, 1, 2, 3, -42, -42, -42, -42, -42 ), ( 10, 11, -1, -1, -42, -42, -42, -42, -42 ), ( 20, 21, 22, 23, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )
+
+  // 10. Write 2D vector into 2D memref at specified location with mask.
+  call @transfer_write_2d_mask(%A, %c0, %c2) : (memref<?x?xf32>, index, index) -> ()
+
+  // 11. Read memref to verify step 10.
+  call @transfer_read_2d(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
+  // CHECK: ( ( 0, 1, -2, 3, -42, -42, -42, -42, -42 ), ( 10, 11, -1, -1, -42, -42, -42, -42, -42 ), ( 20, 21, 22, 23, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )
+
   return
 }
 
-// CHECK: ( ( 12, 13, -42, -42, -42, -42, -42, -42, -42 ), ( 22, 23, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )
-// CHECK: ( ( 12, 22, -42, -42, -42, -42, -42, -42, -42 ), ( 13, 23, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )
-// CHECK: ( ( 0, 1, 2, 3, -42, -42, -42, -42, -42 ), ( 10, 11, 12, 13, -42, -42, -42, -42, -42 ), ( 20, 21, 22, 23, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )
-// CHECK: ( ( 0, -42, 2, -42, -42, -42, -42, -42, -42 ), ( -42, -42, 12, 13, -42, -42, -42, -42, -42 ), ( 20, 21, 22, 23, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )
-// CHECK: ( ( 0, 10, 20, -42, -42, -42, -42, -42, -42 ), ( 1, 11, 21, -42, -42, -42, -42, -42, -42 ), ( 2, 12, 22, -42, -42, -42, -42, -42, -42 ), ( 3, 13, 23, -42, -42, -42, -42, -42, -42 ) )
-// CHECK: ( ( 12, 12, 12, 12, 12, 12, 12, 12, 12 ), ( 13, 13, 13, 13, 13, 13, 13, 13, 13 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir
@@ -1,15 +1,8 @@
-// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
-// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
-// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
-// RUN: FileCheck %s
-
 // RUN: mlir-opt %s -test-progressive-convert-vector-to-scf -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
 // RUN: mlir-cpu-runner -e entry -entry-point-result=void \
 // RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
 // RUN: FileCheck %s
 
-// Test case is based on test-transfer-read-2d.
-
 func @transfer_read_3d(%A : memref<?x?x?x?xf32>,
                        %o: index, %a: index, %b: index, %c: index) {
   %fm42 = constant -42.0: f32
@@ -29,6 +22,17 @@
   return
 }
 
+func @transfer_read_3d_mask_broadcast(
+    %A : memref<?x?x?x?xf32>, %o: index, %a: index, %b: index, %c: index) {
+  %fm42 = constant -42.0: f32
+  %mask = constant dense<[0, 1]> : vector<2xi1>
+  %f = vector.transfer_read %A[%o, %a, %b, %c], %fm42, %mask
+      {permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, 0, 0)>}
+      : memref<?x?x?x?xf32>, vector<2x5x3xf32>
+  vector.print %f: vector<2x5x3xf32>
+  return
+}
+
 func @transfer_read_3d_transposed(%A : memref<?x?x?x?xf32>,
                                   %o: index, %a: index, %b: index, %c: index) {
   %fm42 = constant -42.0: f32
@@ -80,20 +84,34 @@
     }
   }
 
+  // 1. Read 3D vector from 4D memref.
   call @transfer_read_3d(%A, %c0, %c0, %c0, %c0)
       : (memref<?x?x?x?xf32>, index, index, index, index) -> ()
+  // CHECK: ( ( ( 0, 0, -42 ), ( 2, 3, -42 ), ( 4, 6, -42 ), ( 6, 9, -42 ), ( -42, -42, -42 ) ), ( ( 20, 30, -42 ), ( 22, 33, -42 ), ( 24, 36, -42 ), ( 26, 39, -42 ), ( -42, -42, -42 ) ) )
+
+  // 2. Write 3D vector to 4D memref.
   call @transfer_write_3d(%A, %c0, %c0, %c1, %c1)
       : (memref<?x?x?x?xf32>, index, index, index, index) -> ()
+
+  // 3. Read memref to verify step 2.
   call @transfer_read_3d(%A, %c0, %c0, %c0, %c0)
       : (memref<?x?x?x?xf32>, index, index, index, index) -> ()
+  // CHECK: ( ( ( 0, 0, -42 ), ( 2, -1, -42 ), ( 4, -1, -42 ), ( 6, -1, -42 ), ( -42, -42, -42 ) ), ( ( 20, 30, -42 ), ( 22, -1, -42 ), ( 24, -1, -42 ), ( 26, -1, -42 ), ( -42, -42, -42 ) ) )
+
+  // 4. Read 3D vector from 4D memref and transpose vector.
   call @transfer_read_3d_transposed(%A, %c0, %c0, %c0, %c0)
       : (memref<?x?x?x?xf32>, index, index, index, index) -> ()
+  // CHECK: ( ( ( 0, 20, 40 ), ( 0, 20, 40 ), ( 0, 20, 40 ), ( 0, 20, 40 ), ( 0, 20, 40 ) ), ( ( 0, 30, 60 ), ( 0, 30, 60 ), ( 0, 30, 60 ), ( 0, 30, 60 ), ( 0, 30, 60 ) ), ( ( -42, -42, -42 ), ( -42, -42, -42 ), ( -42, -42, -42 ), ( -42, -42, -42 ), ( -42, -42, -42 ) ) )
+
+  // 5. Read 1D vector from 4D memref and broadcast vector to 3D.
   call @transfer_read_3d_broadcast(%A, %c0, %c0, %c0, %c0)
       : (memref<?x?x?x?xf32>, index, index, index, index) -> ()
+  // CHECK: ( ( ( 0, 0, -42 ), ( 0, 0, -42 ), ( 0, 0, -42 ), ( 0, 0, -42 ), ( 0, 0, -42 ) ), ( ( 20, 30, -42 ), ( 20, 30, -42 ), ( 20, 30, -42 ), ( 20, 30, -42 ), ( 20, 30, -42 ) ) )
+
+  // 6. Read 1D vector from 4D memref with mask and broadcast vector to 3D.
+  call @transfer_read_3d_mask_broadcast(%A, %c0, %c0, %c0, %c0)
+      : (memref<?x?x?x?xf32>, index, index, index, index) -> ()
+  // CHECK: ( ( ( -42, -42, -42 ), ( -42, -42, -42 ), ( -42, -42, -42 ), ( -42, -42, -42 ), ( -42, -42, -42 ) ), ( ( 20, 20, 20 ), ( 20, 20, 20 ), ( 20, 20, 20 ), ( 20, 20, 20 ), ( 20, 20, 20 ) ) )
+
   return
 }
-
-// CHECK: ( ( ( 0, 0, -42 ), ( 2, 3, -42 ), ( 4, 6, -42 ), ( 6, 9, -42 ), ( -42, -42, -42 ) ), ( ( 20, 30, -42 ), ( 22, 33, -42 ), ( 24, 36, -42 ), ( 26, 39, -42 ), ( -42, -42, -42 ) ) )
-// CHECK: ( ( ( 0, 0, -42 ), ( 2, -1, -42 ), ( 4, -1, -42 ), ( 6, -1, -42 ), ( -42, -42, -42 ) ), ( ( 20, 30, -42 ), ( 22, -1, -42 ), ( 24, -1, -42 ), ( 26, -1, -42 ), ( -42, -42, -42 ) ) )
-// CHECK: ( ( ( 0, 20, 40 ), ( 0, 20, 40 ), ( 0, 20, 40 ), ( 0, 20, 40 ), ( 0, 20, 40 ) ), ( ( 0, 30, 60 ), ( 0, 30, 60 ), ( 0, 30, 60 ), ( 0, 30, 60 ), ( 0, 30, 60 ) ), ( ( -42, -42, -42 ), ( -42, -42, -42 ), ( -42, -42, -42 ), ( -42, -42, -42 ), ( -42, -42, -42 ) ) )
-// CHECK: ( ( ( 0, 0, -42 ), ( 0, 0, -42 ), ( 0, 0, -42 ), ( 0, 0, -42 ), ( 0, 0, -42 ) ), ( ( 20, 30, -42 ), ( 20, 30, -42 ), ( 20, 30, -42 ), ( 20, 30, -42 ), ( 20, 30, -42 ) ) )