diff --git a/mlir/lib/IR/TypeUtilities.cpp b/mlir/lib/IR/TypeUtilities.cpp
--- a/mlir/lib/IR/TypeUtilities.cpp
+++ b/mlir/lib/IR/TypeUtilities.cpp
@@ -126,6 +126,21 @@
   if (!llvm::all_of(shapedTypes, [](auto t) { return t; }))
     return failure();
 
+  // Return failure if some, but not all, are scalable vectors.
+  // Every type that is not a scalable vector (including non-vector types)
+  // counts as non-scalable here.
+  bool hasScalableVecTypes = false;
+  bool hasNonScalableVecTypes = false;
+  for (Type t : types) {
+    auto vType = t.dyn_cast<VectorType>();
+    if (vType && vType.isScalable())
+      hasScalableVecTypes = true;
+    else
+      hasNonScalableVecTypes = true;
+    if (hasScalableVecTypes && hasNonScalableVecTypes)
+      return failure();
+  }
+
   // Remove all unranked shapes
   auto shapes = llvm::to_vector<8>(llvm::make_filter_range(
       shapedTypes, [](auto shapedType) { return shapedType.hasRank(); }));
diff --git a/mlir/test/Dialect/Arithmetic/invalid.mlir b/mlir/test/Dialect/Arithmetic/invalid.mlir
--- a/mlir/test/Dialect/Arithmetic/invalid.mlir
+++ b/mlir/test/Dialect/Arithmetic/invalid.mlir
@@ -553,3 +553,163 @@
   %0 = arith.trunci %arg0 : i16 to i16
   return
 }
