diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -1155,13 +1155,6 @@
   // Use UnrealizedConversionCast as the bridge so that we don't need to pull
   // in patterns for other dialects.
-  auto addUnrealizedCast = [](OpBuilder &builder, Type type,
-                              ValueRange inputs, Location loc) {
-    auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
-    return std::optional<Value>(cast.getResult(0));
-  };
-  typeConverter.addSourceMaterialization(addUnrealizedCast);
-  typeConverter.addTargetMaterialization(addUnrealizedCast);
   target->addLegalOp<UnrealizedConversionCastOp>();
 
   // Fail hard when there are any remaining 'arith' ops.
diff --git a/mlir/lib/Conversion/ComplexToSPIRV/ComplexToSPIRVPass.cpp b/mlir/lib/Conversion/ComplexToSPIRV/ComplexToSPIRVPass.cpp
--- a/mlir/lib/Conversion/ComplexToSPIRV/ComplexToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/ComplexToSPIRV/ComplexToSPIRVPass.cpp
@@ -40,13 +40,6 @@
   // Use UnrealizedConversionCast as the bridge so that we don't need to pull
   // in patterns for other dialects.
-  auto addUnrealizedCast = [](OpBuilder &builder, Type type,
-                              ValueRange inputs, Location loc) {
-    auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
-    return std::optional<Value>(cast.getResult(0));
-  };
-  typeConverter.addSourceMaterialization(addUnrealizedCast);
-  typeConverter.addTargetMaterialization(addUnrealizedCast);
   target->addLegalOp<UnrealizedConversionCastOp>();
 
   RewritePatternSet patterns(context);
diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRVPass.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRVPass.cpp
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRVPass.cpp
@@ -44,13 +44,6 @@
   // Use UnrealizedConversionCast as the bridge so that we don't need to pull
   // in patterns for other dialects.
-  auto addUnrealizedCast = [](OpBuilder &builder, Type type, ValueRange inputs,
-                              Location loc) {
-    auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
-    return std::optional<Value>(cast.getResult(0));
-  };
-  typeConverter.addSourceMaterialization(addUnrealizedCast);
-  typeConverter.addTargetMaterialization(addUnrealizedCast);
   target->addLegalOp<UnrealizedConversionCastOp>();
 
   RewritePatternSet patterns(context);
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -490,14 +490,6 @@
   result = rewriter.create<spirv::ShiftRightArithmeticOp>(loc, dstType, result,
                                                           shiftValue);
 
-  if (isBool) {
-    dstType = typeConverter.convertType(loadOp.getType());
-    mask = spirv::ConstantOp::getOne(result.getType(), loc, rewriter);
-    result = rewriter.create<spirv::IEqualOp>(loc, result, mask);
-  } else if (result.getType().getIntOrFloatBitWidth() !=
-             static_cast<unsigned>(dstBits)) {
-    result = rewriter.create<spirv::SConvertOp>(loc, dstType, result);
-  }
   rewriter.replaceOp(loadOp, result);
 
   assert(accessChainOp.use_empty());
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.cpp
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.cpp
@@ -45,13 +45,6 @@
   // Use UnrealizedConversionCast as the bridge so that we don't need to pull in
   // patterns for other dialects.
-  auto addUnrealizedCast = [](OpBuilder &builder, Type type, ValueRange inputs,
-                              Location loc) {
-    auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
-    return std::optional<Value>(cast.getResult(0));
-  };
-  typeConverter.addSourceMaterialization(addUnrealizedCast);
-  typeConverter.addTargetMaterialization(addUnrealizedCast);
   target->addLegalOp<UnrealizedConversionCastOp>();
 
   RewritePatternSet patterns(context);
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp
@@ -44,13 +44,6 @@
   // Use UnrealizedConversionCast as the bridge so that we don't need to pull in
   // patterns for other dialects.
-  auto addUnrealizedCast = [](OpBuilder &builder, Type type, ValueRange inputs,
-                              Location loc) {
-    auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
-    return std::optional<Value>(cast.getResult(0));
-  };
-  typeConverter.addSourceMaterialization(addUnrealizedCast);
-  typeConverter.addTargetMaterialization(addUnrealizedCast);
   target->addLegalOp<UnrealizedConversionCastOp>();
 
   RewritePatternSet patterns(context);
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -565,6 +565,84 @@
   return wrapInStructAndGetPointer(arrayType, storageClass);
 }
 
