diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td @@ -457,50 +457,37 @@ def Vector_ExtractMapOp : Vector_Op<"extract_map", [NoSideEffect]>, - Arguments<(ins AnyVector:$vector, Index:$id, I64Attr:$multiplicity)>, + Arguments<(ins AnyVector:$vector, Index:$id)>, Results<(outs AnyVector)> { let summary = "vector extract map operation"; let description = [{ - Takes an 1-D vector and extract a sub-part of the vector starting at id with - a size of `vector size / multiplicity`. This maps a given multiplicity of - the vector to a Value such as a loop induction variable or an SPMD id. + Takes a 1-D vector and extracts a sub-part of the vector starting at id. Similarly to vector.tuple_get, this operation is used for progressive lowering and should be folded away before converting to LLVM. + It is different than `vector.extract_slice` and + `vector.extract_strided_slice` as it takes a Value as index instead of an + attribute. Also in the future it is meant to support extracting along any + dimensions and not only the most major ones. 
- For instance, the following code: - ```mlir - %a = vector.transfer_read %A[%c0]: memref<32xf32>, vector<32xf32> - %b = vector.transfer_read %B[%c0]: memref<32xf32>, vector<32xf32> - %c = addf %a, %b: vector<32xf32> - vector.transfer_write %c, %C[%c0]: memref<32xf32>, vector<32xf32> - ``` - can be rewritten to: - ```mlir - %a = vector.transfer_read %A[%c0]: memref<32xf32>, vector<32xf32> - %b = vector.transfer_read %B[%c0]: memref<32xf32>, vector<32xf32> - %ea = vector.extract_map %a[%id : 32] : vector<32xf32> to vector<1xf32> - %eb = vector.extract_map %b[%id : 32] : vector<32xf32> to vector<1xf32> - %ec = addf %ea, %eb : vector<1xf32> - %c = vector.insert_map %ec, %id, 32 : vector<1xf32> to vector<32xf32> - vector.transfer_write %c, %C[%c0]: memref<32xf32>, vector<32xf32> + For instance: ``` - - Where %id can be an induction variable or an SPMD id going from 0 to 31. - - And then be rewritten to: - ```mlir - %a = vector.transfer_read %A[%id]: memref<32xf32>, vector<1xf32> - %b = vector.transfer_read %B[%id]: memref<32xf32>, vector<1xf32> - %c = addf %a, %b: vector<1xf32> - vector.transfer_write %c, %C[%id]: memref<32xf32>, vector<1xf32> + // dynamic computation producing the value 0 of index type + %idx0 = ... : index + // dynamic computation producing the value 1 of index type + %idx1 = ... 
: index + %0 = constant dense<[0, 1, 2, 3]>: vector<4xi32> + // extracts values [0, 1] + %1 = vector.extract_map %0[%idx0] : vector<4xi32> to vector<2xi32> + // extracts values [1, 2] + %2 = vector.extract_map %0[%idx1] : vector<4xi32> to vector<2xi32> ``` Example: ```mlir - %ev = vector.extract_map %v[%id:32] : vector<32xf32> to vector<1xf32> + %ev = vector.extract_map %v[%id] : vector<32xf32> to vector<1xf32> ``` }]; let builders = [OpBuilder< @@ -512,10 +499,13 @@ VectorType getResultType() { return getResult().getType().cast(); } + int64_t multiplicity() { + return getSourceVectorType().getNumElements() / + getResultType().getNumElements(); + } }]; let assemblyFormat = [{ - $vector `[` $id `:` $multiplicity `]` attr-dict `:` type($vector) `to` - type(results) + $vector `[` $id `]` attr-dict `:` type($vector) `to` type(results) }]; let hasFolder = 1; @@ -694,30 +684,48 @@ } def Vector_InsertMapOp : - Vector_Op<"insert_map", [NoSideEffect]>, - Arguments<(ins AnyVector:$vector, Index:$id, I64Attr:$multiplicity)>, - Results<(outs AnyVector)> { + Vector_Op<"insert_map", [NoSideEffect, AllTypesMatch<["dest", "result"]>]>, + Arguments<(ins AnyVector:$vector, AnyVector:$dest, Index:$id)>, + Results<(outs AnyVector:$result)> { let summary = "vector insert map operation"; let description = [{ - insert an 1-D vector and within a larger vector starting at id. The new - vector created will have a size of `vector size * multiplicity`. This - represents how a sub-part of the vector is written for a given Value such as - a loop induction variable or an SPMD id. + Inserts a 1-D vector within a larger vector starting at id. The new + vector created will have the same size as the destination operand vector. Similarly to vector.tuple_get, this operation is used for progressive lowering and should be folded away before converting to LLVM. + It is different than `vector.insert` and `vector.insert_strided_slice` as it + takes a Value as index instead of an attribute. 
Also in the future it is + meant to support inserting along any dimensions and not only the most major + ones. + This operation is meant to be used in combination with vector.extract_map. - See example in extract.map description. + For instance: + ``` + // dynamic computation producing the value 0 of index type + %idx0 = ... : index + // dynamic computation producing the value 1 of index type + %idx1 = ... : index + %0 = constant dense<[0, 1, 2, 3]>: vector<4xi32> + // extracts values [0, 1] + %1 = vector.extract_map %0[%idx0] : vector<4xi32> to vector<2xi32> + // extracts values [1, 2] + %2 = vector.extract_map %0[%idx1] : vector<4xi32> to vector<2xi32> + // insert [0, 1] into [x, x, x, x] and produce [0, 1, x, x] + %3 = vector.insert_map %1, %0[%idx0] : vector<2xi32> into vector<4xi32> + // insert [1, 2] into [x, x, x, x] and produce [x, 1, 2, x] + %4 = vector.insert_map %2, %0[%idx1] : vector<2xi32> into vector<4xi32> + ``` Example: ```mlir - %v = vector.insert_map %ev, %id, 32 : vector<1xf32> to vector<32xf32> + %v = vector.insert_map %ev, %v[%id] : vector<1xf32> into vector<32xf32> ``` }]; let builders = [OpBuilder< - "Value vector, Value id, int64_t multiplicity">]; + "Value vector, Value dest, Value id, int64_t multiplicity">]; let extraClassDeclaration = [{ VectorType getSourceVectorType() { return vector().getType().cast(); @@ -725,10 +733,14 @@ VectorType getResultType() { return getResult().getType().cast(); } + int64_t multiplicity() { + return getResultType().getNumElements() / + getSourceVectorType().getNumElements(); + } }]; let assemblyFormat = [{ - $vector `,` $id `,` $multiplicity attr-dict `:` type($vector) `to` - type(results) + $vector `,` $dest `[` $id `]` attr-dict + `:` type($vector) `into` type($result) }]; } diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -999,17 +999,18 @@ VectorType type = 
vector.getType().cast(); VectorType resultType = VectorType::get(type.getNumElements() / multiplicity, type.getElementType()); - ExtractMapOp::build(builder, result, resultType, vector, id, multiplicity); + ExtractMapOp::build(builder, result, resultType, vector, id); } static LogicalResult verify(ExtractMapOp op) { if (op.getSourceVectorType().getShape().size() != 1 || op.getResultType().getShape().size() != 1) return op.emitOpError("expects source and destination vectors of rank 1"); - if (op.getResultType().getNumElements() * (int64_t)op.multiplicity() != - op.getSourceVectorType().getNumElements()) - return op.emitOpError("vector sizes mismatch. Source size must be equal " - "to destination size * multiplicity"); + if (op.getSourceVectorType().getNumElements() % + op.getResultType().getNumElements() != + 0) + return op.emitOpError( + "source vector size must be a multiple of destination vector size"); return success(); } @@ -1248,22 +1249,23 @@ //===----------------------------------------------------------------------===// void InsertMapOp::build(OpBuilder &builder, OperationState &result, - Value vector, Value id, int64_t multiplicity) { + Value vector, Value dest, Value id, + int64_t multiplicity) { VectorType type = vector.getType().cast(); VectorType resultType = VectorType::get(type.getNumElements() * multiplicity, type.getElementType()); - InsertMapOp::build(builder, result, resultType, vector, id, multiplicity); + InsertMapOp::build(builder, result, resultType, vector, dest, id); } static LogicalResult verify(InsertMapOp op) { if (op.getSourceVectorType().getShape().size() != 1 || op.getResultType().getShape().size() != 1) return op.emitOpError("expected source and destination vectors of rank 1"); - if ((int64_t)op.multiplicity() * op.getSourceVectorType().getNumElements() != - op.getResultType().getNumElements()) + if (op.getResultType().getNumElements() % + op.getSourceVectorType().getNumElements() != + 0) return op.emitOpError( - "vector sizes 
mismatch. Destination size must be equal " - "to source size * multiplicity"); + "destination vector size must be a multiple of source vector size"); return success(); } diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -2507,8 +2507,8 @@ DistributeOps ops; ops.extract = builder.create(loc, result, id, multiplicity); - ops.insert = - builder.create(loc, ops.extract, id, multiplicity); + ops.insert = builder.create(loc, ops.extract, result, id, + multiplicity); return ops; } @@ -2532,8 +2532,10 @@ Value newRead = vector_transfer_read(extract.getType(), read.memref(), indices, read.permutation_map(), read.padding(), ArrayAttr()); + Value dest = rewriter.create( + read.getLoc(), read.getType(), rewriter.getZeroAttr(read.getType())); newRead = rewriter.create( - read.getLoc(), newRead, extract.id(), extract.multiplicity()); + read.getLoc(), newRead, dest, extract.id(), extract.multiplicity()); rewriter.replaceOp(read, newRead); return success(); } diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -1333,26 +1333,26 @@ func @extract_map_rank(%v: vector<2x32xf32>, %id : index) { // expected-error@+1 {{'vector.extract_map' op expects source and destination vectors of rank 1}} - %0 = vector.extract_map %v[%id : 32] : vector<2x32xf32> to vector<2x1xf32> + %0 = vector.extract_map %v[%id] : vector<2x32xf32> to vector<2x1xf32> } // ----- func @extract_map_size(%v: vector<63xf32>, %id : index) { - // expected-error@+1 {{'vector.extract_map' op vector sizes mismatch. 
Source size must be equal to destination size * multiplicity}} - %0 = vector.extract_map %v[%id : 32] : vector<63xf32> to vector<2xf32> + // expected-error@+1 {{'vector.extract_map' op source vector size must be a multiple of destination vector size}} + %0 = vector.extract_map %v[%id] : vector<63xf32> to vector<2xf32> } // ----- -func @insert_map_rank(%v: vector<2x1xf32>, %id : index) { +func @insert_map_rank(%v: vector<2x1xf32>, %v1: vector<2x32xf32>, %id : index) { // expected-error@+1 {{'vector.insert_map' op expected source and destination vectors of rank 1}} - %0 = vector.insert_map %v, %id, 32 : vector<2x1xf32> to vector<2x32xf32> + %0 = vector.insert_map %v, %v1[%id] : vector<2x1xf32> into vector<2x32xf32> } // ----- -func @insert_map_size(%v: vector<1xf32>, %id : index) { - // expected-error@+1 {{'vector.insert_map' op vector sizes mismatch. Destination size must be equal to source size * multiplicity}} - %0 = vector.insert_map %v, %id, 32 : vector<1xf32> to vector<64xf32> +func @insert_map_size(%v: vector<3xf32>, %v1: vector<64xf32>, %id : index) { + // expected-error@+1 {{'vector.insert_map' op destination vector size must be a multiple of source vector size}} + %0 = vector.insert_map %v, %v1[%id] : vector<3xf32> into vector<64xf32> } diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -435,10 +435,10 @@ // CHECK-LABEL: @extract_insert_map func @extract_insert_map(%v: vector<32xf32>, %id : index) -> vector<32xf32> { - // CHECK: %[[V:.*]] = vector.extract_map %{{.*}}[%{{.*}} : 16] : vector<32xf32> to vector<2xf32> - %vd = vector.extract_map %v[%id : 16] : vector<32xf32> to vector<2xf32> - // CHECK: %[[R:.*]] = vector.insert_map %[[V]], %{{.*}}, 16 : vector<2xf32> to vector<32xf32> - %r = vector.insert_map %vd, %id, 16 : vector<2xf32> to vector<32xf32> + // CHECK: %[[V:.*]] = vector.extract_map %{{.*}}[%{{.*}}] : vector<32xf32> to vector<2xf32> 
+ %vd = vector.extract_map %v[%id] : vector<32xf32> to vector<2xf32> + // CHECK: %[[R:.*]] = vector.insert_map %[[V]], %{{.*}}[%{{.*}}] : vector<2xf32> into vector<32xf32> + %r = vector.insert_map %vd, %v[%id] : vector<2xf32> into vector<32xf32> // CHECK: return %[[R]] : vector<32xf32> return %r : vector<32xf32> } diff --git a/mlir/test/Dialect/Vector/vector-distribution.mlir b/mlir/test/Dialect/Vector/vector-distribution.mlir --- a/mlir/test/Dialect/Vector/vector-distribution.mlir +++ b/mlir/test/Dialect/Vector/vector-distribution.mlir @@ -2,10 +2,11 @@ // CHECK-LABEL: func @distribute_vector_add // CHECK-SAME: (%[[ID:.*]]: index -// CHECK-NEXT: %[[EXA:.*]] = vector.extract_map %{{.*}}[%[[ID]] : 32] : vector<32xf32> to vector<1xf32> -// CHECK-NEXT: %[[EXB:.*]] = vector.extract_map %{{.*}}[%[[ID]] : 32] : vector<32xf32> to vector<1xf32> +// CHECK-NEXT: %[[ADDV:.*]] = addf %{{.*}}, %{{.*}} : vector<32xf32> +// CHECK-NEXT: %[[EXA:.*]] = vector.extract_map %{{.*}}[%[[ID]]] : vector<32xf32> to vector<1xf32> +// CHECK-NEXT: %[[EXB:.*]] = vector.extract_map %{{.*}}[%[[ID]]] : vector<32xf32> to vector<1xf32> // CHECK-NEXT: %[[ADD:.*]] = addf %[[EXA]], %[[EXB]] : vector<1xf32> -// CHECK-NEXT: %[[INS:.*]] = vector.insert_map %[[ADD]], %[[ID]], 32 : vector<1xf32> to vector<32xf32> +// CHECK-NEXT: %[[INS:.*]] = vector.insert_map %[[ADD]], %[[ADDV]][%[[ID]]] : vector<1xf32> into vector<32xf32> // CHECK-NEXT: return %[[INS]] : vector<32xf32> func @distribute_vector_add(%id : index, %A: vector<32xf32>, %B: vector<32xf32>) -> vector<32xf32> { %0 = addf %A, %B : vector<32xf32> diff --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp --- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp +++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp @@ -176,7 +176,7 @@ Optional ops = distributPointwiseVectorOp( builder, op.getOperation(), func.getArgument(0), multiplicity); if (ops.hasValue()) { - SmallPtrSet 
extractOp({ops->extract}); + SmallPtrSet extractOp({ops->extract, ops->insert}); op.getResult().replaceAllUsesExcept(ops->insert.getResult(), extractOp); } });