diff --git a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
--- a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
@@ -878,10 +878,21 @@
   options.emulateNon32BitScalarTypes = this->emulateNon32BitScalarTypes;
   SPIRVTypeConverter typeConverter(targetAttr, options);
 
+  // Use UnrealizedConversionCast as the bridge so that we don't need to pull
+  // in patterns for other dialects.
+  auto addUnrealizedCast = [](OpBuilder &builder, Type type,
+                              ValueRange inputs, Location loc) {
+    auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
+    return Optional<Value>(cast.getResult(0));
+  };
+  typeConverter.addSourceMaterialization(addUnrealizedCast);
+  typeConverter.addTargetMaterialization(addUnrealizedCast);
+  // The materialized casts stay in the IR after this pass, so they must be
+  // legal for the partial conversion below.
+  target->addLegalOp<UnrealizedConversionCastOp>();
+
   RewritePatternSet patterns(&getContext());
   arith::populateArithmeticToSPIRVPatterns(typeConverter, patterns);
-  populateFuncToSPIRVPatterns(typeConverter, patterns);
-  populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
 
   if (failed(applyPartialConversion(module, *target, std::move(patterns))))
     signalPassFailure();
diff --git a/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir b/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
--- a/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
@@ -58,8 +58,10 @@
 }
 
 // CHECK-LABEL: @index_scalar_srem
-// CHECK-SAME: (%[[LHS:.+]]: i32, %[[RHS:.+]]: i32)
+// CHECK-SAME: (%[[A:.+]]: index, %[[B:.+]]: index)
 func.func @index_scalar_srem(%lhs: index, %rhs: index) {
+  // CHECK: %[[LHS:.+]] = builtin.unrealized_conversion_cast %[[A]] : index to i32
+  // CHECK: %[[RHS:.+]] = builtin.unrealized_conversion_cast %[[B]] : index to i32
   // CHECK: %[[LABS:.+]] = spv.GLSL.SAbs %[[LHS]] : i32
   // CHECK: %[[RABS:.+]] = spv.GLSL.SAbs %[[RHS]] : i32
   // CHECK: %[[ABS:.+]] = spv.UMod %[[LABS]], %[[RABS]] : i32
@@ -185,10 +187,8 @@
   spv.target_env = #spv.target_env<#spv.vce, #spv.resource_limits<>>
 } {
 
-// expected-error @+1 {{failed to materialize conversion for block argument #0 that remained live after conversion, type was 'vector<4xi64>', with target type 'vector<4xi32>'}}
 func.func @int_vector4_invalid(%arg0: vector<4xi64>) {
-  // expected-error @+2 {{bitwidth emulation is not implemented yet on unsigned op}}
-  // expected-note @+1 {{see existing live user here}}
+  // expected-error @+1 {{bitwidth emulation is not implemented yet on unsigned op}}
   %0 = arith.divui %arg0, %arg0: vector<4xi64>
   return
 }
@@ -837,8 +837,9 @@
 } {
 
 // CHECK-LABEL: @fpext1
-// CHECK-SAME: %[[ARG:.*]]: f32
+// CHECK-SAME: %[[A:.*]]: f16
 func.func @fpext1(%arg0: f16) -> f64 {
+  // CHECK: %[[ARG:.+]] = builtin.unrealized_conversion_cast %[[A]] : f16 to f32
   // CHECK-NEXT: spv.FConvert %[[ARG]] : f32 to f64
   %0 = arith.extf %arg0 : f16 to f64
   return %0: f64
@@ -863,8 +864,9 @@
 } {
 
 // CHECK-LABEL: @fptrunc1
-// CHECK-SAME: %[[ARG:.*]]: f32
+// CHECK-SAME: %[[A:.*]]: f64
 func.func @fptrunc1(%arg0 : f64) -> f16 {
+  // CHECK: %[[ARG:.+]] = builtin.unrealized_conversion_cast %[[A]] : f64 to f32
   // CHECK-NEXT: spv.FConvert %[[ARG]] : f32 to f16
   %0 = arith.truncf %arg0 : f64 to f16
   return %0: f16
@@ -1110,10 +1112,8 @@
   spv.target_env = #spv.target_env<#spv.vce, #spv.resource_limits<>>
 } {
 
-// expected-error@below {{failed to materialize conversion for block argument #0 that remained live after conversion}}
 func.func @int_vector4_invalid(%arg0: vector<4xi64>) {
-  // expected-error@below {{bitwidth emulation is not implemented yet on unsigned op}}
-  // expected-note@below {{see existing live user here}}
+  // expected-error@+1 {{bitwidth emulation is not implemented yet on unsigned op}}
   %0 = arith.divui %arg0, %arg0: vector<4xi64>
   return
 }
@@ -1733,8 +1733,9 @@
 } {
 
 // CHECK-LABEL: @fpext1
-// CHECK-SAME: %[[ARG:.*]]: f32
+// CHECK-SAME: %[[A:.*]]: f16
 func.func @fpext1(%arg0: f16) -> f64 {
+  // CHECK: %[[ARG:.+]] = builtin.unrealized_conversion_cast %[[A]] : f16 to f32
   // CHECK-NEXT: spv.FConvert %[[ARG]] : f32 to f64
   %0 = arith.extf %arg0 : f16 to f64
   return %0: f64
@@ -1759,8 +1760,9 @@
 } {
 
 // CHECK-LABEL: @fptrunc1
-// CHECK-SAME: %[[ARG:.*]]: f32
+// CHECK-SAME: %[[A:.*]]: f64
 func.func @fptrunc1(%arg0 : f64) -> f16 {
+  // CHECK: %[[ARG:.+]] = builtin.unrealized_conversion_cast %[[A]] : f64 to f32
   // CHECK-NEXT: spv.FConvert %[[ARG]] : f32 to f16
   %0 = arith.truncf %arg0 : f64 to f16
   return %0: f16