diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -169,21 +169,6 @@
   return res;
 }
 
-/// Return true if the scalable vector dimensions are supported. For now, we
-/// only support scalable vectors in the trailing dimension.
-static bool areValidScalableVecDims(ArrayRef<bool> scalableVecDims) {
-  if (scalableVecDims.empty())
-    return true;
-
-  auto isScalable = [](bool isScalableVecSize) { return isScalableVecSize; };
-  if (std::any_of(scalableVecDims.begin(), scalableVecDims.end() - 1,
-                  isScalable)) {
-    return false;
-  }
-
-  return true;
-}
-
 /// Contains the vectorization state and related methods used across the
 /// vectorization process of a given operation.
 struct VectorizationState {
@@ -217,12 +202,6 @@
       scalableDims.append(scalableVecDims.begin(), scalableVecDims.end());
     }
 
-    // Make sure we don't end up with unsupported scalable vector dimensions
-    // after the permutation. If so, we should bail out on that operation in the
-    // scalable preconditions.
-    assert(areValidScalableVecDims(scalableDims) &&
-           "Permuted scalable vector dimensions are not supported");
-
     return VectorType::get(vectorShape, elementType, scalableDims);
   }
 
@@ -1630,11 +1609,10 @@
   if (inputVectorSizes.empty())
     return success();
 
-  if (!areValidScalableVecDims(inputScalableVecDims)) {
-    LDBG("Non-trailing scalable vector dimensions are not supported\n");
-    return failure();
-  }
-
+  // NOTE: Only the trailing dimension is inspected below, so an op whose
+  // scalable dims are all non-trailing (e.g. linalg.matmul vectorized with
+  // vector_sizes [8, [16], 4]) is accepted here without any further checks.
+  // TODO: Check every scalable dimension explicitly instead.
   bool isScalable = inputScalableVecDims.back();
   if (!isScalable)
     return success();
diff --git a/mlir/test/Dialect/Linalg/vectorization-masked.mlir b/mlir/test/Dialect/Linalg/vectorization-masked.mlir
--- a/mlir/test/Dialect/Linalg/vectorization-masked.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization-masked.mlir
@@ -447,41 +447,68 @@
 
 // -----
 