+
+// -----
+
+func @trunci_scalable_to_fl(%arg0 : vector<[4]xi32>) {
+  // expected-error@+1 {{'arith.trunci' op requires the same shape for all operands and results}}
+  %0 = arith.trunci %arg0 : vector<[4]xi32> to vector<4xi8>
+  return
+}
+
+// -----
+
+func @truncf_scalable_to_fl(%arg0 : vector<[4]xf64>) {
+  // expected-error@+1 {{'arith.truncf' op requires the same shape for all operands and results}}
+  %0 = arith.truncf %arg0 : vector<[4]xf64> to vector<4xf32>
+  return
+}
+
+// -----
+
+func @extui_scalable_to_fl(%arg0 : vector<[4]xi32>) {
+  // expected-error@+1 {{'arith.extui' op requires the same shape for all operands and results}}
+  %0 = arith.extui %arg0 : vector<[4]xi32> to vector<4xi64>
+  return
+}
+
+// -----
+
+func @extsi_scalable_to_fl(%arg0 : vector<[4]xi32>) {
+  // expected-error@+1 {{'arith.extsi' op requires the same shape for all operands and results}}
+  %0 = arith.extsi %arg0 : vector<[4]xi32> to vector<4xi64>
+  return
+}
+
+// -----
+
+func @extf_scalable_to_fl(%arg0 : vector<[4]xf32>) {
+  // expected-error@+1 {{'arith.extf' op requires the same shape for all operands and results}}
+  %0 = arith.extf %arg0 : vector<[4]xf32> to vector<4xf64>
+  return
+}
+
+// -----
+
+func @fptoui_scalable_to_fl(%arg0 : vector<[4]xf64>) {
+  // expected-error@+1 {{'arith.fptoui' op requires the same shape for all operands and results}}
+  %0 = arith.fptoui %arg0 : vector<[4]xf64> to vector<4xi32>
+  return
+}
+
+// -----
+
+func @fptosi_scalable_to_fl(%arg0 : vector<[4]xf32>) {
+  // expected-error@+1 {{'arith.fptosi' op requires the same shape for all operands and results}}
+  %0 = arith.fptosi %arg0 : vector<[4]xf32> to vector<4xi32>
+  return
+}
+
+// -----
+
+func @uitofp_scalable_to_fl(%arg0 : vector<[4]xi32>) {
+  // expected-error@+1 {{'arith.uitofp' op requires the same shape for all operands and results}}
+  %0 = arith.uitofp %arg0 : vector<[4]xi32> to vector<4xf32>
+  return
+}
+
+// -----
+
+func @sitofp_scalable_to_fl(%arg0 : vector<[4]xi32>) {
+  // expected-error@+1 {{'arith.sitofp' op requires the same shape for all operands and results}}
+  %0 = arith.sitofp %arg0 : vector<[4]xi32> to vector<4xf32>
+  return
+}
+
+// -----
+
+func @bitcast_scalable_to_fl(%arg0 : vector<[4]xf32>) {
+  // expected-error@+1 {{'arith.bitcast' op requires the same shape for all operands and results}}
+  %0 = arith.bitcast %arg0 : vector<[4]xf32> to vector<4xi32>
+  return
+}
+
+// -----
+
+func @trunci_fl_to_scalable(%arg0 : vector<4xi32>) {
+  // expected-error@+1 {{'arith.trunci' op requires the same shape for all operands and results}}
+  %0 = arith.trunci %arg0 : vector<4xi32> to vector<[4]xi8>
+  return
+}
+
+// -----
+
+func @truncf_fl_to_scalable(%arg0 : vector<4xf64>) {
+  // expected-error@+1 {{'arith.truncf' op requires the same shape for all operands and results}}
+  %0 = arith.truncf %arg0 : vector<4xf64> to vector<[4]xf32>
+  return
+}
+
+// -----
+
+func @extui_fl_to_scalable(%arg0 : vector<4xi32>) {
+  // expected-error@+1 {{'arith.extui' op requires the same shape for all operands and results}}
+  %0 = arith.extui %arg0 : vector<4xi32> to vector<[4]xi64>
+  return
+}
+
+// -----
+
+func @extsi_fl_to_scalable(%arg0 : vector<4xi32>) {
+  // expected-error@+1 {{'arith.extsi' op requires the same shape for all operands and results}}
+  %0 = arith.extsi %arg0 : vector<4xi32> to vector<[4]xi64>
+  return
+}
+
+// -----
+
+func @extf_fl_to_scalable(%arg0 : vector<4xf32>) {
+  // expected-error@+1 {{'arith.extf' op requires the same shape for all operands and results}}
+  %0 = arith.extf %arg0 : vector<4xf32> to vector<[4]xf64>
+  return
+}
+
+// -----
+
+func @fptoui_fl_to_scalable(%arg0 : vector<4xf64>) {
+  // expected-error@+1 {{'arith.fptoui' op requires the same shape for all operands and results}}
+  %0 = arith.fptoui %arg0 : vector<4xf64> to vector<[4]xi32>
+  return
+}
+
+// -----
+
+func @fptosi_fl_to_scalable(%arg0 : vector<4xf32>) {
+  // expected-error@+1 {{'arith.fptosi' op requires the same shape for all operands and results}}
+  %0 = arith.fptosi %arg0 : vector<4xf32> to vector<[4]xi32>
+  return
+}
+
+// -----
+
+func @uitofp_fl_to_scalable(%arg0 : vector<4xi32>) {
+  // expected-error@+1 {{'arith.uitofp' op requires the same shape for all operands and results}}
+  %0 = arith.uitofp %arg0 : vector<4xi32> to vector<[4]xf32>
+  return
+}
+
+// -----
+
+func @sitofp_fl_to_scalable(%arg0 : vector<4xi32>) {
+  // expected-error@+1 {{'arith.sitofp' op requires the same shape for all operands and results}}
+  %0 = arith.sitofp %arg0 : vector<4xi32> to vector<[4]xf32>
+  return
+}
+
+// -----
+
+func @bitcast_fl_to_scalable(%arg0 : vector<4xf32>) {
+  // expected-error@+1 {{'arith.bitcast' op requires the same shape for all operands and results}}
+  %0 = arith.bitcast %arg0 : vector<4xf32> to vector<[4]xi32>
+  return
+}