diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1303,7 +1303,8 @@
 
 static LogicalResult vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op) {
   // TODO: Masking only supports dynamic generic ops for now.
-  if (!isa<linalg::GenericOp>(op))
+  if (!isa<linalg::GenericOp, linalg::ContractionOpInterface>(
+          op.getOperation()))
     return failure();
 
   LDBG("Dynamically-shaped op meets vectorization pre-conditions\n");
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -2832,7 +2832,7 @@
 
 // CHECK-LABEL: func @test_masked_vectorize_pad
 func.func @test_masked_vectorize_pad(
-  %0 : tensor<?x?xf32>, %h0 : index, %h1 : index) 
+  %0 : tensor<?x?xf32>, %h0 : index, %h1 : index)
     -> tensor<2x4xf32>
 {
   // CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
@@ -2841,9 +2841,9 @@
   // CHECK: %[[d0:.*]] = tensor.dim {{.*}} : tensor<?x?xf32>
   // CHECK: %[[d1:.*]] = tensor.dim {{.*}} : tensor<?x?xf32>
   // CHECK: %[[mask:.*]] = vector.create_mask %[[d0]], %[[d1]] : vector<2x4xi1>
-  // CHECK: %[[masked_read:.*]] = vector.mask %[[mask]] { 
-  // CHECK-SAME: vector.transfer_read %{{.*}}[%[[c0]], %[[c0]]], %[[c42]] 
-  // CHECK-SAME: {in_bounds = [true, true]} : tensor<?x?xf32>, vector<2x4xf32> 
+  // CHECK: %[[masked_read:.*]] = vector.mask %[[mask]] {
+  // CHECK-SAME: vector.transfer_read %{{.*}}[%[[c0]], %[[c0]]], %[[c42]]
+  // CHECK-SAME: {in_bounds = [true, true]} : tensor<?x?xf32>, vector<2x4xf32>
   // CHECK-SAME: } : vector<2x4xi1> -> vector<2x4xf32>
   // CHECK: vector.transfer_write %[[masked_read]], %[[empty]][%[[c0]], %[[c0]]]
   // CHECK-SAME: {in_bounds = [true, true]} : vector<2x4xf32>, tensor<2x4xf32>
@@ -2857,7 +2857,47 @@
 
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !pdl.operation):
-  %0 = transform.structured.match ops{["tensor.pad"]} in %arg1
+  %0 = transform.structured.match ops{["tensor.pad"]} in %arg1 : (!pdl.operation) -> !pdl.operation
   transform.structured.masked_vectorize %0 vector_sizes [2, 4]
 }
+
+// -----
+
+func.func @vectorize_dynamic_matmul(%A: memref<?x?xf32>, %B: memref<?x?xf32>, %C: memref<?x?xf32>) {
+  linalg.matmul ins(%A, %B: memref<?x?xf32>, memref<?x?xf32>)
+                outs(%C: memref<?x?xf32>)
+  return
+}
+
+// CHECK-LABEL: func.func @vectorize_dynamic_matmul(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<?x?xf32>, %[[VAL_1:.*]]: memref<?x?xf32>, %[[VAL_2:.*]]: memref<?x?xf32>) {
+// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[VAL_4:.*]] = memref.dim %[[VAL_0]], %[[VAL_3]] : memref<?x?xf32>
+// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[VAL_6:.*]] = memref.dim %[[VAL_1]], %[[VAL_5]] : memref<?x?xf32>
+// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[VAL_8:.*]] = memref.dim %[[VAL_0]], %[[VAL_7]] : memref<?x?xf32>
+// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[VAL_11:.*]] = vector.create_mask %[[VAL_4]], %[[VAL_8]] : vector<8x4xi1>
+// CHECK: %[[VAL_12:.*]] = vector.mask %[[VAL_11]] { vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_9]], %[[VAL_9]]], %[[VAL_10]] {in_bounds = [true, true, true], permutation_map = #map} : memref<?x?xf32>, vector<8x16x4xf32> } : vector<8x4xi1> -> vector<8x16x4xf32>
+// CHECK: %[[VAL_13:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[VAL_14:.*]] = vector.create_mask %[[VAL_8]], %[[VAL_6]] : vector<4x16xi1>
+// CHECK: %[[VAL_15:.*]] = vector.mask %[[VAL_14]] { vector.transfer_read %[[VAL_1]]{{\[}}%[[VAL_9]], %[[VAL_9]]], %[[VAL_13]] {in_bounds = [true, true, true], permutation_map = #map1} : memref<?x?xf32>, vector<8x16x4xf32> } : vector<4x16xi1> -> vector<8x16x4xf32>
+// CHECK: %[[VAL_16:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[VAL_17:.*]] = vector.create_mask %[[VAL_4]], %[[VAL_6]] : vector<8x16xi1>
+// CHECK: %[[VAL_18:.*]] = vector.mask %[[VAL_17]] { vector.transfer_read %[[VAL_2]]{{\[}}%[[VAL_9]], %[[VAL_9]]], %[[VAL_16]] {in_bounds = [true, true]} : memref<?x?xf32>, vector<8x16xf32> } : vector<8x16xi1> -> vector<8x16xf32>
+// CHECK: %[[VAL_19:.*]] = arith.mulf %[[VAL_12]], %[[VAL_15]] : vector<8x16x4xf32>
+// CHECK: %[[VAL_20:.*]] = vector.create_mask %[[VAL_4]], %[[VAL_6]], %[[VAL_8]] : vector<8x16x4xi1>
+// CHECK: %[[VAL_21:.*]] = vector.mask %[[VAL_20]] { vector.multi_reduction <add>, %[[VAL_19]], %[[VAL_18]] [2] : vector<8x16x4xf32> to vector<8x16xf32> } : vector<8x16x4xi1> -> vector<8x16xf32>
+// CHECK: %[[VAL_22:.*]] = arith.constant 0 : index
+// CHECK: vector.mask %[[VAL_17]] { vector.transfer_write %[[VAL_21]], %[[VAL_2]]{{\[}}%[[VAL_22]], %[[VAL_22]]] {in_bounds = [true, true]} : vector<8x16xf32>, memref<?x?xf32> } : vector<8x16xi1>
+// CHECK: return
+// CHECK: }
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+  transform.structured.masked_vectorize %0 vector_sizes [8, 16, 4]
+}