diff --git a/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp b/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp
--- a/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp
+++ b/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp
@@ -46,23 +46,79 @@
   return isa<vector::LoadOp>(read);
 }
 
+namespace {
+/// A vector.create_mask op and a position to extract from its result.
+struct TransferMask {
+  vector::CreateMaskOp createMaskOp;
+  SmallVector<int64_t> extractPosition;
+};
+} // namespace
+
 /// If the given vector load op has a mask that is defined by
 /// vector.create_mask, return that op.
-static vector::CreateMaskOp getMaskOp(Operation *loadOp) {
+static FailureOr<TransferMask> getMaskOp(Operation *loadOp) {
   auto transferRead = dyn_cast<vector::TransferReadOp>(loadOp);
   if (!transferRead || !transferRead.getMask())
-    return {};
-  auto maskOp = transferRead.getMask().getDefiningOp<vector::CreateMaskOp>();
-  // TODO: Support 2D masks and higher. Ops with a >1D mask are ignored at the
-  // moment.
-  if (maskOp.getVectorType().getRank() != 1)
-    return {};
-  return maskOp;
+    return TransferMask{{}, {}};
+  assert(transferRead.getMask().getType().getRank() == 1 &&
+         "expected 1-D mask");
+
+  // Case 1: Mask is the result of a vector.create_mask.
+  if (auto maskOp =
+          transferRead.getMask().getDefiningOp<vector::CreateMaskOp>())
+    return TransferMask{maskOp, {}};
+
+  // Case 2: Mask is the result of a vector.extract(vector.create_mask). Only
+  // 2D -> 1D extracts are supported at the moment.
+  if (auto extractOp =
+          transferRead.getMask().getDefiningOp<vector::ExtractOp>())
+    if (auto maskOp =
+            extractOp.getVector().getDefiningOp<vector::CreateMaskOp>())
+      if (extractOp.getPosition().size() == 1 &&
+          extractOp.getSourceVectorType().getRank() == 2)
+        return TransferMask{maskOp,
+                            SmallVector<int64_t>(extractOp.getPosition())};
+
+  // All other cases: not supported.
+  return {};
+}
+
+/// Build an SSA value that represents the number of read elements.
+static Value buildNumReadElements(OpBuilder &b, Location loc,
+                                  Operation *readOp) {
+  FailureOr<TransferMask> transferMask = getMaskOp(readOp);
+  assert(succeeded(transferMask) && "invalid transfer mask");
+
+  // No mask => no num_read_elements.
+  if (!transferMask->createMaskOp)
+    return Value();
+
+  // No extract: return the size of the "ones" segment in the mask.
+  if (transferMask->extractPosition.empty()) {
+    assert(transferMask->createMaskOp.getNumOperands() == 1 &&
+           "expected single operand");
+    return transferMask->createMaskOp.getOperand(0);
+  }
+
+  // vector.extract(vector.create_mask).
+  // If extract_pos < num_ones, take the number of elements from the least
+  // significant dimension. Otherwise, the extracted row is all-false and zero
+  // elements are read.
+  assert(transferMask->createMaskOp.getVectorType().getRank() == 2 &&
+         "expected 2D mask");
+  assert(transferMask->extractPosition.size() == 1 &&
+         "expected 2D -> 1D extract");
+  Value cmp = b.create<arith::CmpIOp>(
+      loc, arith::CmpIPredicate::slt,
+      b.create<arith::ConstantIndexOp>(loc,
+                                       transferMask->extractPosition.front()),
+      transferMask->createMaskOp->getOperands().front());
+  return b.create<arith::SelectOp>(
+      loc, cmp, transferMask->createMaskOp->getOperands().back(),
+      b.create<arith::ConstantIndexOp>(loc, 0));
 }
 
 /// Return "true" if the conversion to async copy is supported by "async copy".
 static bool resultsInSupportedAsyncCopy(MemRefType memrefType,
-                                        Operation::operand_range indices,
                                         VectorType vecType) {
   assert(vecType.getRank() == 1 && "expected 1-D vector");
   constexpr int64_t kSupportedCpAsyncAlignmentsInBytes[3] = {4, 8, 16};
@@ -121,7 +177,7 @@
       if (getConstantIntValue(transferRead.getPadding()) !=
           static_cast<int64_t>(0))
         return;
-      if (!getMaskOp(readOp))
+      if (failed(getMaskOp(readOp)))
         return;
     }
   }
@@ -131,9 +187,9 @@
     VectorType vecType = cast<VectorType>(vectorVal.getType());
 
     if (!resultsInSupportedAsyncCopy(cast<MemRefType>(loadBase.getType()),
-                                     nvgpu::getIndices(readOp), vecType) ||
+                                     vecType) ||
         !resultsInSupportedAsyncCopy(cast<MemRefType>(storeBase.getType()),
-                                     nvgpu::getIndices(writeOp), vecType))
+                                     vecType))
       return;
 
     copyToSharedMem.insert(writeOp);
@@ -184,11 +240,8 @@
     Operation *readOp = vectorVal.getDefiningOp();
     Value storeBase = nvgpu::getMemrefOperand(writeOp);
     Value loadBase = nvgpu::getMemrefOperand(readOp);
-    Value numReadElements;
-    if (vector::CreateMaskOp maskOp = getMaskOp(readOp)) {
-      assert(maskOp.getNumOperands() == 1 && "expected single operand");
-      numReadElements = maskOp.getOperand(0);
-    }
+    Value numReadElements =
+        buildNumReadElements(rewriter, writeOp->getLoc(), readOp);
     auto dstMemref = cast<MemRefType>(storeBase.getType());
     int64_t sizeInBytes =
         (dstMemref.getElementTypeBitWidth() * numElements) / 8;
diff --git a/mlir/test/Dialect/NVGPU/transform-create-async-groups.mlir b/mlir/test/Dialect/NVGPU/transform-create-async-groups.mlir
--- a/mlir/test/Dialect/NVGPU/transform-create-async-groups.mlir
+++ b/mlir/test/Dialect/NVGPU/transform-create-async-groups.mlir
@@ -151,3 +151,51 @@
     transform.nvgpu.create_async_groups %top_level_func {bypass_l1} : (!transform.any_op) -> (!transform.any_op)
   }
 }
+
+// -----
+
+// 2D vector.transfer_read with a mask.
+builtin.module {
+  // CHECK-LABEL: @read_2d_with_mask(
+  // CHECK-SAME: %[[sz0:.*]]: index, %[[sz1:.*]]: index, %[[a:.*]]: memref<1024x1024xf32>
+  func.func @read_2d_with_mask(%sz0: index, %sz1: index, %a: memref<1024x1024xf32>) {
+    // CHECK: %[[c0:.*]] = arith.constant 0 : index
+    // CHECK: %[[c1:.*]] = arith.constant 1 : index
+    // CHECK: %[[c2:.*]] = arith.constant 2 : index
+    %0 = memref.alloc() : memref<4x32x16xf32, #gpu.address_space<workgroup>>
+    %c0 = arith.constant 0 : index
+    %cst_0 = arith.constant 0.000000e+00 : f32
+    // CHECK: %[[mask:.*]] = vector.create_mask
+    // CHECK: %[[e0:.*]] = vector.extract %[[mask]][0] : vector<3x4xi1>
+    // CHECK: %[[e1:.*]] = vector.extract %[[mask]][1] : vector<3x4xi1>
+    // CHECK: %[[e2:.*]] = vector.extract %[[mask]][2] : vector<3x4xi1>
+
+    // CHECK: %[[cmpi0:.*]] = arith.cmpi slt, %[[c0]], %[[sz0]]
+    // CHECK: %[[s0:.*]] = arith.select %[[cmpi0]], %[[sz1]], %[[c0]]
+    // CHECK: nvgpu.device_async_copy %[[a]][%[[c0]], %[[c0]]], {{.*}}, 4, %[[s0]] {bypassL1}
+
+    // CHECK: %[[cmpi1:.*]] = arith.cmpi slt, %[[c1]], %[[sz0]]
+    // CHECK: %[[s1:.*]] = arith.select %[[cmpi1]], %[[sz1]], %[[c0]]
+    // CHECK: nvgpu.device_async_copy %[[a]][%[[c1]], %[[c0]]], {{.*}}, 4, %[[s1]] {bypassL1}
+
+    // CHECK: %[[cmpi2:.*]] = arith.cmpi slt, %[[c2]], %[[sz0]]
+    // CHECK: %[[s2:.*]] = arith.select %[[cmpi2]], %[[sz1]], %[[c0]]
+    // CHECK: nvgpu.device_async_copy %[[a]][%[[c2]], %[[c0]]], {{.*}}, 4, %[[s2]] {bypassL1}
+    %mask = vector.create_mask %sz0, %sz1 : vector<3x4xi1>
+    %1 = vector.transfer_read %a[%c0, %c0], %cst_0, %mask {in_bounds = [true, true]} : memref<1024x1024xf32>, vector<3x4xf32>
+    vector.transfer_write %1, %0[%c0, %c0, %c0] {in_bounds = [true, true]} : vector<3x4xf32>, memref<4x32x16xf32, #gpu.address_space<workgroup>>
+
+    return
+  }
+
+  transform.sequence failures(propagate) {
+  ^bb1(%variant_op: !transform.any_op):
+    %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %top_level_func {
+      transform.apply_patterns.vector.transfer_to_scf max_transfer_rank = 1 full_unroll = true
+    } : !transform.any_op
+    transform.nvgpu.create_async_groups %top_level_func {bypass_l1} : (!transform.any_op) -> (!transform.any_op)
+    %top_level_func_2 = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+    transform.apply_cse to %top_level_func_2 : !transform.any_op
+  }
+}
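
Illustrative sketch (not part of the patch; the function and buffer names below are hypothetical): the transform pairs a contiguous vector.transfer_read from global memory with a vector.transfer_write into shared memory and replaces both with one nvgpu.device_async_copy. For the 1-D masked case that was already supported before this change, the copy's element count is simply the vector.create_mask operand:

  func.func @sketch_1d_masked_copy(%sz: index, %a: memref<1024xf32>) {
    %0 = memref.alloc() : memref<32xf32, #gpu.address_space<workgroup>>
    %c0 = arith.constant 0 : index
    // Padding must be 0 for the masked read to qualify.
    %cst = arith.constant 0.000000e+00 : f32
    %mask = vector.create_mask %sz : vector<4xi1>
    %v = vector.transfer_read %a[%c0], %cst, %mask {in_bounds = [true]}
        : memref<1024xf32>, vector<4xf32>
    vector.transfer_write %v, %0[%c0] {in_bounds = [true]}
        : vector<4xf32>, memref<32xf32, #gpu.address_space<workgroup>>
    return
  }

  // After create_async_groups {bypass_l1}, the read/write pair becomes:
  //   %token = nvgpu.device_async_copy %a[%c0], %0[%c0], 4, %sz {bypassL1}
  //       : memref<1024xf32> to memref<32xf32, #gpu.address_space<workgroup>>

For the new 2-D case exercised by @read_2d_with_mask, each 1-D row produced by unrolling instead receives arith.select(row < %sz0, %sz1, %c0) as its element count, as built by buildNumReadElements above.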