diff --git a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
--- a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
@@ -784,6 +784,7 @@
     TypeCastingOpPattern,
     TypeCastingOpPattern,
     TypeCastingOpPattern,
+    TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
     CmpIOpBooleanPattern, CmpIOpPattern,
     CmpFOpNanNonePattern, CmpFOpPattern
   >(typeConverter, patterns.getContext());
diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
--- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
@@ -65,6 +65,26 @@
                   ConversionPatternRewriter &rewriter) const override;
 };
 
+/// Converts std.br to spv.Branch. The successor block is kept and the branch
+/// operands are taken from the adaptor.
+struct BranchOpPattern final : public OpConversionPattern<BranchOp> {
+  using OpConversionPattern<BranchOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(BranchOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
+/// Converts std.cond_br to spv.BranchConditional. Both successor blocks are
+/// kept; the condition and the branch operands are taken from the adaptor.
+struct CondBranchOpPattern final : public OpConversionPattern<CondBranchOp> {
+  using OpConversionPattern<CondBranchOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(CondBranchOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
 /// Converts tensor.extract into loading using access chains from SPIR-V local
 /// variables.
 class TensorExtractPattern final
@@ -176,6 +196,34 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// BranchOpPattern
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+BranchOpPattern::matchAndRewrite(BranchOp op, OpAdaptor adaptor,
+                                 ConversionPatternRewriter &rewriter) const {
+  rewriter.replaceOpWithNewOp<spirv::BranchOp>(op, op.getDest(),
+                                               adaptor.getDestOperands());
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// CondBranchOpPattern
+//===----------------------------------------------------------------------===//
+
+LogicalResult CondBranchOpPattern::matchAndRewrite(
+    CondBranchOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  // Read the condition through the adaptor, like the branch operands, so the
+  // new op is built from the values the conversion framework remapped.
+  rewriter.replaceOpWithNewOp<spirv::BranchConditionalOp>(
+      op, adaptor.getCondition(), op.getTrueDest(),
+      adaptor.getTrueDestOperands(), op.getFalseDest(),
+      adaptor.getFalseDestOperands());
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Pattern population
 //===----------------------------------------------------------------------===//
@@ -194,7 +242,8 @@
       spirv::UnaryAndBinaryOpPattern,
       spirv::UnaryAndBinaryOpPattern,
 
-      ReturnOpPattern, SelectOpPattern, SplatPattern>(typeConverter, context);
+      ReturnOpPattern, SelectOpPattern, SplatPattern, BranchOpPattern,
+      CondBranchOpPattern>(typeConverter, context);
 }
 
 void populateTensorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
diff --git a/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir b/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
--- a/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
@@ -572,6 +572,15 @@
   return
 }
 
+// CHECK-LABEL: @bit_cast
+func @bit_cast(%arg0: vector<2xf32>, %arg1: i64) {
+  // CHECK: spv.Bitcast %{{.+}} : vector<2xf32> to vector<2xi32>
+  %0 = arith.bitcast %arg0 : vector<2xf32> to vector<2xi32>
+  // CHECK: spv.Bitcast %{{.+}} : i64 to f64
+  %1 = arith.bitcast %arg1 : i64 to f64
+  return
+}
+
 // CHECK-LABEL: @fpext1
 func @fpext1(%arg0: f16) -> f64 {
   // CHECK: spv.FConvert %{{.*}} : f16 to f64
diff --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
--- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
@@ -933,3 +933,51 @@
   %splat = splat %f : vector<4xf32>
   return %splat : vector<4xf32>
 }
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// std.br, std.cond_br
+//===----------------------------------------------------------------------===//
+
+module attributes {
+  spv.target_env = #spv.target_env<#spv.vce<v1.0, [], []>, {}>
+} {
+
+// CHECK-LABEL: func @simple_loop
+func @simple_loop() {
+^bb0:
+// CHECK-NEXT: spv.Branch ^bb1
+  br ^bb1
+
+// CHECK-NEXT: ^bb1: // pred: ^bb0
+// CHECK-NEXT: {{.*}} = spv.Constant 1 : i32
+// CHECK-NEXT: {{.*}} = spv.Constant 42 : i32
+// CHECK-NEXT: spv.Branch ^bb2({{.*}} : i32)
+^bb1: // pred: ^bb0
+  %c1 = arith.constant 1 : index
+  %c42 = arith.constant 42 : index
+  br ^bb2(%c1 : index)
+
+// CHECK: ^bb2({{.*}}: i32): // 2 preds: ^bb1, ^bb3
+// CHECK-NEXT: {{.*}} = spv.SLessThan {{.*}}, {{.*}} : i32
+// CHECK-NEXT: spv.BranchConditional {{.*}}, ^bb3, ^bb4
+^bb2(%0: index): // 2 preds: ^bb1, ^bb3
+  %1 = arith.cmpi slt, %0, %c42 : index
+  cond_br %1, ^bb3, ^bb4
+
+// CHECK: ^bb3: // pred: ^bb2
+// CHECK-NEXT: {{.*}} = spv.Constant 1 : i32
+// CHECK-NEXT: {{.*}} = spv.IAdd {{.*}}, {{.*}} : i32
+// CHECK-NEXT: spv.Branch ^bb2({{.*}} : i32)
+^bb3: // pred: ^bb2
+  %c1_0 = arith.constant 1 : index
+  %2 = arith.addi %0, %c1_0 : index
+  br ^bb2(%2 : index)
+
+// CHECK: ^bb4: // pred: ^bb2
+^bb4: // pred: ^bb2
+  return
+}
+
+}