-func.func @vectorize_dynamic_matmul(%A: memref<?x?xf32>, %B: memref<?x?xf32>, %C: memref<?x?xf32>) {
+func.func @matmul(%A: memref<?x?xf32>, %B: memref<?x?xf32>, %C: memref<?x?xf32>) {
   linalg.matmul ins(%A, %B: memref<?x?xf32>, memref<?x?xf32>)
             outs(%C: memref<?x?xf32>)
   return
 }
 
-// CHECK-LABEL:   func.func @vectorize_dynamic_matmul(
-// CHECK-SAME:      %[[VAL_0:.*]]: memref<?x?xf32>, %[[VAL_1:.*]]: memref<?x?xf32>, %[[VAL_2:.*]]: memref<?x?xf32>) {
+// CHECK-LABEL:   func.func @matmul(
+// CHECK-SAME:      %[[A:.*]]: memref<?x?xf32>, %[[B:.*]]: memref<?x?xf32>, %[[C:.*]]: memref<?x?xf32>) {
 // CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 0 : index
-// CHECK-DAG:       %[[VAL_4:.*]] = memref.dim %[[VAL_0]], %[[VAL_3]] : memref<?x?xf32>
+// CHECK-DAG:       %[[VAL_4:.*]] = memref.dim %[[A]], %[[VAL_3]] : memref<?x?xf32>
 // CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 1 : index
-// CHECK-DAG:       %[[VAL_6:.*]] = memref.dim %[[VAL_1]], %[[VAL_5]] : memref<?x?xf32>
+// CHECK-DAG:       %[[VAL_6:.*]] = memref.dim %[[B]], %[[VAL_5]] : memref<?x?xf32>
 // CHECK-DAG:       %[[VAL_7:.*]] = arith.constant 1 : index
-// CHECK-DAG:       %[[VAL_8:.*]] = memref.dim %[[VAL_0]], %[[VAL_7]] : memref<?x?xf32>
-// CHECK-DAG:       %[[VAL_9:.*]] = arith.constant 0 : index
-// CHECK-DAG:       %[[VAL_10:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK:           %[[VAL_11:.*]] = vector.create_mask %[[VAL_4]], %[[VAL_8]] : vector<8x4xi1>
-// CHECK:           %[[VAL_12:.*]] = vector.mask %[[VAL_11]] { vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_9]], %[[VAL_9]]], %[[VAL_10]] {in_bounds = [true, true, true], permutation_map = #map} : memref<?x?xf32>, vector<8x16x4xf32> } : vector<8x4xi1> -> vector<8x16x4xf32>
-// CHECK:           %[[VAL_13:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK:           %[[VAL_14:.*]] = vector.create_mask %[[VAL_8]], %[[VAL_6]] : vector<4x16xi1>
-// CHECK:           %[[VAL_15:.*]] = vector.mask %[[VAL_14]] { vector.transfer_read %[[VAL_1]]{{\[}}%[[VAL_9]], %[[VAL_9]]], %[[VAL_13]] {in_bounds = [true, true, true], permutation_map = #map1} : memref<?x?xf32>, vector<8x16x4xf32> } : vector<4x16xi1> -> vector<8x16x4xf32>
-// CHECK:           %[[VAL_16:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK:           %[[VAL_17:.*]] = vector.create_mask %[[VAL_4]], %[[VAL_6]] : vector<8x16xi1>
-// CHECK:           %[[VAL_18:.*]] = vector.mask %[[VAL_17]] { vector.transfer_read %[[VAL_2]]{{\[}}%[[VAL_9]], %[[VAL_9]]], %[[VAL_16]] {in_bounds = [true, true]} : memref<?x?xf32>, vector<8x16xf32> } : vector<8x16xi1> -> vector<8x16xf32>
-// CHECK:           %[[VAL_19:.*]] = arith.mulf %[[VAL_12]], %[[VAL_15]] : vector<8x16x4xf32>
-// CHECK:           %[[VAL_20:.*]] = vector.create_mask %[[VAL_4]], %[[VAL_6]], %[[VAL_8]] : vector<8x16x4xi1>
-// CHECK:           %[[VAL_21:.*]] = vector.mask %[[VAL_20]] { vector.multi_reduction <add>, %[[VAL_19]], %[[VAL_18]] [2] : vector<8x16x4xf32> to vector<8x16xf32> } : vector<8x16x4xi1> -> vector<8x16xf32>
-// CHECK:           %[[VAL_22:.*]] = arith.constant 0 : index
-// CHECK:           vector.mask %[[VAL_17]] { vector.transfer_write %[[VAL_21]], %[[VAL_2]]{{\[}}%[[VAL_22]], %[[VAL_22]]] {in_bounds = [true, true]} : vector<8x16xf32>, memref<?x?xf32> } : vector<8x16xi1>
-// CHECK:           return
-// CHECK:         }
+// CHECK-DAG:       %[[VAL_8:.*]] = memref.dim %[[A]], %[[VAL_7]] : memref<?x?xf32>
+// CHECK:           %[[MASK_A:.*]] = vector.create_mask %[[VAL_4]], %[[VAL_8]] : vector<8x4xi1>
+// CHECK:           %[[LOAD_A:.*]] = vector.mask %[[MASK_A]] { vector.transfer_read %[[A]]{{\[}}%{{.*}}, %{{.*}}], %{{.*}} {in_bounds = [true, true, true], permutation_map = #{{.*}}} : memref<?x?xf32>, vector<8x16x4xf32> } : vector<8x4xi1> -> vector<8x16x4xf32>
+// CHECK:           %[[MASK_B:.*]] = vector.create_mask %[[VAL_8]], %[[VAL_6]] : vector<4x16xi1>
+// CHECK:           %[[LOAD_B:.*]] = vector.mask %[[MASK_B]] { vector.transfer_read %[[B]]{{\[}}%{{.*}}, %{{.*}}], %{{.*}} {in_bounds = [true, true, true], permutation_map = #{{.*}}} : memref<?x?xf32>, vector<8x16x4xf32> } : vector<4x16xi1> -> vector<8x16x4xf32>
+// CHECK:           %[[MASK_C:.*]] = vector.create_mask %[[VAL_4]], %[[VAL_6]] : vector<8x16xi1>
+// CHECK:           %[[LOAD_C:.*]] = vector.mask %[[MASK_C]] { vector.transfer_read %[[C]]{{\[}}%{{.*}}, %{{.*}}], %{{.*}} {in_bounds = [true, true]} : memref<?x?xf32>, vector<8x16xf32> } : vector<8x16xi1> -> vector<8x16xf32>
+// CHECK:           %[[MULF:.*]] = arith.mulf %[[LOAD_A]], %[[LOAD_B]] : vector<8x16x4xf32>
+// CHECK:           %[[MASK_MULTI_RED:.*]] = vector.create_mask %[[VAL_4]], %[[VAL_6]], %[[VAL_8]] : vector<8x16x4xi1>
+// CHECK:           %[[MULTI_RED:.*]] = vector.mask %[[MASK_MULTI_RED]] { vector.multi_reduction <add>, %[[MULF]], %[[LOAD_C]] [2] : vector<8x16x4xf32> to vector<8x16xf32> } : vector<8x16x4xi1> -> vector<8x16xf32>
+// CHECK:           %[[C0:.*]] = arith.constant 0 : index
+// CHECK:           vector.mask %[[MASK_C]] { vector.transfer_write %[[MULTI_RED]], %[[C]]{{\[}}%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x16xf32>, memref<?x?xf32> } : vector<8x16xi1>
 
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !transform.any_op):
-  %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-  transform.structured.masked_vectorize %0 vector_sizes [8, 16, 4] : !transform.any_op
+  %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  transform.structured.masked_vectorize %matmul vector_sizes [8, 16, 4] : !transform.any_op
+}
+
+// -----
+
+func.func @matmul_scalable(%A: memref<?x?xf32>, %B: memref<?x?xf32>, %C: memref<?x?xf32>) {
+  linalg.matmul ins(%A, %B: memref<?x?xf32>, memref<?x?xf32>)
+            outs(%C: memref<?x?xf32>)
+  return
 }
 
