diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -656,26 +656,6 @@ result.attributes.append(dictAttr.getValue().begin(), dictAttr.getValue().end()); - // Convert array of string into an array of IteratyType enums. This is needed, - // because tests still use the old format when 'iterator_types' attribute is - // represented as an array of strings. - // TODO: Remove this conversion once tests are fixed. - ArrayAttr iteratorTypes = llvm::cast( - result.attributes.get(getIteratorTypesAttrName(result.name))); - - SmallVector iteratorTypeAttrs; - - for (StringRef s : iteratorTypes.getAsValueRange()) { - auto maybeIteratorType = symbolizeIteratorType(s); - if (!maybeIteratorType.has_value()) - return parser.emitError(loc) << "unexpected iterator_type (" << s << ")"; - - iteratorTypeAttrs.push_back( - IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value())); - } - result.attributes.set(getIteratorTypesAttrName(result.name), - parser.getBuilder().getArrayAttr(iteratorTypeAttrs)); - if (!result.attributes.get(getKindAttrName(result.name))) { result.addAttribute( getKindAttrName(result.name), @@ -705,22 +685,7 @@ traitAttrsSet.insert(attrNames.begin(), attrNames.end()); SmallVector attrs; for (auto attr : (*this)->getAttrs()) { - if (attr.getName() == getIteratorTypesAttrName()) { - auto iteratorTypes = - llvm::cast(attr.getValue()) - .getAsValueRange(); - // Convert IteratorType enums into the string representation. This is - // needed, because tests still use the old format when 'iterator_types' - // attribute is represented as an array of strings. - // TODO: Remove this conversion once tests are fixed. - SmallVector iteratorTypeNames = llvm::to_vector( - llvm::map_range(iteratorTypes, [&](IteratorType t) -> Attribute { - return StringAttr::get(getContext(), stringifyIteratorType(t)); - })); - - attrs.emplace_back(getIteratorTypesAttrName(), - ArrayAttr::get(getContext(), iteratorTypeNames)); - } else if (traitAttrsSet.count(attr.getName().strref()) > 0) + if (traitAttrsSet.count(attr.getName().strref()) > 0) attrs.push_back(attr); } diff --git a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir --- a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir +++ b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir @@ -89,7 +89,7 @@ %B = vector.transfer_read %arg1[%c39, %c40], %cst {in_bounds = [true, true], permutation_map = #map0} : memref<128x128xi8, #gpu.address_space>, vector<8x32xi8> %C = vector.transfer_read %arg2[%c49, %c40], %cst0 {in_bounds = [true, true]} : memref<128x128xi32>, vector<16x8xi32> // CHECK: [[d:%.+]] = nvgpu.mma.sync({{.*}}) {mmaShape = [16, 8, 32]} : (vector<4x4xi8>, vector<2x4xi8>, vector<2x2xi32>) -> vector<2x2xi32> - %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %A, %B, %C : vector<16x32xi8>, vector<8x32xi8> into vector<16x8xi32> + %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type], kind = #vector.kind} %A, %B, %C : vector<16x32xi8>, vector<8x32xi8> into vector<16x8xi32> // CHECK: [[row:%.+]] = affine.apply [[$rowC0_map]]()[{{%.+}}] // CHECK: [[col:%.+]] = affine.apply [[$colC0_map]]()[{{%.+}}] @@ -154,7 +154,7 @@ %B = vector.transfer_read %arg1[%c39, %c40], %cst {in_bounds = [true, true], permutation_map = #map0} : memref<128x128xf64>, vector<8x4xf64> %C = vector.transfer_read %arg2[%c49, %c40], %cst0 {in_bounds = [true, true]} : memref<128x128xf64>, vector<8x8xf64> // CHECK: [[d:%.+]] = nvgpu.mma.sync({{.*}}) {mmaShape = [8, 8, 4]} : (vector<1x1xf64>, vector<1x1xf64>, vector<1x2xf64>) -> vector<1x2xf64> - %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %A, %B, %C : vector<8x4xf64>, vector<8x4xf64> into vector<8x8xf64> + %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type], kind = #vector.kind} %A, %B, %C : vector<8x4xf64>, vector<8x4xf64> into vector<8x8xf64> // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowC0_map]] // CHECK-DAG: [[col:%.+]] = affine.apply [[$colC0_map]] @@ -192,7 +192,7 @@ %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<20x20xf16, #gpu.address_space>, vector<16x16xf16> %B = vector.transfer_read %arg1[%c0, %c0], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<20x20xf16, #gpu.address_space>, vector<8x16xf16> %C = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<20x20xf16, #gpu.address_space>, vector<16x8xf16> - %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %A, %B, %C : vector<16x16xf16>, vector<8x16xf16> into vector<16x8xf16> + %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type], kind = #vector.kind} %A, %B, %C : vector<16x16xf16>, vector<8x16xf16> into vector<16x8xf16> vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x8xf16>, memref<20x20xf16, #gpu.address_space> return } @@ -237,7 +237,7 @@ // CHECK: nvgpu.mma.sync([[fragmentA]], [[fragmentB0]], [[fragmentC0]]) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16> %B0 = vector.extract_strided_slice %B {offsets = [0, 0], sizes = [8, 16], strides = [1, 1]} : vector<16x16xf16> to vector<8x16xf16> %C0 = vector.extract_strided_slice %C {offsets = [0, 0], sizes = [16, 8], strides = [1, 1]} : vector<16x16xf16> to vector<16x8xf16> - %D0 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %A, %B0, %C0 : vector<16x16xf16>, vector<8x16xf16> into vector<16x8xf16> + %D0 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type], kind = #vector.kind} %A, %B0, %C0 : vector<16x16xf16>, vector<8x16xf16> into vector<16x8xf16> vector.transfer_write %D0, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x8xf16>, memref<42x64xf16, #gpu.address_space> // CHECK-DAG: [[fragmentB1:%.+]] = vector.extract_strided_slice [[fragmentB]] {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf16> to vector<2x2xf16> @@ -245,7 +245,7 @@ // CHECK: nvgpu.mma.sync([[fragmentA]], [[fragmentB1]], [[fragmentC1]]) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16> %B1 = vector.extract_strided_slice %B {offsets = [8, 0], sizes = [8, 16], strides = [1, 1]} : vector<16x16xf16> to vector<8x16xf16> %C1 = vector.extract_strided_slice %C {offsets = [0, 8], sizes = [16, 8], strides = [1, 1]} : vector<16x16xf16> to vector<16x8xf16> - %D1 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %A, %B1, %C1 : vector<16x16xf16>, vector<8x16xf16> into vector<16x8xf16> + %D1 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type], kind = #vector.kind} %A, %B1, %C1 : vector<16x16xf16>, vector<8x16xf16> into vector<16x8xf16> vector.transfer_write %D1, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x8xf16>, memref<42x64xf16, #gpu.address_space> return @@ -293,7 +293,7 @@ // CHECK: nvgpu.mma.sync([[fragmentA]], [[fragmentB0]], [[fragmentC0]]) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16> %B0 = vector.extract_strided_slice %B {offsets = [0, 0], sizes = [8, 16], strides = [1, 1]} : vector<16x16xf16> to vector<8x16xf16> %C0 = vector.extract_strided_slice %C {offsets = [0, 0], sizes = [16, 8], strides = [1, 1]} : vector<16x16xf16> to vector<16x8xf16> - %D0 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %A, %B0, %C0 : vector<16x16xf16>, vector<8x16xf16> into vector<16x8xf16> + %D0 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type], kind = #vector.kind} %A, %B0, %C0 : vector<16x16xf16>, vector<8x16xf16> into vector<16x8xf16> vector.transfer_write %D0, %arg2[%c0, %c0, %c0] {in_bounds = [true, true]} : vector<16x8xf16>, memref<1x32x40xf16, #gpu.address_space> return @@ -330,7 +330,7 @@ // CHECK-DAG: [[n_coord:%.+]] = affine.apply [[$contiguous_map]] // CHECK: nvgpu.ldmatrix %arg2[[[C0]], [[m_coord]], [[n_coord]]] {numTiles = 2 : i32, transpose = false} : memref<2x20x20xf16, #gpu.address_space> -> vector<2x2xf16> %C = vector.transfer_read %arg2[%c0, %c0, %c0], %cst {in_bounds = [true, true]} : memref<2x20x20xf16, #gpu.address_space>, vector<16x8xf16> - %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %A, %B, %C : vector<16x16xf16>, vector<8x16xf16> into vector<16x8xf16> + %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type], kind = #vector.kind} %A, %B, %C : vector<16x16xf16>, vector<8x16xf16> into vector<16x8xf16> vector.transfer_write %D, %arg2[%c0, %c0, %c0] {in_bounds = [true, true]} : vector<16x8xf16>, memref<2x20x20xf16, #gpu.address_space> return } @@ -375,7 +375,7 @@ %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<20x20xf16, #gpu.address_space>, vector<16x16xf16> %B = vector.transfer_read %arg1[%c0, %c0], %cst {in_bounds = [true, true]} : memref<20x20xf16, #gpu.address_space>, vector<8x16xf16> %C = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<20x20xf16, #gpu.address_space>, vector<16x8xf16> - %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %A, %B, %C : vector<16x16xf16>, vector<8x16xf16> into vector<16x8xf16> + %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type], kind = #vector.kind} %A, %B, %C : vector<16x16xf16>, vector<8x16xf16> into vector<16x8xf16> vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x8xf16>, memref<20x20xf16, #gpu.address_space> return } @@ -428,7 +428,7 @@ // CHECK-SAME: -> vector<2x2xf32> %A = vector.transfer_read %arg0[%c1, %c3], %cst {in_bounds = [true, true]} : memref<20x20xf32, #gpu.address_space>, vector<16x4xf32> %B = vector.transfer_read %arg1[%c3, %c3], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<20x20xf32, #gpu.address_space>, vector<8x4xf32> - %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %A, %B, %cst_0 : vector<16x4xf32>, vector<8x4xf32> into vector<16x8xf32> + %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type], kind = #vector.kind} %A, %B, %cst_0 : vector<16x4xf32>, vector<8x4xf32> into vector<16x8xf32> // CHECK: vector.extract [[d_frag]][0] : vector<2x2xf32> // CHECK: affine.apply [[$rowC_map]] @@ -491,7 +491,7 @@ // CHECK-SAME: -> vector<2x2xf32> %A = vector.transfer_read %arg0[%c1, %c3], %cst {in_bounds = [true, true]} : memref<20x20xf32, #gpu.address_space>, vector<16x8xf32> %B = vector.transfer_read %arg1[%c3, %c3], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<20x20xf32, #gpu.address_space>, vector<8x8xf32> - %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %A, %B, %cst_0 : vector<16x8xf32>, vector<8x8xf32> into vector<16x8xf32> + %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type], kind = #vector.kind} %A, %B, %cst_0 : vector<16x8xf32>, vector<8x8xf32> into vector<16x8xf32> // CHECK: vector.extract [[d_frag]][0] : vector<2x2xf32> // CHECK: affine.apply [[$rowC_map]] @@ -561,7 +561,7 @@ // CHECK-SAME: -> vector<2x2xf32> %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true], permutation_map = #map0} : memref<20x20xf32, #gpu.address_space>, vector<16x8xf32> %B = vector.transfer_read %arg1[%c0, %c0], %cst {in_bounds = [true, true]} : memref<20x20xf32, #gpu.address_space>, vector<8x8xf32> - %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], + %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type], kind = #vector.kind} %A, %B, %cst_0 : vector<16x8xf32>, vector<8x8xf32> into vector<16x8xf32> // CHECK: vector.extract [[d_frag]][0] : vector<2x2xf32> @@ -629,7 +629,7 @@ %B = vector.transfer_read %arg1[%c0, %c0], %cst {in_bounds = [true, true]} : memref<128x128xi4, #gpu.address_space>, vector<8x64xi4> %C = vector.transfer_read %arg2[%c0, %c0], %cst0 {in_bounds = [true, true]} : memref<128x128xi32>, vector<16x8xi32> // CHECK: [[d:%.+]] = nvgpu.mma.sync({{.*}}) {mmaShape = [16, 8, 64]} : (vector<4x8xi4>, vector<2x8xi4>, vector<2x2xi32>) -> vector<2x2xi32> - %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %A, %B, %C : vector<16x64xi4>, vector<8x64xi4> into vector<16x8xi32> + %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type], kind = #vector.kind} %A, %B, %C : vector<16x64xi4>, vector<8x64xi4> into vector<16x8xi32> // CHECK: [[lane:%.+]] = gpu.lane_id // CHECK: [[v:%.+]] = vector.extract [[d]][0] : vector<2x2xi32> @@ -699,7 +699,7 @@ %B = vector.transfer_read %arg1[%c0, %c0], %cst {in_bounds = [true, true]} : memref<128x128xi8, #gpu.address_space>, vector<8x32xi8> %C = vector.transfer_read %arg2[%c0, %c0], %cst0 {in_bounds = [true, true]} : memref<128x128xi32>, vector<16x8xi32> // CHECK: [[d:%.+]] = nvgpu.mma.sync({{.*}}) {mmaShape = [16, 8, 32]} : (vector<4x4xi8>, vector<2x4xi8>, vector<2x2xi32>) -> vector<2x2xi32> - %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %A, %B, %C : vector<16x32xi8>, vector<8x32xi8> into vector<16x8xi32> + %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type], kind = #vector.kind} %A, %B, %C : vector<16x32xi8>, vector<8x32xi8> into vector<16x8xi32> // CHECK: [[lane:%.+]] = gpu.lane_id // CHECK: [[v:%.+]] = vector.extract [[d]][0] : vector<2x2xi32> diff --git a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir --- a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir +++ b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir @@ -20,7 +20,7 @@ %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16> %B = vector.transfer_read %arg1[%c0, %c0], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16> %C = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16> - %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %A, %B, %C : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16> + %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type], kind = #vector.kind} %A, %B, %C : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16> vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<16x16xf16> return } @@ -47,7 +47,7 @@ %cst = arith.constant 0.000000e+00 : f16 %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16> %B = vector.transfer_read %arg1[%c0, %c0], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16> - %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %A, %B, %cst_0 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16> + %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type], kind = #vector.kind} %A, %B, %cst_0 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16> vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<16x16xf16> return } @@ -74,7 +74,7 @@ %cst = arith.constant 0.000000e+00 : f16 %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16> %B = vector.transfer_read %arg1[%c0, %c0], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16> - %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %A, %B, %C : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16> + %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type], kind = #vector.kind} %A, %B, %C : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16> vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<16x16xf16> return } @@ -106,7 +106,7 @@ %14 = scf.for %arg17 = %c0 to %c128 step %c32 iter_args(%arg18 = %C) -> (vector<16x16xf16>) { %17 = vector.transfer_read %arg0[%c0, %arg17], %cst {in_bounds = [true, true]} : memref<128x128xf16>, vector<16x16xf16> %18 = vector.transfer_read %arg1[%arg17, %c0], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<128x128xf16>, vector<16x16xf16> - %19 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %17, %18, %arg18 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16> + %19 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type], kind = #vector.kind} %17, %18, %arg18 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16> scf.yield %19 : vector<16x16xf16> } vector.transfer_write %14, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<128x128xf16> @@ -139,7 +139,7 @@ %cst = arith.constant 0.000000e+00 : f16 %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16> %B = vector.transfer_read %arg1[%c0, %c0], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16> - %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %A, %B, %cst_0 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16> + %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type], kind = #vector.kind} %A, %B, %cst_0 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16> %E = arith.addf %D, %cst_1 : vector<16x16xf16> vector.transfer_write %E, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<16x16xf16> return @@ -170,7 +170,7 @@ %cst = arith.constant 0.000000e+00 : f16 %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16> %B = vector.transfer_read %arg1[%c0, %c0], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16> - %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %A, %B, %cst_0 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16> + %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type], kind = #vector.kind} %A, %B, %cst_0 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16> %E = vector.transfer_read %arg3[%c0, %c0, %c0, %c0], %cst {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2, d3)->(0, d3)>} : memref<16x16x16x16xf16>, vector<16x16xf16> @@ -202,7 +202,7 @@ %A = vector.transfer_read %arg0[%c0, %c0, %c0], %cst {in_bounds = [true, true]} : memref<2x16x16xf16>, vector<16x16xf16> %B = vector.transfer_read %arg1[%c0], %cst {permutation_map = #map4, in_bounds = [true, true]} : memref<16xf16>, vector<16x16xf16> %C = vector.transfer_read %arg2[%c0, %c0, %c0], %cst {in_bounds = [true, true]} : memref<2x16x16xf16>, vector<16x16xf16> - %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %A, %B, %C : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16> + %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type], kind = #vector.kind} %A, %B, %C : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16> vector.transfer_write %D, %arg2[%c0, %c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<2x16x16xf16> return } @@ -230,7 +230,7 @@ %A = vector.transfer_read %arg0[%c0, %c0, %c0], %cst {in_bounds = [true, true]} : memref<2x16x16xf16, affine_map<(d0, d1, d2) -> (d0 * 512 + d1 * 32 + d2)>>, vector<16x16xf16> %B = vector.transfer_read %arg1[%c0], %cst {permutation_map = #map4, in_bounds = [true, true]} : memref<16xf16>, vector<16x16xf16> %C = vector.transfer_read %arg2[%c0, %c0, %c0], %cst {in_bounds = [true, true]} : memref<2x16x16xf16>, vector<16x16xf16> - %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %A, %B, %C : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16> + %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type], kind = #vector.kind} %A, %B, %C : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16> vector.transfer_write %D, %arg2[%c0, %c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<2x16x16xf16> return } @@ -257,7 +257,7 @@ %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16> %B = vector.transfer_read %arg1[%c0, %c0], %cst {permutation_map = #map5, in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16> %C = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16> - %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %A, %B, %C : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16> + %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type], kind = #vector.kind} %A, %B, %C : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16> vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<16x16xf16> return } @@ -284,7 +284,7 @@ %A = vector.transfer_read %arg0[%c0], %cst {in_bounds = [true, true], permutation_map = affine_map<(d0) -> (d0, 0)>} : memref<16xf16>, vector<16x16xf16> %B = vector.transfer_read %arg1[%c0], %cst {in_bounds = [true, true], permutation_map = affine_map<(d0) -> (d0, 0)>} : memref<16xf16>, vector<16x16xf16> %C = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16> - %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %A, %B, %C : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16> + %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type], kind = #vector.kind} %A, %B, %C : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16> vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<16x16xf16> return } @@ -311,7 +311,7 @@ %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true], permutation_map = affine_map<(d0, d1) -> (d1, 0)>} : memref<32x32xf16>, vector<16x16xf16> %B = vector.transfer_read %arg1[%c0, %c0], %cst {in_bounds = [true, true], permutation_map = affine_map<(d0, d1) -> (d1, 0)>} : memref<32x32xf16>, vector<16x16xf16> %C = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16> - %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %A, %B, %C : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16> + %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type], kind = #vector.kind} %A, %B, %C : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16> vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<16x16xf16> return } @@ -334,7 +334,7 @@ // CHECK-DAG: %[[A:.+]] = vector.transfer_read %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} {in_bounds = [true, true]} : memref<16x16xi8>, vector<16x16xi8> // CHECK-DAG: %[[B:.+]] = vector.transfer_read %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} {in_bounds = [true, true]} : memref<16x16xi8>, vector<16x16xi8> // CHECK-DAG: %[[C:.+]] = vector.transfer_read %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} {in_bounds = [true, true]} : memref<16x16xi32>, vector<16x16xi32> -// CHECK: %[[D:.+]] = vector.contract {indexing_maps = [#[[$map]], #[[$map1]], #[[$map2]]], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %[[A]], %[[B]], %[[C]] : vector<16x16xi8>, vector<16x16xi8> into vector<16x16xi32> +// CHECK: %[[D:.+]] = vector.contract {indexing_maps = [#[[$map]], #[[$map1]], #[[$map2]]], iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type], kind = #vector.kind} %[[A]], %[[B]], %[[C]] : vector<16x16xi8>, vector<16x16xi8> into vector<16x16xi32> // CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] {in_bounds = [true, true]} : vector<16x16xi32>, memref<16x16xi32> func.func @matmul_no_extend_int8(%arg0: memref<16x16xi8>, %arg1: memref<16x16xi8>, %arg2: memref<16x16xi32>) { %cst_0 = arith.constant dense<0> : vector<16x16xi8> @@ -344,7 +344,7 @@ %A = vector.transfer_read %arg0[%c0, %c0], %cst_i8 {in_bounds = [true, true]} : memref<16x16xi8>, vector<16x16xi8> %B = vector.transfer_read %arg1[%c0, %c0], %cst_i8 {permutation_map = #map0, in_bounds = [true, true]} : memref<16x16xi8>, vector<16x16xi8> %C = vector.transfer_read %arg2[%c0, %c0], %cst_i32 {in_bounds = [true, true]} : memref<16x16xi32>, vector<16x16xi32> - %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %A, %B, %C : vector<16x16xi8>, vector<16x16xi8> into vector<16x16xi32> + %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type], kind = #vector.kind} %A, %B, %C : vector<16x16xi8>, vector<16x16xi8> into vector<16x16xi32> vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xi32>, memref<16x16xi32> return } @@ -374,7 +374,7 @@ %C = vector.transfer_read %arg2[%c0, %c0], %cst_i32 {in_bounds = [true, true]} : memref<16x16xi32>, vector<16x16xi32> %Ae = arith.extsi %Ar : vector<16x16xi8> to vector<16x16xi32> %Be = arith.extsi %Br : vector<16x16xi8> to vector<16x16xi32> - %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %Ae, %Be, %C : vector<16x16xi32>, vector<16x16xi32> into vector<16x16xi32> + %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type], kind = #vector.kind} %Ae, %Be, %C : vector<16x16xi32>, vector<16x16xi32> into vector<16x16xi32> vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xi32>, memref<16x16xi32> return } @@ -404,7 +404,7 @@ %C = vector.transfer_read %arg2[%c0, %c0], %cst_i32 {in_bounds = [true, true]} : memref<16x16xi32>, vector<16x16xi32> %Ae = arith.extui %Ar : vector<16x16xi8> to vector<16x16xi32> %Be = arith.extsi %Br : vector<16x16xi8> to vector<16x16xi32> - %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %Ae, %Be, %C : vector<16x16xi32>, vector<16x16xi32> into vector<16x16xi32> + %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type], kind = #vector.kind} %Ae, %Be, %C : vector<16x16xi32>, vector<16x16xi32> into vector<16x16xi32> vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xi32>, memref<16x16xi32> return } @@ -434,7 +434,7 @@ %C = vector.transfer_read %arg2[%c0, %c0], %cst_i32 {in_bounds = [true, true]} : memref<16x16xi32>, vector<16x16xi32> %Ae = arith.extui %Ar : vector<16x32xi8> to vector<16x32xi32> %Be = arith.extsi %Br : vector<16x32xi8> to vector<16x32xi32> - %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %Ae, %Be, %C : vector<16x32xi32>, vector<16x32xi32> into vector<16x16xi32> + %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type], kind = #vector.kind} %Ae, %Be, %C : vector<16x32xi32>, vector<16x32xi32> into vector<16x16xi32> vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xi32>, memref<16x16xi32> return -} \ No newline at end of file +} diff --git a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir --- a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir +++ b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir @@ -38,14 +38,14 @@ /// w == 0, kw == 0 // CHECK: %[[CONTRACT_0:.+]] = vector.contract { // CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]], -// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"] +// CHECK-SAME: iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type, #vector.iterator_type] // CHECK-SAME: %[[V_INPUT_0]], %[[V_FILTER]], %[[V_OUTPUT_0]] // CHECK-SAME: : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32> /// w == 1, kw == 0 // CHECK: %[[CONTRACT_1:.+]] = vector.contract { // CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]], -// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"] +// CHECK-SAME: iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type, #vector.iterator_type] // CHECK-SAME: %[[V_INPUT_1]], %[[V_FILTER]], %[[V_OUTPUT_1]] // CHECK-SAME: : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32> @@ -104,7 +104,7 @@ // CHECK: %[[EXT_RHS_0:.+]] = arith.extsi %[[V_FILTER]] : vector<3x8xi8> to vector<3x8xi32> // CHECK: %[[CONTRACT_0:.+]] = vector.contract { // CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]], -// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"] +// CHECK-SAME: iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type, #vector.iterator_type] // CHECK-SAME: %[[EXT_LHS_0]], %[[EXT_RHS_0]], %[[V_OUTPUT_0]] // CHECK-SAME: : vector<4x1x3xi32>, vector<3x8xi32> into vector<4x1x8xi32> @@ -113,7 +113,7 @@ // CHECK: %[[EXT_RHS_1:.+]] = arith.extsi %[[V_FILTER]] : vector<3x8xi8> to vector<3x8xi32> // CHECK: %[[CONTRACT_1:.+]] = vector.contract { // CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]], -// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"] +// CHECK-SAME: iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type, #vector.iterator_type] // CHECK-SAME: %[[EXT_LHS_1]], %[[EXT_RHS_1]], %[[V_OUTPUT_1]] // CHECK-SAME: : vector<4x1x3xi32>, vector<3x8xi32> into vector<4x1x8xi32> @@ -172,25 +172,25 @@ /// w == 0, kw == 0 // CHECK: %[[CONTRACT_0:.+]] = vector.contract { // CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]], -// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"] +// CHECK-SAME: iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type, #vector.iterator_type] // CHECK-SAME: %[[V_INPUT_0]], %[[V_FILTER_0]], %[[V_OUTPUT_0]] // CHECK-SAME: : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32> /// w == 1, kw == 0 // CHECK: %[[CONTRACT_1:.+]] = vector.contract { // CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]], -// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"] +// CHECK-SAME: iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type, #vector.iterator_type] // CHECK-SAME: %[[V_INPUT_1]], %[[V_FILTER_0]], %[[V_OUTPUT_1]] // CHECK-SAME: : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32> /// w == 1, kw == 1 // CHECK: %[[CONTRACT_2:.+]] = vector.contract { // CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]], -// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"] +// CHECK-SAME: iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type, #vector.iterator_type] // CHECK-SAME: %[[V_INPUT_2]], %[[V_FILTER_1]], %[[CONTRACT_0]] // CHECK-SAME: : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32> /// w == 1, kw == 1 // CHECK: %[[CONTRACT_3:.+]] = vector.contract { // CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]], -// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"] +// CHECK-SAME: iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type, #vector.iterator_type] // CHECK-SAME: %[[V_INPUT_3]], %[[V_FILTER_1]], %[[CONTRACT_1]] // CHECK-SAME: : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32> @@ -240,13 +240,13 @@ /// w == 0, kw == 0 // CHECK: %[[CONTRACT_0:.+]] = vector.contract { // CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]], -// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"] +// CHECK-SAME: iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type, #vector.iterator_type] // CHECK-SAME: %[[V_INPUT_0]], %[[V_FILTER_0]], %[[V_OUTPUT_R]] // CHECK-SAME: : vector<4x2x3xf32>, vector<3x8xf32> into vector<4x2x8xf32> /// w == 0, kw == 1 // CHECK: %[[CONTRACT_1:.+]] = vector.contract { // CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]], -// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"] +// CHECK-SAME: iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type, #vector.iterator_type] // CHECK-SAME: %[[V_INPUT_1]], %[[V_FILTER_1]], %[[CONTRACT_0]] // CHECK-SAME: : vector<4x2x3xf32>, vector<3x8xf32> into vector<4x2x8xf32> @@ -298,14 +298,14 @@ /// w == 0, kw == 0 // CHECK: %[[CONTRACT_0:.+]] = vector.contract { // CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]], -// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"] +// CHECK-SAME: iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type, #vector.iterator_type] // CHECK-SAME: %[[V_INPUT_0]], %[[V_FILTER]], %[[V_OUTPUT_0]] // CHECK-SAME: : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32> /// w == 1, kw == 0 // CHECK: %[[CONTRACT_1:.+]] = vector.contract { // CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]], -// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"] +// CHECK-SAME: iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type, #vector.iterator_type] // CHECK-SAME: %[[V_INPUT_1]], %[[V_FILTER]], %[[V_OUTPUT_1]] // CHECK-SAME: : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32> @@ -372,25 +372,25 @@ /// w == 0, kw == 0 // CHECK: %[[CONTRACT_0:.+]] = vector.contract { // CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]], -// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"] +// CHECK-SAME: iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type, #vector.iterator_type] // CHECK-SAME: %[[V_INPUT_0]], %[[V_FILTER_0]], %[[V_OUTPUT_0]] // CHECK-SAME: : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32> /// w == 1, kw == 0 // CHECK: %[[CONTRACT_1:.+]] = vector.contract { // CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]], -// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"] +// CHECK-SAME: iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type, #vector.iterator_type] // CHECK-SAME: %[[V_INPUT_1]], %[[V_FILTER_0]], %[[V_OUTPUT_1]] // CHECK-SAME: : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32> /// w == 1, kw == 1 // CHECK: %[[CONTRACT_2:.+]] = vector.contract { // CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]], -// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"] +// CHECK-SAME: iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type, #vector.iterator_type] // CHECK-SAME: %[[V_INPUT_2]], %[[V_FILTER_1]], %[[CONTRACT_0]] // CHECK-SAME: : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32> /// w == 1, kw == 1 // CHECK: %[[CONTRACT_3:.+]] = vector.contract { // CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]], -// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"] +// CHECK-SAME: iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type, #vector.iterator_type] // CHECK-SAME: %[[V_INPUT_3]], %[[V_FILTER_1]], %[[CONTRACT_1]] // CHECK-SAME: : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32> @@ -448,13 +448,13 @@ /// w == 0, kw == 0 // CHECK: %[[CONTRACT_0:.+]] = vector.contract { // CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]], -// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"] +// CHECK-SAME: iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type, #vector.iterator_type] // CHECK-SAME: %[[V_INPUT_0]], %[[V_FILTER_0]], %[[V_OUTPUT_R]] // CHECK-SAME: : vector<4x2x3xf32>, vector<3x8xf32> into vector<4x2x8xf32> /// w == 0, kw == 1 // CHECK: %[[CONTRACT_1:.+]] = vector.contract { // CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]], -// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"] +// CHECK-SAME: iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type, #vector.iterator_type] // CHECK-SAME: %[[V_INPUT_1]], %[[V_FILTER_1]], %[[CONTRACT_0]] // CHECK-SAME: : vector<4x2x3xf32>, vector<3x8xf32> into vector<4x2x8xf32> diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -961,7 +961,7 @@ ] #contraction_trait0 = { indexing_maps = #contraction_accesses0, - iterator_types = ["parallel", "parallel", "reduction"] + iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type] } // CHECK-LABEL: func @contractions diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -664,7 +664,7 @@ ] #contraction_trait = { indexing_maps = #contraction_accesses, - iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"] + iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type, #vector.iterator_type, #vector.iterator_type] } func.func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>, %arg2: vector<8x15x5xf32>, %arg3 : vector<8x15x8x5xf32>, @@ -683,7 +683,7 @@ ] #contraction_trait = { indexing_maps = #contraction_accesses, - iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"] + iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type, #vector.iterator_type, #vector.iterator_type] } func.func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>, %arg2: vector<8x15x5xf32>, %arg3 : vector<8x15x8x5xf32>, @@ -702,7 +702,7 @@ ] #contraction_trait = { indexing_maps = #contraction_accesses, - iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"] + iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type, #vector.iterator_type, #vector.iterator_type] } func.func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>, %arg2: vector<8x15x5xf32>, %arg3 : vector<8x15x8x5xf32>, @@ -721,7 +721,7 @@ ] #contraction_trait = { indexing_maps = #contraction_accesses, - iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"] + iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type, #vector.iterator_type, #vector.iterator_type] } func.func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>, %arg2: vector<8x15x5xf32>, %arg3 : vector<8x15x8x5xf32>, @@ -740,7 +740,7 @@ ] #contraction_trait = { indexing_maps = #contraction_accesses, - iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"] + iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type, #vector.iterator_type, #vector.iterator_type] } func.func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>, %arg2: vector<8x15x5xf32>, %arg3 : vector<8x15x8x5xf32>, @@ -759,7 +759,7 @@ ] #contraction_trait = { indexing_maps = #contraction_accesses, - iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"] + iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type, #vector.iterator_type, #vector.iterator_type] } func.func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>, %arg2: vector<8x15x5xf32>, %arg3 : vector<8x15x8x5xf32>, @@ -778,7 +778,7 @@ ] #contraction_trait = { indexing_maps = #contraction_accesses, - iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"] + iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type, #vector.iterator_type, #vector.iterator_type] } func.func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>, %arg2: vector<8x15x5xf32>, %arg3 : vector<8x15x8x5xf32>, @@ -797,7 +797,7 @@ ] #contraction_trait = { indexing_maps = #contraction_accesses, - iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"] + iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type, #vector.iterator_type, #vector.iterator_type] } func.func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>, %arg2: vector<8x15x5xf32>, %arg3 : vector<8x15x8x5xf32>, @@ -816,7 +816,7 @@ ] #contraction_trait = { indexing_maps = #contraction_accesses, - iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"] + iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type, #vector.iterator_type, #vector.iterator_type] } func.func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>, %arg2: vector<88x15x5xf32>, %arg3 : vector<8x15x8x5xf32>, @@ -835,7 +835,7 @@ ] #contraction_trait = { indexing_maps = #contraction_accesses, - iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"] + iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type, #vector.iterator_type, #vector.iterator_type] } func.func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>, %arg2: vector<8x15x5xf32>, %arg3 : vector<8x15x8x5xf32>, @@ -856,7 +856,7 @@ ] #contraction_trait = { indexing_maps = #contraction_accesses, - iterator_types = ["parallel", "parallel", "reduction"] + iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type] } func.func @contraction(%arg0: vector<4x3xi32>, %arg1: vector<3x7xf32>, @@ -875,7 +875,7 @@ ] #contraction_trait = { indexing_maps = #contraction_accesses, - iterator_types = ["parallel", "parallel", "reduction"] + iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type] } func.func @contraction(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vector<2x3xf32>) -> vector<3x2xf32> @@ -896,7 +896,7 @@ affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d1)> ], - iterator_types = ["reduction", "parallel", "reduction"], + iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type], kind = #vector.kind} %arg0, %arg1, %arg2 : vector<1x2xi32>, vector<2xi32> into vector<1xi32> return %result : vector<1xi32> } @@ -1660,7 +1660,7 @@ // expected-error@+1 {{op only supports signless integer types}} %0 = vector.contract { indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], - iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind + iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type], kind = #vector.kind } %arg0, %arg1, %arg2 : vector<16x32xsi8>, vector<32x16xsi8> into vector<16x16xsi32> return %0: vector<16x16xsi32> } diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -278,13 +278,13 @@ ] #contraction_to_scalar_trait = { indexing_maps = #contraction_to_scalar_accesses, - iterator_types = ["reduction"] + iterator_types = [#vector.iterator_type] } // CHECK-LABEL: @contraction_to_scalar func.func @contraction_to_scalar(%arg0: vector<10xf32>, %arg1: vector<10xf32>) -> f32 { // CHECK: %[[C0:.*]] = arith.constant 0.000000e+00 : f32 %f0 = arith.constant 0.0: f32 - // CHECK: %[[X:.*]] = vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["reduction"], kind = #vector.kind} %{{.*}}, %{{.*}}, %[[C0]] : vector<10xf32>, vector<10xf32> into f32 + // CHECK: %[[X:.*]] = vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = [#vector.iterator_type], kind = #vector.kind} %{{.*}}, %{{.*}}, %[[C0]] : vector<10xf32>, vector<10xf32> into f32 %0 = vector.contract #contraction_to_scalar_trait %arg0, %arg1, %f0 : vector<10xf32>, vector<10xf32> into f32 // CHECK: return %[[X]] : f32 @@ -295,7 +295,7 @@ func.func @contraction_extra_attrs(%arg0: vector<10xf32>, %arg1: vector<10xf32>) -> f32 { // CHECK: %[[C0:.*]] = arith.constant 0.000000e+00 : f32 %f0 = arith.constant 0.0: f32 - // CHECK: %[[X:.*]] = vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["reduction"], kind = #vector.kind} %{{.*}}, %{{.*}}, %[[C0]] {first_attr = 1 : i32, second_attr = "string"} : vector<10xf32>, vector<10xf32> into f32 + // CHECK: %[[X:.*]] = vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = [#vector.iterator_type], kind = #vector.kind} %{{.*}}, %{{.*}}, %[[C0]] {first_attr = 1 : i32, second_attr = "string"} : vector<10xf32>, vector<10xf32> into f32 %0 = vector.contract #contraction_to_scalar_trait %arg0, %arg1, %f0 {first_attr = 1 : i32, second_attr = "string"} : vector<10xf32>, vector<10xf32> into f32 @@ -310,14 +310,14 @@ ] #contraction_to_scalar_max_trait = { indexing_maps = #contraction_to_scalar_max_accesses, - iterator_types = ["reduction"], + iterator_types = [#vector.iterator_type], kind = #vector.kind } // CHECK-LABEL: @contraction_to_scalar_with_max func.func @contraction_to_scalar_with_max(%arg0: vector<10xf32>, %arg1: vector<10xf32>) -> f32 { // CHECK: %[[C0:.*]] = arith.constant 0.000000e+00 : f32 %f0 = arith.constant 0.0: f32 - // CHECK: %[[X:.*]] = vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["reduction"], kind = #vector.kind} %{{.*}}, %{{.*}}, %[[C0]] : vector<10xf32>, vector<10xf32> into f32 + // CHECK: %[[X:.*]] = vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = [#vector.iterator_type], kind = #vector.kind} %{{.*}}, %{{.*}}, %[[C0]] : vector<10xf32>, vector<10xf32> into f32 %0 = vector.contract #contraction_to_scalar_max_trait %arg0, %arg1, %f0 : vector<10xf32>, vector<10xf32> into f32 // CHECK: return %[[X]] : f32 @@ -331,7 +331,7 @@ ] #contraction_trait0 = { indexing_maps = #contraction_accesses0, - iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"] + iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type, #vector.iterator_type, #vector.iterator_type] } #contraction_accesses1 = [ // 7, 8, 16, 15 affine_map<(f0, f1, f2, f3, c0, c1) -> (c0, f0, c1, f2)>, @@ -340,8 +340,8 @@ // 8, 8, 15, 5 affine_map<(f0, f1, f2, f3, c0, c1) -> (f0, f1, f2, f3)> ] -#iterator_types1 = ["parallel", "parallel", "parallel", "parallel", "reduction", - "reduction"] +#iterator_types1 = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type, #vector.iterator_type, #vector.iterator_type, + #vector.iterator_type] #contraction_trait1 = { indexing_maps = #contraction_accesses1, iterator_types = #iterator_types1 @@ -356,21 +356,21 @@ %arg2 : vector<8x15x5xf32>, %arg3 : vector<8x8x15x5xf32>, %arg4 : vector<7x8x16x15xf16>, %arg5 : vector<8x16x7x5xf16>) { // Test contraction with batch and contracting dims. - // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"], kind = #vector.kind} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32> + // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type, #vector.iterator_type, #vector.iterator_type], kind = #vector.kind} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32> %0 = vector.contract #contraction_trait0 %arg0, %arg1, %arg2 : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32> // Test contraction with only contracting dims. In this case the lhs/rhs // dimension of size 8 will be considered a parallel dim for lhs/rhs and will // appear twice in the output. - // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"], kind = #vector.kind} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32> + // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type, #vector.iterator_type, #vector.iterator_type, #vector.iterator_type], kind = #vector.kind} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32> %1 = vector.contract #contraction_trait1 %arg0, %arg1, %arg3 : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32> // Test contraction with mixed type. - // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"], kind = #vector.kind} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf16>, vector<8x16x7x5xf16> into vector<8x8x15x5xf32> + // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type, #vector.iterator_type, #vector.iterator_type, #vector.iterator_type], kind = #vector.kind} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf16>, vector<8x16x7x5xf16> into vector<8x8x15x5xf32> %3 = vector.contract #contraction_trait1 %arg4, %arg5, %arg3 : vector<7x8x16x15xf16>, vector<8x16x7x5xf16> into vector<8x8x15x5xf32> // Test contraction with "max" instead of "add". - // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"], kind = #vector.kind} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32> + // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type, #vector.iterator_type, #vector.iterator_type, #vector.iterator_type], kind = #vector.kind} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32> %4 = vector.contract #contraction_trait2 %arg0, %arg1, %arg3 : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32> return diff --git a/mlir/test/Dialect/Vector/vector-contract-matmul-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-matmul-transforms.mlir --- a/mlir/test/Dialect/Vector/vector-contract-matmul-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-contract-matmul-transforms.mlir @@ -8,7 +8,7 @@ %0 = vector.contract {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>], - iterator_types = ["reduction"], + iterator_types = [#vector.iterator_type], kind = #vector.kind} %arg0, %arg1, %arg2 : vector<4xf32>, vector<4xf32> into f32 return %0 : f32 @@ -23,7 +23,7 @@ %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], - iterator_types = ["parallel", "parallel", "reduction"], + iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type], kind = #vector.kind} %arg0, %arg1, %arg2 : vector<4x4xi32>, vector<4x4xi32> into vector<4x4xi32> return %res : vector<4x4xi32> } @@ -37,7 +37,7 @@ %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], - iterator_types = ["parallel", "parallel", "reduction"], + iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type], kind = #vector.kind} %arg0, %arg1, %arg2 : vector<4x4xi32>, vector<4x4xi32> into vector<4x4xi32> return %res : vector<4x4xi32> } @@ -55,7 +55,7 @@ %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], - iterator_types = ["parallel", "parallel", "reduction"], + iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type], kind = #vector.kind} %lhs, %rhs, %arg2 : vector<8x4xi32>, vector<4x16xi32> into vector<8x16xi32> return %res : vector<8x16xi32> } @@ -70,7 +70,7 @@ %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], - iterator_types = ["parallel", "parallel", "reduction"], + iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type], kind = #vector.kind} %arg0, %arg1, %arg2 : vector<4x16xi32>, vector<16x4xi32> into vector<4x4xi32> return %res : vector<4x4xi32> } @@ -88,7 +88,7 @@ %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], - iterator_types = ["parallel", "parallel", "reduction"], + iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type], kind = #vector.kind} %lhs, %rhs, %arg2 : vector<8x4xi32>, vector<4x16xi32> into vector<8x16xi32> return %res : vector<8x16xi32> } @@ -102,7 +102,7 @@ %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d0)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], - iterator_types = ["parallel", "parallel", "reduction"], + iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type], kind = #vector.kind} %arg0, %arg1, %arg2 : vector<4x4xi32>, vector<4x4xi32> into vector<4x4xi32> return %res : vector<4x4xi32> } @@ -117,7 +117,7 @@ %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d0)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], - iterator_types = ["parallel", "parallel", "reduction"], + iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type], kind = #vector.kind} %arg0, %arg1, %arg2 : vector<4x4xi32>, vector<4x4xi32> into vector<4x4xi32> return %res : vector<4x4xi32> } @@ -136,7 +136,7 @@ %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d0)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], - iterator_types = ["parallel", "parallel", "reduction"], + iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type], kind = #vector.kind} %lhs, %rhs, %arg2 : vector<4x8xi32>, vector<4x16xi32> into vector<8x16xi32> return %res : vector<8x16xi32> } @@ -149,7 +149,7 @@ %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d1, d0)>], - iterator_types = ["parallel", "parallel", "reduction"], + iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type], kind = #vector.kind} %arg0, %arg1, %arg2 : vector<4x4xi32>, vector<4x4xi32> into vector<4x4xi32> return %res : vector<4x4xi32> } @@ -164,7 +164,7 @@ %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d0)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d1, d0)>], - iterator_types = ["parallel", "parallel", "reduction"], + iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type], kind = #vector.kind} %arg0, %arg1, %arg2 : vector<4x4xi32>, vector<4x4xi32> into vector<4x4xi32> return %res : vector<4x4xi32> } @@ -178,7 +178,7 @@ %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d1, d0)>], - iterator_types = ["parallel", "parallel", "reduction"], + iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type], kind = #vector.kind} %arg0, %arg1, %arg2 : vector<4x4xi32>, vector<4x4xi32> into vector<4x4xi32> return %res : vector<4x4xi32> } @@ -192,7 +192,7 @@ %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d0)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d1, d0)>], - iterator_types = ["parallel", "parallel", "reduction"], + iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type], kind = #vector.kind} %arg0, %arg1, %arg2 : vector<4x4xi32>, vector<4x4xi32> into vector<4x4xi32> return %res : vector<4x4xi32> } diff --git a/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir --- a/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir @@ -7,11 +7,11 @@ ] #matvec_trait = { indexing_maps = #matvec_accesses, - iterator_types = ["parallel", "reduction"] + iterator_types = [#vector.iterator_type, #vector.iterator_type] } #matvecmax_trait = { indexing_maps = #matvec_accesses, - iterator_types = ["parallel", "reduction"], + iterator_types = [#vector.iterator_type, #vector.iterator_type], kind = #vector.kind } @@ -22,7 +22,7 @@ ] #mattransvec_trait = { indexing_maps = #mattransvec_accesses, - iterator_types = ["parallel", "reduction"] + iterator_types = [#vector.iterator_type, #vector.iterator_type] } #vecmat_accesses = [ @@ -32,7 +32,7 @@ ] #vecmat_trait = { indexing_maps = #vecmat_accesses, - iterator_types = ["parallel", "reduction"] + iterator_types = [#vector.iterator_type, #vector.iterator_type] } #vecmattrans_accesses = [ @@ -42,7 +42,7 @@ ] #vecmattrans_trait = { indexing_maps = #vecmattrans_accesses, - iterator_types = ["parallel", "reduction"] + iterator_types = [#vector.iterator_type, #vector.iterator_type] } #redpar_vecmattrans_accesses = [ @@ -52,7 +52,7 @@ ] #redpar_vecmattrans_trait = { indexing_maps = #redpar_vecmattrans_accesses, - iterator_types = ["reduction", "parallel"] + iterator_types = [#vector.iterator_type, #vector.iterator_type] } // CHECK-LABEL: func @matvec2x2 diff --git a/mlir/test/Dialect/Vector/vector-contract-to-dot-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-dot-transforms.mlir --- a/mlir/test/Dialect/Vector/vector-contract-to-dot-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-contract-to-dot-transforms.mlir @@ -7,7 +7,7 @@ ] #dotp_trait = { indexing_maps = #dotp_accesses, - iterator_types = ["reduction"] + iterator_types = [#vector.iterator_type] } // CHECK-LABEL: func @extract_contract1 @@ -57,7 +57,7 @@ ] #matvec_trait = { indexing_maps = #matvec_accesses, - iterator_types = ["parallel", "reduction"] + iterator_types = [#vector.iterator_type, #vector.iterator_type] } // CHECK-LABEL: func @extract_contract2 @@ -114,7 +114,7 @@ ] #vecmat_trait = { indexing_maps = #vecmat_accesses, - iterator_types = ["parallel", "reduction"] + iterator_types = [#vector.iterator_type, #vector.iterator_type] } // CHECK-LABEL: func @extract_contract3 @@ -148,7 +148,7 @@ ] #matmat_trait = { indexing_maps = #matmat_accesses, - iterator_types = ["parallel", "parallel", "reduction"] + iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type] } // CHECK-LABEL: func @extract_contract4 @@ -198,7 +198,7 @@ ] #contraction2d_trait = { indexing_maps = #contraction2d_accesses, - iterator_types = ["reduction", "reduction"] + iterator_types = [#vector.iterator_type, #vector.iterator_type] } // CHECK-LABEL: func @full_contract1 @@ -230,7 +230,7 @@ ] #contraction2d_trans_trait = { indexing_maps = #contraction2d_trans_accesses, - iterator_types = ["reduction", "reduction"] + iterator_types = [#vector.iterator_type, #vector.iterator_type] } // CHECK-LABEL: func @full_contract2 @@ -289,7 +289,7 @@ affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d1)> ], - iterator_types = ["reduction", "parallel", "reduction"], + iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type], kind = #vector.kind } %arg0, %arg1, %arg2 : vector<1x2xi32>, vector<2x2xi32>, vector<2xi32> into vector<2xi32> return %res : vector<2xi32> diff --git a/mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir --- a/mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir @@ -7,7 +7,7 @@ ] #matmat_trait = { indexing_maps = #matmat_accesses, - iterator_types = ["parallel", "parallel", "reduction"] + iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type] } // CHECK-LABEL: func @matmul diff --git a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir --- a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir @@ -7,7 +7,7 @@ ] #matvec_trait = { indexing_maps = #matvec_accesses, - iterator_types = ["parallel", "reduction"] + iterator_types = [#vector.iterator_type, #vector.iterator_type] } #matmat_accesses = [ @@ -17,7 +17,7 @@ ] #matmat_trait = { indexing_maps = #matmat_accesses, - iterator_types = ["parallel", "parallel", "reduction"] + iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type] } #matmat_accesses_0 = [ @@ -27,7 +27,7 @@ ] #matmat_trait_0 = { indexing_maps = #matmat_accesses_0, - iterator_types = ["parallel", "parallel", "reduction"] + iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type] } // CHECK-LABEL: func.func @masked_extract_contract2( @@ -159,7 +159,7 @@ ] #matmat_trait_1 = { indexing_maps = #matmat_accesses_1, - iterator_types = ["parallel", "parallel", "reduction"] + iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type] } // CHECK-LABEL: func @matmul_1 @@ -187,7 +187,7 @@ ] #matmat_trait_2 = { indexing_maps = #matmat_accesses_2, - iterator_types = ["parallel", "parallel", "reduction"] + iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type] } // CHECK-LABEL: func @matmul_2 @@ -213,7 +213,7 @@ ] #matmat_trait_3 = { indexing_maps = #matmat_accesses_3, - iterator_types = ["parallel", "parallel", "reduction"] + iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type] } // CHECK-LABEL: func @matmul_3 @@ -240,7 +240,7 @@ ] #matmat_trait_4 = { indexing_maps = #matmat_accesses_4, - iterator_types = ["parallel", "parallel", "reduction"] + iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type] } // CHECK-LABEL: func @matmul_4 @@ -267,7 +267,7 @@ ] #matmat_trait_5 = { indexing_maps = #matmat_accesses_5, - iterator_types = ["parallel", "parallel", "reduction"] + iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type] } // CHECK-LABEL: func @matmul_5 @@ -294,7 +294,7 @@ ] #matmat_trait_6 = { indexing_maps = #matmat_accesses_6, - iterator_types = ["parallel", "parallel", "reduction"] + iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type] } // CHECK-LABEL: func @matmul_6 @@ -321,7 +321,7 @@ ] #matmat_trait_7 = { indexing_maps = #matmat_accesses_7, - iterator_types = ["parallel", "parallel", "reduction"] + iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type] } // CHECK-LABEL: func @matmul_7 diff --git a/mlir/test/Dialect/Vector/vector-contract-to-parallel-arith-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-parallel-arith-transforms.mlir --- a/mlir/test/Dialect/Vector/vector-contract-to-parallel-arith-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-contract-to-parallel-arith-transforms.mlir @@ -6,7 +6,7 @@ // CHECK: %[[F:.*]] = vector.fma %[[E0]], %[[E1]], %{{.*}} : vector<4xf32> // CHECK: return %[[F]] : vector<4xf32> func.func @parallel_contract_lowering(%arg0: vector<1x1x4xf32>, %arg1: vector<1x1x4xf32>, %arg2: vector<4xf32>) -> vector<4xf32> { - %0 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2, d0)>, affine_map<(d0, d1, d2) -> (d1, d2, d0)>, affine_map<(d0, d1, d2) -> (d0)>], iterator_types = ["parallel", "reduction", "reduction"], kind = #vector.kind} %arg0, %arg1, %arg2 : vector<1x1x4xf32>, vector<1x1x4xf32> into vector<4xf32> + %0 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2, d0)>, affine_map<(d0, d1, d2) -> (d1, d2, d0)>, affine_map<(d0, d1, d2) -> (d0)>], iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type], kind = #vector.kind} %arg0, %arg1, %arg2 : vector<1x1x4xf32>, vector<1x1x4xf32> into vector<4xf32> return %0 : vector<4xf32> } @@ -18,7 +18,7 @@ // CHECK: %[[F:.*]] = vector.fma %[[E0]], %[[E1]], %{{.*}} : vector<4xf32> // CHECK: return %[[F]] : vector<4xf32> func.func @parallel_contract_lowering_broadcast(%arg0: vector<1x1xf32>, %arg1: vector<1x1x4xf32>, %arg2: vector<4xf32>) -> vector<4xf32> { - %0 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d1, d2, d0)>, affine_map<(d0, d1, d2) -> (d0)>], iterator_types = ["parallel", "reduction", "reduction"], kind = #vector.kind} %arg0, %arg1, %arg2 : vector<1x1xf32>, vector<1x1x4xf32> into vector<4xf32> + %0 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d1, d2, d0)>, affine_map<(d0, d1, d2) -> (d0)>], iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type], kind = #vector.kind} %arg0, %arg1, %arg2 : vector<1x1xf32>, vector<1x1x4xf32> into vector<4xf32> return %0 : vector<4xf32> } @@ -31,7 +31,7 @@ // CHECK: %[[F:.*]] = vector.fma %[[E0]], %[[E1]], %arg2 : vector<4xf32> // CHECK: return %[[F]] : vector<4xf32> func.func @parallel_contract_lowering_transpose(%arg0: vector<1x1xf32>, %arg1: vector<1x4x1xf32>, %arg2: vector<4xf32>) -> vector<4xf32> { - %0 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d1, d0, d2)>, affine_map<(d0, d1, d2) -> (d0)>], iterator_types = ["parallel", "reduction", "reduction"], kind = #vector.kind} %arg0, %arg1, %arg2 : vector<1x1xf32>, vector<1x4x1xf32> into vector<4xf32> + %0 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d1, d0, d2)>, affine_map<(d0, d1, d2) -> (d0)>], iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type], kind = #vector.kind} %arg0, %arg1, %arg2 : vector<1x1xf32>, vector<1x4x1xf32> into vector<4xf32> return %0 : vector<4xf32> } @@ -46,7 +46,7 @@ indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> ()>], - iterator_types = ["reduction", "reduction"], kind = #vector.kind} + iterator_types = [#vector.iterator_type, #vector.iterator_type], kind = #vector.kind} %arg0, %arg1, %arg2 : vector<1x1xf32>, vector<1x1xf32> into f32 return %0 : f32 } diff --git a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir --- a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir @@ -9,7 +9,7 @@ // CHECK-NEXT: %[[R1:.+]] = vector.extract %{{.*}}[0] : vector<1x8x16xf32> // CHECK-NEXT: %[[R2:.+]] = vector.extract %{{.*}}[0] : vector<1x16x16xf32> // CHECK-NEXT: %[[R3:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]], -// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} +// CHECK-SAME: iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type], kind = #vector.kind} // CHECK-SAME: %[[R0]], %[[R1]], %[[R2]] : vector<16x8xf32>, vector<8x16xf32> into vector<16x16xf32> // CHECK-NEXT: %[[R4:.+]] = vector.broadcast %[[R3]] : vector<16x16xf32> to vector<1x16x16xf32> // CHECK-NEXT: return %[[R4]] : vector<1x16x16xf32> @@ -21,7 +21,7 @@ ] #contraction_trait0 = { indexing_maps = #contraction_accesses0, - iterator_types = ["parallel", "parallel", "parallel", "reduction"] + iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type, #vector.iterator_type] } func.func @cast_away_contraction_leading_one_dims(%arg0: vector<1x16x8xf32>, %arg1: vector<1x8x16xf32>, %arg2: vector<1x16x16xf32>) -> vector<1x16x16xf32> { @@ -39,7 +39,7 @@ // CHECK-NEXT: %[[R1:.+]] = vector.extract %{{.*}}[0, 0] : vector<1x1x8xf32> // CHECK-NEXT: %[[R2:.+]] = vector.extract %{{.*}}[0, 0] : vector<1x1x16xf32> // CHECK-NEXT: %[[R3:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]], -// CHECK-SAME: iterator_types = ["parallel", "reduction"], kind = #vector.kind} +// CHECK-SAME: iterator_types = [#vector.iterator_type, #vector.iterator_type], kind = #vector.kind} // CHECK-SAME: %[[R1]], %[[R0]], %[[R2]] : vector<8xf32>, vector<8x16xf32> into vector<16xf32> // CHECK-NEXT: %[[R4:.+]] = vector.broadcast %[[R3]] : vector<16xf32> to vector<1x16xf32> // CHECK-NEXT: %[[R5:.+]] = vector.broadcast %[[R4]] : vector<1x16xf32> to vector<1x1x16xf32> @@ -52,7 +52,7 @@ ] #contraction_trait1 = { indexing_maps = #contraction_accesses1, - iterator_types = ["parallel", "parallel", "parallel", "reduction"], + iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type, #vector.iterator_type], kind = #vector.kind } @@ -73,7 +73,7 @@ // CHECK-NEXT: %[[R3:.+]] = vector.extract %[[R2]][0] : vector<1x2x8xf32> // CHECK-NEXT: %[[R4:.+]] = vector.extract %{{.*}}[0] : vector<1x2x16xf32> // CHECK-NEXT: %[[R5:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]], -// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} +// CHECK-SAME: iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type], kind = #vector.kind} // CHECK-SAME: %[[R1]], %[[R3]], %[[R4]] : vector<8x16xf32>, vector<2x8xf32> into vector<2x16xf32> // CHECK-NEXT: %[[R6:.+]] = vector.broadcast %[[R5]] : vector<2x16xf32> to vector<1x2x16xf32> // CHECK-NEXT: return %[[R6]] : vector<1x2x16xf32> @@ -85,7 +85,7 @@ ] #contraction_trait2 = { indexing_maps = #contraction_accesses2, - iterator_types = ["parallel", "parallel", "parallel", "reduction"] + iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type, #vector.iterator_type] } @@ -109,7 +109,7 @@ // CHECK-NEXT: %[[R5:.+]] = vector.extract %[[R4]][0] : vector<1x2x8xf32> // CHECK-NEXT: %[[R6:.+]] = vector.extract %{{.*}}[0, 0] : vector<1x1x2x16xf32> // CHECK-NEXT: %[[R7:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]], -// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} +// CHECK-SAME: iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type], kind = #vector.kind} // CHECK-SAME: %[[R3]], %[[R5]], %[[R6]] : vector<8x16xf32>, vector<2x8xf32> into vector<2x16xf32> // CHECK-NEXT: %[[R8:.+]] = vector.broadcast %[[R7]] : vector<2x16xf32> to vector<1x2x16xf32> // CHECK-NEXT: %[[R9:.+]] = vector.broadcast %[[R8]] : vector<1x2x16xf32> to vector<1x1x2x16xf32> @@ -122,7 +122,7 @@ ] #contraction_trait2 = { indexing_maps = #contraction_accesses2, - iterator_types = ["parallel","parallel", "parallel", "parallel", "reduction"] + iterator_types = [#vector.iterator_type,#vector.iterator_type, #vector.iterator_type, #vector.iterator_type, #vector.iterator_type] } @@ -143,7 +143,7 @@ // CHECK-NEXT: %[[R3:.+]] = vector.extract %[[R1]][0, 0] : vector<1x1x2x8xf32> // CHECK-NEXT: %[[R4:.+]] = vector.extract %{{.*}}[0, 0] : vector<1x1x2x16xf32> // CHECK-NEXT: %[[R5:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]], -// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} +// CHECK-SAME: iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type], kind = #vector.kind} // CHECK-SAME: %[[R2]], %[[R3]], %[[R4]] : vector<8x16xf32>, vector<2x8xf32> into vector<2x16xf32> // CHECK-NEXT: %[[R6:.+]] = vector.broadcast %[[R5]] : vector<2x16xf32> to vector<1x2x16xf32> // CHECK-NEXT: %[[R7:.+]] = vector.broadcast %[[R6]] : vector<1x2x16xf32> to vector<1x1x2x16xf32> @@ -156,7 +156,7 @@ ] #contraction_trait3 = { indexing_maps = #contraction_accesses3, - iterator_types = ["parallel","parallel", "parallel", "parallel", "reduction"] + iterator_types = [#vector.iterator_type,#vector.iterator_type, #vector.iterator_type, #vector.iterator_type, #vector.iterator_type] } func.func @cast_away_contraction_leading_one_dims_nonleadingunitdim_rank4_acctranspose(%arg0: vector<1x8x1x16xf32>, %arg1: vector<1x2x8x1xf32>, %arg2: vector<1x1x2x16xf32>) -> vector<1x1x2x16xf32> { diff --git a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir --- a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir +++ b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir @@ -6,7 +6,7 @@ // CHECK-LABEL: multidimreduction_contract // CHECK-SAME: (%[[ARG0:.*]]: vector<8x32x16xf32>, %[[ARG1:.*]]: vector<8x32x16xf32>, %[[ARG2:.*]]: vector<8x16xf32>) // CHECK-NEXT: %[[R:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map0]], #[[$map1]]], -// CHECK-SAME: iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind} +// CHECK-SAME: iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type], kind = #vector.kind} // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]] : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x16xf32> // CHECK-NEXT: return %[[R]] : vector<8x16xf32> func.func @multidimreduction_contract( @@ -24,7 +24,7 @@ // CHECK-LABEL: multidimreduction_contract_int // CHECK-SAME: (%[[ARG0:.*]]: vector<8x32x16xi32>, %[[ARG1:.*]]: vector<8x32x16xi32>, %[[ARG2:.*]]: vector<8x16xi32>) // CHECK-NEXT: %[[R:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map0]], #[[$map1]]], -// CHECK-SAME: iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind} +// CHECK-SAME: iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type], kind = #vector.kind} // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]] : vector<8x32x16xi32>, vector<8x32x16xi32> into vector<8x16xi32> // CHECK-NEXT: return %[[R]] : vector<8x16xi32> func.func @multidimreduction_contract_int( @@ -47,7 +47,7 @@ // CHECK-SAME: (%[[ARG0:.+]]: vector<32x16x8xf32>, // CHECK-NEXT: %[[C0:.+]] = arith.constant dense<0.000000e+00> : vector<8x32xf32> // CHECK-NEXT: %[[R:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]], -// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} +// CHECK-SAME: iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type], kind = #vector.kind} // CHECK-SAME: %[[ARG0]], %{{.*}}, %[[C0]] : vector<32x16x8xf32>, vector<8x32x16xf32> into vector<8x32xf32> // CHECK-NEXT: return %[[R]] : vector<8x32xf32> func.func @contract_transpose( @@ -55,7 +55,7 @@ %cst = arith.constant dense<0.000000e+00> : vector<8x32xf32> %0 = vector.transpose %arg0, [2, 0, 1] : vector<32x16x8xf32> to vector<8x32x16xf32> %1 = vector.contract {indexing_maps = [#map0, #map0, #map1], - iterator_types = ["parallel", "parallel", "reduction"], + iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type], kind = #vector.kind} %0, %arg1, %cst : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32> return %1 : vector<8x32xf32> } @@ -73,7 +73,7 @@ // CHECK-SAME: (%[[ARG0:.+]]: vector<32x16xf32>, // CHECK-NEXT: %[[C0:.+]] = arith.constant dense<0.000000e+00> : vector<8x32xf32> // CHECK-NEXT: %[[R:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]], -// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} +// CHECK-SAME: iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type], kind = #vector.kind} // CHECK-SAME: %[[ARG0]], %{{.*}}, %[[C0]] : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32> // CHECK-NEXT: return %[[R]] : vector<8x32xf32> func.func @contract_broadcast( @@ -81,7 +81,7 @@ %cst = arith.constant dense<0.000000e+00> : vector<8x32xf32> %0 = vector.broadcast %arg0 : vector<32x16xf32> to vector<8x32x16xf32> %1 = vector.contract {indexing_maps = [#map0, #map0, #map1], - iterator_types = ["parallel", "parallel", "reduction"], + iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type], kind = #vector.kind} %0, %arg1, %cst : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32> return %1 : vector<8x32xf32> } @@ -103,14 +103,14 @@ // CHECK-SAME: (%[[ARG0:.+]]: vector<8x4xi32>, %[[ARG1:.+]]: vector<8x4xi32>, %[[ARG2:.+]]: vector<8x8xi32>) // CHECK: vector.contract // CHECK-SAME: indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]] -// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"] +// CHECK-SAME: iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type] // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]] : vector<8x4xi32>, vector<8x4xi32> into vector<8x8xi32> func.func @contract_broadcast_unit_dim_reduction(%arg0 : vector<8x4xi32>, %arg1 : vector<8x4xi32>, %arg2 : vector<8x8xi32>) -> vector<8x8xi32> { %0 = vector.broadcast %arg0 : vector<8x4xi32> to vector<1x8x4xi32> %1 = vector.broadcast %arg1 : vector<8x4xi32> to vector<1x8x4xi32> %result = vector.contract { indexing_maps = [#map0, #map1, #map2], - iterator_types = ["reduction", "parallel", "parallel", "reduction"], + iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type, #vector.iterator_type], kind = #vector.kind } %0, %1, %arg2 : vector<1x8x4xi32>, vector<1x8x4xi32> into vector<8x8xi32> return %result : vector<8x8xi32> @@ -137,14 +137,14 @@ // CHECK: %[[BROADCAST1:.+]] = vector.broadcast %[[ARG1]] : vector<8x4xi32> to vector<2x8x4xi32> // CHECK: vector.contract // CHECK-SAME: indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]] -// CHECK-SAME: iterator_types = ["parallel", "reduction", "parallel", "reduction"] +// CHECK-SAME: iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type, #vector.iterator_type] // CHECK-SAME: %[[BROADCAST0]], %[[BROADCAST1]], %[[ARG2]] : vector<2x8x4xi32>, vector<2x8x4xi32> into vector<8x8xi32> func.func @contract_broadcast_non_unit_dim_reduction_with_permutation(%arg0 : vector<8x4xi32>, %arg1 : vector<8x4xi32>, %arg2 : vector<8x8xi32>) -> vector<8x8xi32> { %0 = vector.broadcast %arg0 : vector<8x4xi32> to vector<2x8x4xi32> %1 = vector.broadcast %arg1 : vector<8x4xi32> to vector<2x8x4xi32> %result = vector.contract { indexing_maps = [#map0, #map1, #map2], - iterator_types = ["parallel", "reduction", "parallel", "reduction"], + iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type, #vector.iterator_type], kind = #vector.kind } %0, %1, %arg2 : vector<2x8x4xi32>, vector<2x8x4xi32> into vector<8x8xi32> return %result : vector<8x8xi32> @@ -169,14 +169,14 @@ // CHECK: %[[BROADCAST1:.+]] = vector.broadcast %[[ARG1]] : vector<8xi32> to vector<1x8xi32> // CHECK: vector.contract // CHECK-SAME: indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]] -// CHECK-SAME: iterator_types = ["reduction", "parallel", "parallel"] +// CHECK-SAME: iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type] // CHECK-SAME: %[[BROADCAST0]], %[[BROADCAST1]], %[[ARG2]] : vector<1x8xi32>, vector<1x8xi32> into vector<8x8xi32> func.func @contract_broadcast_unit_dim_reduction_as_only_reduction(%arg0 : vector<8xi32>, %arg1 : vector<8xi32>, %arg2 : vector<8x8xi32>) -> vector<8x8xi32> { %0 = vector.broadcast %arg0 : vector<8xi32> to vector<1x8xi32> %1 = vector.broadcast %arg1 : vector<8xi32> to vector<1x8xi32> %result = vector.contract { indexing_maps = [#map0, #map1, #map2], - iterator_types = ["reduction", "parallel", "parallel"], + iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type], kind = #vector.kind } %0, %1, %arg2 : vector<1x8xi32>, vector<1x8xi32> into vector<8x8xi32> return %result : vector<8x8xi32> @@ -200,14 +200,14 @@ // CHECK: %[[BROADCAST1:.+]] = vector.broadcast %[[ARG1]] : vector<2xi32> to vector<1x1x2xi32> // CHECK: vector.contract // CHECK-SAME: indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]] -// CHECK-SAME: iterator_types = ["reduction", "parallel", "reduction"] +// CHECK-SAME: iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type] // CHECK-SAME: %[[ARG0]], %[[BROADCAST1]], %[[ARG2]] : vector<1x2xi32>, vector<1x1x2xi32> into vector<1xi32> func.func @contract_broadcast_dimension_would_go_unused_in_lhs_rhs(%arg0 : vector<1x2xi32>, %arg1 : vector<2xi32>, %arg2 : vector<1xi32>) -> vector<1xi32> { %1 = vector.broadcast %arg1 : vector<2xi32> to vector<1x1x2xi32> %result = vector.contract { indexing_maps = [#map0, #map1, #map2], - iterator_types = ["reduction", "parallel", "reduction"], + iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type], kind = #vector.kind } %arg0, %1, %arg2 : vector<1x2xi32>, vector<1x1x2xi32> into vector<1xi32> return %result : vector<1xi32> @@ -231,14 +231,14 @@ // CHECK: %[[BROADCAST1:.+]] = vector.broadcast %[[ARG1]] : vector<1xf32> to vector<1x1xf32> // CHECK: vector.contract // CHECK-SAME: indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]] -// CHECK-SAME: iterator_types = ["parallel", "reduction"] +// CHECK-SAME: iterator_types = [#vector.iterator_type, #vector.iterator_type] // CHECK-SAME: %[[ARG0]], %[[BROADCAST1]], %[[ARG2]] : vector<1xf32>, vector<1x1xf32> into vector<1xf32> func.func @contract_broadcast_would_have_no_reduction_dim_pair(%arg0 : vector<1xf32>, %arg1 : vector<1xf32>, %arg2 : vector<1xf32>) -> vector<1xf32> { %1 = vector.broadcast %arg1 : vector<1xf32> to vector<1x1xf32> %result = vector.contract { indexing_maps = [#map0, #map1, #map2], - iterator_types = ["parallel", "reduction"], + iterator_types = [#vector.iterator_type, #vector.iterator_type], kind = #vector.kind } %arg0, %1, %arg2 : vector<1xf32>, vector<1x1xf32> into vector<1xf32> return %result : vector<1xf32> @@ -367,7 +367,7 @@ // CHECK-SAME: (%[[LHS:.+]]: vector<2x4x4xf32>, %[[RHS:.+]]: vector<4x8xf32>, %[[ACC:.+]]: vector<2x8x4xf32>) // CHECK: %[[CONTRACT:.+]] = vector.contract // CHECK-SAME: indexing_maps = [#[[$LHS_MAP]], #[[$RHS_MAP]], #[[$ACC_MAP]]] -// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"] +// CHECK-SAME: iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type, #vector.iterator_type] // CHECK-SAME: kind = #vector.kind // CHECK-SAME: %[[LHS]], %[[RHS]], %[[ACC]] // CHECK: return %[[CONTRACT]] @@ -379,7 +379,7 @@ affine_map<(d0, d1, d2, d3) -> (d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> ], - iterator_types = ["parallel", "parallel", "parallel", "reduction"], + iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type, #vector.iterator_type], kind = #vector.kind } %lhs, %rhs, %accT : vector<2x4x4xf32>, vector<4x8xf32> into vector<2x4x8xf32> %resT = vector.transpose %contract, [0, 2, 1] : vector<2x4x8xf32> to vector<2x8x4xf32> diff --git a/mlir/test/Dialect/Vector/vector-transforms.mlir b/mlir/test/Dialect/Vector/vector-transforms.mlir --- a/mlir/test/Dialect/Vector/vector-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-transforms.mlir @@ -80,10 +80,10 @@ // CHECK-NEXT: %[[VTR6:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C0]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32> // CHECK-NEXT: %[[VTR7:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C2]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32> -// CHECK-NEXT: %[[R0:.*]] = vector.contract {indexing_maps = [#map{{.*}}, #map{{.*}}, #map{{.*}}], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind} %[[VTR0]], %[[VTR2]], %[[VTR4]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> -// CHECK-NEXT: %[[R1:.*]] = vector.contract {indexing_maps = [#map{{.*}}, #map{{.*}}, #map{{.*}}], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind} %[[VTR0]], %[[VTR3]], %[[VTR5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> -// CHECK-NEXT: %[[R2:.*]] = vector.contract {indexing_maps = [#map{{.*}}, #map{{.*}}, #map{{.*}}], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind} %[[VTR1]], %[[VTR2]], %[[VTR6]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> -// CHECK-NEXT: %[[R3:.*]] = vector.contract {indexing_maps = [#map{{.*}}, #map{{.*}}, #map{{.*}}], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind} %[[VTR1]], %[[VTR3]], %[[VTR7]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R0:.*]] = vector.contract {indexing_maps = [#map{{.*}}, #map{{.*}}, #map{{.*}}], iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type], kind = #vector.kind} %[[VTR0]], %[[VTR2]], %[[VTR4]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R1:.*]] = vector.contract {indexing_maps = [#map{{.*}}, #map{{.*}}, #map{{.*}}], iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type], kind = #vector.kind} %[[VTR0]], %[[VTR3]], %[[VTR5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R2:.*]] = vector.contract {indexing_maps = [#map{{.*}}, #map{{.*}}, #map{{.*}}], iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type], kind = #vector.kind} %[[VTR1]], %[[VTR2]], %[[VTR6]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R3:.*]] = vector.contract {indexing_maps = [#map{{.*}}, #map{{.*}}, #map{{.*}}], iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type], kind = #vector.kind} %[[VTR1]], %[[VTR3]], %[[VTR7]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> // CHECK-NEXT: vector.transfer_write %[[R0]], %{{.*}}[%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<2x2xf32>, memref<4x4xf32> // CHECK-NEXT: vector.transfer_write %[[R1]], %{{.*}}[%[[C0]], %[[C2]]] {in_bounds = [true, true]} : vector<2x2xf32>, memref<4x4xf32> @@ -98,7 +98,7 @@ ] #contraction_trait1 = { indexing_maps = #contraction_accesses1, - iterator_types = ["parallel", "reduction", "parallel"] + iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type] } func.func @contraction4x4_ikj_xfer_read(%arg0 : memref<4x2xf32>, @@ -219,10 +219,10 @@ // CHECK-NEXT: %[[VTR6:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C0]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32> // CHECK-NEXT: %[[VTR7:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C2]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32> -// CHECK-NEXT: %[[R0:.*]] = vector.contract {indexing_maps = [#map{{.*}}, #map{{.*}}, #map{{.*}}], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind} %[[VTR0]], %[[VTR2]], %[[VTR4]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> -// CHECK-NEXT: %[[R1:.*]] = vector.contract {indexing_maps = [#map{{.*}}, #map{{.*}}, #map{{.*}}], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind} %[[VTR0]], %[[VTR3]], %[[VTR5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> -// CHECK-NEXT: %[[R2:.*]] = vector.contract {indexing_maps = [#map{{.*}}, #map{{.*}}, #map{{.*}}], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind} %[[VTR1]], %[[VTR2]], %[[VTR6]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> -// CHECK-NEXT: %[[R3:.*]] = vector.contract {indexing_maps = [#map{{.*}}, #map{{.*}}, #map{{.*}}], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind} %[[VTR1]], %[[VTR3]], %[[VTR7]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R0:.*]] = vector.contract {indexing_maps = [#map{{.*}}, #map{{.*}}, #map{{.*}}], iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type], kind = #vector.kind} %[[VTR0]], %[[VTR2]], %[[VTR4]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R1:.*]] = vector.contract {indexing_maps = [#map{{.*}}, #map{{.*}}, #map{{.*}}], iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type], kind = #vector.kind} %[[VTR0]], %[[VTR3]], %[[VTR5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R2:.*]] = vector.contract {indexing_maps = [#map{{.*}}, #map{{.*}}, #map{{.*}}], iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type], kind = #vector.kind} %[[VTR1]], %[[VTR2]], %[[VTR6]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R3:.*]] = vector.contract {indexing_maps = [#map{{.*}}, #map{{.*}}, #map{{.*}}], iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type], kind = #vector.kind} %[[VTR1]], %[[VTR3]], %[[VTR7]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> // CHECK-NEXT: %[[VTW0:.*]] = vector.transfer_write %[[R0]], %{{.*}}[%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<2x2xf32>, tensor<4x4xf32> // CHECK-NEXT: %[[VTW1:.*]] = vector.transfer_write %[[R1]], %[[VTW0]][%[[C0]], %[[C2]]] {in_bounds = [true, true]} : vector<2x2xf32>, tensor<4x4xf32> diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir --- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir +++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir @@ -8,7 +8,7 @@ {indexing_maps = [affine_map<(i, j, k) -> (i, k)>, affine_map<(i, j, k) -> (j, k)>, affine_map<(i, j, k) -> (i, j)>], - iterator_types = ["parallel", "parallel", "reduction"]} + iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type]} %lhs, %rhs, %init : vector<8x4xf32>, vector<8x4xf32> into vector<8x8xf32> return %0 : vector<8x8xf32> } @@ -158,7 +158,7 @@ {indexing_maps = [affine_map<(i, j, k) -> (i, k)>, affine_map<(i, j, k) -> (j, k)>, affine_map<(i, j, k) -> (i, j)>], - iterator_types = ["parallel", "parallel", "reduction"]} + iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type]} %lhs, %rhs, %init : vector<8x8xf16>, vector<8x8xf16> into vector<8x8xf16> return %0 : vector<8x8xf16> } @@ -271,7 +271,7 @@ {indexing_maps = [affine_map<(d0,d1,d2,c0) -> (d0,d1,c0)>, affine_map<(d0,d1,d2,c0) -> (d0,d2,c0)>, affine_map<(d0,d1,d2,c0) -> (d0,d1,d2)>], - iterator_types = ["parallel", "parallel", "parallel", "reduction"]} + iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type, #vector.iterator_type]} %lhs, %rhs, %init : vector<8x8x4xf32>, vector<8x8x4xf32> into vector<8x8x8xf32> return %0 : vector<8x8x8xf32> } diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-contraction.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-contraction.mlir --- a/mlir/test/Integration/Dialect/Vector/CPU/test-contraction.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/test-contraction.mlir @@ -10,7 +10,7 @@ ] #dotp_trait = { indexing_maps = #dotp_accesses, - iterator_types = ["reduction"] + iterator_types = [#vector.iterator_type] } #matvec_accesses = [ @@ -20,7 +20,7 @@ ] #matvec_trait = { indexing_maps = #matvec_accesses, - iterator_types = ["parallel", "reduction"] + iterator_types = [#vector.iterator_type, #vector.iterator_type] } #mattransvec_accesses = [ @@ -30,7 +30,7 @@ ] #mattransvec_trait = { indexing_maps = #mattransvec_accesses, - iterator_types = ["parallel", "reduction"] + iterator_types = [#vector.iterator_type, #vector.iterator_type] } #matmat_accesses = [ @@ -40,7 +40,7 @@ ] #matmat_trait = { indexing_maps = #matmat_accesses, - iterator_types = ["parallel", "parallel", "reduction"] + iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type] } #mattransmat_accesses = [ @@ -50,7 +50,7 @@ ] #mattransmat_trait = { indexing_maps = #mattransmat_accesses, - iterator_types = ["parallel", "parallel", "reduction"] + iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type] } #matmattrans_accesses = [ @@ -60,7 +60,7 @@ ] #matmattrans_trait = { indexing_maps = #matmattrans_accesses, - iterator_types = ["parallel", "parallel", "reduction"] + iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type] } #mattransmattrans_accesses = [ @@ -70,7 +70,7 @@ ] #mattransmattrans_trait = { indexing_maps = #mattransmattrans_accesses, - iterator_types = ["parallel", "parallel", "reduction"] + iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type] } #matmat_then_trans_accesses = [ @@ -80,7 +80,7 @@ ] #matmat_then_trans_trait = { indexing_maps = #matmat_then_trans_accesses, - iterator_types = ["parallel", "parallel", "reduction"] + iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type] } #contract2d_accesses = [ @@ -90,7 +90,7 @@ ] #contract2d_trait = { indexing_maps = #contract2d_accesses, - iterator_types = ["reduction", "reduction"] + iterator_types = [#vector.iterator_type, #vector.iterator_type] } #contract2d_alt_accesses = [ @@ -100,7 +100,7 @@ ] #contract2d_alt_trait = { indexing_maps = #contract2d_alt_accesses, - iterator_types = ["reduction", "reduction"] + iterator_types = [#vector.iterator_type, #vector.iterator_type] } #contract2d_trans_accesses = [ @@ -110,7 +110,7 @@ ] #contract2d_trans_trait = { indexing_maps = #contract2d_trans_accesses, - iterator_types = ["reduction", "reduction"] + iterator_types = [#vector.iterator_type, #vector.iterator_type] } #contract2d_trans_alt_accesses = [ @@ -120,7 +120,7 @@ ] #contract2d_trans_alt_trait = { indexing_maps = #contract2d_trans_alt_accesses, - iterator_types = ["reduction", "reduction"] + iterator_types = [#vector.iterator_type, #vector.iterator_type] } #column_major_matmat_accesses = [ @@ -130,7 +130,7 @@ ] #column_major_matmat_trait = { indexing_maps = #column_major_matmat_accesses, - iterator_types = ["parallel", "parallel", "reduction"] + iterator_types = [#vector.iterator_type, #vector.iterator_type, #vector.iterator_type] } func.func @entry() { diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-sparse-dot-matvec.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-sparse-dot-matvec.mlir --- a/mlir/test/Integration/Dialect/Vector/CPU/test-sparse-dot-matvec.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/test-sparse-dot-matvec.mlir @@ -50,7 +50,7 @@ ] #dot_trait = { indexing_maps = #contraction_accesses, - iterator_types = ["reduction"] + iterator_types = [#vector.iterator_type] } func.func @spmv8x8(%AVAL: memref<8xvector<4xf32>>,