diff --git a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
--- a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
@@ -1,4 +1,5 @@
 // RUN: mlir-opt %s --test-vector-gather-lowering | FileCheck %s
+// RUN: mlir-opt %s --test-vector-gather-lowering --canonicalize | FileCheck %s --check-prefix=CANON
 
 // CHECK-LABEL: @gather_memref_1d
 // CHECK-SAME: ([[BASE:%.+]]: memref<?xf32>, [[IDXVEC:%.+]]: vector<2xindex>, [[MASK:%.+]]: vector<2xi1>, [[PASS:%.+]]: vector<2xf32>)
@@ -125,3 +126,29 @@
   %0 = vector.gather %base[%c0, %c1][%v], %mask, %pass_thru : tensor<?x?xf32>, vector<2x3xindex>, vector<2x3xi1>, vector<2x3xf32> into vector<2x3xf32>
   return %0 : vector<2x3xf32>
 }
+
+// Check that all-true and all-false masks get optimized out after canonicalization:
+// all-true keeps only the unconditional extracts/inserts; all-false returns the pass-through.
+
+// CANON-LABEL: @gather_tensor_1d_all_set
+// CANON-NOT:     scf.if
+// CANON:         tensor.extract
+// CANON:         tensor.extract
+// CANON:         [[FINAL:%.+]] = vector.insert %{{.+}}, %{{.+}} [1] : f32 into vector<2xf32>
+// CANON-NEXT:    return [[FINAL]] : vector<2xf32>
+func.func @gather_tensor_1d_all_set(%base: tensor<?xf32>, %v: vector<2xindex>, %pass_thru: vector<2xf32>) -> vector<2xf32> {
+  %mask = arith.constant dense<true> : vector<2xi1>
+  %c0 = arith.constant 0 : index
+  %0 = vector.gather %base[%c0][%v], %mask, %pass_thru : tensor<?xf32>, vector<2xindex>, vector<2xi1>, vector<2xf32> into vector<2xf32>
+  return %0 : vector<2xf32>
+}
+
+// CANON-LABEL: @gather_tensor_1d_none_set
+// CANON-SAME:    ([[BASE:%.+]]: tensor<?xf32>, [[IDXVEC:%.+]]: vector<2xindex>, [[PASS:%.+]]: vector<2xf32>)
+// CANON-NEXT:    return [[PASS]] : vector<2xf32>
+func.func @gather_tensor_1d_none_set(%base: tensor<?xf32>, %v: vector<2xindex>, %pass_thru: vector<2xf32>) -> vector<2xf32> {
+  %mask = arith.constant dense<false> : vector<2xi1>
+  %c0 = arith.constant 0 : index
+  %0 = vector.gather %base[%c0][%v], %mask, %pass_thru : tensor<?xf32>, vector<2xindex>, vector<2xi1>, vector<2xf32> into vector<2xf32>
+  return %0 : vector<2xf32>
+}