diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVOCLOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVOCLOps.td
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVOCLOps.td
@@ -0,0 +1,169 @@
+//===- SPIRVOCLOps.td - OpenCL extended insts spec file ----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This is the op definition spec of OpenCL extension ops.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef SPIRV_OCL_OPS
+#define SPIRV_OCL_OPS
+
+include "mlir/Dialect/SPIRV/SPIRVBase.td"
+
+//===----------------------------------------------------------------------===//
+// SPIR-V OpenCL opcode specification.
+//===----------------------------------------------------------------------===//
+
+// Base class for all OpenCL ops.
+class SPV_OCLOp<string mnemonic, int opcode, list<OpTrait> traits = []> :
+  SPV_ExtInstOp<mnemonic, "OCL", "OpenCL.std", opcode, traits>;
+
+// Base class for OpenCL unary ops.
+class SPV_OCLUnaryOp<string mnemonic, int opcode, Type resultType,
+                     Type operandType, list<OpTrait> traits = []> :
+  SPV_OCLOp<mnemonic, opcode, !listconcat([NoSideEffect], traits)> {
+
+  let arguments = (ins
+    SPV_ScalarOrVectorOf<operandType>:$operand
+  );
+
+  let results = (outs
+    SPV_ScalarOrVectorOf<resultType>:$result
+  );
+
+  let parser = [{ return parseUnaryOp(parser, result); }];
+
+  let printer = [{ return printUnaryOp(getOperation(), p); }];
+
+  let verifier = [{ return success(); }];
+}
+
+// Base class for OpenCL Unary arithmetic ops where return type matches
+// the operand type.
+class SPV_OCLUnaryArithmeticOp<string mnemonic, int opcode, Type type,
+                               list<OpTrait> traits = []> :
+  SPV_OCLUnaryOp<mnemonic, opcode, type, type, traits>;
+
+// Base class for OpenCL binary ops.
+class SPV_OCLBinaryOp<string mnemonic, int opcode, Type resultType,
+                      Type operandType, list<OpTrait> traits = []> :
+  SPV_OCLOp<mnemonic, opcode, !listconcat([NoSideEffect], traits)> {
+
+  let arguments = (ins
+    SPV_ScalarOrVectorOf<operandType>:$lhs,
+    SPV_ScalarOrVectorOf<operandType>:$rhs
+  );
+
+  let results = (outs
+    SPV_ScalarOrVectorOf<resultType>:$result
+  );
+
+  let parser = [{ return impl::parseOneResultSameOperandTypeOp(parser, result); }];
+
+  let printer = [{ return impl::printOneResultOp(getOperation(), p); }];
+
+  let verifier = [{ return success(); }];
+}
+
+// Base class for OpenCL Binary arithmetic ops where operand types and
+// return type matches.
+class SPV_OCLBinaryArithmeticOp<string mnemonic, int opcode, Type type,
+                                list<OpTrait> traits = []> :
+  SPV_OCLBinaryOp<mnemonic, opcode, type, type, traits>;
+
+// -----
+
+def SPV_OCLExpOp : SPV_OCLUnaryArithmeticOp<"exp", 19, SPV_Float> {
+  let summary = "Exponentiation of Operand 1";
+
+  let description = [{
+    Compute the base-e exponential of x. (i.e. e^x)
+
+    Result Type and x must be floating-point or vector(2,3,4,8,16) of
+    floating-point values.
+
+    All of the operands, including the Result Type operand,
+    must be of the same type.
+
+    <!-- End of AutoGen section -->
+    ```
+    float-scalar-vector-type ::= float-type |
+                                 `vector<` integer-literal `x` float-type `>`
+    exp-op ::= ssa-id `=` `spv.OCL.exp` ssa-use `:`
+               float-scalar-vector-type
+    ```
+    #### Example:
+
+    ```mlir
+    %2 = spv.OCL.exp %0 : f32
+    %3 = spv.OCL.exp %1 : vector<3xf16>
+    ```
+  }];
+}
+
+// -----
+
+def SPV_OCLFAbsOp : SPV_OCLUnaryArithmeticOp<"fabs", 23, SPV_Float> {
+  let summary = "Absolute value of operand";
+
+  let description = [{
+    Compute the absolute value of x.
+
+    Result Type and x must be floating-point or vector(2,3,4,8,16) of
+    floating-point values.
+
+    All of the operands, including the Result Type operand,
+    must be of the same type.
+
+    <!-- End of AutoGen section -->
+    ```
+    float-scalar-vector-type ::= float-type |
+                                 `vector<` integer-literal `x` float-type `>`
+    fabs-op ::= ssa-id `=` `spv.OCL.fabs` ssa-use `:`
+                float-scalar-vector-type
+    ```
+    #### Example:
+
+    ```mlir
+    %2 = spv.OCL.fabs %0 : f32
+    %3 = spv.OCL.fabs %1 : vector<3xf16>
+    ```
+  }];
+}
+
+// -----
+
+def SPV_OCLSAbsOp : SPV_OCLUnaryArithmeticOp<"s_abs", 141, SPV_Integer> {
+  let summary = "Absolute value of operand";
+
+  let description = [{
+    Returns |x|, where x is treated as signed integer.
+
+    Result Type and x must be integer or vector(2,3,4,8,16) of
+    integer values.
+
+    All of the operands, including the Result Type operand,
+    must be of the same type.
+
+    <!-- End of AutoGen section -->
+    ```
+    integer-scalar-vector-type ::= integer-type |
+                                   `vector<` integer-literal `x` integer-type `>`
+    s_abs-op ::= ssa-id `=` `spv.OCL.s_abs` ssa-use `:`
+                 integer-scalar-vector-type
+    ```
+    #### Example:
+
+    ```mlir
+    %2 = spv.OCL.s_abs %0 : i32
+    %3 = spv.OCL.s_abs %1 : vector<3xi16>
+    ```
+  }];
+}
+
+#endif // SPIRV_OCL_OPS
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
@@ -34,6 +34,7 @@
 include "mlir/Dialect/SPIRV/SPIRVLogicalOps.td"
 include "mlir/Dialect/SPIRV/SPIRVMatrixOps.td"
 include "mlir/Dialect/SPIRV/SPIRVNonUniformOps.td"
