diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -2259,22 +2259,21 @@
   Location loc = xferOp->getLoc();
   VectorType vtp = xferOp.getVectorType();

-  // * Create a vector with linear indices [ 0 .. vector_length - 1 ].
-  // * Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
-  // * Let dim the memref dimension, compute the vector comparison mask
-  //   (in-bounds mask):
-  //   [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
+  // Create the in-bounds mask with all elements between [0 .. dim - offset)
+  // set and [dim - offset .. vector_length) unset.
   //
   // TODO: when the leaf transfer rank is k > 1, we need the last `k`
   //       dimensions here.
-  unsigned vecWidth = vtp.getNumElements();
   unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
   Value off = xferOp.indices()[lastIndex];
   Value dim = vector::createOrFoldDimOp(rewriter, loc, xferOp.source(),
                                         lastIndex);
-  Value mask = buildVectorComparison(rewriter, xferOp, indexOptimizations,
-                                     vecWidth, dim, &off);
-
+  Value b = rewriter.create<arith::SubIOp>(loc, dim.getType(), dim, off);
+  Value mask = rewriter.create<vector::CreateMaskOp>(
+      loc,
+      VectorType::get(vtp.getShape(), rewriter.getI1Type(),
+                      vtp.getNumScalableDims()),
+      b);
   if (xferOp.mask()) {
     // Intersect the in-bounds with the mask specified as an op parameter.
     mask = rewriter.create<arith::AndIOp>(loc, mask, xferOp.mask());
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir
--- a/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir
@@ -25,16 +25,26 @@
 }

 // CMP32-LABEL: @transfer_read_1d
+// CMP32: %[[MEM:.*]]: memref<?xf32>, %[[OFF:.*]]: index) -> vector<16xf32> {
+// CMP32: %[[D:.*]] = memref.dim %[[MEM]], %{{.*}} : memref<?xf32>
+// CMP32: %[[S:.*]] = arith.subi %[[D]], %[[OFF]] : index
 // CMP32: %[[C:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xi32>
-// CMP32: %[[A:.*]] = arith.addi %{{.*}}, %[[C]] : vector<16xi32>
-// CMP32: %[[M:.*]] = arith.cmpi slt, %[[A]], %{{.*}} : vector<16xi32>
+// CMP32: %[[B:.*]] = arith.index_cast %[[S]] : index to i32
+// CMP32: %[[B0:.*]] = llvm.insertelement %[[B]], %{{.*}} : vector<16xi32>
+// CMP32: %[[BV:.*]] = llvm.shufflevector %[[B0]], {{.*}} : vector<16xi32>, vector<16xi32>
+// CMP32: %[[M:.*]] = arith.cmpi slt, %[[C]], %[[BV]] : vector<16xi32>
 // CMP32: %[[L:.*]] = llvm.intr.masked.load %{{.*}}, %[[M]], %{{.*}}
 // CMP32: return %[[L]] : vector<16xf32>

-// CMP64-LABEL: @transfer_read_1d
+// CMP64-LABEL: @transfer_read_1d(
+// CMP64: %[[MEM:.*]]: memref<?xf32>, %[[OFF:.*]]: index) -> vector<16xf32> {
+// CMP64: %[[D:.*]] = memref.dim %[[MEM]], %{{.*}} : memref<?xf32>
+// CMP64: %[[S:.*]] = arith.subi %[[D]], %[[OFF]] : index
 // CMP64: %[[C:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xi64>
-// CMP64: %[[A:.*]] = arith.addi %{{.*}}, %[[C]] : vector<16xi64>
-// CMP64: %[[M:.*]] = arith.cmpi slt, %[[A]], %{{.*}} : vector<16xi64>
+// CMP64: %[[B:.*]] = arith.index_cast %[[S]] : index to i64
+// CMP64: %[[B0:.*]] = llvm.insertelement %[[B]], %{{.*}} : vector<16xi64>
+// CMP64: %[[BV:.*]] = llvm.shufflevector %[[B0]], {{.*}} : vector<16xi64>, vector<16xi64>
+// CMP64: %[[M:.*]] = arith.cmpi slt, %[[C]], %[[BV]] : vector<16xi64>
 // CMP64: %[[L:.*]] = llvm.intr.masked.load %{{.*}}, %[[M]], %{{.*}}
 // CMP64: return %[[L]] : vector<16xf32>
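The net effect of the VectorTransforms.cpp change above can be sketched in stand-alone MLIR. This is a hand-written illustration, not output of the patch (the function and value names are hypothetical): the in-bounds mask for a vector<16xf32> transfer is now materialized as a single subtraction followed by vector.create_mask, rather than by comparing a broadcast offset-plus-iota vector against a broadcast dimension.

    func @in_bounds_mask(%mem: memref<?xf32>, %off: index) -> vector<16xi1> {
      %c0 = arith.constant 0 : index
      // Runtime size of the (only) memref dimension.
      %dim = memref.dim %mem, %c0 : memref<?xf32>
      // bound = dim - off: lanes [0 .. bound) are in bounds.
      %bound = arith.subi %dim, %off : index
      // create_mask sets the first `bound` lanes and clears the rest.
      %mask = vector.create_mask %bound : vector<16xi1>
      return %mask : vector<16xi1>
    }

The vector-mask-to-llvm.mlir checks above then cover how this vector.create_mask is itself lowered, under 32-bit and 64-bit vector index widths respectively.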
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1245,34 +1245,33 @@
   return %f: vector<17xf32>
 }
 // CHECK-LABEL: func @transfer_read_1d
-// CHECK-SAME: %[[BASE:[a-zA-Z0-9]*]]: index) -> vector<17xf32>
-// CHECK: %[[c7:.*]] = arith.constant 7.0
+// CHECK-SAME: %[[MEM:.*]]: memref<?xf32>,
+// CHECK-SAME: %[[BASE:.*]]: index) -> vector<17xf32>
+// CHECK: %[[C7:.*]] = arith.constant 7.0
+//
+// 1. Let dim be the memref dimension, compute the in-bound index (dim - offset)
 // CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[DIM:.*]] = memref.dim %{{.*}}, %[[C0]] : memref<?xf32>
+// CHECK: %[[DIM:.*]] = memref.dim %[[MEM]], %[[C0]] : memref<?xf32>
+// CHECK: %[[BOUND:.*]] = arith.subi %[[DIM]], %[[BASE]] : index
 //
-// 1. Create a vector with linear indices [ 0 .. vector_length - 1 ].
+// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
 // CHECK: %[[linearIndex:.*]] = arith.constant dense
 // CHECK-SAME: <[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]> :
 // CHECK-SAME: vector<17xi32>
 //
-// 2. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
-// CHECK: %[[otrunc:.*]] = arith.index_cast %[[BASE]] : index to i32
-// CHECK: %[[offsetVecInsert:.*]] = llvm.insertelement %[[otrunc]]
-// CHECK: %[[offsetVec:.*]] = llvm.shufflevector %[[offsetVecInsert]]
-// CHECK: %[[offsetVec2:.*]] = arith.addi %[[offsetVec]], %[[linearIndex]] : vector<17xi32>
-//
-// 3. Let dim the memref dimension, compute the vector comparison mask:
-//    [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
-// CHECK: %[[dtrunc:.*]] = arith.index_cast %[[DIM]] : index to i32
-// CHECK: %[[dimVecInsert:.*]] = llvm.insertelement %[[dtrunc]]
-// CHECK: %[[dimVec:.*]] = llvm.shufflevector %[[dimVecInsert]]
-// CHECK: %[[mask:.*]] = arith.cmpi slt, %[[offsetVec2]], %[[dimVec]] : vector<17xi32>
+// 3. Create bound vector to compute in-bound mask:
+//    [ 0 .. vector_length - 1 ] < [ dim - offset .. dim - offset ]
+// CHECK: %[[btrunc:.*]] = arith.index_cast %[[BOUND]] : index to i32
+// CHECK: %[[boundVecInsert:.*]] = llvm.insertelement %[[btrunc]]
+// CHECK: %[[boundVect:.*]] = llvm.shufflevector %[[boundVecInsert]]
+// CHECK: %[[mask:.*]] = arith.cmpi slt, %[[linearIndex]], %[[boundVect]]
+// CHECK-SAME: : vector<17xi32>
 //
 // 4. Create pass-through vector.
 // CHECK: %[[PASS_THROUGH:.*]] = arith.constant dense<7.{{.*}}> : vector<17xf32>
 //
 // 5. Bitcast to vector form.
-// CHECK: %[[gep:.*]] = llvm.getelementptr {{.*}} :
+// CHECK: %[[gep:.*]] = llvm.getelementptr %{{.*}} :
 // CHECK-SAME: (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
 // CHECK: %[[vecPtr:.*]] = llvm.bitcast %[[gep]] :
 // CHECK-SAME: !llvm.ptr<f32> to !llvm.ptr<vector<17xf32>>
@@ -1280,21 +1279,24 @@
 // 6. Rewrite as a masked read.
 // CHECK: %[[loaded:.*]] = llvm.intr.masked.load %[[vecPtr]], %[[mask]],
 // CHECK-SAME: %[[PASS_THROUGH]] {alignment = 4 : i32} :
-// CHECK-SAME: (!llvm.ptr<vector<17xf32>>, vector<17xi1>, vector<17xf32>) -> vector<17xf32>
 //
-// 1. Create a vector with linear indices [ 0 .. vector_length - 1 ].
+// 1. Let dim be the memref dimension, compute the in-bound index (dim - offset)
+// CHECK: %[[C0_b:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM_b:.*]] = memref.dim %[[MEM]], %[[C0_b]] : memref<?xf32>
+// CHECK: %[[BOUND_b:.*]] = arith.subi %[[DIM_b]], %[[BASE]] : index
+//
+// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
 // CHECK: %[[linearIndex_b:.*]] = arith.constant dense
 // CHECK-SAME: <[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]> :
 // CHECK-SAME: vector<17xi32>
 //
-// 2. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
-// CHECK: llvm.shufflevector %{{.*}} : vector<17xi32>
-// CHECK: arith.addi
-//
-// 3. Let dim the memref dimension, compute the vector comparison mask:
-//    [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
-// CHECK: llvm.shufflevector %{{.*}} : vector<17xi32>
-// CHECK: %[[mask_b:.*]] = arith.cmpi slt, {{.*}} : vector<17xi32>
+// 3. Create bound vector to compute in-bound mask:
+//    [ 0 .. vector_length - 1 ] < [ dim - offset .. dim - offset ]
+// CHECK: %[[btrunc_b:.*]] = arith.index_cast %[[BOUND_b]] : index to i32
+// CHECK: %[[boundVecInsert_b:.*]] = llvm.insertelement %[[btrunc_b]]
+// CHECK: %[[boundVect_b:.*]] = llvm.shufflevector %[[boundVecInsert_b]]
+// CHECK: %[[mask_b:.*]] = arith.cmpi slt, %[[linearIndex_b]],
+// CHECK-SAME: %[[boundVect_b]] : vector<17xi32>
 //
 // 4. Bitcast to vector form.
 // CHECK: %[[gep_b:.*]] = llvm.getelementptr {{.*}} :
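The numbered steps checked above correspond to the comparison-based expansion of vector.create_mask during LLVM conversion. A rough sketch of that expansion for the 17-lane case, with hypothetical value names (%bound stands for the earlier dim - offset result; the lane broadcast is what the llvm.insertelement / llvm.shufflevector pair implements):

    %linearIndex = arith.constant
        dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]>
        : vector<17xi32>
    %btrunc = arith.index_cast %bound : index to i32
    // Broadcast the bound to all lanes.
    %boundVect = vector.broadcast %btrunc : i32 to vector<17xi32>
    // Lane i of the mask is set iff i < bound.
    %mask = arith.cmpi slt, %linearIndex, %boundVect : vector<17xi32>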
@@ -1344,16 +1346,20 @@
 // CHECK: %[[c1:.*]] = arith.constant 1 : index
 // CHECK: %[[DIM:.*]] = memref.dim %{{.*}}, %[[c1]] : memref<?x?xf32>
 //
-// Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
-// CHECK: %[[trunc:.*]] = arith.index_cast %[[BASE_1]] : index to i32
-// CHECK: %[[offsetVecInsert:.*]] = llvm.insertelement %[[trunc]]
-// CHECK: %[[offsetVec:.*]] = llvm.shufflevector %[[offsetVecInsert]]
+// Compute the in-bound index (dim - offset)
+// CHECK: %[[BOUND:.*]] = arith.subi %[[DIM]], %[[BASE_1]] : index
+//
+// Create a vector with linear indices [ 0 .. vector_length - 1 ].
+// CHECK: %[[linearIndex:.*]] = arith.constant dense
+// CHECK-SAME: <[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]> :
+// CHECK-SAME: vector<17xi32>
 //
-// Let dim the memref dimension, compute the vector comparison mask:
-//    [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
-// CHECK: %[[dimtrunc:.*]] = arith.index_cast %[[DIM]] : index to i32
-// CHECK: %[[dimtruncInsert:.*]] = llvm.insertelement %[[dimtrunc]]
-// CHECK: llvm.shufflevector %[[dimtruncInsert]]
+// Create bound vector to compute in-bound mask:
+//    [ 0 .. vector_length - 1 ] < [ dim - offset .. dim - offset ]
+// CHECK: %[[btrunc:.*]] = arith.index_cast %[[BOUND]] : index to i32
+// CHECK: %[[boundVecInsert:.*]] = llvm.insertelement %[[btrunc]]
+// CHECK: %[[boundVect:.*]] = llvm.shufflevector %[[boundVecInsert]]
+// CHECK: %[[mask:.*]] = arith.cmpi slt, %[[linearIndex]], %[[boundVect]]

 // -----
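Finally, the if (xferOp.mask()) branch in the C++ change covers transfer ops that already carry an explicit mask; there the computed in-bounds mask is intersected with the user-supplied mask via arith::AndIOp. A hypothetical op that would take that path (a hand-written sketch, not from the patch):

    func @masked_read(%mem: memref<?xf32>, %off: index,
                      %m: vector<16xi1>) -> vector<16xf32> {
      %pad = arith.constant 0.0 : f32
      // The lowering materializes the in-bounds mask as above,
      // then ANDs it with %m before the masked load.
      %v = vector.transfer_read %mem[%off], %pad, %m
          : memref<?xf32>, vector<16xf32>
      return %v : vector<16xf32>
    }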