+// CHECK-LABEL:   func.func @matmul_scalable(
+// CHECK-SAME:      %[[A:.*]]: memref<?x?xf32>, %[[B:.*]]: memref<?x?xf32>, %[[C:.*]]: memref<?x?xf32>) {
+// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[VAL_4:.*]] = memref.dim %[[A]], %[[VAL_3]] : memref<?x?xf32>
+// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[VAL_6:.*]] = memref.dim %[[B]], %[[VAL_5]] : memref<?x?xf32>
+// CHECK-DAG:       %[[VAL_7:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[VAL_8:.*]] = memref.dim %[[A]], %[[VAL_7]] : memref<?x?xf32>
+// CHECK:           %[[MASK_A:.*]] = vector.create_mask %[[VAL_4]], %[[VAL_8]] : vector<8x4xi1>
+// CHECK:           %[[LOAD_A:.*]] = vector.mask %[[MASK_A]] { vector.transfer_read %[[A]]{{\[}}%{{.*}}, %{{.*}}], %{{.*}} {in_bounds = [true, true, true], permutation_map = #{{.*}}} : memref<?x?xf32>, vector<8x[16]x4xf32> } : vector<8x4xi1> -> vector<8x[16]x4xf32>
+// CHECK:           %[[MASK_B:.*]] = vector.create_mask %[[VAL_8]], %[[VAL_6]] : vector<4x[16]xi1>
+// CHECK:           %[[LOAD_B:.*]] = vector.mask %[[MASK_B]] { vector.transfer_read %[[B]]{{\[}}%{{.*}}, %{{.*}}], %{{.*}} {in_bounds = [true, true, true], permutation_map = #{{.*}}} : memref<?x?xf32>, vector<8x[16]x4xf32> } : vector<4x[16]xi1> -> vector<8x[16]x4xf32>
+// CHECK:           %[[MASK_C:.*]] = vector.create_mask %[[VAL_4]], %[[VAL_6]] : vector<8x[16]xi1>
+// CHECK:           %[[LOAD_C:.*]] = vector.mask %[[MASK_C]] { vector.transfer_read %[[C]]{{\[}}%{{.*}}, %{{.*}}], %{{.*}} {in_bounds = [true, true]} : memref<?x?xf32>, vector<8x[16]xf32> } : vector<8x[16]xi1> -> vector<8x[16]xf32>
+// CHECK:           %[[MULF:.*]] = arith.mulf %[[LOAD_A]], %[[LOAD_B]] : vector<8x[16]x4xf32>
+// CHECK:           %[[MASK_MULTI_RED:.*]] = vector.create_mask %[[VAL_4]], %[[VAL_6]], %[[VAL_8]] : vector<8x[16]x4xi1>
+// CHECK:           %[[MULTI_RED:.*]] = vector.mask %[[MASK_MULTI_RED]] { vector.multi_reduction <add>, %[[MULF]], %[[LOAD_C]] [2] : vector<8x[16]x4xf32> to vector<8x[16]xf32> } : vector<8x[16]x4xi1> -> vector<8x[16]xf32>
+// CHECK:           %[[C0:.*]] = arith.constant 0 : index
+// CHECK:           vector.mask %[[MASK_C]] { vector.transfer_write %[[MULTI_RED]], %[[C]]{{\[}}%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x[16]xf32>, memref<?x?xf32> } : vector<8x[16]xi1>
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  transform.structured.masked_vectorize %matmul vector_sizes [8, [16], 4] : !transform.any_op
+}