diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td @@ -1419,7 +1419,7 @@ Arguments<(ins AnyMemRef:$base, VectorOfRankAndType<[1], [AnyInteger]>:$indices, VectorOfRankAndType<[1], [I1]>:$mask, - Variadic>:$pass_thru)>, + VectorOfRank<[1]>:$pass_thru)>, Results<(outs VectorOfRank<[1]>:$result)> { let summary = "gathers elements from memory into a vector as defined by an index vector and mask"; @@ -1428,10 +1428,8 @@ The gather operation gathers elements from memory into a 1-D vector as defined by a base and a 1-D index vector, but only if the corresponding bit is set in a 1-D mask vector. Otherwise, the element is taken from a - 1-D pass-through vector, if provided, or left undefined. Informally the - semantics are: + 1-D pass-through vector. Informally the semantics are: ``` - if (!defined(pass_thru)) pass_thru = [undef, .., undef] result[0] := mask[0] ? base[index[0]] : pass_thru[0] result[1] := mask[1] ? base[index[1]] : pass_thru[1] etc. @@ -1447,8 +1445,8 @@ Example: ```mlir - %g = vector.gather %base, %indices, %mask, %pass_thru - : (memref, vector<16xi32>, vector<16xi1>, vector<16xf32>) -> vector<16xf32> + %g = vector.gather %base[%indices], %mask, %pass_thru + : memref, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32> ``` }]; let extraClassDeclaration = [{ @@ -1462,15 +1460,14 @@ return mask().getType().cast(); } VectorType getPassThruVectorType() { - return (llvm::size(pass_thru()) == 0) - ? VectorType() - : (*pass_thru().begin()).getType().cast(); + return pass_thru().getType().cast(); } VectorType getResultVectorType() { return result().getType().cast(); } }]; - let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)"; + let assemblyFormat = "$base `[` $indices `]` `,` $mask `,` $pass_thru attr-dict `:` " + "type($base) `,` type($indices) `,` type($mask) `,` type($pass_thru) `into` type($result)"; let hasCanonicalizer = 1; } @@ -1507,8 +1504,8 @@ Example: ```mlir - vector.scatter %base, %indices, %mask, %value - : vector<16xi32>, vector<16xi1>, vector<16xf32> into memref + vector.scatter %base[%indices], %mask, %value + : memref, vector<16xi32>, vector<16xi1>, vector<16xf32> ``` }]; let extraClassDeclaration = [{ @@ -1525,8 +1522,8 @@ return value().getType().cast(); } }]; - let assemblyFormat = "$base `,` $indices `,` $mask `,` $value attr-dict `:` " - "type($indices) `,` type($mask) `,` type($value) `into` type($base)"; + let assemblyFormat = "$base `[` $indices `]` `,` $mask `,` $value attr-dict `:` " + "type($base) `,` type($indices) `,` type($mask) `,` type($value)"; let hasCanonicalizer = 1; } diff --git a/mlir/integration_test/Dialect/Vector/CPU/test-gather.mlir b/mlir/integration_test/Dialect/Vector/CPU/test-gather.mlir --- a/mlir/integration_test/Dialect/Vector/CPU/test-gather.mlir +++ b/mlir/integration_test/Dialect/Vector/CPU/test-gather.mlir @@ -3,18 +3,10 @@ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \ // RUN: FileCheck %s -func @gather8(%base: memref, - %indices: vector<8xi32>, %mask: vector<8xi1>) -> vector<8xf32> { - %g = vector.gather %base, %indices, %mask - : (memref, vector<8xi32>, vector<8xi1>) -> vector<8xf32> - return %g : vector<8xf32> -} - -func @gather_pass_thru8(%base: memref, - %indices: vector<8xi32>, %mask: vector<8xi1>, - %pass_thru: vector<8xf32>) -> vector<8xf32> { - %g = vector.gather %base, %indices, %mask, %pass_thru - : (memref, vector<8xi32>, vector<8xi1>, vector<8xf32>) -> vector<8xf32> +func @gather8(%base: memref, %indices: vector<8xi32>, + %mask: vector<8xi1>, %pass_thru: vector<8xf32>) -> vector<8xf32> { + %g = vector.gather %base[%indices], %mask, %pass_thru + : memref, vector<8xi32>, vector<8xi1>, vector<8xf32> into vector<8xf32> return %g : vector<8xf32> } @@ -63,31 +55,31 @@ // Gather tests. // - %g1 = call @gather8(%A, %idx, %all) - : (memref, vector<8xi32>, vector<8xi1>) + %g1 = call @gather8(%A, %idx, %all, %pass) + : (memref, vector<8xi32>, vector<8xi1>, vector<8xf32>) -> (vector<8xf32>) vector.print %g1 : vector<8xf32> // CHECK: ( 0, 6, 1, 3, 5, 4, 9, 2 ) - %g2 = call @gather_pass_thru8(%A, %idx, %none, %pass) + %g2 = call @gather8(%A, %idx, %none, %pass) : (memref, vector<8xi32>, vector<8xi1>, vector<8xf32>) -> (vector<8xf32>) vector.print %g2 : vector<8xf32> // CHECK: ( -7, -7, -7, -7, -7, -7, -7, -7 ) - %g3 = call @gather_pass_thru8(%A, %idx, %some, %pass) + %g3 = call @gather8(%A, %idx, %some, %pass) : (memref, vector<8xi32>, vector<8xi1>, vector<8xf32>) -> (vector<8xf32>) vector.print %g3 : vector<8xf32> // CHECK: ( 0, 6, 1, 3, -7, -7, -7, -7 ) - %g4 = call @gather_pass_thru8(%A, %idx, %more, %pass) + %g4 = call @gather8(%A, %idx, %more, %pass) : (memref, vector<8xi32>, vector<8xi1>, vector<8xf32>) -> (vector<8xf32>) vector.print %g4 : vector<8xf32> // CHECK: ( 0, 6, 1, 3, -7, -7, -7, 2 ) - %g5 = call @gather_pass_thru8(%A, %idx, %all, %pass) + %g5 = call @gather8(%A, %idx, %all, %pass) : (memref, vector<8xi32>, vector<8xi1>, vector<8xf32>) -> (vector<8xf32>) vector.print %g5 : vector<8xf32> diff --git a/mlir/integration_test/Dialect/Vector/CPU/test-scatter.mlir b/mlir/integration_test/Dialect/Vector/CPU/test-scatter.mlir --- a/mlir/integration_test/Dialect/Vector/CPU/test-scatter.mlir +++ b/mlir/integration_test/Dialect/Vector/CPU/test-scatter.mlir @@ -6,8 +6,8 @@ func @scatter8(%base: memref, %indices: vector<8xi32>, %mask: vector<8xi1>, %value: vector<8xf32>) { - vector.scatter %base, %indices, %mask, %value - : vector<8xi32>, vector<8xi1>, vector<8xf32> into memref + vector.scatter %base[%indices], %mask, %value + : memref, vector<8xi32>, vector<8xi1>, vector<8xf32> return } diff --git a/mlir/integration_test/Dialect/Vector/CPU/test-sparse-dot-matvec.mlir b/mlir/integration_test/Dialect/Vector/CPU/test-sparse-dot-matvec.mlir --- a/mlir/integration_test/Dialect/Vector/CPU/test-sparse-dot-matvec.mlir +++ b/mlir/integration_test/Dialect/Vector/CPU/test-sparse-dot-matvec.mlir @@ -60,11 +60,12 @@ %cn = constant 8 : index %f0 = constant 0.0 : f32 %mask = vector.constant_mask [4] : vector<4xi1> + %pass = vector.broadcast %f0 : f32 to vector<4xf32> scf.for %i = %c0 to %cn step %c1 { %aval = load %AVAL[%i] : memref<8xvector<4xf32>> %aidx = load %AIDX[%i] : memref<8xvector<4xi32>> - %0 = vector.gather %X, %aidx, %mask - : (memref, vector<4xi32>, vector<4xi1>) -> vector<4xf32> + %0 = vector.gather %X[%aidx], %mask, %pass + : memref, vector<4xi32>, vector<4xi1>, vector<4xf32> into vector<4xf32> %1 = vector.contract #dot_trait %aval, %0, %f0 : vector<4xf32>, vector<4xf32> into f32 store %1, %B[%i] : memref } diff --git a/mlir/integration_test/Dialect/Vector/CPU/test-sparse-saxpy-jagged-matvec.mlir b/mlir/integration_test/Dialect/Vector/CPU/test-sparse-saxpy-jagged-matvec.mlir --- a/mlir/integration_test/Dialect/Vector/CPU/test-sparse-saxpy-jagged-matvec.mlir +++ b/mlir/integration_test/Dialect/Vector/CPU/test-sparse-saxpy-jagged-matvec.mlir @@ -50,12 +50,15 @@ %c0 = constant 0 : index %c1 = constant 1 : index %cn = constant 4 : index + %f0 = constant 0.0 : f32 %mask = vector.constant_mask [8] : vector<8xi1> + %pass = vector.broadcast %f0 : f32 to vector<8xf32> %b = load %B[%c0] : memref<1xvector<8xf32>> %b_out = scf.for %k = %c0 to %cn step %c1 iter_args(%b_iter = %b) -> (vector<8xf32>) { %aval = load %AVAL[%k] : memref<4xvector<8xf32>> %aidx = load %AIDX[%k] : memref<4xvector<8xi32>> - %0 = vector.gather %X, %aidx, %mask : (memref, vector<8xi32>, vector<8xi1>) -> vector<8xf32> + %0 = vector.gather %X[%aidx], %mask, %pass + : memref, vector<8xi32>, vector<8xi1>, vector<8xf32> into vector<8xf32> %b_new = vector.fma %aval, %0, %b_iter : vector<8xf32> scf.yield %b_new : vector<8xf32> } diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -2446,11 +2446,8 @@ return op.emitOpError("expected result dim to match indices dim"); if (resVType.getDimSize(0) != maskVType.getDimSize(0)) return op.emitOpError("expected result dim to match mask dim"); - if (llvm::size(op.pass_thru()) != 0) { - VectorType passVType = op.getPassThruVectorType(); - if (resVType != passVType) - return op.emitOpError("expected pass_thru of same type as result type"); - } + if (resVType != op.getPassThruVectorType()) + return op.emitOpError("expected pass_thru of same type as result type"); return success(); } diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -1096,7 +1096,7 @@ // CHECK: llvm.return func @gather_op(%arg0: memref, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) -> vector<3xf32> { - %0 = vector.gather %arg0, %arg1, %arg2, %arg3 : (memref, vector<3xi32>, vector<3xi1>, vector<3xf32>) -> vector<3xf32> + %0 = vector.gather %arg0[%arg1], %arg2, %arg3 : memref, vector<3xi32>, vector<3xi1>, vector<3xf32> into vector<3xf32> return %0 : vector<3xf32> } @@ -1106,7 +1106,7 @@ // CHECK: llvm.return %[[G]] : !llvm.vec<3 x f32> func @scatter_op(%arg0: memref, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) { - vector.scatter %arg0, %arg1, %arg2, %arg3 : vector<3xi32>, vector<3xi1>, vector<3xf32> into memref + vector.scatter %arg0[%arg1], %arg2, %arg3 : memref, vector<3xi32>, vector<3xi1>, vector<3xf32> return } diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -1238,65 +1238,83 @@ // ----- -func @gather_base_type_mismatch(%base: memref, %indices: vector<16xi32>, %mask: vector<16xi1>) { +func @gather_base_type_mismatch(%base: memref, %indices: vector<16xi32>, + %mask: vector<16xi1>, %pass_thru: vector<16xf32>) { // expected-error@+1 {{'vector.gather' op base and result element type should match}} - %0 = vector.gather %base, %indices, %mask : (memref, vector<16xi32>, vector<16xi1>) -> vector<16xf32> + %0 = vector.gather %base[%indices], %mask, %pass_thru + : memref, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32> } // ----- -func @gather_rank_mismatch(%base: memref, %indices: vector<16xi32>, %mask: vector<16xi1>) { +func @gather_rank_mismatch(%base: memref, %indices: vector<16xi32>, + %mask: vector<16xi1>, %pass_thru: vector<16xf32>) { // expected-error@+1 {{'vector.gather' op result #0 must be of ranks 1, but got 'vector<2x16xf32>'}} - %0 = vector.gather %base, %indices, %mask : (memref, vector<16xi32>, vector<16xi1>) -> vector<2x16xf32> + %0 = vector.gather %base[%indices], %mask, %pass_thru + : memref, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<2x16xf32> } // ----- -func @gather_dim_indices_mismatch(%base: memref, %indices: vector<17xi32>, %mask: vector<16xi1>) { +func @gather_dim_indices_mismatch(%base: memref, %indices: vector<17xi32>, + %mask: vector<16xi1>, %pass_thru: vector<16xf32>) { // expected-error@+1 {{'vector.gather' op expected result dim to match indices dim}} - %0 = vector.gather %base, %indices, %mask : (memref, vector<17xi32>, vector<16xi1>) -> vector<16xf32> + %0 = vector.gather %base[%indices], %mask, %pass_thru + : memref, vector<17xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32> } // ----- -func @gather_dim_mask_mismatch(%base: memref, %indices: vector<16xi32>, %mask: vector<17xi1>) { +func @gather_dim_mask_mismatch(%base: memref, %indices: vector<16xi32>, + %mask: vector<17xi1>, %pass_thru: vector<16xf32>) { // expected-error@+1 {{'vector.gather' op expected result dim to match mask dim}} - %0 = vector.gather %base, %indices, %mask : (memref, vector<16xi32>, vector<17xi1>) -> vector<16xf32> + %0 = vector.gather %base[%indices], %mask, %pass_thru + : memref, vector<16xi32>, vector<17xi1>, vector<16xf32> into vector<16xf32> } // ----- -func @gather_pass_thru_type_mismatch(%base: memref, %indices: vector<16xi32>, %mask: vector<16xi1>, %pass_thru: vector<16xf64>) { +func @gather_pass_thru_type_mismatch(%base: memref, %indices: vector<16xi32>, + %mask: vector<16xi1>, %pass_thru: vector<16xf64>) { // expected-error@+1 {{'vector.gather' op expected pass_thru of same type as result type}} - %0 = vector.gather %base, %indices, %mask, %pass_thru : (memref, vector<16xi32>, vector<16xi1>, vector<16xf64>) -> vector<16xf32> + %0 = vector.gather %base[%indices], %mask, %pass_thru + : memref, vector<16xi32>, vector<16xi1>, vector<16xf64> into vector<16xf32> } // ----- -func @scatter_base_type_mismatch(%base: memref, %indices: vector<16xi32>, %mask: vector<16xi1>, %value: vector<16xf32>) { +func @scatter_base_type_mismatch(%base: memref, %indices: vector<16xi32>, + %mask: vector<16xi1>, %value: vector<16xf32>) { // expected-error@+1 {{'vector.scatter' op base and value element type should match}} - vector.scatter %base, %indices, %mask, %value : vector<16xi32>, vector<16xi1>, vector<16xf32> into memref + vector.scatter %base[%indices], %mask, %value + : memref, vector<16xi32>, vector<16xi1>, vector<16xf32> } // ----- -func @scatter_rank_mismatch(%base: memref, %indices: vector<16xi32>, %mask: vector<16xi1>, %value: vector<2x16xf32>) { +func @scatter_rank_mismatch(%base: memref, %indices: vector<16xi32>, + %mask: vector<16xi1>, %value: vector<2x16xf32>) { // expected-error@+1 {{'vector.scatter' op operand #3 must be of ranks 1, but got 'vector<2x16xf32>'}} - vector.scatter %base, %indices, %mask, %value : vector<16xi32>, vector<16xi1>, vector<2x16xf32> into memref + vector.scatter %base[%indices], %mask, %value + : memref, vector<16xi32>, vector<16xi1>, vector<2x16xf32> } // ----- -func @scatter_dim_indices_mismatch(%base: memref, %indices: vector<17xi32>, %mask: vector<16xi1>, %value: vector<16xf32>) { +func @scatter_dim_indices_mismatch(%base: memref, %indices: vector<17xi32>, + %mask: vector<16xi1>, %value: vector<16xf32>) { // expected-error@+1 {{'vector.scatter' op expected value dim to match indices dim}} - vector.scatter %base, %indices, %mask, %value : vector<17xi32>, vector<16xi1>, vector<16xf32> into memref + vector.scatter %base[%indices], %mask, %value + : memref, vector<17xi32>, vector<16xi1>, vector<16xf32> } // ----- -func @scatter_dim_mask_mismatch(%base: memref, %indices: vector<16xi32>, %mask: vector<17xi1>, %value: vector<16xf32>) { +func @scatter_dim_mask_mismatch(%base: memref, %indices: vector<16xi32>, + %mask: vector<17xi1>, %value: vector<16xf32>) { // expected-error@+1 {{'vector.scatter' op expected value dim to match mask dim}} - vector.scatter %base, %indices, %mask, %value : vector<16xi32>, vector<17xi1>, vector<16xf32> into memref + vector.scatter %base[%indices], %mask, %value + : memref, vector<16xi32>, vector<17xi1>, vector<16xf32> } // ----- diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -461,21 +461,19 @@ } // CHECK-LABEL: @gather_and_scatter -func @gather_and_scatter(%base: memref, %indices: vector<16xi32>, %mask: vector<16xi1>) { - // CHECK: %[[X:.*]] = vector.gather %{{.*}}, %{{.*}}, %{{.*}} : (memref, vector<16xi32>, vector<16xi1>) -> vector<16xf32> - %0 = vector.gather %base, %indices, %mask : (memref, vector<16xi32>, vector<16xi1>) -> vector<16xf32> - // CHECK: %[[Y:.*]] = vector.gather %{{.*}}, %{{.*}}, %{{.*}}, %[[X]] : (memref, vector<16xi32>, vector<16xi1>, vector<16xf32>) -> vector<16xf32> - %1 = vector.gather %base, %indices, %mask, %0 : (memref, vector<16xi32>, vector<16xi1>, vector<16xf32>) -> vector<16xf32> - // CHECK: vector.scatter %{{.*}}, %{{.*}}, %{{.*}}, %[[Y]] : vector<16xi32>, vector<16xi1>, vector<16xf32> into memref - vector.scatter %base, %indices, %mask, %1 : vector<16xi32>, vector<16xi1>, vector<16xf32> into memref +func @gather_and_scatter(%base: memref, %indices: vector<16xi32>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) { + // CHECK: %[[X:.*]] = vector.gather %{{.*}}[%{{.*}}], %{{.*}}, %{{.*}} : memref, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32> + %0 = vector.gather %base[%indices], %mask, %pass_thru : memref, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32> + // CHECK: vector.scatter %{{.*}}[%{{.*}}], %{{.*}}, %[[X]] : memref, vector<16xi32>, vector<16xi1>, vector<16xf32> + vector.scatter %base[%indices], %mask, %0 : memref, vector<16xi32>, vector<16xi1>, vector<16xf32> return } // CHECK-LABEL: @expand_and_compress -func @expand_and_compress(%base: memref, %mask: vector<16xi1>, %passthru: vector<16xf32>) { +func @expand_and_compress(%base: memref, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) { %c0 = constant 0 : index // CHECK: %[[X:.*]] = vector.expandload %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref, vector<16xi1>, vector<16xf32> into vector<16xf32> - %0 = vector.expandload %base[%c0], %mask, %passthru : memref, vector<16xi1>, vector<16xf32> into vector<16xf32> + %0 = vector.expandload %base[%c0], %mask, %pass_thru : memref, vector<16xi1>, vector<16xf32> into vector<16xf32> // CHECK: vector.compressstore %{{.*}}[{{.*}}], %{{.*}}, %[[X]] : memref, vector<16xi1>, vector<16xf32> vector.compressstore %base[%c0], %mask, %0 : memref, vector<16xi1>, vector<16xf32> return diff --git a/mlir/test/Dialect/Vector/vector-mem-transforms.mlir b/mlir/test/Dialect/Vector/vector-mem-transforms.mlir --- a/mlir/test/Dialect/Vector/vector-mem-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-mem-transforms.mlir @@ -86,12 +86,12 @@ // CHECK-SAME: %[[A1:.*]]: vector<16xi32>, // CHECK-SAME: %[[A2:.*]]: vector<16xf32>) -> vector<16xf32> { // CHECK-NEXT: %[[M:.*]] = vector.constant_mask [16] : vector<16xi1> -// CHECK-NEXT: %[[G:.*]] = vector.gather %[[A0]], %[[A1]], %[[M]], %[[A2]] : (memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>) -> vector<16xf32> +// CHECK-NEXT: %[[G:.*]] = vector.gather %[[A0]][%[[A1]]], %[[M]], %[[A2]] : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32> // CHECK-NEXT: return %[[G]] : vector<16xf32> func @gather1(%base: memref<16xf32>, %indices: vector<16xi32>, %pass_thru: vector<16xf32>) -> vector<16xf32> { %mask = vector.constant_mask [16] : vector<16xi1> - %ld = vector.gather %base, %indices, %mask, %pass_thru - : (memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>) -> vector<16xf32> + %ld = vector.gather %base[%indices], %mask, %pass_thru + : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32> return %ld : vector<16xf32> } @@ -102,8 +102,8 @@ // CHECK-NEXT: return %[[A2]] : vector<16xf32> func @gather2(%base: memref<16xf32>, %indices: vector<16xi32>, %pass_thru: vector<16xf32>) -> vector<16xf32> { %mask = vector.constant_mask [0] : vector<16xi1> - %ld = vector.gather %base, %indices, %mask, %pass_thru - : (memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>) -> vector<16xf32> + %ld = vector.gather %base[%indices], %mask, %pass_thru + : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32> return %ld : vector<16xf32> } @@ -112,12 +112,12 @@ // CHECK-SAME: %[[A1:.*]]: vector<16xi32>, // CHECK-SAME: %[[A2:.*]]: vector<16xf32>) { // CHECK-NEXT: %[[M:.*]] = vector.constant_mask [16] : vector<16xi1> -// CHECK-NEXT: vector.scatter %[[A0]], %[[A1]], %[[M]], %[[A2]] : vector<16xi32>, vector<16xi1>, vector<16xf32> into memref<16xf32> +// CHECK-NEXT: vector.scatter %[[A0]][%[[A1]]], %[[M]], %[[A2]] : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> // CHECK-NEXT: return func @scatter1(%base: memref<16xf32>, %indices: vector<16xi32>, %value: vector<16xf32>) { %mask = vector.constant_mask [16] : vector<16xi1> - vector.scatter %base, %indices, %mask, %value - : vector<16xi32>, vector<16xi1>, vector<16xf32> into memref<16xf32> + vector.scatter %base[%indices], %mask, %value + : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> return } @@ -129,8 +129,8 @@ func @scatter2(%base: memref<16xf32>, %indices: vector<16xi32>, %value: vector<16xf32>) { %0 = vector.type_cast %base : memref<16xf32> to memref> %mask = vector.constant_mask [0] : vector<16xi1> - vector.scatter %base, %indices, %mask, %value - : vector<16xi32>, vector<16xi1>, vector<16xf32> into memref<16xf32> + vector.scatter %base[%indices], %mask, %value + : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> return }