diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -757,10 +757,35 @@
           rewriter, read.getLoc(), d0 + scale * d1,
           {indices[indexPos], warpOp.getLaneid()});
     }
-    Value newRead = rewriter.create<vector::TransferReadOp>(
+    auto newRead = rewriter.create<vector::TransferReadOp>(
         read.getLoc(), distributedVal.getType(), read.getSource(), indices,
         read.getPermutationMapAttr(), read.getPadding(), read.getMask(),
         read.getInBoundsAttr());
+
+    // Check that the produced operation is legal.
+    // The transfer op may be reading from values that are defined within
+    // warpOp's body, which is illegal.
+    // We do the check late because indices may be changed by
+    // makeComposedAffineApply. This rewrite may remove dependencies on
+    // warpOp's body.
+    // E.g., warpOp {
+    //   %idx = affine.apply...[%outsideDef]
+    //   ... = transfer_read ...[%idx]
+    // }
+    // will be rewritten into:
+    // warpOp {
+    // }
+    // %new_idx = affine.apply...[%outsideDef]
+    // ... = transfer_read ...[%new_idx]
+    if (!llvm::all_of(newRead->getOperands(), [&](Value value) {
+          return warpOp.isDefinedOutsideOfRegion(value);
+        })) {
+      // Erase the speculatively created op before bailing out: it sits after
+      // warpOp but uses values from warpOp's body, which is invalid IR.
+      rewriter.eraseOp(newRead);
+      return failure();
+    }
+
     rewriter.replaceAllUsesWith(distributedVal, newRead);
     return success();
   }
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1109,3 +1109,46 @@
   }
   return %r : vector<4x96xf32>
 }
+
+// -----
+
+// Check that we don't propagate transfer_reads that have dependencies on
+// values inside the warp_execute_on_lane_0.
+// In this case, propagating would create a transfer_read that depends on the
+// extractelement defined in the body.
+
+// CHECK-PROP-LABEL: func @transfer_read_no_prop(
+// CHECK-PROP-SAME: %[[IN2:[^ :]*]]: vector<1x2xindex>,
+// CHECK-PROP-SAME: %[[AR1:[^ :]*]]: memref<1x4x2xi32>,
+// CHECK-PROP-SAME: %[[AR2:[^ :]*]]: memref<1x4x1024xf32>)
+// CHECK-PROP-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-PROP-DAG: %[[THREADID:.*]] = gpu.thread_id x
+// CHECK-PROP: %[[W:.*]] = vector.warp_execute_on_lane_0(%[[THREADID]])[32] args(%[[IN2]]
+// CHECK-PROP: %[[GATHER:.*]] = vector.gather %[[AR1]][{{.*}}]
+// CHECK-PROP: %[[EXTRACT:.*]] = vector.extract %[[GATHER]][0] : vector<1x64xi32>
+// CHECK-PROP: %[[CAST:.*]] = arith.index_cast %[[EXTRACT]] : vector<64xi32> to vector<64xindex>
+// CHECK-PROP: %[[EXTRACTELT:.*]] = vector.extractelement %[[CAST]][{{.*}}: i32] : vector<64xindex>
+// CHECK-PROP: %[[TRANSFERREAD:.*]] = vector.transfer_read %[[AR2]][%[[C0]], %[[EXTRACTELT]], %[[C0]]],
+// CHECK-PROP: vector.yield %[[TRANSFERREAD]] : vector<64xf32>
+// CHECK-PROP: return %[[W]]
+func.func @transfer_read_no_prop(%in2: vector<1x2xindex>, %ar1 : memref<1x4x2xi32>, %ar2 : memref<1x4x1024xf32>)-> vector<2xf32> {
+  %0 = gpu.thread_id x
+  %c0_i32 = arith.constant 0 : i32
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant dense<0> : vector<1x64xi32>
+  %cst_0 = arith.constant dense<true> : vector<1x64xi1>
+  %cst_1 = arith.constant dense<3> : vector<64xindex>
+  %cst_2 = arith.constant dense<0> : vector<64xindex>
+  %cst_6 = arith.constant 0.000000e+00 : f32
+
+  %18 = vector.warp_execute_on_lane_0(%0)[32] args(%in2 : vector<1x2xindex>) -> (vector<2xf32>) {
+  ^bb0(%arg4: vector<1x64xindex>):
+    %28 = vector.gather %ar1[%c0, %c0, %c0] [%arg4], %cst_0, %cst : memref<1x4x2xi32>, vector<1x64xindex>, vector<1x64xi1>, vector<1x64xi32> into vector<1x64xi32>
+    %29 = vector.extract %28[0] : vector<1x64xi32>
+    %30 = arith.index_cast %29 : vector<64xi32> to vector<64xindex>
+    %36 = vector.extractelement %30[%c0_i32 : i32] : vector<64xindex>
+    %37 = vector.transfer_read %ar2[%c0, %36, %c0], %cst_6 {in_bounds = [true]} : memref<1x4x1024xf32>, vector<64xf32>
+    vector.yield %37 : vector<64xf32>
+  }
+  return %18 : vector<2xf32>
+}