diff --git a/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp b/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp
--- a/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp
+++ b/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp
@@ -68,19 +68,16 @@
           transferRead.getMask().getDefiningOp<vector::CreateMaskOp>())
     return TransferMask{maskOp, {}};
 
-  // Case 2: Mask is the result of a vector.extract(vector.create_mask). Only
-  // 2D -> 1D extracts are supported at the moment.
+  // Case 2: Mask is the result of a vector.extract(vector.create_mask).
   if (auto extractOp =
           transferRead.getMask().getDefiningOp<vector::ExtractOp>())
     if (auto maskOp =
             extractOp.getVector().getDefiningOp<vector::CreateMaskOp>())
-      if (extractOp.getPosition().size() == 1 &&
-          extractOp.getSourceVectorType().getRank() == 2)
-        return TransferMask{maskOp,
-                            SmallVector<int64_t>(extractOp.getPosition())};
+      return TransferMask{maskOp,
+                          SmallVector<int64_t>(extractOp.getPosition())};
 
   // All other cases: not supported.
-  return {};
+  return failure();
 }
 
 /// Build an SSA value that represents the number of read elements.
@@ -102,18 +99,27 @@
   // vector.extract(vector.create_mask).
 
   // If extract_pos < num_ones, take number of elements from the least
-  // significant dimension.
-  assert(transferMask->createMaskOp.getVectorType().getRank() == 2 &&
-         "expected 2D mask");
-  assert(transferMask->extractPosition.size() == 1 &&
-         "expected 2D->1D extract");
-  Value cmp = b.create<arith::CmpIOp>(
-      loc, arith::CmpIPredicate::slt,
-      b.create<arith::ConstantIndexOp>(loc,
-                                       transferMask->extractPosition.front()),
-      transferMask->createMaskOp->getOperands().front());
+  // significant dimension. (Do this for all dimensions and bit-AND the
+  // conditions.)
+  assert(transferMask->createMaskOp.getVectorType().getRank() -
+                 transferMask->extractPosition.size() ==
+             1 &&
+         "expected N-D -> (N-1)-D extract");
+  Value cond;
+  // Note: There is one more `sz` than `pos`; the loop ends with the last `pos`.
+  for (auto [pos, sz] : llvm::zip(transferMask->extractPosition,
+                                  transferMask->createMaskOp->getOperands())) {
+    Value cmp =
+        b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
+                                b.create<arith::ConstantIndexOp>(loc, pos), sz);
+    if (!cond) {
+      cond = cmp;
+      continue;
+    }
+    cond = b.create<arith::AndIOp>(loc, cmp, cond);
+  }
   return b.create<arith::SelectOp>(
-      loc, cmp, transferMask->createMaskOp->getOperands().back(),
+      loc, cond, transferMask->createMaskOp->getOperands().back(),
       b.create<arith::ConstantIndexOp>(loc, 0));
 }
 
diff --git a/mlir/test/Dialect/NVGPU/transform-create-async-groups.mlir b/mlir/test/Dialect/NVGPU/transform-create-async-groups.mlir
--- a/mlir/test/Dialect/NVGPU/transform-create-async-groups.mlir
+++ b/mlir/test/Dialect/NVGPU/transform-create-async-groups.mlir
@@ -165,10 +165,6 @@
     %0 = memref.alloc() : memref<4x32x16xf32, #gpu.address_space<workgroup>>
     %c0 = arith.constant 0 : index
     %cst_0 = arith.constant 0.000000e+00 : f32
 
-    // CHECK: %[[mask:.*]] = vector.create_mask
-    // CHECK: %[[e0:.*]] = vector.extract %[[mask]][0] : vector<3x4xi1>
-    // CHECK: %[[e1:.*]] = vector.extract %[[mask]][1] : vector<3x4xi1>
-    // CHECK: %[[e2:.*]] = vector.extract %[[mask]][2] : vector<3x4xi1>
     // CHECK: %[[cmpi0:.*]] = arith.cmpi slt, %[[c0]], %[[sz0]]
    // CHECK: %[[s0:.*]] = arith.select %[[cmpi0]], %[[sz1]], %[[c0]]
@@ -199,3 +195,64 @@
     transform.apply_cse to %top_level_func_2 : !transform.any_op
   }
 }
+
+// -----
+
+// 3D vector.transfer_read with a mask.
+builtin.module {
+  // CHECK-LABEL: @read_3d_with_mask(
+  //  CHECK-SAME: %[[sz0:.*]]: index, %[[sz1:.*]]: index, %[[sz2:.*]]: index, %[[a:.*]]: memref<1024x1024x1024xf32>
+  func.func @read_3d_with_mask(%sz0: index, %sz1: index, %sz2: index, %a: memref<1024x1024x1024xf32>) {
+    // CHECK: %[[c0:.*]] = arith.constant 0 : index
+    // CHECK: %[[c1:.*]] = arith.constant 1 : index
+    // CHECK: %[[c2:.*]] = arith.constant 2 : index
+    %0 = memref.alloc() : memref<4x32x16xf32, #gpu.address_space<workgroup>>
+    %c0 = arith.constant 0 : index
+    %cst_0 = arith.constant 0.000000e+00 : f32
+
+    // CHECK: %[[cmpi0:.*]] = arith.cmpi slt, %[[c0]], %[[sz0]]
+    // CHECK: %[[cmpi1:.*]] = arith.cmpi slt, %[[c0]], %[[sz1]]
+    // CHECK: %[[cond0:.*]] = arith.andi %[[cmpi1]], %[[cmpi0]]
+    // CHECK: %[[s0:.*]] = arith.select %[[cond0]], %[[sz2]], %[[c0]]
+    // CHECK: nvgpu.device_async_copy %[[a]][%[[c0]], %[[c0]], %[[c0]]], {{.*}}, 4, %[[s0]] {bypassL1}
+
+    // CHECK: %[[cmpi2:.*]] = arith.cmpi slt, %[[c1]], %[[sz1]]
+    // CHECK: %[[cond1:.*]] = arith.andi %[[cmpi2]], %[[cmpi0]]
+    // CHECK: %[[s1:.*]] = arith.select %[[cond1]], %[[sz2]], %[[c0]]
+    // CHECK: nvgpu.device_async_copy %[[a]][%[[c0]], %[[c1]], %[[c0]]], {{.*}}, 4, %[[s1]] {bypassL1}
+
+    // CHECK: %[[cmpi3:.*]] = arith.cmpi slt, %[[c2]], %[[sz1]]
+    // CHECK: %[[cond2:.*]] = arith.andi %[[cmpi3]], %[[cmpi0]]
+    // CHECK: %[[s2:.*]] = arith.select %[[cond2]], %[[sz2]], %[[c0]]
+    // CHECK: nvgpu.device_async_copy %[[a]][%[[c0]], %[[c2]], %[[c0]]], {{.*}}, 4, %[[s2]] {bypassL1}
+
+    // CHECK: %[[cmpi4:.*]] = arith.cmpi slt, %[[c1]], %[[sz0]]
+    // CHECK: %[[cond3:.*]] = arith.andi %[[cmpi1]], %[[cmpi4]]
+    // CHECK: %[[s3:.*]] = arith.select %[[cond3]], %[[sz2]], %[[c0]]
+    // CHECK: nvgpu.device_async_copy %[[a]][%[[c1]], %[[c0]], %[[c0]]], {{.*}}, 4, %[[s3]] {bypassL1}
+
+    // CHECK: %[[cond4:.*]] = arith.andi %[[cmpi2]], %[[cmpi4]]
+    // CHECK: %[[s4:.*]] = arith.select %[[cond4]], %[[sz2]], %[[c0]]
+    // CHECK: nvgpu.device_async_copy %[[a]][%[[c1]], %[[c1]], %[[c0]]], {{.*}}, 4, %[[s4]] {bypassL1}
+
+    // CHECK: %[[cond5:.*]] = arith.andi %[[cmpi3]], %[[cmpi4]]
+    // CHECK: %[[s5:.*]] = arith.select %[[cond5]], %[[sz2]], %[[c0]]
+    // CHECK: nvgpu.device_async_copy %[[a]][%[[c1]], %[[c2]], %[[c0]]], {{.*}}, 4, %[[s5]] {bypassL1}
+    %mask = vector.create_mask %sz0, %sz1, %sz2 : vector<2x3x4xi1>
+    %1 = vector.transfer_read %a[%c0, %c0, %c0], %cst_0, %mask {in_bounds = [true, true, true]} : memref<1024x1024x1024xf32>, vector<2x3x4xf32>
+    vector.transfer_write %1, %0[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<2x3x4xf32>, memref<4x32x16xf32, #gpu.address_space<workgroup>>
+
+    return
+  }
+
+  transform.sequence failures(propagate) {
+  ^bb1(%variant_op: !transform.any_op):
+    %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %top_level_func {
+      transform.apply_patterns.vector.transfer_to_scf max_transfer_rank = 1 full_unroll = true
+    } : !transform.any_op
+    transform.nvgpu.create_async_groups %top_level_func {bypass_l1} : (!transform.any_op) -> (!transform.any_op)
+    %top_level_func_2 = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+    transform.apply_cse to %top_level_func_2 : !transform.any_op
+  }
+}
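For reference, this is the shape of the IR that the reworked `buildNumReadElements` produces for a single unrolled 1-D transfer, reconstructed from the CHECK lines above (a sketch only; the SSA value names are illustrative and not part of the test). For the mask `vector.create_mask %sz0, %sz1, %sz2 : vector<2x3x4xi1>` extracted at position `[1, 2]`, each leading dimension contributes one `arith.cmpi slt` of its static extract position against the corresponding `create_mask` operand, the per-dimension conditions are combined with `arith.andi`, and the innermost mask size `%sz2` is selected only if all leading positions are in bounds:

```mlir
// Sketch of the expansion for extract position [1, 2] of a 2x3x4 mask
// (corresponds to the %[[s5]] group in the test above).
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
// One comparison per extracted (leading) dimension: pos < sz.
%cmp0 = arith.cmpi slt, %c1, %sz0 : index
%cmp1 = arith.cmpi slt, %c2, %sz1 : index
// Bit-AND the per-dimension conditions.
%cond = arith.andi %cmp1, %cmp0 : i1
// Number of elements to copy: %sz2 if in bounds, otherwise 0.
%num_read = arith.select %cond, %sz2, %c0 : index
```

The previously supported 2D -> 1D case falls out as the degenerate form: with a single leading dimension the loop runs once, so there is one comparison and no `arith.andi`, matching the unchanged `%[[cmpi0]]`/`%[[s0]]` CHECK lines in the existing 2D test.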