[mlir][VectorToSPIRV] Lower constant-position vector.extractelement /
vector.insertelement to spirv.CompositeExtract / spirv.CompositeInsert

When the position operand matches a constant integer, emit the static
composite ops; otherwise emit spirv.VectorExtractDynamic /
spirv.VectorInsertDynamic. Both branches build the new ops from the
adaptor's (converted) operands.

diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
@@ -191,19 +192,25 @@
   LogicalResult
   matchAndRewrite(vector::ExtractElementOp extractOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    Type vectorType =
-        getTypeConverter()->convertType(adaptor.getVector().getType());
-    if (!vectorType)
+    Type resultType = getTypeConverter()->convertType(extractOp.getType());
+    if (!resultType)
       return failure();
 
-    if (vectorType.isa<spirv::ScalarType>()) {
+    if (adaptor.getVector().getType().isa<spirv::ScalarType>()) {
       rewriter.replaceOp(extractOp, adaptor.getVector());
       return success();
     }
 
-    rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
-        extractOp, extractOp.getType(), adaptor.getVector(),
-        extractOp.getPosition());
+    // A constant position maps to the static spirv.CompositeExtract;
+    // otherwise fall back to spirv.VectorExtractDynamic.
+    APInt cstPos;
+    if (matchPattern(adaptor.getPosition(), m_ConstantInt(&cstPos)))
+      rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
+          extractOp, resultType, adaptor.getVector(),
+          rewriter.getI32ArrayAttr({static_cast<int>(cstPos.getSExtValue())}));
+    else
+      rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
+          extractOp, resultType, adaptor.getVector(), adaptor.getPosition());
     return success();
   }
 };
@@ -224,9 +231,17 @@
       return success();
     }
 
-    rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
-        insertOp, vectorType, insertOp.getDest(), adaptor.getSource(),
-        insertOp.getPosition());
+    // A constant position maps to the static spirv.CompositeInsert; otherwise
+    // use spirv.VectorInsertDynamic. Both consume the adaptor's operands.
+    APInt cstPos;
+    if (matchPattern(adaptor.getPosition(), m_ConstantInt(&cstPos)))
+      rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
+          insertOp, adaptor.getSource(), adaptor.getDest(),
+          cstPos.getSExtValue());
+    else
+      rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
+          insertOp, vectorType, adaptor.getDest(), adaptor.getSource(),
+          adaptor.getPosition());
     return success();
   }
 };
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -187,9 +187,20 @@
 
 // -----
 
+// CHECK-LABEL: @extract_element_cst
+// CHECK-SAME: %[[V:.*]]: vector<4xf32>
+// CHECK: spirv.CompositeExtract %[[V]][1 : i32] : vector<4xf32>
+func.func @extract_element_cst(%arg0 : vector<4xf32>) -> f32 {
+  %idx = arith.constant 1 : i32
+  %0 = vector.extractelement %arg0[%idx : i32] : vector<4xf32>
+  return %0: f32
+}
+
+// -----
+
 // CHECK-LABEL: @extract_element_index
 func.func @extract_element_index(%arg0 : vector<4xf32>, %id : index) -> f32 {
-  // CHECK: vector.extractelement
+  // CHECK: spirv.VectorExtractDynamic
   %0 = vector.extractelement %arg0[%id : index] : vector<4xf32>
   return %0: f32
 }
@@ -249,9 +260,20 @@
 
 // -----
 
+// CHECK-LABEL: @insert_element_cst
+// CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32>
+// CHECK: spirv.CompositeInsert %[[VAL]], %[[V]][2 : i32] : f32 into vector<4xf32>
+func.func @insert_element_cst(%val: f32, %arg0 : vector<4xf32>) -> vector<4xf32> {
+  %idx = arith.constant 2 : i32
+  %0 = vector.insertelement %val, %arg0[%idx : i32] : vector<4xf32>
+  return %0: vector<4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @insert_element_index
 func.func @insert_element_index(%val: f32, %arg0 : vector<4xf32>, %id : index) -> vector<4xf32> {
-  // CHECK: vector.insertelement
+  // CHECK: spirv.VectorInsertDynamic
   %0 = vector.insertelement %val, %arg0[%id : index] : vector<4xf32>
   return %0: vector<4xf32>
 }