+//===----------------------------------------------------------------------===//
+// Type casting materialization
+//===----------------------------------------------------------------------===//
+
+/// Converts the given `inputs` to the original source `type` considering the
+/// `targetEnv`'s capabilities.
+///
+/// This function is meant to be used for source materialization in type
+/// converters. When the type converter needs to materialize a cast op back
+/// to some original source type, we need to check whether the original source
+/// type is supported in the target environment. If so, we can insert legal
+/// SPIR-V cast ops accordingly.
+///
+/// Note that in SPIR-V the capabilities for storage and compute are separate.
+/// This function is meant to handle the **compute** side; so it does not
+/// involve storage classes in its logic. The storage side is expected to be
+/// handled by MemRef conversion logic.
+static std::optional<Value> castToSourceType(const spirv::TargetEnv &targetEnv,
+                                             OpBuilder &builder, Type type,
+                                             ValueRange inputs, Location loc) {
+  // We can only cast one value in SPIR-V.
+  if (inputs.size() != 1) {
+    auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
+    return castOp.getResult(0);
+  }
+  Value input = inputs.front();
+
+  // Only support integer types for now. Floating point types to be
+  // implemented.
+  if (!isa<IntegerType>(type)) {
+    auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
+    return castOp.getResult(0);
+  }
+  auto inputType = cast<IntegerType>(input.getType());
+
+  auto scalarType = dyn_cast<spirv::ScalarType>(type);
+  if (!scalarType) {
+    auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
+    return castOp.getResult(0);
+  }
+
+  // Only support source type with a smaller bitwidth. This would mean we are
+  // truncating to go back so we don't need to worry about the signedness.
+  // For extension, we cannot have enough signal here to decide which op to
+  // use.
+  if (inputType.getIntOrFloatBitWidth() < scalarType.getIntOrFloatBitWidth()) {
+    auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
+    return castOp.getResult(0);
+  }
+
+  // Boolean values would need to use different ops than normal integer values.
+  if (type.isInteger(1)) {
+    Value one = spirv::ConstantOp::getOne(inputType, loc, builder);
+    return builder.create<spirv::IEqualOp>(loc, input, one);
+  }
+
+  // Check that the source integer type is supported by the environment.
+  SmallVector<ArrayRef<spirv::Extension>, 1> exts;
+  SmallVector<ArrayRef<spirv::Capability>, 2> caps;
+  scalarType.getExtensions(exts);
+  scalarType.getCapabilities(caps);
+  if (failed(checkCapabilityRequirements(type, targetEnv, caps)) ||
+      failed(checkExtensionRequirements(type, targetEnv, exts))) {
+    auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
+    return castOp.getResult(0);
+  }
+
+  // We've already made sure this is truncating previously, so we don't need to
+  // care about signedness here. Still try to use a corresponding op for better
+  // consistency though.
+  if (type.isSignedInteger()) {
+    return builder.create<spirv::SConvertOp>(loc, type, input);
+  }
+  return builder.create<spirv::UConvertOp>(loc, type, input);
+}
+
+//===----------------------------------------------------------------------===//
+// SPIRVTypeConverter
+//===----------------------------------------------------------------------===//
+
 SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
                                        const SPIRVConversionOptions &options)
     : targetEnv(targetAttr), options(options) {
@@ -611,6 +689,17 @@
   addConversion([this](MemRefType memRefType) {
     return convertMemrefType(this->targetEnv, this->options, memRefType);
   });
+
+  // Register some last line of defense casting logic.
+  addSourceMaterialization(
+      [this](OpBuilder &builder, Type type, ValueRange inputs, Location loc) {
+        return castToSourceType(this->targetEnv, builder, type, inputs, loc);
+      });
+  addTargetMaterialization([](OpBuilder &builder, Type type, ValueRange inputs,
+                              Location loc) {
+    auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
+    return std::optional<Value>(cast.getResult(0));
+  });
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
--- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
@@ -297,7 +297,8 @@
   // CHECK: %[[T1:.+]] = spirv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
   // CHECK: %[[T2:.+]] = spirv.Constant 24 : i32
   // CHECK: %[[T3:.+]] = spirv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32
-  // CHECK: spirv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32
+  // CHECK: %[[SR:.+]] = spirv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32
+  // CHECK: builtin.unrealized_conversion_cast %[[SR]]
   %0 = memref.load %arg0[] : memref<i8, #spirv.storage_class<StorageBuffer>>
   return %0 : i8
 }
