diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -1246,8 +1246,9 @@ ``` This operation always reads a slice starting at `%A[%expr1, %expr2, %expr3, - %expr4]`. The size of the slice is 3 along d2 and 5 along d0, so the slice - is: `%A[%expr1 : %expr1 + 5, %expr2, %expr3:%expr3 + 3, %expr4]` + %expr4]`. The size of the slice can be inferred from the resulting vector + shape and walking back through the permutation map: 3 along d2 and 5 along + d0, so the slice is: `%A[%expr1 : %expr1 + 5, %expr2, %expr3:%expr3 + 3, %expr4]` That slice needs to be read into a `vector<3x4x5xf32>`. Since the permutation map is not full rank, there must be a broadcast along vector @@ -1257,44 +1258,52 @@ ```mlir // %expr1, %expr2, %expr3, %expr4 defined before this point - %tmp = alloc() : vector<3x4x5xf32> - %view_in_tmp = "element_type_cast"(%tmp) : memref<1xvector<3x4x5xf32>> + // alloc a temporary buffer for performing the "gather" of the slice. + %tmp = memref.alloc() : memref<vector<3x4x5xf32>> for %i = 0 to 3 { affine.for %j = 0 to 4 { affine.for %k = 0 to 5 { - %a = load %A[%expr1 + %k, %expr2, %expr3 + %i, %expr4] : - memref<?x?x?x?xf32> - store %tmp[%i, %j, %k] : vector<3x4x5xf32> + // Note that this load does not involve %j. + %a = load %A[%expr1 + %k, %expr2, %expr3 + %i, %expr4] : memref<?x?x?x?xf32> + // Update the temporary gathered slice with the individual element + %slice = memref.load %tmp : memref<vector<3x4x5xf32>> -> vector<3x4x5xf32> + %updated = vector.insert %a, %slice[%i, %j, %k] : f32 into vector<3x4x5xf32> + memref.store %updated, %tmp : memref<vector<3x4x5xf32>> }}} - %c0 = arith.constant 0 : index - %vec = load %view_in_tmp[%c0] : vector<3x4x5xf32> + // At this point we gathered the elements from the original + // memref into the desired vector layout, stored in the `%tmp` allocation. 
+ %vec = memref.load %tmp : memref<vector<3x4x5xf32>> -> vector<3x4x5xf32> ``` On a GPU one could then map `i`, `j`, `k` to blocks and threads. Notice that - the temporary storage footprint is `3 * 5` values but `3 * 4 * 5` values are - actually transferred between `%A` and `%tmp`. + the temporary storage footprint could conceptually be only `3 * 5` values but + `3 * 4 * 5` values are actually transferred between `%A` and `%tmp`. - Alternatively, if a notional vector broadcast operation were available, the - lowered code would resemble: + Alternatively, if a notional vector broadcast operation were available, we + could avoid the loop on `%j` and the lowered code would resemble: ```mlir // %expr1, %expr2, %expr3, %expr4 defined before this point - %tmp = alloc() : vector<3x4x5xf32> - %view_in_tmp = "element_type_cast"(%tmp) : memref<1xvector<3x4x5xf32>> + %tmp = memref.alloc() : memref<vector<3x4x5xf32>> for %i = 0 to 3 { affine.for %k = 0 to 5 { - %a = load %A[%expr1 + %k, %expr2, %expr3 + %i, %expr4] : - memref<?x?x?x?xf32> - store %tmp[%i, 0, %k] : vector<3x4x5xf32> + %a = load %A[%expr1 + %k, %expr2, %expr3 + %i, %expr4] : memref<?x?x?x?xf32> + %slice = memref.load %tmp : memref<vector<3x4x5xf32>> -> vector<3x4x5xf32> + // Here we only store to the first element in dimension one + %updated = vector.insert %a, %slice[%i, 0, %k] : f32 into vector<3x4x5xf32> + memref.store %updated, %tmp : memref<vector<3x4x5xf32>> }} - %c0 = arith.constant 0 : index - %tmpvec = load %view_in_tmp[%c0] : vector<3x4x5xf32> + // At this point we gathered the elements from the original + // memref into the desired vector layout, stored in the `%tmp` allocation. + // However we haven't replicated them along dimension one; we need + // to broadcast now. + %partialVec = memref.load %tmp : memref<vector<3x4x5xf32>> -> vector<3x4x5xf32> - %vec = broadcast %tmpvec, 1 : vector<3x4x5xf32> + %vec = broadcast %partialVec, 1 : vector<3x4x5xf32> ``` where `broadcast` broadcasts from element 0 to all others along the - specified dimension. 
This time, the temporary storage footprint is `3 * 5` - values which is the same amount of data as the `3 * 5` values transferred. + specified dimension. This time, the number of loaded elements is `3 * 5` + values. An additional broadcast along dimension `1` is required. On a GPU this broadcast could be implemented using a warp-shuffle if loop `j` were mapped to `threadIdx.x`. @@ -1310,7 +1319,7 @@ // Read the slice `%A[%i0, %i1:%i1+256, %i2:%i2+32]` into vector<32x256xf32> // and pad with %f0 to handle the boundary case: %f0 = arith.constant 0.0f : f32 - for %i0 = 0 to %0 { + affine.for %i0 = 0 to %0 { affine.for %i1 = 0 to %1 step 256 { affine.for %i2 = 0 to %2 step 32 { %v = vector.transfer_read %A[%i0, %i1, %i2], (%f0) @@ -1320,7 +1329,7 @@ // or equivalently (rewrite with vector.transpose) %f0 = arith.constant 0.0f : f32 - for %i0 = 0 to %0 { + affine.for %i0 = 0 to %0 { affine.for %i1 = 0 to %1 step 256 { affine.for %i2 = 0 to %2 step 32 { %v0 = vector.transfer_read %A[%i0, %i1, %i2], (%f0) @@ -1333,7 +1342,7 @@ // Read the slice `%A[%i0, %i1]` (i.e. the element `%A[%i0, %i1]`) into // vector<128xf32>. The underlying implementation will require a 1-D vector // broadcast: - for %i0 = 0 to %0 { + affine.for %i0 = 0 to %0 { affine.for %i1 = 0 to %1 { %3 = vector.transfer_read %A[%i0, %i1] {permutation_map: (d0, d1) -> (0)} :