|
| 1 | +use chalk_ir::{DebruijnIndex, WhereClause}; |
| 2 | +use hir_def::{ |
| 3 | + lang_item::LangItem, AssocItemId, ConstId, FunctionId, GenericDefId, HasModule, TraitId, |
| 4 | + TypeAliasId, |
| 5 | +}; |
| 6 | + |
| 7 | +use crate::{ |
| 8 | + all_super_traits, db::HirDatabase, generics::generics, layout::LayoutError, |
| 9 | + lower::callable_item_sig, make_single_type_binders, static_lifetime, wrap_empty_binders, DynTy, |
| 10 | + Interner, QuantifiedWhereClauses, Substitution, TyBuilder, TyKind, |
| 11 | +}; |
| 12 | + |
| 13 | +#[derive(Debug, Clone, PartialEq, Eq)] |
| 14 | +pub enum ObjectSafetyError { |
| 15 | + LayoutError(LayoutError), |
| 16 | +} |
| 17 | + |
| 18 | +#[derive(Debug, Clone, PartialEq, Eq)] |
| 19 | +pub enum ObjectSafetyViolation { |
| 20 | + SizedSelf, |
| 21 | + SelfReferencial, |
| 22 | + NonLifetimeBinder, |
| 23 | + Method(FunctionId, MethodViolationCode), |
| 24 | + AssocConst(ConstId), |
| 25 | + GAT(TypeAliasId), |
| 26 | + // This doesn't exist in rustc, but added for better visualization |
| 27 | + HasNonSafeSuperTrait(TraitId), |
| 28 | +} |
| 29 | + |
| 30 | +#[derive(Debug, Clone, PartialEq, Eq)] |
| 31 | +pub enum MethodViolationCode { |
| 32 | + StaticMethod, |
| 33 | + ReferencesSelfInput, |
| 34 | + ReferencesSelfOutput, |
| 35 | + ReferencesImplTraitInTrait, |
| 36 | + AsyncFn, |
| 37 | + WhereClauseReferencesSelf, |
| 38 | + Generic, |
| 39 | + UndispatchableReceiver, |
| 40 | + HasNonLifetimeTypeParam, |
| 41 | + NonReceiverSelfParam, |
| 42 | +} |
| 43 | + |
| 44 | +// Basically, this is almost same as `rustc_trait_selection::traits::object_safety` |
| 45 | +// but some difference; |
| 46 | +// |
| 47 | +// 1. While rustc gathers almost every violation, but this only early return on |
| 48 | +// first violation for perf. |
| 49 | +// |
| 50 | +// These can be changed anytime while implementing. |
| 51 | +pub fn object_safety_of_trait_query( |
| 52 | + db: &dyn HirDatabase, |
| 53 | + trait_: TraitId, |
| 54 | +) -> Result<Option<ObjectSafetyViolation>, ObjectSafetyError> { |
| 55 | + for super_trait in all_super_traits(db.upcast(), trait_).into_iter().skip(1) { |
| 56 | + if db.object_safety_of_trait(super_trait)?.is_some() { |
| 57 | + return Ok(Some(ObjectSafetyViolation::HasNonSafeSuperTrait(super_trait))); |
| 58 | + } |
| 59 | + } |
| 60 | + |
| 61 | + if generics_require_sized_self(db, trait_.into()) { |
| 62 | + return Ok(Some(ObjectSafetyViolation::SizedSelf)); |
| 63 | + } |
| 64 | + |
| 65 | + // TODO: bound referencing self |
| 66 | + |
| 67 | + // TODO: non lifetime binder |
| 68 | + |
| 69 | + let trait_data = db.trait_data(trait_); |
| 70 | + for (_, assoc_item) in &trait_data.items { |
| 71 | + let item_violation = object_safety_violation_for_assoc_item(db, trait_, *assoc_item)?; |
| 72 | + if item_violation.is_some() { |
| 73 | + return Ok(item_violation); |
| 74 | + } |
| 75 | + } |
| 76 | + |
| 77 | + Ok(None) |
| 78 | +} |
| 79 | + |
| 80 | +fn generics_require_sized_self(db: &dyn HirDatabase, def: GenericDefId) -> bool { |
| 81 | + let krate = def.module(db.upcast()).krate(); |
| 82 | + let Some(_sized) = db.lang_item(krate, LangItem::Sized).and_then(|l| l.as_trait()) else { |
| 83 | + return false; |
| 84 | + }; |
| 85 | + |
| 86 | + let _predicates = db.generic_predicates(def); |
| 87 | + // TODO: elaborate with `utils::elaborate_clause_supertraits` and check `Self: Sized` |
| 88 | + |
| 89 | + false |
| 90 | +} |
| 91 | + |
| 92 | +fn object_safety_violation_for_assoc_item( |
| 93 | + db: &dyn HirDatabase, |
| 94 | + trait_: TraitId, |
| 95 | + item: AssocItemId, |
| 96 | +) -> Result<Option<ObjectSafetyViolation>, ObjectSafetyError> { |
| 97 | + match item { |
| 98 | + AssocItemId::ConstId(it) => Ok(Some(ObjectSafetyViolation::AssocConst(it))), |
| 99 | + AssocItemId::FunctionId(it) => virtual_call_violations_for_method(db, trait_, it) |
| 100 | + .map(|v| v.map(|v| ObjectSafetyViolation::Method(it, v))) |
| 101 | + .map_err(ObjectSafetyError::LayoutError), |
| 102 | + AssocItemId::TypeAliasId(it) => { |
| 103 | + let generics = generics(db.upcast(), it.into()); |
| 104 | + // rustc checks if the `generic_associate_type_extended` feature gate is set |
| 105 | + if generics.len_self() > 0 && db.type_alias_impl_traits(it).is_none() { |
| 106 | + Ok(Some(ObjectSafetyViolation::GAT(it))) |
| 107 | + } else { |
| 108 | + Ok(None) |
| 109 | + } |
| 110 | + } |
| 111 | + } |
| 112 | +} |
| 113 | + |
| 114 | +fn virtual_call_violations_for_method( |
| 115 | + db: &dyn HirDatabase, |
| 116 | + trait_: TraitId, |
| 117 | + func: FunctionId, |
| 118 | +) -> Result<Option<MethodViolationCode>, LayoutError> { |
| 119 | + let func_data = db.function_data(func); |
| 120 | + if !func_data.has_self_param() { |
| 121 | + return Ok(Some(MethodViolationCode::StaticMethod)); |
| 122 | + } |
| 123 | + |
| 124 | + // TODO: check self reference in params |
| 125 | + |
| 126 | + // TODO: check self reference in return type |
| 127 | + |
| 128 | + // TODO: check asyncness, RPIT |
| 129 | + |
| 130 | + let generic_params = db.generic_params(func.into()); |
| 131 | + if generic_params.len_type_or_consts() > 0 { |
| 132 | + return Ok(Some(MethodViolationCode::Generic)); |
| 133 | + } |
| 134 | + |
| 135 | + // Check if the receiver is a correct type like `Self`, `Box<Self>`, `Arc<Self>`, etc |
| 136 | + // |
| 137 | + // TODO: rustc does this in two steps :thinking_face: |
| 138 | + // I'm doing only the second, real one, layout check |
| 139 | + // TODO: clean all the messes for building receiver types to check layout of |
| 140 | + |
| 141 | + // Check for types like `Rc<()>` |
| 142 | + let sig = callable_item_sig(db, func.into()); |
| 143 | + // TODO: Getting receiver type that substituted `Self` by `()`. there might be more clever way? |
| 144 | + let subst = Substitution::from_iter( |
| 145 | + Interner, |
| 146 | + std::iter::repeat(TyBuilder::unit()).take(sig.len(Interner)), |
| 147 | + ); |
| 148 | + let sig = sig.substitute(Interner, &subst); |
| 149 | + let receiver_ty = sig.params()[0].to_owned(); |
| 150 | + let layout = db.layout_of_ty(receiver_ty, db.trait_environment(trait_.into()))?; |
| 151 | + |
| 152 | + if !matches!(layout.abi, rustc_abi::Abi::Scalar(..)) { |
| 153 | + return Ok(Some(MethodViolationCode::UndispatchableReceiver)); |
| 154 | + } |
| 155 | + |
| 156 | + // Check for types like `Rc<dyn Trait>` |
| 157 | + // TODO: `dyn Trait` and receiver type building is a total mess |
| 158 | + let trait_ref = |
| 159 | + TyBuilder::trait_ref(db, trait_).fill_with_bound_vars(DebruijnIndex::INNERMOST, 0).build(); |
| 160 | + let bound = wrap_empty_binders(WhereClause::Implemented(trait_ref)); |
| 161 | + let bounds = QuantifiedWhereClauses::from_iter(Interner, [bound]); |
| 162 | + let dyn_trait = TyKind::Dyn(DynTy { |
| 163 | + bounds: make_single_type_binders(bounds), |
| 164 | + lifetime: static_lifetime(), |
| 165 | + }) |
| 166 | + .intern(Interner); |
| 167 | + let sig = callable_item_sig(db, func.into()); |
| 168 | + let subst = Substitution::from_iter( |
| 169 | + Interner, |
| 170 | + std::iter::once(dyn_trait) |
| 171 | + .chain(std::iter::repeat(TyBuilder::unit())) |
| 172 | + .take(sig.len(Interner)), |
| 173 | + ); |
| 174 | + let sig = sig.substitute(Interner, &subst); |
| 175 | + let receiver_ty = sig.params()[0].to_owned(); |
| 176 | + let layout = db.layout_of_ty(receiver_ty, db.trait_environment(trait_.into()))?; |
| 177 | + |
| 178 | + if !matches!(layout.abi, rustc_abi::Abi::ScalarPair(..)) { |
| 179 | + return Ok(Some(MethodViolationCode::UndispatchableReceiver)); |
| 180 | + } |
| 181 | + |
| 182 | + Ok(None) |
| 183 | +} |
0 commit comments