@@ -321,7 +322,8 @@
   // CHECK: %[[MASK:.+]] = spirv.Constant 65535 : i32
   // CHECK: %[[T1:.+]] = spirv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
   // CHECK: %[[T3:.+]] = spirv.ShiftLeftLogical %[[T1]], %[[SIXTEEN]] : i32, i32
-  // CHECK: spirv.ShiftRightArithmetic %[[T3]], %[[SIXTEEN]] : i32, i32
+  // CHECK: %[[SR:.+]] = spirv.ShiftRightArithmetic %[[T3]], %[[SIXTEEN]] : i32, i32
+  // CHECK: builtin.unrealized_conversion_cast %[[SR]]
   %0 = memref.load %arg0[%index] : memref<10xi16, #spirv.storage_class<StorageBuffer>>
   return %0: i16
 }
@@ -448,7 +450,8 @@
   // CHECK: %[[AND:.+]] = spirv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
   // CHECK: %[[C28:.+]] = spirv.Constant 28 : i32
   // CHECK: %[[SL:.+]] = spirv.ShiftLeftLogical %[[AND]], %[[C28]] : i32, i32
-  // CHECK: spirv.ShiftRightArithmetic %[[SL]], %[[C28]] : i32, i32
+  // CHECK: %[[SR:.+]] = spirv.ShiftRightArithmetic %[[SL]], %[[C28]] : i32, i32
+  // CHECK: builtin.unrealized_conversion_cast %[[SR]]
   %0 = memref.load %arg0[%i] : memref<?xi4, #spirv.storage_class<StorageBuffer>>
   return %0 : i4
 }
@@ -479,3 +482,41 @@
 }
 
 } // end module
+
+// -----
+
+// Check that casts are properly inserted if the corresponding **compute**
+// capability is allowed.
+module attributes {
+  spirv.target_env = #spirv.target_env<
+    #spirv.vce<v1.0, [Shader, Int8, Int16], []>, #spirv.resource_limits<>>
+} {
+
+// CHECK-LABEL: @load_i1
+func.func @load_i1(%arg0: memref<i1, #spirv.storage_class<StorageBuffer>>) -> i1 {
+  // CHECK: %[[ONE:.+]] = spirv.Constant 1 : i32
+  // CHECK: %[[RES:.+]] = spirv.IEqual %{{.+}}, %[[ONE]] : i32
+  // CHECK: return %[[RES]]
+  %0 = memref.load %arg0[] : memref<i1, #spirv.storage_class<StorageBuffer>>
+  return %0 : i1
+}
+
+// CHECK-LABEL: @load_i8
+func.func @load_i8(%arg0: memref<i8, #spirv.storage_class<StorageBuffer>>) -> i8 {
+  // CHECK: %[[RES:.+]] = spirv.UConvert %{{.+}} : i32 to i8
+  // CHECK: return %[[RES]]
+  %0 = memref.load %arg0[] : memref<i8, #spirv.storage_class<StorageBuffer>>
+  return %0 : i8
+}
+
+// CHECK-LABEL: @load_i16
+func.func @load_i16(%arg0: memref<10xi16, #spirv.storage_class<StorageBuffer>>, %index : index) -> i16 {
+  // CHECK: %[[RES:.+]] = spirv.UConvert %{{.+}} : i32 to i16
+  // CHECK: return %[[RES]]
+  %0 = memref.load %arg0[%index] : memref<10xi16, #spirv.storage_class<StorageBuffer>>
+  return %0: i16
+}
+
+} // end module