+include "mlir/Dialect/SPIRV/SPIRVOCLOps.td"
 include "mlir/Dialect/SPIRV/SPIRVStructureOps.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 
diff --git a/mlir/test/Dialect/SPIRV/Serialization/ocl-ops.mlir b/mlir/test/Dialect/SPIRV/Serialization/ocl-ops.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/SPIRV/Serialization/ocl-ops.mlir
@@ -0,0 +1,17 @@
+// RUN: mlir-translate -test-spirv-roundtrip %s | FileCheck %s
+
+spv.module Physical64 OpenCL requires #spv.vce<v1.0, [Kernel, Addresses], []> {
+  spv.func @float_insts(%arg0 : f32) "None" {
+    // CHECK: {{%.*}} = spv.OCL.exp {{%.*}} : f32
+    %0 = spv.OCL.exp %arg0 : f32
+    // CHECK: {{%.*}} = spv.OCL.fabs {{%.*}} : f32
+    %1 = spv.OCL.fabs %arg0 : f32
+    spv.Return
+  }
+
+  spv.func @integer_insts(%arg0 : i32) "None" {
+    // CHECK: {{%.*}} = spv.OCL.s_abs {{%.*}} : i32
+    %0 = spv.OCL.s_abs %arg0 : i32
+    spv.Return
+  }
+}
diff --git a/mlir/test/Dialect/SPIRV/ocl-ops.mlir b/mlir/test/Dialect/SPIRV/ocl-ops.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/SPIRV/ocl-ops.mlir
@@ -0,0 +1,168 @@
+// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// spv.OCL.exp
+//===----------------------------------------------------------------------===//
+
+func @exp(%arg0 : f32) -> () {
+  // CHECK: spv.OCL.exp {{%.*}} : f32
+  %2 = spv.OCL.exp %arg0 : f32
+  return
+}
+
+func @expvec(%arg0 : vector<3xf16>) -> () {
+  // CHECK: spv.OCL.exp {{%.*}} : vector<3xf16>
+  %2 = spv.OCL.exp %arg0 : vector<3xf16>
+  return
+}
+
+// -----
+
+func @exp(%arg0 : i32) -> () {
+  // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
+  %2 = spv.OCL.exp %arg0 : i32
+  return
+}
+
+// -----
+
+func @exp(%arg0 : vector<5xf32>) -> () {
+  // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4}}
+  %2 = spv.OCL.exp %arg0 : vector<5xf32>
+  return
+}
+
+// -----
+
+func @exp(%arg0 : f32, %arg1 : f32) -> () {
+  // expected-error @+1 {{expected ':'}}
+  %2 = spv.OCL.exp %arg0, %arg1 : i32
+  return
+}
+
+// -----
+
+func @exp(%arg0 : i32) -> () {
+  // expected-error @+2 {{expected non-function type}}
+  %2 = spv.OCL.exp %arg0 :
+  return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spv.OCL.fabs
+//===----------------------------------------------------------------------===//
+
+func @fabs(%arg0 : f32) -> () {
+  // CHECK: spv.OCL.fabs {{%.*}} : f32
+  %2 = spv.OCL.fabs %arg0 : f32
+  return
+}
+
+func @fabsvec(%arg0 : vector<3xf16>) -> () {
+  // CHECK: spv.OCL.fabs {{%.*}} : vector<3xf16>
+  %2 = spv.OCL.fabs %arg0 : vector<3xf16>
+  return
+}
+
+func @fabsf64(%arg0 : f64) -> () {
+  // CHECK: spv.OCL.fabs {{%.*}} : f64
+  %2 = spv.OCL.fabs %arg0 : f64
+  return
+}
+
+// -----
+
+func @fabs(%arg0 : i32) -> () {
+  // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
+  %2 = spv.OCL.fabs %arg0 : i32
+  return
+}
+
+// -----
+
+func @fabs(%arg0 : vector<5xf32>) -> () {
+  // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4}}
+  %2 = spv.OCL.fabs %arg0 : vector<5xf32>
+  return
+}
+
+// -----
+
+func @fabs(%arg0 : f32, %arg1 : f32) -> () {
+  // expected-error @+1 {{expected ':'}}
+  %2 = spv.OCL.fabs %arg0, %arg1 : i32
+  return
+}
+
+// -----
+
+func @fabs(%arg0 : i32) -> () {
+  // expected-error @+2 {{expected non-function type}}
+  %2 = spv.OCL.fabs %arg0 :
+  return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spv.OCL.s_abs
+//===----------------------------------------------------------------------===//
+
+func @sabs(%arg0 : i32) -> () {
+  // CHECK: spv.OCL.s_abs {{%.*}} : i32
+  %2 = spv.OCL.s_abs %arg0 : i32
+  return
+}
+
+func @sabsvec(%arg0 : vector<3xi16>) -> () {
+  // CHECK: spv.OCL.s_abs {{%.*}} : vector<3xi16>
+  %2 = spv.OCL.s_abs %arg0 : vector<3xi16>
+  return
+}
+
+func @sabsi64(%arg0 : i64) -> () {
+  // CHECK: spv.OCL.s_abs {{%.*}} : i64
+  %2 = spv.OCL.s_abs %arg0 : i64
+  return
+}
+
+func @sabsi8(%arg0 : i8) -> () {
+  // CHECK: spv.OCL.s_abs {{%.*}} : i8
+  %2 = spv.OCL.s_abs %arg0 : i8
+  return
+}
+
+// -----
+
+func @sabs(%arg0 : f32) -> () {
+  // expected-error @+1 {{op operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values}}
+  %2 = spv.OCL.s_abs %arg0 : f32
+  return
+}
+
+// -----
+
+func @sabs(%arg0 : vector<5xi32>) -> () {
+  // expected-error @+1 {{op operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4}}
+  %2 = spv.OCL.s_abs %arg0 : vector<5xi32>
+  return
+}
+
+// -----
+
+func @sabs(%arg0 : i32, %arg1 : i32) -> () {
+  // expected-error @+1 {{expected ':'}}
+  %2 = spv.OCL.s_abs %arg0, %arg1 : i32
+  return
+}
+
+// -----
+
+func @sabs(%arg0 : i32) -> () {
+  // expected-error @+2 {{expected non-function type}}
+  %2 = spv.OCL.s_abs %arg0 :
+  return
+}
+